All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
600 lines
20 KiB
Python
600 lines
20 KiB
Python
"""Google Gemini Provider
|
|
|
|
Integration with Google's Gemini API for GCP customers.
|
|
Supports Gemini Pro, Gemini Ultra, and other models.
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
|
|
import requests
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
|
|
|
|
|
|
class GeminiProvider(BaseLLMProvider):
    """Google Gemini API provider.

    Provides integration with Google's Gemini models.

    Supports:
    - Gemini 1.5 Pro (gemini-1.5-pro)
    - Gemini 1.5 Flash (gemini-1.5-flash)
    - Gemini 1.0 Pro (gemini-pro)

    Environment Variables:
    - GOOGLE_API_KEY: Google AI API key
    - GEMINI_MODEL: Default model (optional)
    """

    # Public Gemini API endpoint (v1beta exposes generateContent/tools).
    API_URL = "https://generativelanguage.googleapis.com/v1beta/models"

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Gemini provider.

        Args:
            api_key: Google API key. Defaults to GOOGLE_API_KEY env var.
            model: Model to use. Defaults to the GEMINI_MODEL env var,
                then gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
        """
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "")
        # BUG FIX: the old signature defaulted `model` to "gemini-1.5-pro",
        # which is always truthy, so the documented GEMINI_MODEL env-var
        # fallback below could never take effect. Defaulting to None
        # restores it while keeping the same effective default for callers.
        self.model = model or os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _get_api_url(self, model: str | None = None, stream: bool = False) -> str:
        """Build the API URL for a given model.

        Args:
            model: Model name. Uses default if not specified.
            stream: Whether to use streaming endpoint.

        Returns:
            Full API URL (includes the API key as a query parameter, per
            the public Gemini API's key-based auth scheme).
        """
        m = model or self.model
        action = "streamGenerateContent" if stream else "generateContent"
        return f"{self.API_URL}/{m}:{action}?key={self.api_key}"

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Gemini API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")

        model = kwargs.get("model", self.model)

        response = requests.post(
            self._get_api_url(model),
            headers={"Content-Type": "application/json"},
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Concatenate all text parts of the first candidate.
        content = ""
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]

        # Total tokens = prompt tokens + completion tokens.
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")

        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Gemini API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'
                (OpenAI chat format; 'assistant' messages may carry
                'tool_calls', 'tool' messages carry a tool result).
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")

        model = kwargs.get("model", self.model)

        # Convert OpenAI-style messages to Gemini "contents" format.
        gemini_contents = []
        system_instruction = None

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                # Gemini takes the system prompt as a separate top-level
                # systemInstruction field rather than a message.
                system_instruction = content
            elif role == "assistant":
                # Assistant turns map to role "model"; tool invocations
                # become functionCall parts.
                parts = []
                if content:
                    parts.append({"text": content})

                if msg.get("tool_calls"):
                    for tc in msg["tool_calls"]:
                        func = tc.get("function", {})
                        args = func.get("arguments", {})
                        # OpenAI serializes arguments as a JSON string;
                        # Gemini wants a structured object.
                        if isinstance(args, str):
                            try:
                                args = json.loads(args)
                            except json.JSONDecodeError:
                                args = {}
                        parts.append(
                            {
                                "functionCall": {
                                    "name": func.get("name", ""),
                                    "args": args,
                                }
                            }
                        )

                # Skip degenerate assistant turns with neither text nor
                # tool calls -- the API rejects empty "parts".
                if parts:
                    gemini_contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                # Tool result in Gemini functionResponse format.
                gemini_contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                # User (or unknown role) message.
                gemini_contents.append({"role": "user", "parts": [{"text": content}]})

        # Convert OpenAI-style tools to Gemini functionDeclarations.
        gemini_tools = None
        if tools:
            function_declarations = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool["function"]
                    function_declarations.append(
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "parameters": func.get("parameters", {}),
                        }
                    )
            if function_declarations:
                gemini_tools = [{"functionDeclarations": function_declarations}]

        request_body = {
            "contents": gemini_contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.temperature),
                "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
            },
        }

        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }

        if gemini_tools:
            request_body["tools"] = gemini_tools

        response = requests.post(
            self._get_api_url(model),
            headers={"Content-Type": "application/json"},
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Parse text and functionCall parts from the first candidate.
        content = ""
        tool_calls = None

        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            id=f"call_{len(tool_calls)}",  # Gemini doesn't provide IDs
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )

        # Total tokens = prompt tokens + completion tokens.
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")

        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )
