# Source: openrabbit/tools/ai-review/clients/providers/gemini_provider.py

"""Google Gemini Provider
Integration with Google's Gemini API for GCP customers.
Supports Gemini Pro, Gemini Ultra, and other models.
"""
import json
import os
import sys
import requests
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
class GeminiProvider(BaseLLMProvider):
    """Google Gemini API provider.

    Provides integration with Google's Gemini models via the public
    Generative Language API.

    Supports:
    - Gemini 1.5 Pro (gemini-1.5-pro)
    - Gemini 1.5 Flash (gemini-1.5-flash)
    - Gemini 1.0 Pro (gemini-pro)

    Environment Variables:
    - GOOGLE_API_KEY: Google AI API key
    - GEMINI_MODEL: Default model (optional)
    """

    API_URL = "https://generativelanguage.googleapis.com/v1beta/models"

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Gemini provider.

        Args:
            api_key: Google API key. Defaults to GOOGLE_API_KEY env var.
            model: Model to use. Defaults to the GEMINI_MODEL env var,
                then "gemini-1.5-pro".
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
        """
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "")
        # NOTE: the parameter default must be None (not a model name),
        # otherwise the GEMINI_MODEL env var fallback is unreachable.
        self.model = model or os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _get_api_url(self, model: str | None = None, stream: bool = False) -> str:
        """Build the API URL for a given model.

        The API key is deliberately NOT embedded in the query string; it is
        sent via the x-goog-api-key header (see _headers) so it cannot leak
        into server/proxy logs.

        Args:
            model: Model name. Uses default if not specified.
            stream: Whether to use streaming endpoint.

        Returns:
            Full API URL.
        """
        m = model or self.model
        action = "streamGenerateContent" if stream else "generateContent"
        return f"{self.API_URL}/{m}:{action}"

    def _headers(self) -> dict:
        """Request headers carrying the API key (x-goog-api-key)."""
        return {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key,
        }

    @staticmethod
    def _messages_to_gemini(messages: list[dict]) -> tuple[str | None, list[dict]]:
        """Convert OpenAI-style messages to Gemini "contents".

        Args:
            messages: List of message dicts with 'role' and 'content'.

        Returns:
            Tuple of (system_instruction, contents). system_instruction is
            None when no system message was present.
        """
        contents: list[dict] = []
        system_instruction = None
        for msg in messages:
            role = msg.get("role", "user")
            # `content` may be explicitly None (e.g. assistant messages that
            # carry only tool calls) — coalesce to "" so we never emit
            # {"text": None}.
            content = msg.get("content") or ""
            if role == "system":
                system_instruction = content
            elif role == "assistant":
                parts = []
                if content:
                    parts.append({"text": content})
                for tc in msg.get("tool_calls") or []:
                    func = tc.get("function", {})
                    args = func.get("arguments", {})
                    if isinstance(args, str):
                        # OpenAI serializes arguments as a JSON string.
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            args = {}
                    parts.append(
                        {
                            "functionCall": {
                                "name": func.get("name", ""),
                                "args": args,
                            }
                        }
                    )
                contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                # Tool response in Gemini format.
                contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                # Anything else is treated as a user message.
                contents.append({"role": "user", "parts": [{"text": content}]})
        return system_instruction, contents

    @staticmethod
    def _tools_to_gemini(tools: list[dict] | None) -> list[dict] | None:
        """Convert OpenAI-style tool definitions to Gemini format.

        Returns:
            Gemini "tools" list, or None when there is nothing to declare.
        """
        if not tools:
            return None
        declarations = [
            {
                "name": tool["function"]["name"],
                "description": tool["function"].get("description", ""),
                "parameters": tool["function"].get("parameters", {}),
            }
            for tool in tools
            if tool.get("type") == "function"
        ]
        return [{"functionDeclarations": declarations}] if declarations else None

    @staticmethod
    def _parse_response(data: dict) -> tuple[str, list | None, str | None, int]:
        """Extract (content, tool_calls, finish_reason, tokens_used).

        Args:
            data: Decoded JSON body of a generateContent response.
        """
        content = ""
        tool_calls = None
        finish_reason = None
        candidates = data.get("candidates", [])
        if candidates:
            finish_reason = candidates[0].get("finishReason")
            for part in candidates[0].get("content", {}).get("parts", []):
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            # Gemini doesn't provide call IDs; synthesize one.
                            id=f"call_{len(tool_calls)}",
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )
        return content, tool_calls, finish_reason, tokens_used

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Gemini API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")
        model = kwargs.get("model", self.model)
        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(),
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            timeout=120,
        )
        response.raise_for_status()
        content, _, finish_reason, tokens_used = self._parse_response(response.json())
        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Gemini API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")
        model = kwargs.get("model", self.model)
        system_instruction, gemini_contents = self._messages_to_gemini(messages)
        gemini_tools = self._tools_to_gemini(tools)
        request_body = {
            "contents": gemini_contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.temperature),
                "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
            },
        }
        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }
        if gemini_tools:
            request_body["tools"] = gemini_tools
        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(),
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        content, tool_calls, finish_reason, tokens_used = self._parse_response(
            response.json()
        )
        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )
class VertexAIGeminiProvider(BaseLLMProvider):
    """Google Vertex AI Gemini provider for enterprise GCP deployments.

    Uses Vertex AI endpoints instead of the public Gemini API.
    Supports regional deployments and IAM authentication.

    Environment Variables:
    - GOOGLE_CLOUD_PROJECT: GCP project ID
    - GOOGLE_CLOUD_REGION: GCP region (default: us-central1)
    - VERTEX_AI_MODEL: Default model (optional)
    """

    def __init__(
        self,
        project: str | None = None,
        region: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credentials=None,
    ):
        """Initialize the Vertex AI Gemini provider.

        Args:
            project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var.
            region: GCP region. Defaults to GOOGLE_CLOUD_REGION env var,
                then "us-central1".
            model: Model to use. Defaults to VERTEX_AI_MODEL env var,
                then "gemini-1.5-pro".
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            credentials: Google credentials object. If not provided,
                uses Application Default Credentials.
        """
        self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
        # NOTE: parameter defaults must be None (not concrete values),
        # otherwise the env var fallbacks below are unreachable.
        self.region = region or os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
        self.model = model or os.environ.get("VERTEX_AI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._credentials = credentials
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get a Google Cloud access token.

        Caches the token and refreshes it shortly before expiry.

        Returns:
            Access token string.

        Raises:
            ImportError: If google-auth is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer).
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token
        try:
            import google.auth
            from google.auth.transport.requests import Request
        except ImportError:
            raise ImportError(
                "google-auth package is required for Vertex AI authentication. "
                "Install with: pip install google-auth"
            )
        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )
        if not self._credentials.valid:
            self._credentials.refresh(Request())
        self._token = self._credentials.token
        # GCP access tokens typically expire in 1 hour; 3500s leaves a
        # small safety margin on top of the 300s check above.
        self._token_expires_at = time.time() + 3500
        return self._token

    def _get_api_url(self, model: str | None = None) -> str:
        """Build the Vertex AI API URL.

        Args:
            model: Model name. Uses default if not specified.

        Returns:
            Full API URL.
        """
        m = model or self.model
        return (
            f"https://{self.region}-aiplatform.googleapis.com/v1/"
            f"projects/{self.project}/locations/{self.region}/"
            f"publishers/google/models/{m}:generateContent"
        )

    def _headers(self, token: str) -> dict:
        """Request headers carrying the IAM bearer token."""
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _messages_to_gemini(messages: list[dict]) -> tuple[str | None, list[dict]]:
        """Convert OpenAI-style messages to Gemini "contents".

        Returns:
            Tuple of (system_instruction, contents). system_instruction is
            None when no system message was present.
        """
        contents: list[dict] = []
        system_instruction = None
        for msg in messages:
            role = msg.get("role", "user")
            # `content` may be explicitly None (assistant tool-call-only
            # messages) — coalesce to "" so we never emit {"text": None}.
            content = msg.get("content") or ""
            if role == "system":
                system_instruction = content
            elif role == "assistant":
                parts = []
                if content:
                    parts.append({"text": content})
                for tc in msg.get("tool_calls") or []:
                    func = tc.get("function", {})
                    args = func.get("arguments", {})
                    if isinstance(args, str):
                        # OpenAI serializes arguments as a JSON string.
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            args = {}
                    parts.append(
                        {
                            "functionCall": {
                                "name": func.get("name", ""),
                                "args": args,
                            }
                        }
                    )
                contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                contents.append({"role": "user", "parts": [{"text": content}]})
        return system_instruction, contents

    @staticmethod
    def _tools_to_gemini(tools: list[dict] | None) -> list[dict] | None:
        """Convert OpenAI-style tool definitions to Gemini format."""
        if not tools:
            return None
        declarations = [
            {
                "name": tool["function"]["name"],
                "description": tool["function"].get("description", ""),
                "parameters": tool["function"].get("parameters", {}),
            }
            for tool in tools
            if tool.get("type") == "function"
        ]
        return [{"functionDeclarations": declarations}] if declarations else None

    @staticmethod
    def _parse_response(data: dict) -> tuple[str, list | None, str | None, int]:
        """Extract (content, tool_calls, finish_reason, tokens_used)."""
        content = ""
        tool_calls = None
        finish_reason = None
        candidates = data.get("candidates", [])
        if candidates:
            finish_reason = candidates[0].get("finishReason")
            for part in candidates[0].get("content", {}).get("parts", []):
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            # Gemini doesn't provide call IDs; synthesize one.
                            id=f"call_{len(tool_calls)}",
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )
        return content, tool_calls, finish_reason, tokens_used

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to Vertex AI Gemini.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If GCP project ID is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")
        model = kwargs.get("model", self.model)
        token = self._get_token()
        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(token),
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            timeout=120,
        )
        response.raise_for_status()
        content, _, finish_reason, tokens_used = self._parse_response(response.json())
        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to Vertex AI Gemini with tool support.

        Args:
            messages: List of message dicts.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If GCP project ID is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")
        model = kwargs.get("model", self.model)
        token = self._get_token()
        system_instruction, gemini_contents = self._messages_to_gemini(messages)
        gemini_tools = self._tools_to_gemini(tools)
        request_body = {
            "contents": gemini_contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.temperature),
                "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
            },
        }
        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }
        if gemini_tools:
            request_body["tools"] = gemini_tools
        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(token),
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        content, tool_calls, finish_reason, tokens_used = self._parse_response(
            response.json()
        )
        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )