just why not
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s

This commit is contained in:
2026-01-07 21:19:46 +01:00
parent a1fe47cdf4
commit e8d28225e0
24 changed files with 6431 additions and 250 deletions

View File

@@ -77,11 +77,13 @@ class OpenAIProvider(BaseLLMProvider):
model: str = "gpt-4o-mini",
temperature: float = 0,
max_tokens: int = 4096,
timeout: int = 120,
):
self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
self.timeout = timeout
self.api_url = "https://api.openai.com/v1/chat/completions"
def call(self, prompt: str, **kwargs) -> LLMResponse:
@@ -101,7 +103,7 @@ class OpenAIProvider(BaseLLMProvider):
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"messages": [{"role": "user", "content": prompt}],
},
timeout=120,
timeout=self.timeout,
)
response.raise_for_status()
data = response.json()
@@ -145,7 +147,7 @@ class OpenAIProvider(BaseLLMProvider):
"Content-Type": "application/json",
},
json=request_body,
timeout=120,
timeout=self.timeout,
)
response.raise_for_status()
data = response.json()
@@ -186,11 +188,13 @@ class OpenRouterProvider(BaseLLMProvider):
model: str = "anthropic/claude-3.5-sonnet",
temperature: float = 0,
max_tokens: int = 4096,
timeout: int = 120,
):
self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "")
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
self.timeout = timeout
self.api_url = "https://openrouter.ai/api/v1/chat/completions"
def call(self, prompt: str, **kwargs) -> LLMResponse:
@@ -210,7 +214,7 @@ class OpenRouterProvider(BaseLLMProvider):
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"messages": [{"role": "user", "content": prompt}],
},
timeout=120,
timeout=self.timeout,
)
response.raise_for_status()
data = response.json()
@@ -254,7 +258,7 @@ class OpenRouterProvider(BaseLLMProvider):
"Content-Type": "application/json",
},
json=request_body,
timeout=120,
timeout=self.timeout,
)
response.raise_for_status()
data = response.json()
@@ -294,10 +298,12 @@ class OllamaProvider(BaseLLMProvider):
host: str | None = None,
model: str = "codellama:13b",
temperature: float = 0,
timeout: int = 300,
):
self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
self.model = model
self.temperature = temperature
self.timeout = timeout
def call(self, prompt: str, **kwargs) -> LLMResponse:
"""Call Ollama API."""
@@ -311,7 +317,7 @@ class OllamaProvider(BaseLLMProvider):
"temperature": kwargs.get("temperature", self.temperature),
},
},
timeout=300, # Longer timeout for local models
timeout=self.timeout,
)
response.raise_for_status()
data = response.json()
@@ -477,12 +483,18 @@ class LLMClient:
provider = config.get("provider", "openai")
provider_config = {}
# Get timeout configuration
timeouts = config.get("timeouts", {})
llm_timeout = timeouts.get("llm", 120)
ollama_timeout = timeouts.get("ollama", 300)
# Map config keys to provider-specific settings
if provider == "openai":
provider_config = {
"model": config.get("model", {}).get("openai", "gpt-4o-mini"),
"temperature": config.get("temperature", 0),
"max_tokens": config.get("max_tokens", 16000),
"timeout": llm_timeout,
}
elif provider == "openrouter":
provider_config = {
@@ -491,11 +503,13 @@ class LLMClient:
),
"temperature": config.get("temperature", 0),
"max_tokens": config.get("max_tokens", 16000),
"timeout": llm_timeout,
}
elif provider == "ollama":
provider_config = {
"model": config.get("model", {}).get("ollama", "codellama:13b"),
"temperature": config.get("temperature", 0),
"timeout": ollama_timeout,
}
return cls(provider=provider, config=provider_config)

View File

@@ -0,0 +1,27 @@
"""LLM Providers Package
This package contains additional LLM provider implementations
beyond the core providers in llm_client.py.
Providers:
- AnthropicProvider: Direct Anthropic Claude API
- AzureOpenAIProvider: Azure OpenAI Service with API key auth
- AzureOpenAIWithAADProvider: Azure OpenAI with Azure AD auth
- GeminiProvider: Google Gemini API (public)
- VertexAIGeminiProvider: Google Vertex AI Gemini (enterprise GCP)
"""
from clients.providers.anthropic_provider import AnthropicProvider
from clients.providers.azure_provider import (
AzureOpenAIProvider,
AzureOpenAIWithAADProvider,
)
from clients.providers.gemini_provider import GeminiProvider, VertexAIGeminiProvider
__all__ = [
"AnthropicProvider",
"AzureOpenAIProvider",
"AzureOpenAIWithAADProvider",
"GeminiProvider",
"VertexAIGeminiProvider",
]