|
|
|
|
|
|
class VertexAIGeminiProvider(BaseLLMProvider):
    """Google Vertex AI Gemini provider for enterprise GCP deployments.

    Uses Vertex AI endpoints instead of the public Gemini API.
    Supports regional deployments and IAM authentication.

    Environment Variables:
    - GOOGLE_CLOUD_PROJECT: GCP project ID
    - GOOGLE_CLOUD_REGION: GCP region (default: us-central1)
    - VERTEX_AI_MODEL: Default model (optional)
    """

    def __init__(
        self,
        project: str | None = None,
        region: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credentials=None,
    ):
        """Initialize the Vertex AI Gemini provider.

        Args:
            project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var.
            region: GCP region. Defaults to the GOOGLE_CLOUD_REGION env var,
                then us-central1.
            model: Model to use. Defaults to the VERTEX_AI_MODEL env var,
                then gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            credentials: Google credentials object. If not provided,
                uses Application Default Credentials.
        """
        self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
        # BUG FIX: `region` and `model` previously had truthy string defaults
        # ("us-central1" / "gemini-1.5-pro"), so the documented
        # GOOGLE_CLOUD_REGION and VERTEX_AI_MODEL env-var fallbacks below
        # were unreachable. None defaults restore them while preserving the
        # same effective defaults for existing callers.
        self.region = region or os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
        self.model = model or os.environ.get("VERTEX_AI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._credentials = credentials
        # Cached bearer token and its (approximate) expiry epoch seconds.
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get a Google Cloud access token.

        Caches the token and reuses it until shortly before expiry.

        Returns:
            Access token string.

        Raises:
            ImportError: If google-auth is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer).
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token

        try:
            import google.auth
            from google.auth.transport.requests import Request
        except ImportError as e:
            # Chain the original error so the real import failure is visible.
            raise ImportError(
                "google-auth package is required for Vertex AI authentication. "
                "Install with: pip install google-auth"
            ) from e

        if self._credentials is None:
            # Application Default Credentials with the cloud-platform scope.
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )

        if not self._credentials.valid:
            self._credentials.refresh(Request())

        self._token = self._credentials.token
        # Tokens typically expire in 1 hour; refresh slightly early.
        self._token_expires_at = time.time() + 3500

        return self._token

    def _get_api_url(self, model: str | None = None) -> str:
        """Build the Vertex AI API URL.

        Args:
            model: Model name. Uses default if not specified.

        Returns:
            Full API URL for the regional generateContent endpoint.
        """
        m = model or self.model
        return (
            f"https://{self.region}-aiplatform.googleapis.com/v1/"
            f"projects/{self.project}/locations/{self.region}/"
            f"publishers/google/models/{m}:generateContent"
        )

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to Vertex AI Gemini.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If the GCP project ID is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model", self.model)
        token = self._get_token()

        response = requests.post(
            self._get_api_url(model),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Concatenate all text parts of the first candidate.
        content = ""
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]

        # Total tokens = prompt tokens + completion tokens.
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")

        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to Vertex AI Gemini with tool support.

        Args:
            messages: List of message dicts (OpenAI chat format).
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If the GCP project ID is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model", self.model)
        token = self._get_token()

        # Convert messages to Gemini format (same scheme as GeminiProvider).
        gemini_contents = []
        system_instruction = None

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                # System prompt goes in the top-level systemInstruction field.
                system_instruction = content
            elif role == "assistant":
                parts = []
                if content:
                    parts.append({"text": content})
                if msg.get("tool_calls"):
                    for tc in msg["tool_calls"]:
                        func = tc.get("function", {})
                        args = func.get("arguments", {})
                        # OpenAI serializes arguments as a JSON string;
                        # Gemini wants a structured object.
                        if isinstance(args, str):
                            try:
                                args = json.loads(args)
                            except json.JSONDecodeError:
                                args = {}
                        parts.append(
                            {
                                "functionCall": {
                                    "name": func.get("name", ""),
                                    "args": args,
                                }
                            }
                        )
                # Skip degenerate assistant turns with neither text nor
                # tool calls -- the API rejects empty "parts".
                if parts:
                    gemini_contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                gemini_contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                gemini_contents.append({"role": "user", "parts": [{"text": content}]})

        # Convert tools to Gemini functionDeclarations.
        gemini_tools = None
        if tools:
            function_declarations = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool["function"]
                    function_declarations.append(
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "parameters": func.get("parameters", {}),
                        }
                    )
            if function_declarations:
                gemini_tools = [{"functionDeclarations": function_declarations}]

        request_body = {
            "contents": gemini_contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.temperature),
                "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
            },
        }

        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }

        if gemini_tools:
            request_body["tools"] = gemini_tools

        response = requests.post(
            self._get_api_url(model),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Parse text and functionCall parts from the first candidate.
        content = ""
        tool_calls = None

        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            id=f"call_{len(tool_calls)}",  # Gemini doesn't provide IDs
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )

        # Total tokens = prompt tokens + completion tokens.
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")

        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )
|