just why not
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s

This commit is contained in:
2026-01-07 21:19:46 +01:00
parent a1fe47cdf4
commit e8d28225e0
24 changed files with 6431 additions and 250 deletions

View File

@@ -0,0 +1,599 @@
"""Google Gemini Provider
Integration with Google's Gemini API for GCP customers.
Supports Gemini Pro, Gemini Ultra, and other models.
"""
import json
import os
import sys
import requests
# Ensure the repository root (the parent of this file's directory) is on
# sys.path so the absolute "clients.llm_client" import below resolves even
# when this module is executed directly rather than via the package.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
class GeminiProvider(BaseLLMProvider):
    """Google Gemini API provider.

    Provides integration with Google's Gemini models via the public
    Generative Language API.

    Supports:
    - Gemini 1.5 Pro (gemini-1.5-pro)
    - Gemini 1.5 Flash (gemini-1.5-flash)
    - Gemini 1.0 Pro (gemini-pro)

    Environment Variables:
    - GOOGLE_API_KEY: Google AI API key
    - GEMINI_MODEL: Default model (optional)
    """

    API_URL = "https://generativelanguage.googleapis.com/v1beta/models"

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Gemini provider.

        Args:
            api_key: Google API key. Defaults to GOOGLE_API_KEY env var.
            model: Model to use. Defaults to the GEMINI_MODEL env var,
                then gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
        """
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "")
        # `model` previously defaulted to a truthy literal, which made the
        # documented GEMINI_MODEL env-var fallback unreachable. Defaulting
        # to None restores it while keeping the same final default.
        self.model = model or os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _get_api_url(self, model: str | None = None, stream: bool = False) -> str:
        """Build the API URL for a given model.

        The API key is intentionally NOT appended as a ``?key=`` query
        parameter: query strings routinely end up in access logs, proxy
        logs, and exception messages. The key is sent via the
        ``x-goog-api-key`` header instead (see ``_headers``).

        Args:
            model: Model name. Uses default if not specified.
            stream: Whether to use streaming endpoint.

        Returns:
            Full API URL, without credentials.
        """
        m = model or self.model
        action = "streamGenerateContent" if stream else "generateContent"
        return f"{self.API_URL}/{m}:{action}"

    def _headers(self) -> dict:
        """Request headers carrying the API key out of the URL."""
        return {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key,
        }

    def _generation_config(self, kwargs: dict) -> dict:
        """Merge per-call overrides with the instance defaults."""
        return {
            "temperature": kwargs.get("temperature", self.temperature),
            "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
        }

    @staticmethod
    def _convert_messages(messages: list[dict]) -> tuple[list[dict], str | None]:
        """Convert OpenAI-style chat messages to Gemini ``contents``.

        Args:
            messages: List of dicts with 'role', 'content' and optionally
                'tool_calls' / 'name' keys.

        Returns:
            Tuple of (gemini_contents, system_instruction). If several
            system messages are present, the last one wins.
        """
        gemini_contents: list[dict] = []
        system_instruction: str | None = None
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                system_instruction = content
            elif role == "assistant":
                parts = []
                if content:
                    parts.append({"text": content})
                for tc in msg.get("tool_calls") or []:
                    func = tc.get("function", {})
                    args = func.get("arguments", {})
                    if isinstance(args, str):
                        # OpenAI serializes arguments as a JSON string;
                        # Gemini expects a real object.
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            args = {}
                    parts.append(
                        {
                            "functionCall": {
                                "name": func.get("name", ""),
                                "args": args,
                            }
                        }
                    )
                # An assistant turn with neither text nor tool calls would
                # produce an invalid empty "parts" list; skip it.
                if parts:
                    gemini_contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                # Tool response in Gemini format.
                gemini_contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                # User (or unknown-role) message.
                gemini_contents.append(
                    {"role": "user", "parts": [{"text": content}]}
                )
        return gemini_contents, system_instruction

    @staticmethod
    def _convert_tools(tools: list[dict] | None) -> list[dict] | None:
        """Convert OpenAI-style tool definitions to Gemini format.

        Returns:
            Gemini ``tools`` list, or None when nothing converts.
        """
        if not tools:
            return None
        function_declarations = [
            {
                "name": tool["function"]["name"],
                "description": tool["function"].get("description", ""),
                "parameters": tool["function"].get("parameters", {}),
            }
            for tool in tools
            if tool.get("type") == "function"
        ]
        if not function_declarations:
            return None
        return [{"functionDeclarations": function_declarations}]

    @staticmethod
    def _parse_response(data: dict, model: str, provider: str) -> LLMResponse:
        """Build an LLMResponse from a Gemini API JSON payload.

        Concatenates all text parts of the first candidate and collects
        any functionCall parts as ToolCall objects.
        """
        content = ""
        tool_calls = None
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            # Gemini doesn't provide call IDs; synthesize
                            # stable positional ones.
                            id=f"call_{len(tool_calls)}",
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )
        finish_reason = candidates[0].get("finishReason") if candidates else None
        return LLMResponse(
            content=content,
            model=model,
            provider=provider,
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Gemini API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")
        model = kwargs.get("model", self.model)
        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(),
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": self._generation_config(kwargs),
            },
            timeout=120,
        )
        response.raise_for_status()
        return self._parse_response(response.json(), model, "gemini")

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Gemini API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")
        model = kwargs.get("model", self.model)
        gemini_contents, system_instruction = self._convert_messages(messages)
        gemini_tools = self._convert_tools(tools)
        request_body = {
            "contents": gemini_contents,
            "generationConfig": self._generation_config(kwargs),
        }
        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }
        if gemini_tools:
            request_body["tools"] = gemini_tools
        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(),
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        return self._parse_response(response.json(), model, "gemini")
class VertexAIGeminiProvider(BaseLLMProvider):
    """Google Vertex AI Gemini provider for enterprise GCP deployments.

    Uses Vertex AI endpoints instead of the public Gemini API.
    Supports regional deployments and IAM authentication.

    Environment Variables:
    - GOOGLE_CLOUD_PROJECT: GCP project ID
    - GOOGLE_CLOUD_REGION: GCP region (default: us-central1)
    - VERTEX_AI_MODEL: Default model (optional)
    """

    def __init__(
        self,
        project: str | None = None,
        region: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credentials=None,
    ):
        """Initialize the Vertex AI Gemini provider.

        Args:
            project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var.
            region: GCP region. Defaults to the GOOGLE_CLOUD_REGION env var,
                then us-central1.
            model: Model to use. Defaults to the VERTEX_AI_MODEL env var,
                then gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            credentials: Google credentials object. If not provided,
                uses Application Default Credentials.
        """
        self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
        # `region` and `model` previously defaulted to truthy literals, which
        # made the documented GOOGLE_CLOUD_REGION / VERTEX_AI_MODEL env-var
        # fallbacks unreachable. Defaulting to None restores them while
        # keeping the same final defaults.
        self.region = region or os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
        self.model = model or os.environ.get("VERTEX_AI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._credentials = credentials
        # Cached bearer token and its approximate expiry (epoch seconds).
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get a Google Cloud access token, caching it until near expiry.

        Returns:
            Access token string.

        Raises:
            ImportError: If google-auth is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer).
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token
        try:
            import google.auth
            from google.auth.transport.requests import Request
        except ImportError as exc:
            # Chain the cause so the original import failure stays visible.
            raise ImportError(
                "google-auth package is required for Vertex AI authentication. "
                "Install with: pip install google-auth"
            ) from exc
        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )
        if not self._credentials.valid:
            self._credentials.refresh(Request())
        self._token = self._credentials.token
        # Tokens typically expire in 1 hour; record a slightly earlier
        # expiry so we refresh before the real deadline.
        self._token_expires_at = time.time() + 3500
        return self._token

    def _get_api_url(self, model: str | None = None) -> str:
        """Build the Vertex AI generateContent URL for a model.

        Args:
            model: Model name. Uses default if not specified.

        Returns:
            Full regional Vertex AI API URL.
        """
        m = model or self.model
        return (
            f"https://{self.region}-aiplatform.googleapis.com/v1/"
            f"projects/{self.project}/locations/{self.region}/"
            f"publishers/google/models/{m}:generateContent"
        )

    def _headers(self) -> dict:
        """Authenticated JSON request headers (fetches/refreshes the token)."""
        return {
            "Authorization": f"Bearer {self._get_token()}",
            "Content-Type": "application/json",
        }

    def _generation_config(self, kwargs: dict) -> dict:
        """Merge per-call overrides with the instance defaults."""
        return {
            "temperature": kwargs.get("temperature", self.temperature),
            "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
        }

    @staticmethod
    def _convert_messages(messages: list[dict]) -> tuple[list[dict], str | None]:
        """Convert OpenAI-style chat messages to Gemini ``contents``.

        Args:
            messages: List of dicts with 'role', 'content' and optionally
                'tool_calls' / 'name' keys.

        Returns:
            Tuple of (gemini_contents, system_instruction). If several
            system messages are present, the last one wins.
        """
        gemini_contents: list[dict] = []
        system_instruction: str | None = None
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                system_instruction = content
            elif role == "assistant":
                parts = []
                if content:
                    parts.append({"text": content})
                for tc in msg.get("tool_calls") or []:
                    func = tc.get("function", {})
                    args = func.get("arguments", {})
                    if isinstance(args, str):
                        # OpenAI serializes arguments as a JSON string;
                        # Gemini expects a real object.
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            args = {}
                    parts.append(
                        {
                            "functionCall": {
                                "name": func.get("name", ""),
                                "args": args,
                            }
                        }
                    )
                # An assistant turn with neither text nor tool calls would
                # produce an invalid empty "parts" list; skip it.
                if parts:
                    gemini_contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                # Tool response in Gemini format.
                gemini_contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                # User (or unknown-role) message.
                gemini_contents.append(
                    {"role": "user", "parts": [{"text": content}]}
                )
        return gemini_contents, system_instruction

    @staticmethod
    def _convert_tools(tools: list[dict] | None) -> list[dict] | None:
        """Convert OpenAI-style tool definitions to Gemini format.

        Returns:
            Gemini ``tools`` list, or None when nothing converts.
        """
        if not tools:
            return None
        function_declarations = [
            {
                "name": tool["function"]["name"],
                "description": tool["function"].get("description", ""),
                "parameters": tool["function"].get("parameters", {}),
            }
            for tool in tools
            if tool.get("type") == "function"
        ]
        if not function_declarations:
            return None
        return [{"functionDeclarations": function_declarations}]

    @staticmethod
    def _parse_response(data: dict, model: str, provider: str) -> LLMResponse:
        """Build an LLMResponse from a Gemini API JSON payload.

        Concatenates all text parts of the first candidate and collects
        any functionCall parts as ToolCall objects.
        """
        content = ""
        tool_calls = None
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            # Gemini doesn't provide call IDs; synthesize
                            # stable positional ones.
                            id=f"call_{len(tool_calls)}",
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )
        finish_reason = candidates[0].get("finishReason") if candidates else None
        return LLMResponse(
            content=content,
            model=model,
            provider=provider,
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to Vertex AI Gemini.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If the GCP project ID is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")
        model = kwargs.get("model", self.model)
        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(),
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": self._generation_config(kwargs),
            },
            timeout=120,
        )
        response.raise_for_status()
        return self._parse_response(response.json(), model, "vertex-ai")

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to Vertex AI Gemini with tool support.

        Args:
            messages: List of message dicts.
            tools: List of tool definitions.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If the GCP project ID is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")
        model = kwargs.get("model", self.model)
        gemini_contents, system_instruction = self._convert_messages(messages)
        gemini_tools = self._convert_tools(tools)
        request_body = {
            "contents": gemini_contents,
            "generationConfig": self._generation_config(kwargs),
        }
        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }
        if gemini_tools:
            request_body["tools"] = gemini_tools
        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(),
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        return self._parse_response(response.json(), model, "vertex-ai")