just why not
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
This commit is contained in:
599
tools/ai-review/clients/providers/gemini_provider.py
Normal file
599
tools/ai-review/clients/providers/gemini_provider.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""Google Gemini Provider
|
||||
|
||||
Integration with Google's Gemini API for GCP customers.
|
||||
Supports Gemini Pro, Gemini Ultra, and other models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import requests
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
|
||||
|
||||
|
||||
class GeminiProvider(BaseLLMProvider):
    """Google Gemini API provider.

    Provides integration with Google's Gemini models.

    Supports:
    - Gemini 1.5 Pro (gemini-1.5-pro)
    - Gemini 1.5 Flash (gemini-1.5-flash)
    - Gemini 1.0 Pro (gemini-pro)

    Environment Variables:
    - GOOGLE_API_KEY: Google AI API key
    - GEMINI_MODEL: Default model (optional)
    """

    API_URL = "https://generativelanguage.googleapis.com/v1beta/models"

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Gemini provider.

        Args:
            api_key: Google API key. Defaults to GOOGLE_API_KEY env var.
            model: Model to use. Defaults to the GEMINI_MODEL env var,
                then gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
        """
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "")
        # NOTE: the parameter default used to be the (truthy) string
        # "gemini-1.5-pro", which meant the documented GEMINI_MODEL env var
        # was never consulted. Defaulting to None restores the documented
        # fallback chain while keeping the same final default.
        self.model = model or os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _get_api_url(self, model: str | None = None, stream: bool = False) -> str:
        """Build the API URL for a given model.

        The API key is deliberately NOT embedded in the URL: it is sent via
        the ``x-goog-api-key`` header (see ``_headers``) so it cannot leak
        into server logs, proxies, or tracebacks that record request URLs.

        Args:
            model: Model name. Uses default if not specified.
            stream: Whether to use streaming endpoint.

        Returns:
            Full API URL.
        """
        m = model or self.model
        action = "streamGenerateContent" if stream else "generateContent"
        return f"{self.API_URL}/{m}:{action}"

    def _headers(self) -> dict:
        """Standard request headers, carrying the API key out of the URL."""
        return {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key,
        }

    def _generation_config(self, **kwargs) -> dict:
        """Build the generationConfig payload, honoring per-call overrides."""
        return {
            "temperature": kwargs.get("temperature", self.temperature),
            "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
        }

    @staticmethod
    def _parse_response(data: dict) -> tuple:
        """Extract fields from a generateContent response body.

        Args:
            data: Parsed JSON response from the API.

        Returns:
            Tuple of (content, tool_calls, finish_reason, tokens_used),
            where tool_calls is None when the model made no function calls.
        """
        content = ""
        tool_calls: list[ToolCall] | None = None
        candidates = data.get("candidates", [])
        finish_reason = candidates[0].get("finishReason") if candidates else None

        if candidates:
            for part in candidates[0].get("content", {}).get("parts", []):
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            # Gemini doesn't provide call IDs; synthesize one.
                            id=f"call_{len(tool_calls)}",
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )

        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )
        return content, tool_calls, finish_reason, tokens_used

    @staticmethod
    def _to_gemini_contents(messages: list[dict]) -> tuple:
        """Convert OpenAI-style chat messages to Gemini ``contents``.

        Args:
            messages: List of message dicts with 'role' and 'content'.

        Returns:
            Tuple of (contents, system_instruction); system_instruction is
            None when no system message was present.
        """
        contents: list[dict] = []
        system_instruction = None

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                # Gemini carries the system prompt in a separate field,
                # not as a conversation turn.
                system_instruction = content
            elif role == "assistant":
                parts = []
                if content:
                    parts.append({"text": content})
                for tc in msg.get("tool_calls") or []:
                    func = tc.get("function", {})
                    args = func.get("arguments", {})
                    if isinstance(args, str):
                        # OpenAI-style tool calls encode arguments as a JSON
                        # string; Gemini wants a structured object.
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            args = {}
                    parts.append(
                        {
                            "functionCall": {
                                "name": func.get("name", ""),
                                "args": args,
                            }
                        }
                    )
                contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                # Tool response in Gemini format.
                contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                # Anything else is treated as a user message.
                contents.append({"role": "user", "parts": [{"text": content}]})

        return contents, system_instruction

    @staticmethod
    def _to_gemini_tools(tools: list[dict] | None) -> list[dict] | None:
        """Convert OpenAI-style tool definitions to Gemini format.

        Args:
            tools: List of tool definitions in OpenAI format, or None.

        Returns:
            Gemini ``tools`` payload, or None when nothing converts.
        """
        if not tools:
            return None
        function_declarations = [
            {
                "name": tool["function"]["name"],
                "description": tool["function"].get("description", ""),
                "parameters": tool["function"].get("parameters", {}),
            }
            for tool in tools
            if tool.get("type") == "function"
        ]
        if not function_declarations:
            return None
        return [{"functionDeclarations": function_declarations}]

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Gemini API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")

        model = kwargs.get("model", self.model)

        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(),
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": self._generation_config(**kwargs),
            },
            timeout=120,
        )
        response.raise_for_status()

        content, _, finish_reason, tokens_used = self._parse_response(
            response.json()
        )

        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Gemini API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")

        model = kwargs.get("model", self.model)

        gemini_contents, system_instruction = self._to_gemini_contents(messages)
        gemini_tools = self._to_gemini_tools(tools)

        request_body = {
            "contents": gemini_contents,
            "generationConfig": self._generation_config(**kwargs),
        }
        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }
        if gemini_tools:
            request_body["tools"] = gemini_tools

        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(),
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()

        content, tool_calls, finish_reason, tokens_used = self._parse_response(
            response.json()
        )

        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )
|
||||
|
||||
|
||||
class VertexAIGeminiProvider(BaseLLMProvider):
    """Google Vertex AI Gemini provider for enterprise GCP deployments.

    Uses Vertex AI endpoints instead of the public Gemini API.
    Supports regional deployments and IAM authentication.

    Environment Variables:
    - GOOGLE_CLOUD_PROJECT: GCP project ID
    - GOOGLE_CLOUD_REGION: GCP region (default: us-central1)
    - VERTEX_AI_MODEL: Default model (optional)
    """

    def __init__(
        self,
        project: str | None = None,
        region: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credentials=None,
    ):
        """Initialize the Vertex AI Gemini provider.

        Args:
            project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var.
            region: GCP region. Defaults to GOOGLE_CLOUD_REGION env var,
                then us-central1.
            model: Model to use. Defaults to VERTEX_AI_MODEL env var,
                then gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            credentials: Google credentials object. If not provided,
                uses Application Default Credentials.
        """
        self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
        # NOTE: region/model previously defaulted to truthy strings
        # ("us-central1" / "gemini-1.5-pro"), so the documented
        # GOOGLE_CLOUD_REGION / VERTEX_AI_MODEL env vars were never
        # consulted. None defaults restore the documented fallback chain.
        self.region = region or os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
        self.model = model or os.environ.get("VERTEX_AI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._credentials = credentials
        # Cached bearer token and its (epoch-seconds) expiry time.
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get a Google Cloud access token.

        Returns:
            Access token string.

        Raises:
            ImportError: If google-auth is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer).
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token

        try:
            import google.auth
            from google.auth.transport.requests import Request
        except ImportError as err:
            # Chain the original error so the real import failure
            # (e.g. a broken partial install) stays visible.
            raise ImportError(
                "google-auth package is required for Vertex AI authentication. "
                "Install with: pip install google-auth"
            ) from err

        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )

        if not self._credentials.valid:
            self._credentials.refresh(Request())

        self._token = self._credentials.token

        # Prefer the actual expiry reported by the credentials over a
        # hardcoded guess. google-auth exposes expiry as a naive UTC
        # datetime, so attach UTC before converting to epoch seconds.
        expiry = getattr(self._credentials, "expiry", None)
        if expiry is not None:
            from datetime import timezone

            self._token_expires_at = expiry.replace(
                tzinfo=timezone.utc
            ).timestamp()
        else:
            # Tokens typically expire in 1 hour.
            self._token_expires_at = time.time() + 3500

        return self._token

    def _get_api_url(self, model: str | None = None) -> str:
        """Build the Vertex AI API URL.

        Args:
            model: Model name. Uses default if not specified.

        Returns:
            Full API URL.
        """
        m = model or self.model
        return (
            f"https://{self.region}-aiplatform.googleapis.com/v1/"
            f"projects/{self.project}/locations/{self.region}/"
            f"publishers/google/models/{m}:generateContent"
        )

    def _headers(self, token: str) -> dict:
        """Request headers with IAM bearer authentication."""
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

    def _generation_config(self, **kwargs) -> dict:
        """Build the generationConfig payload, honoring per-call overrides."""
        return {
            "temperature": kwargs.get("temperature", self.temperature),
            "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
        }

    @staticmethod
    def _parse_response(data: dict) -> tuple:
        """Extract fields from a generateContent response body.

        Args:
            data: Parsed JSON response from the API.

        Returns:
            Tuple of (content, tool_calls, finish_reason, tokens_used),
            where tool_calls is None when the model made no function calls.
        """
        content = ""
        tool_calls: list[ToolCall] | None = None
        candidates = data.get("candidates", [])
        finish_reason = candidates[0].get("finishReason") if candidates else None

        if candidates:
            for part in candidates[0].get("content", {}).get("parts", []):
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            # Gemini doesn't provide call IDs; synthesize one.
                            id=f"call_{len(tool_calls)}",
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )

        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )
        return content, tool_calls, finish_reason, tokens_used

    @staticmethod
    def _to_gemini_contents(messages: list[dict]) -> tuple:
        """Convert OpenAI-style chat messages to Gemini ``contents``.

        Args:
            messages: List of message dicts with 'role' and 'content'.

        Returns:
            Tuple of (contents, system_instruction); system_instruction is
            None when no system message was present.
        """
        contents: list[dict] = []
        system_instruction = None

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                # Vertex AI carries the system prompt in a separate field.
                system_instruction = content
            elif role == "assistant":
                parts = []
                if content:
                    parts.append({"text": content})
                for tc in msg.get("tool_calls") or []:
                    func = tc.get("function", {})
                    args = func.get("arguments", {})
                    if isinstance(args, str):
                        # OpenAI tool calls encode arguments as JSON text;
                        # Gemini wants a structured object.
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            args = {}
                    parts.append(
                        {
                            "functionCall": {
                                "name": func.get("name", ""),
                                "args": args,
                            }
                        }
                    )
                contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                contents.append({"role": "user", "parts": [{"text": content}]})

        return contents, system_instruction

    @staticmethod
    def _to_gemini_tools(tools: list[dict] | None) -> list[dict] | None:
        """Convert OpenAI-style tool definitions to Gemini format.

        Args:
            tools: List of tool definitions in OpenAI format, or None.

        Returns:
            Gemini ``tools`` payload, or None when nothing converts.
        """
        if not tools:
            return None
        function_declarations = [
            {
                "name": tool["function"]["name"],
                "description": tool["function"].get("description", ""),
                "parameters": tool["function"].get("parameters", {}),
            }
            for tool in tools
            if tool.get("type") == "function"
        ]
        if not function_declarations:
            return None
        return [{"functionDeclarations": function_declarations}]

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to Vertex AI Gemini.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options.

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If the GCP project ID is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model", self.model)
        token = self._get_token()

        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(token),
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": self._generation_config(**kwargs),
            },
            timeout=120,
        )
        response.raise_for_status()

        content, _, finish_reason, tokens_used = self._parse_response(
            response.json()
        )

        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to Vertex AI Gemini with tool support.

        Args:
            messages: List of message dicts.
            tools: List of tool definitions.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If the GCP project ID is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model", self.model)
        token = self._get_token()

        gemini_contents, system_instruction = self._to_gemini_contents(messages)
        gemini_tools = self._to_gemini_tools(tools)

        request_body = {
            "contents": gemini_contents,
            "generationConfig": self._generation_config(**kwargs),
        }
        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }
        if gemini_tools:
            request_body["tools"] = gemini_tools

        response = requests.post(
            self._get_api_url(model),
            headers=self._headers(token),
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()

        content, tool_calls, finish_reason, tokens_used = self._parse_response(
            response.json()
        )

        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )
|
||||
Reference in New Issue
Block a user