View File

@@ -0,0 +1,249 @@
"""Anthropic Claude Provider
Direct integration with Anthropic's Claude API.
Supports Claude 3.5 Sonnet, Claude 3 Opus, and other models.
"""
import json
import os
# Import base classes from parent module
import sys
from dataclasses import dataclass
import requests
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
class AnthropicProvider(BaseLLMProvider):
    """Anthropic Claude API provider.

    Provides direct integration with Anthropic's Claude models
    without going through OpenRouter.

    Supports:
    - Claude 3.5 Sonnet (claude-3-5-sonnet-20241022)
    - Claude 3 Opus (claude-3-opus-20240229)
    - Claude 3 Sonnet (claude-3-sonnet-20240229)
    - Claude 3 Haiku (claude-3-haiku-20240307)
    """

    API_URL = "https://api.anthropic.com/v1/messages"
    API_VERSION = "2023-06-01"

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "claude-3-5-sonnet-20241022",
        temperature: float = 0,
        max_tokens: int = 4096,
        timeout: int = 120,
    ):
        """Initialize the Anthropic provider.

        Args:
            api_key: Anthropic API key. Defaults to ANTHROPIC_API_KEY env var.
            model: Model to use. Defaults to Claude 3.5 Sonnet.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            timeout: HTTP request timeout in seconds. Defaults to 120.
        """
        self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Configurable request timeout, consistent with the other providers.
        self.timeout = timeout

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Anthropic API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Anthropic API key is required")
        response = requests.post(
            self.API_URL,
            headers={
                "x-api-key": self.api_key,
                "anthropic-version": self.API_VERSION,
                "Content-Type": "application/json",
            },
            json={
                "model": kwargs.get("model", self.model),
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
                "messages": [{"role": "user", "content": prompt}],
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        # Anthropic returns a list of content blocks; concatenate text blocks.
        content = ""
        for block in data.get("content", []):
            if block.get("type") == "text":
                content += block.get("text", "")
        return LLMResponse(
            content=content,
            model=data.get("model", self.model),
            provider="anthropic",
            tokens_used=data.get("usage", {}).get("input_tokens", 0)
            + data.get("usage", {}).get("output_tokens", 0),
            finish_reason=data.get("stop_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Anthropic API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'
                (OpenAI-style; converted to Anthropic format internally).
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Anthropic API key is required")
        # Convert OpenAI-style messages to Anthropic format. The system
        # message is hoisted out because Anthropic takes it as a separate
        # top-level "system" field rather than a message.
        anthropic_messages = []
        system_content = None
        for msg in messages:
            role = msg.get("role", "user")
            if role == "system":
                system_content = msg.get("content", "")
            elif role == "assistant":
                # Assistant tool calls become "tool_use" content blocks.
                if msg.get("tool_calls"):
                    content = []
                    if msg.get("content"):
                        content.append({"type": "text", "text": msg["content"]})
                    for tc in msg["tool_calls"]:
                        content.append(
                            {
                                "type": "tool_use",
                                "id": tc["id"],
                                "name": tc["function"]["name"],
                                # OpenAI encodes arguments as a JSON string;
                                # Anthropic wants a structured object.
                                "input": json.loads(tc["function"]["arguments"])
                                if isinstance(tc["function"]["arguments"], str)
                                else tc["function"]["arguments"],
                            }
                        )
                    anthropic_messages.append({"role": "assistant", "content": content})
                else:
                    anthropic_messages.append(
                        {
                            "role": "assistant",
                            "content": msg.get("content", ""),
                        }
                    )
            elif role == "tool":
                # Tool results are sent back as user-role "tool_result" blocks.
                anthropic_messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "tool_result",
                                "tool_use_id": msg.get("tool_call_id", ""),
                                "content": msg.get("content", ""),
                            }
                        ],
                    }
                )
            else:
                anthropic_messages.append(
                    {
                        "role": "user",
                        "content": msg.get("content", ""),
                    }
                )
        # Convert OpenAI-style tool definitions to Anthropic's schema.
        anthropic_tools = None
        if tools:
            anthropic_tools = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool["function"]
                    anthropic_tools.append(
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "input_schema": func.get("parameters", {}),
                        }
                    )
        request_body = {
            "model": kwargs.get("model", self.model),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "temperature": kwargs.get("temperature", self.temperature),
            "messages": anthropic_messages,
        }
        if system_content:
            request_body["system"] = system_content
        if anthropic_tools:
            request_body["tools"] = anthropic_tools
        response = requests.post(
            self.API_URL,
            headers={
                "x-api-key": self.api_key,
                "anthropic-version": self.API_VERSION,
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        # Parse response: collect text blocks and any tool_use blocks.
        content = ""
        tool_calls = None
        for block in data.get("content", []):
            if block.get("type") == "text":
                content += block.get("text", "")
            elif block.get("type") == "tool_use":
                if tool_calls is None:
                    tool_calls = []
                tool_calls.append(
                    ToolCall(
                        id=block.get("id", ""),
                        name=block.get("name", ""),
                        arguments=block.get("input", {}),
                    )
                )
        return LLMResponse(
            content=content,
            model=data.get("model", self.model),
            provider="anthropic",
            tokens_used=data.get("usage", {}).get("input_tokens", 0)
            + data.get("usage", {}).get("output_tokens", 0),
            finish_reason=data.get("stop_reason"),
            tool_calls=tool_calls,
        )

View File

@@ -0,0 +1,420 @@
"""Azure OpenAI Provider
Integration with Azure OpenAI Service for enterprise deployments.
Supports custom deployments, regional endpoints, and Azure AD auth.
"""
import json
import os
import sys
import requests
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
class AzureOpenAIProvider(BaseLLMProvider):
    """Azure OpenAI Service provider.

    Provides integration with Azure-hosted OpenAI models for
    enterprise customers with Azure deployments.

    Supports:
    - GPT-4, GPT-4 Turbo, GPT-4o
    - GPT-3.5 Turbo
    - Custom fine-tuned models

    Environment Variables:
    - AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
    - AZURE_OPENAI_API_KEY: API key for authentication
    - AZURE_OPENAI_DEPLOYMENT: Default deployment name
    - AZURE_OPENAI_API_VERSION: API version (default: 2024-02-15-preview)
    """

    DEFAULT_API_VERSION = "2024-02-15-preview"

    def __init__(
        self,
        endpoint: str | None = None,
        api_key: str | None = None,
        deployment: str | None = None,
        api_version: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        timeout: int = 120,
    ):
        """Initialize the Azure OpenAI provider.

        Args:
            endpoint: Azure OpenAI endpoint URL.
                Defaults to AZURE_OPENAI_ENDPOINT env var.
            api_key: API key for authentication.
                Defaults to AZURE_OPENAI_API_KEY env var.
            deployment: Deployment name to use.
                Defaults to AZURE_OPENAI_DEPLOYMENT env var.
            api_version: API version string.
                Defaults to AZURE_OPENAI_API_VERSION env var or latest.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum tokens in response.
            timeout: HTTP request timeout in seconds. Defaults to 120.
        """
        # Trailing slash is stripped so URL building below stays canonical.
        self.endpoint = (
            endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT", "")
        ).rstrip("/")
        self.api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY", "")
        self.deployment = deployment or os.environ.get("AZURE_OPENAI_DEPLOYMENT", "")
        self.api_version = api_version or os.environ.get(
            "AZURE_OPENAI_API_VERSION", self.DEFAULT_API_VERSION
        )
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Configurable request timeout, consistent with the other providers.
        self.timeout = timeout

    def _get_api_url(self, deployment: str | None = None) -> str:
        """Build the API URL for a given deployment.

        Args:
            deployment: Deployment name. Uses default if not specified.

        Returns:
            Full API URL for chat completions.
        """
        deploy = deployment or self.deployment
        return (
            f"{self.endpoint}/openai/deployments/{deploy}"
            f"/chat/completions?api-version={self.api_version}"
        )

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Azure OpenAI API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (deployment, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If required configuration is missing.
            requests.HTTPError: If the API request fails.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.api_key:
            raise ValueError("Azure OpenAI API key is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")
        deployment = kwargs.get("deployment", self.deployment)
        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "api-key": self.api_key,
                "Content-Type": "application/json",
            },
            json={
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})
        return LLMResponse(
            content=message.get("content", ""),
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Azure OpenAI API with tool support.

        Azure OpenAI uses the same format as OpenAI for tools, so the
        messages and tool definitions are passed through unchanged.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If required configuration is missing.
            requests.HTTPError: If the API request fails.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.api_key:
            raise ValueError("Azure OpenAI API key is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")
        deployment = kwargs.get("deployment", self.deployment)
        request_body = {
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "temperature": kwargs.get("temperature", self.temperature),
        }
        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")
        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "api-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})
        # Parse tool calls if present; arguments may arrive as a JSON string.
        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = []
            for tc in message["tool_calls"]:
                func = tc.get("function", {})
                args = func.get("arguments", "{}")
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        # Malformed arguments degrade to an empty dict rather
                        # than failing the whole response.
                        args = {}
                tool_calls.append(
                    ToolCall(
                        id=tc.get("id", ""),
                        name=func.get("name", ""),
                        arguments=args,
                    )
                )
        return LLMResponse(
            content=message.get("content", "") or "",
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
            tool_calls=tool_calls,
        )
class AzureOpenAIWithAADProvider(AzureOpenAIProvider):
    """Azure OpenAI provider with Azure Active Directory authentication.

    Uses Azure AD tokens instead of API keys for authentication.
    Requires azure-identity package for token acquisition.

    Environment Variables:
    - AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
    - AZURE_OPENAI_DEPLOYMENT: Default deployment name
    - AZURE_TENANT_ID: Azure AD tenant ID (optional)
    - AZURE_CLIENT_ID: Azure AD client ID (optional)
    - AZURE_CLIENT_SECRET: Azure AD client secret (optional)
    """

    SCOPE = "https://cognitiveservices.azure.com/.default"

    def __init__(
        self,
        endpoint: str | None = None,
        deployment: str | None = None,
        api_version: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credential=None,
        timeout: int = 120,
    ):
        """Initialize the Azure OpenAI AAD provider.

        Args:
            endpoint: Azure OpenAI endpoint URL.
            deployment: Deployment name to use.
            api_version: API version string.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum tokens in response.
            credential: Azure credential object. If not provided,
                uses DefaultAzureCredential.
            timeout: HTTP request timeout in seconds. Defaults to 120.
        """
        super().__init__(
            endpoint=endpoint,
            api_key="",  # Not used with AAD
            deployment=deployment,
            api_version=api_version,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Request timeout used by this class's own HTTP calls; set directly
        # so it applies regardless of the base-class constructor signature.
        self.timeout = timeout
        self._credential = credential
        # Cached AAD token and its epoch expiry (0 forces first acquisition).
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get an Azure AD token for authentication.

        Returns:
            Bearer token string.

        Raises:
            ImportError: If azure-identity is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer)
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token
        try:
            from azure.identity import DefaultAzureCredential
        except ImportError as exc:
            raise ImportError(
                "azure-identity package is required for AAD authentication. "
                "Install with: pip install azure-identity"
            ) from exc
        if self._credential is None:
            self._credential = DefaultAzureCredential()
        token = self._credential.get_token(self.SCOPE)
        self._token = token.token
        self._token_expires_at = token.expires_on
        return self._token

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Azure OpenAI API using AAD auth.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options.

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If required configuration is missing.
            requests.HTTPError: If the API request fails.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")
        deployment = kwargs.get("deployment", self.deployment)
        token = self._get_token()
        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json={
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})
        return LLMResponse(
            content=message.get("content", ""),
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Azure OpenAI API with tool support using AAD auth.

        Args:
            messages: List of message dicts.
            tools: List of tool definitions.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If required configuration is missing.
            requests.HTTPError: If the API request fails.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")
        deployment = kwargs.get("deployment", self.deployment)
        token = self._get_token()
        request_body = {
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "temperature": kwargs.get("temperature", self.temperature),
        }
        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")
        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})
        # Parse tool calls if present; arguments may arrive as a JSON string.
        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = []
            for tc in message["tool_calls"]:
                func = tc.get("function", {})
                args = func.get("arguments", "{}")
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        args = {}
                tool_calls.append(
                    ToolCall(
                        id=tc.get("id", ""),
                        name=func.get("name", ""),
                        arguments=args,
                    )
                )
        return LLMResponse(
            content=message.get("content", "") or "",
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
            tool_calls=tool_calls,
        )

View File

@@ -0,0 +1,599 @@
"""Google Gemini Provider
Integration with Google's Gemini API for GCP customers.
Supports Gemini Pro, Gemini Ultra, and other models.
"""
import json
import os
import sys
import requests
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
class GeminiProvider(BaseLLMProvider):
    """Google Gemini API provider.

    Provides integration with Google's Gemini models.

    Supports:
    - Gemini 1.5 Pro (gemini-1.5-pro)
    - Gemini 1.5 Flash (gemini-1.5-flash)
    - Gemini 1.0 Pro (gemini-pro)

    Environment Variables:
    - GOOGLE_API_KEY: Google AI API key
    - GEMINI_MODEL: Default model (optional)
    """

    API_URL = "https://generativelanguage.googleapis.com/v1beta/models"

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        timeout: int = 120,
    ):
        """Initialize the Gemini provider.

        Args:
            api_key: Google API key. Defaults to GOOGLE_API_KEY env var.
            model: Model to use. Defaults to GEMINI_MODEL env var,
                falling back to gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            timeout: HTTP request timeout in seconds. Defaults to 120.
        """
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "")
        # model defaults to None (not "gemini-1.5-pro") so the documented
        # GEMINI_MODEL env-var fallback is actually reachable.
        self.model = model or os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Configurable request timeout, consistent with the other providers.
        self.timeout = timeout

    def _get_api_url(self, model: str | None = None, stream: bool = False) -> str:
        """Build the API URL for a given model.

        Args:
            model: Model name. Uses default if not specified.
            stream: Whether to use streaming endpoint.

        Returns:
            Full API URL.
        """
        # NOTE(review): the API key is embedded in the URL query string, so it
        # may end up in proxy/server logs; Google also accepts it via the
        # x-goog-api-key header — consider switching.
        m = model or self.model
        action = "streamGenerateContent" if stream else "generateContent"
        return f"{self.API_URL}/{m}:{action}?key={self.api_key}"

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Gemini API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")
        model = kwargs.get("model", self.model)
        response = requests.post(
            self._get_api_url(model),
            headers={"Content-Type": "application/json"},
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        # Concatenate the text parts of the first candidate.
        content = ""
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]
        # Get token counts
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )
        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")
        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Gemini API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'
                (OpenAI-style; converted to Gemini format internally).
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")
        model = kwargs.get("model", self.model)
        # Convert OpenAI-style messages to Gemini "contents". The system
        # message is hoisted into Gemini's separate systemInstruction field.
        gemini_contents = []
        system_instruction = None
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                system_instruction = content
            elif role == "assistant":
                # Assistant turns map to role "model"; tool calls become
                # functionCall parts.
                parts = []
                if content:
                    parts.append({"text": content})
                if msg.get("tool_calls"):
                    for tc in msg["tool_calls"]:
                        func = tc.get("function", {})
                        args = func.get("arguments", {})
                        if isinstance(args, str):
                            try:
                                args = json.loads(args)
                            except json.JSONDecodeError:
                                args = {}
                        parts.append(
                            {
                                "functionCall": {
                                    "name": func.get("name", ""),
                                    "args": args,
                                }
                            }
                        )
                gemini_contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                # Tool response in Gemini format
                gemini_contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                # User message
                gemini_contents.append({"role": "user", "parts": [{"text": content}]})
        # Convert OpenAI-style tools to Gemini functionDeclarations.
        gemini_tools = None
        if tools:
            function_declarations = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool["function"]
                    function_declarations.append(
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "parameters": func.get("parameters", {}),
                        }
                    )
            if function_declarations:
                gemini_tools = [{"functionDeclarations": function_declarations}]
        request_body = {
            "contents": gemini_contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.temperature),
                "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
            },
        }
        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }
        if gemini_tools:
            request_body["tools"] = gemini_tools
        response = requests.post(
            self._get_api_url(model),
            headers={"Content-Type": "application/json"},
            json=request_body,
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        # Parse response: text parts plus any functionCall parts.
        content = ""
        tool_calls = None
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            id=f"call_{len(tool_calls)}",  # Gemini doesn't provide IDs
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )
        # Get token counts
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )
        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")
        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )
class VertexAIGeminiProvider(BaseLLMProvider):
    """Google Vertex AI Gemini provider for enterprise GCP deployments.

    Uses Vertex AI endpoints instead of the public Gemini API.
    Supports regional deployments and IAM authentication.

    Environment Variables:
    - GOOGLE_CLOUD_PROJECT: GCP project ID
    - GOOGLE_CLOUD_REGION: GCP region (default: us-central1)
    - VERTEX_AI_MODEL: Default model (optional)
    """

    def __init__(
        self,
        project: str | None = None,
        region: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credentials=None,
        timeout: int = 120,
    ):
        """Initialize the Vertex AI Gemini provider.

        Args:
            project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var.
            region: GCP region. Defaults to GOOGLE_CLOUD_REGION env var,
                falling back to us-central1.
            model: Model to use. Defaults to VERTEX_AI_MODEL env var,
                falling back to gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            credentials: Google credentials object. If not provided,
                uses Application Default Credentials.
            timeout: HTTP request timeout in seconds. Defaults to 120.
        """
        self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
        # region/model default to None (not truthy literals) so the documented
        # GOOGLE_CLOUD_REGION / VERTEX_AI_MODEL env-var fallbacks are reachable.
        self.region = region or os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
        self.model = model or os.environ.get("VERTEX_AI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Configurable request timeout, consistent with the other providers.
        self.timeout = timeout
        self._credentials = credentials
        # Cached access token and its epoch expiry (0 forces first acquisition).
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get a Google Cloud access token.

        Returns:
            Access token string.

        Raises:
            ImportError: If google-auth is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer)
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token
        try:
            import google.auth
            from google.auth.transport.requests import Request
        except ImportError as exc:
            raise ImportError(
                "google-auth package is required for Vertex AI authentication. "
                "Install with: pip install google-auth"
            ) from exc
        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )
        if not self._credentials.valid:
            self._credentials.refresh(Request())
        self._token = self._credentials.token
        # Tokens typically expire in 1 hour
        self._token_expires_at = time.time() + 3500
        return self._token

    def _get_api_url(self, model: str | None = None) -> str:
        """Build the Vertex AI API URL.

        Args:
            model: Model name. Uses default if not specified.

        Returns:
            Full API URL.
        """
        m = model or self.model
        return (
            f"https://{self.region}-aiplatform.googleapis.com/v1/"
            f"projects/{self.project}/locations/{self.region}/"
            f"publishers/google/models/{m}:generateContent"
        )

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to Vertex AI Gemini.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options.

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If GCP project ID is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")
        model = kwargs.get("model", self.model)
        token = self._get_token()
        response = requests.post(
            self._get_api_url(model),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        # Concatenate the text parts of the first candidate.
        content = ""
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]
        # Get token counts
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )
        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")
        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to Vertex AI Gemini with tool support.

        Args:
            messages: List of message dicts.
            tools: List of tool definitions.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If GCP project ID is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")
        model = kwargs.get("model", self.model)
        token = self._get_token()
        # Convert messages to Gemini format (same as GeminiProvider)
        gemini_contents = []
        system_instruction = None
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                system_instruction = content
            elif role == "assistant":
                parts = []
                if content:
                    parts.append({"text": content})
                if msg.get("tool_calls"):
                    for tc in msg["tool_calls"]:
                        func = tc.get("function", {})
                        args = func.get("arguments", {})
                        if isinstance(args, str):
                            try:
                                args = json.loads(args)
                            except json.JSONDecodeError:
                                args = {}
                        parts.append(
                            {
                                "functionCall": {
                                    "name": func.get("name", ""),
                                    "args": args,
                                }
                            }
                        )
                gemini_contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                gemini_contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                gemini_contents.append({"role": "user", "parts": [{"text": content}]})
        # Convert tools to Gemini format
        gemini_tools = None
        if tools:
            function_declarations = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool["function"]
                    function_declarations.append(
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "parameters": func.get("parameters", {}),
                        }
                    )
            if function_declarations:
                gemini_tools = [{"functionDeclarations": function_declarations}]
        request_body = {
            "contents": gemini_contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.temperature),
                "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
            },
        }
        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }
        if gemini_tools:
            request_body["tools"] = gemini_tools
        response = requests.post(
            self._get_api_url(model),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        # Parse response: text parts plus any functionCall parts.
        content = ""
        tool_calls = None
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            id=f"call_{len(tool_calls)}",
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )
        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")
        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )