just why not
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
This commit is contained in:
@@ -77,11 +77,13 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
model: str = "gpt-4o-mini",
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 4096,
|
||||
timeout: int = 120,
|
||||
):
|
||||
self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.timeout = timeout
|
||||
self.api_url = "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
def call(self, prompt: str, **kwargs) -> LLMResponse:
|
||||
@@ -101,7 +103,7 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
},
|
||||
timeout=120,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -145,7 +147,7 @@ class OpenAIProvider(BaseLLMProvider):
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=request_body,
|
||||
timeout=120,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -186,11 +188,13 @@ class OpenRouterProvider(BaseLLMProvider):
|
||||
model: str = "anthropic/claude-3.5-sonnet",
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 4096,
|
||||
timeout: int = 120,
|
||||
):
|
||||
self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "")
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.timeout = timeout
|
||||
self.api_url = "https://openrouter.ai/api/v1/chat/completions"
|
||||
|
||||
def call(self, prompt: str, **kwargs) -> LLMResponse:
|
||||
@@ -210,7 +214,7 @@ class OpenRouterProvider(BaseLLMProvider):
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
},
|
||||
timeout=120,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -254,7 +258,7 @@ class OpenRouterProvider(BaseLLMProvider):
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=request_body,
|
||||
timeout=120,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -294,10 +298,12 @@ class OllamaProvider(BaseLLMProvider):
|
||||
host: str | None = None,
|
||||
model: str = "codellama:13b",
|
||||
temperature: float = 0,
|
||||
timeout: int = 300,
|
||||
):
|
||||
self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.timeout = timeout
|
||||
|
||||
def call(self, prompt: str, **kwargs) -> LLMResponse:
|
||||
"""Call Ollama API."""
|
||||
@@ -311,7 +317,7 @@ class OllamaProvider(BaseLLMProvider):
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
},
|
||||
},
|
||||
timeout=300, # Longer timeout for local models
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -477,12 +483,18 @@ class LLMClient:
|
||||
provider = config.get("provider", "openai")
|
||||
provider_config = {}
|
||||
|
||||
# Get timeout configuration
|
||||
timeouts = config.get("timeouts", {})
|
||||
llm_timeout = timeouts.get("llm", 120)
|
||||
ollama_timeout = timeouts.get("ollama", 300)
|
||||
|
||||
# Map config keys to provider-specific settings
|
||||
if provider == "openai":
|
||||
provider_config = {
|
||||
"model": config.get("model", {}).get("openai", "gpt-4o-mini"),
|
||||
"temperature": config.get("temperature", 0),
|
||||
"max_tokens": config.get("max_tokens", 16000),
|
||||
"timeout": llm_timeout,
|
||||
}
|
||||
elif provider == "openrouter":
|
||||
provider_config = {
|
||||
@@ -491,11 +503,13 @@ class LLMClient:
|
||||
),
|
||||
"temperature": config.get("temperature", 0),
|
||||
"max_tokens": config.get("max_tokens", 16000),
|
||||
"timeout": llm_timeout,
|
||||
}
|
||||
elif provider == "ollama":
|
||||
provider_config = {
|
||||
"model": config.get("model", {}).get("ollama", "codellama:13b"),
|
||||
"temperature": config.get("temperature", 0),
|
||||
"timeout": ollama_timeout,
|
||||
}
|
||||
|
||||
return cls(provider=provider, config=provider_config)
|
||||
|
||||
27
tools/ai-review/clients/providers/__init__.py
Normal file
27
tools/ai-review/clients/providers/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""LLM Providers Package
|
||||
|
||||
This package contains additional LLM provider implementations
|
||||
beyond the core providers in llm_client.py.
|
||||
|
||||
Providers:
|
||||
- AnthropicProvider: Direct Anthropic Claude API
|
||||
- AzureOpenAIProvider: Azure OpenAI Service with API key auth
|
||||
- AzureOpenAIWithAADProvider: Azure OpenAI with Azure AD auth
|
||||
- GeminiProvider: Google Gemini API (public)
|
||||
- VertexAIGeminiProvider: Google Vertex AI Gemini (enterprise GCP)
|
||||
"""
|
||||
|
||||
from clients.providers.anthropic_provider import AnthropicProvider
|
||||
from clients.providers.azure_provider import (
|
||||
AzureOpenAIProvider,
|
||||
AzureOpenAIWithAADProvider,
|
||||
)
|
||||
from clients.providers.gemini_provider import GeminiProvider, VertexAIGeminiProvider
|
||||
|
||||
__all__ = [
|
||||
"AnthropicProvider",
|
||||
"AzureOpenAIProvider",
|
||||
"AzureOpenAIWithAADProvider",
|
||||
"GeminiProvider",
|
||||
"VertexAIGeminiProvider",
|
||||
]
|
||||
249
tools/ai-review/clients/providers/anthropic_provider.py
Normal file
249
tools/ai-review/clients/providers/anthropic_provider.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""Anthropic Claude Provider
|
||||
|
||||
Direct integration with Anthropic's Claude API.
|
||||
Supports Claude 3.5 Sonnet, Claude 3 Opus, and other models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
# Import base classes from parent module
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
|
||||
import requests
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
|
||||
|
||||
|
||||
class AnthropicProvider(BaseLLMProvider):
|
||||
"""Anthropic Claude API provider.
|
||||
|
||||
Provides direct integration with Anthropic's Claude models
|
||||
without going through OpenRouter.
|
||||
|
||||
Supports:
|
||||
- Claude 3.5 Sonnet (claude-3-5-sonnet-20241022)
|
||||
- Claude 3 Opus (claude-3-opus-20240229)
|
||||
- Claude 3 Sonnet (claude-3-sonnet-20240229)
|
||||
- Claude 3 Haiku (claude-3-haiku-20240307)
|
||||
"""
|
||||
|
||||
API_URL = "https://api.anthropic.com/v1/messages"
|
||||
API_VERSION = "2023-06-01"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str = "claude-3-5-sonnet-20241022",
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 4096,
|
||||
):
|
||||
"""Initialize the Anthropic provider.
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key. Defaults to ANTHROPIC_API_KEY env var.
|
||||
model: Model to use. Defaults to Claude 3.5 Sonnet.
|
||||
temperature: Sampling temperature (0-1).
|
||||
max_tokens: Maximum tokens in response.
|
||||
"""
|
||||
self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY", "")
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def call(self, prompt: str, **kwargs) -> LLMResponse:
|
||||
"""Make a call to the Anthropic API.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send.
|
||||
**kwargs: Additional options (model, temperature, max_tokens).
|
||||
|
||||
Returns:
|
||||
LLMResponse with the generated content.
|
||||
|
||||
Raises:
|
||||
ValueError: If API key is not set.
|
||||
requests.HTTPError: If the API request fails.
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("Anthropic API key is required")
|
||||
|
||||
response = requests.post(
|
||||
self.API_URL,
|
||||
headers={
|
||||
"x-api-key": self.api_key,
|
||||
"anthropic-version": self.API_VERSION,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": kwargs.get("model", self.model),
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Extract content from response
|
||||
content = ""
|
||||
for block in data.get("content", []):
|
||||
if block.get("type") == "text":
|
||||
content += block.get("text", "")
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=data.get("model", self.model),
|
||||
provider="anthropic",
|
||||
tokens_used=data.get("usage", {}).get("input_tokens", 0)
|
||||
+ data.get("usage", {}).get("output_tokens", 0),
|
||||
finish_reason=data.get("stop_reason"),
|
||||
)
|
||||
|
||||
def call_with_tools(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""Make a call to the Anthropic API with tool support.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: List of tool definitions in OpenAI format.
|
||||
**kwargs: Additional options.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool_calls.
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("Anthropic API key is required")
|
||||
|
||||
# Convert OpenAI-style messages to Anthropic format
|
||||
anthropic_messages = []
|
||||
system_content = None
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
|
||||
if role == "system":
|
||||
system_content = msg.get("content", "")
|
||||
elif role == "assistant":
|
||||
# Handle assistant messages with tool calls
|
||||
if msg.get("tool_calls"):
|
||||
content = []
|
||||
if msg.get("content"):
|
||||
content.append({"type": "text", "text": msg["content"]})
|
||||
for tc in msg["tool_calls"]:
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc["id"],
|
||||
"name": tc["function"]["name"],
|
||||
"input": json.loads(tc["function"]["arguments"])
|
||||
if isinstance(tc["function"]["arguments"], str)
|
||||
else tc["function"]["arguments"],
|
||||
}
|
||||
)
|
||||
anthropic_messages.append({"role": "assistant", "content": content})
|
||||
else:
|
||||
anthropic_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
)
|
||||
elif role == "tool":
|
||||
# Tool response
|
||||
anthropic_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.get("tool_call_id", ""),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
anthropic_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
)
|
||||
|
||||
# Convert OpenAI-style tools to Anthropic format
|
||||
anthropic_tools = None
|
||||
if tools:
|
||||
anthropic_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool["function"]
|
||||
anthropic_tools.append(
|
||||
{
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"input_schema": func.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
|
||||
request_body = {
|
||||
"model": kwargs.get("model", self.model),
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
"messages": anthropic_messages,
|
||||
}
|
||||
|
||||
if system_content:
|
||||
request_body["system"] = system_content
|
||||
|
||||
if anthropic_tools:
|
||||
request_body["tools"] = anthropic_tools
|
||||
|
||||
response = requests.post(
|
||||
self.API_URL,
|
||||
headers={
|
||||
"x-api-key": self.api_key,
|
||||
"anthropic-version": self.API_VERSION,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=request_body,
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Parse response
|
||||
content = ""
|
||||
tool_calls = None
|
||||
|
||||
for block in data.get("content", []):
|
||||
if block.get("type") == "text":
|
||||
content += block.get("text", "")
|
||||
elif block.get("type") == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=block.get("id", ""),
|
||||
name=block.get("name", ""),
|
||||
arguments=block.get("input", {}),
|
||||
)
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=data.get("model", self.model),
|
||||
provider="anthropic",
|
||||
tokens_used=data.get("usage", {}).get("input_tokens", 0)
|
||||
+ data.get("usage", {}).get("output_tokens", 0),
|
||||
finish_reason=data.get("stop_reason"),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
420
tools/ai-review/clients/providers/azure_provider.py
Normal file
420
tools/ai-review/clients/providers/azure_provider.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""Azure OpenAI Provider
|
||||
|
||||
Integration with Azure OpenAI Service for enterprise deployments.
|
||||
Supports custom deployments, regional endpoints, and Azure AD auth.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import requests
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
|
||||
|
||||
|
||||
class AzureOpenAIProvider(BaseLLMProvider):
|
||||
"""Azure OpenAI Service provider.
|
||||
|
||||
Provides integration with Azure-hosted OpenAI models for
|
||||
enterprise customers with Azure deployments.
|
||||
|
||||
Supports:
|
||||
- GPT-4, GPT-4 Turbo, GPT-4o
|
||||
- GPT-3.5 Turbo
|
||||
- Custom fine-tuned models
|
||||
|
||||
Environment Variables:
|
||||
- AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
|
||||
- AZURE_OPENAI_API_KEY: API key for authentication
|
||||
- AZURE_OPENAI_DEPLOYMENT: Default deployment name
|
||||
- AZURE_OPENAI_API_VERSION: API version (default: 2024-02-15-preview)
|
||||
"""
|
||||
|
||||
DEFAULT_API_VERSION = "2024-02-15-preview"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str | None = None,
|
||||
api_key: str | None = None,
|
||||
deployment: str | None = None,
|
||||
api_version: str | None = None,
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 4096,
|
||||
):
|
||||
"""Initialize the Azure OpenAI provider.
|
||||
|
||||
Args:
|
||||
endpoint: Azure OpenAI endpoint URL.
|
||||
Defaults to AZURE_OPENAI_ENDPOINT env var.
|
||||
api_key: API key for authentication.
|
||||
Defaults to AZURE_OPENAI_API_KEY env var.
|
||||
deployment: Deployment name to use.
|
||||
Defaults to AZURE_OPENAI_DEPLOYMENT env var.
|
||||
api_version: API version string.
|
||||
Defaults to AZURE_OPENAI_API_VERSION env var or latest.
|
||||
temperature: Sampling temperature (0-2).
|
||||
max_tokens: Maximum tokens in response.
|
||||
"""
|
||||
self.endpoint = (
|
||||
endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT", "")
|
||||
).rstrip("/")
|
||||
self.api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY", "")
|
||||
self.deployment = deployment or os.environ.get("AZURE_OPENAI_DEPLOYMENT", "")
|
||||
self.api_version = api_version or os.environ.get(
|
||||
"AZURE_OPENAI_API_VERSION", self.DEFAULT_API_VERSION
|
||||
)
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def _get_api_url(self, deployment: str | None = None) -> str:
|
||||
"""Build the API URL for a given deployment.
|
||||
|
||||
Args:
|
||||
deployment: Deployment name. Uses default if not specified.
|
||||
|
||||
Returns:
|
||||
Full API URL for chat completions.
|
||||
"""
|
||||
deploy = deployment or self.deployment
|
||||
return (
|
||||
f"{self.endpoint}/openai/deployments/{deploy}"
|
||||
f"/chat/completions?api-version={self.api_version}"
|
||||
)
|
||||
|
||||
def call(self, prompt: str, **kwargs) -> LLMResponse:
|
||||
"""Make a call to the Azure OpenAI API.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send.
|
||||
**kwargs: Additional options (deployment, temperature, max_tokens).
|
||||
|
||||
Returns:
|
||||
LLMResponse with the generated content.
|
||||
|
||||
Raises:
|
||||
ValueError: If required configuration is missing.
|
||||
requests.HTTPError: If the API request fails.
|
||||
"""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Azure OpenAI endpoint is required")
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure OpenAI API key is required")
|
||||
if not self.deployment and not kwargs.get("deployment"):
|
||||
raise ValueError("Azure OpenAI deployment name is required")
|
||||
|
||||
deployment = kwargs.get("deployment", self.deployment)
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(deployment),
|
||||
headers={
|
||||
"api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content", ""),
|
||||
model=data.get("model", deployment),
|
||||
provider="azure",
|
||||
tokens_used=data.get("usage", {}).get("total_tokens", 0),
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
)
|
||||
|
||||
def call_with_tools(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""Make a call to the Azure OpenAI API with tool support.
|
||||
|
||||
Azure OpenAI uses the same format as OpenAI for tools.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: List of tool definitions in OpenAI format.
|
||||
**kwargs: Additional options.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool_calls.
|
||||
"""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Azure OpenAI endpoint is required")
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure OpenAI API key is required")
|
||||
if not self.deployment and not kwargs.get("deployment"):
|
||||
raise ValueError("Azure OpenAI deployment name is required")
|
||||
|
||||
deployment = kwargs.get("deployment", self.deployment)
|
||||
|
||||
request_body = {
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
}
|
||||
|
||||
if tools:
|
||||
request_body["tools"] = tools
|
||||
request_body["tool_choice"] = kwargs.get("tool_choice", "auto")
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(deployment),
|
||||
headers={
|
||||
"api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=request_body,
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
# Parse tool calls if present
|
||||
tool_calls = None
|
||||
if message.get("tool_calls"):
|
||||
tool_calls = []
|
||||
for tc in message["tool_calls"]:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", "{}")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tc.get("id", ""),
|
||||
name=func.get("name", ""),
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content", "") or "",
|
||||
model=data.get("model", deployment),
|
||||
provider="azure",
|
||||
tokens_used=data.get("usage", {}).get("total_tokens", 0),
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
|
||||
class AzureOpenAIWithAADProvider(AzureOpenAIProvider):
|
||||
"""Azure OpenAI provider with Azure Active Directory authentication.
|
||||
|
||||
Uses Azure AD tokens instead of API keys for authentication.
|
||||
Requires azure-identity package for token acquisition.
|
||||
|
||||
Environment Variables:
|
||||
- AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
|
||||
- AZURE_OPENAI_DEPLOYMENT: Default deployment name
|
||||
- AZURE_TENANT_ID: Azure AD tenant ID (optional)
|
||||
- AZURE_CLIENT_ID: Azure AD client ID (optional)
|
||||
- AZURE_CLIENT_SECRET: Azure AD client secret (optional)
|
||||
"""
|
||||
|
||||
SCOPE = "https://cognitiveservices.azure.com/.default"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str | None = None,
|
||||
deployment: str | None = None,
|
||||
api_version: str | None = None,
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 4096,
|
||||
credential=None,
|
||||
):
|
||||
"""Initialize the Azure OpenAI AAD provider.
|
||||
|
||||
Args:
|
||||
endpoint: Azure OpenAI endpoint URL.
|
||||
deployment: Deployment name to use.
|
||||
api_version: API version string.
|
||||
temperature: Sampling temperature (0-2).
|
||||
max_tokens: Maximum tokens in response.
|
||||
credential: Azure credential object. If not provided,
|
||||
uses DefaultAzureCredential.
|
||||
"""
|
||||
super().__init__(
|
||||
endpoint=endpoint,
|
||||
api_key="", # Not used with AAD
|
||||
deployment=deployment,
|
||||
api_version=api_version,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
self._credential = credential
|
||||
self._token = None
|
||||
self._token_expires_at = 0
|
||||
|
||||
def _get_token(self) -> str:
|
||||
"""Get an Azure AD token for authentication.
|
||||
|
||||
Returns:
|
||||
Bearer token string.
|
||||
|
||||
Raises:
|
||||
ImportError: If azure-identity is not installed.
|
||||
"""
|
||||
import time
|
||||
|
||||
# Return cached token if still valid (with 5 min buffer)
|
||||
if self._token and self._token_expires_at > time.time() + 300:
|
||||
return self._token
|
||||
|
||||
try:
|
||||
from azure.identity import DefaultAzureCredential
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"azure-identity package is required for AAD authentication. "
|
||||
"Install with: pip install azure-identity"
|
||||
)
|
||||
|
||||
if self._credential is None:
|
||||
self._credential = DefaultAzureCredential()
|
||||
|
||||
token = self._credential.get_token(self.SCOPE)
|
||||
self._token = token.token
|
||||
self._token_expires_at = token.expires_on
|
||||
|
||||
return self._token
|
||||
|
||||
def call(self, prompt: str, **kwargs) -> LLMResponse:
|
||||
"""Make a call to the Azure OpenAI API using AAD auth.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send.
|
||||
**kwargs: Additional options.
|
||||
|
||||
Returns:
|
||||
LLMResponse with the generated content.
|
||||
"""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Azure OpenAI endpoint is required")
|
||||
if not self.deployment and not kwargs.get("deployment"):
|
||||
raise ValueError("Azure OpenAI deployment name is required")
|
||||
|
||||
deployment = kwargs.get("deployment", self.deployment)
|
||||
token = self._get_token()
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(deployment),
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content", ""),
|
||||
model=data.get("model", deployment),
|
||||
provider="azure",
|
||||
tokens_used=data.get("usage", {}).get("total_tokens", 0),
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
)
|
||||
|
||||
def call_with_tools(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""Make a call to the Azure OpenAI API with tool support using AAD auth.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts.
|
||||
tools: List of tool definitions.
|
||||
**kwargs: Additional options.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool_calls.
|
||||
"""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Azure OpenAI endpoint is required")
|
||||
if not self.deployment and not kwargs.get("deployment"):
|
||||
raise ValueError("Azure OpenAI deployment name is required")
|
||||
|
||||
deployment = kwargs.get("deployment", self.deployment)
|
||||
token = self._get_token()
|
||||
|
||||
request_body = {
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
}
|
||||
|
||||
if tools:
|
||||
request_body["tools"] = tools
|
||||
request_body["tool_choice"] = kwargs.get("tool_choice", "auto")
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(deployment),
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=request_body,
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
# Parse tool calls if present
|
||||
tool_calls = None
|
||||
if message.get("tool_calls"):
|
||||
tool_calls = []
|
||||
for tc in message["tool_calls"]:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", "{}")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tc.get("id", ""),
|
||||
name=func.get("name", ""),
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content", "") or "",
|
||||
model=data.get("model", deployment),
|
||||
provider="azure",
|
||||
tokens_used=data.get("usage", {}).get("total_tokens", 0),
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
599
tools/ai-review/clients/providers/gemini_provider.py
Normal file
599
tools/ai-review/clients/providers/gemini_provider.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""Google Gemini Provider
|
||||
|
||||
Integration with Google's Gemini API for GCP customers.
|
||||
Supports Gemini Pro, Gemini Ultra, and other models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import requests
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
|
||||
|
||||
|
||||
class GeminiProvider(BaseLLMProvider):
|
||||
"""Google Gemini API provider.
|
||||
|
||||
Provides integration with Google's Gemini models.
|
||||
|
||||
Supports:
|
||||
- Gemini 1.5 Pro (gemini-1.5-pro)
|
||||
- Gemini 1.5 Flash (gemini-1.5-flash)
|
||||
- Gemini 1.0 Pro (gemini-pro)
|
||||
|
||||
Environment Variables:
|
||||
- GOOGLE_API_KEY: Google AI API key
|
||||
- GEMINI_MODEL: Default model (optional)
|
||||
"""
|
||||
|
||||
API_URL = "https://generativelanguage.googleapis.com/v1beta/models"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str = "gemini-1.5-pro",
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 4096,
|
||||
):
|
||||
"""Initialize the Gemini provider.
|
||||
|
||||
Args:
|
||||
api_key: Google API key. Defaults to GOOGLE_API_KEY env var.
|
||||
model: Model to use. Defaults to gemini-1.5-pro.
|
||||
temperature: Sampling temperature (0-1).
|
||||
max_tokens: Maximum tokens in response.
|
||||
"""
|
||||
self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "")
|
||||
self.model = model or os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def _get_api_url(self, model: str | None = None, stream: bool = False) -> str:
|
||||
"""Build the API URL for a given model.
|
||||
|
||||
Args:
|
||||
model: Model name. Uses default if not specified.
|
||||
stream: Whether to use streaming endpoint.
|
||||
|
||||
Returns:
|
||||
Full API URL.
|
||||
"""
|
||||
m = model or self.model
|
||||
action = "streamGenerateContent" if stream else "generateContent"
|
||||
return f"{self.API_URL}/{m}:{action}?key={self.api_key}"
|
||||
|
||||
def call(self, prompt: str, **kwargs) -> LLMResponse:
|
||||
"""Make a call to the Gemini API.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send.
|
||||
**kwargs: Additional options (model, temperature, max_tokens).
|
||||
|
||||
Returns:
|
||||
LLMResponse with the generated content.
|
||||
|
||||
Raises:
|
||||
ValueError: If API key is not set.
|
||||
requests.HTTPError: If the API request fails.
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("Google API key is required")
|
||||
|
||||
model = kwargs.get("model", self.model)
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(model),
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={
|
||||
"contents": [{"parts": [{"text": prompt}]}],
|
||||
"generationConfig": {
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
"maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
},
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Extract content from response
|
||||
content = ""
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
content += part["text"]
|
||||
|
||||
# Get token counts
|
||||
usage = data.get("usageMetadata", {})
|
||||
tokens_used = usage.get("promptTokenCount", 0) + usage.get(
|
||||
"candidatesTokenCount", 0
|
||||
)
|
||||
|
||||
finish_reason = None
|
||||
if candidates:
|
||||
finish_reason = candidates[0].get("finishReason")
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=model,
|
||||
provider="gemini",
|
||||
tokens_used=tokens_used,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def call_with_tools(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""Make a call to the Gemini API with tool support.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: List of tool definitions in OpenAI format.
|
||||
**kwargs: Additional options.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool_calls.
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("Google API key is required")
|
||||
|
||||
model = kwargs.get("model", self.model)
|
||||
|
||||
# Convert OpenAI-style messages to Gemini format
|
||||
gemini_contents = []
|
||||
system_instruction = None
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
system_instruction = content
|
||||
elif role == "assistant":
|
||||
# Handle assistant messages with tool calls
|
||||
parts = []
|
||||
if content:
|
||||
parts.append({"text": content})
|
||||
|
||||
if msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
parts.append(
|
||||
{
|
||||
"functionCall": {
|
||||
"name": func.get("name", ""),
|
||||
"args": args,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
gemini_contents.append({"role": "model", "parts": parts})
|
||||
elif role == "tool":
|
||||
# Tool response in Gemini format
|
||||
gemini_contents.append(
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{
|
||||
"functionResponse": {
|
||||
"name": msg.get("name", ""),
|
||||
"response": {"result": content},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
# User message
|
||||
gemini_contents.append({"role": "user", "parts": [{"text": content}]})
|
||||
|
||||
# Convert OpenAI-style tools to Gemini format
|
||||
gemini_tools = None
|
||||
if tools:
|
||||
function_declarations = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool["function"]
|
||||
function_declarations.append(
|
||||
{
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
if function_declarations:
|
||||
gemini_tools = [{"functionDeclarations": function_declarations}]
|
||||
|
||||
request_body = {
|
||||
"contents": gemini_contents,
|
||||
"generationConfig": {
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
"maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
},
|
||||
}
|
||||
|
||||
if system_instruction:
|
||||
request_body["systemInstruction"] = {
|
||||
"parts": [{"text": system_instruction}]
|
||||
}
|
||||
|
||||
if gemini_tools:
|
||||
request_body["tools"] = gemini_tools
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(model),
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=request_body,
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Parse response
|
||||
content = ""
|
||||
tool_calls = None
|
||||
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
content += part["text"]
|
||||
elif "functionCall" in part:
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
fc = part["functionCall"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=f"call_{len(tool_calls)}", # Gemini doesn't provide IDs
|
||||
name=fc.get("name", ""),
|
||||
arguments=fc.get("args", {}),
|
||||
)
|
||||
)
|
||||
|
||||
# Get token counts
|
||||
usage = data.get("usageMetadata", {})
|
||||
tokens_used = usage.get("promptTokenCount", 0) + usage.get(
|
||||
"candidatesTokenCount", 0
|
||||
)
|
||||
|
||||
finish_reason = None
|
||||
if candidates:
|
||||
finish_reason = candidates[0].get("finishReason")
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=model,
|
||||
provider="gemini",
|
||||
tokens_used=tokens_used,
|
||||
finish_reason=finish_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIGeminiProvider(BaseLLMProvider):
    """Google Vertex AI Gemini provider for enterprise GCP deployments.

    Uses Vertex AI endpoints instead of the public Gemini API.
    Supports regional deployments and IAM authentication.

    Environment Variables:
        - GOOGLE_CLOUD_PROJECT: GCP project ID
        - GOOGLE_CLOUD_REGION: GCP region (default: us-central1)
        - VERTEX_AI_MODEL: Default model (optional)
    """

    def __init__(
        self,
        project: str | None = None,
        region: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credentials=None,
        timeout: int = 120,
    ):
        """Initialize the Vertex AI Gemini provider.

        Args:
            project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var.
            region: GCP region. Defaults to GOOGLE_CLOUD_REGION env var,
                then us-central1.
            model: Model to use. Defaults to VERTEX_AI_MODEL env var,
                then gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            credentials: Google credentials object. If not provided,
                uses Application Default Credentials.
            timeout: HTTP request timeout in seconds.
        """
        # BUGFIX: the old truthy defaults ("us-central1", "gemini-1.5-pro")
        # meant `region or os.environ.get(...)` never consulted the
        # documented GOOGLE_CLOUD_REGION / VERTEX_AI_MODEL env vars.
        # Defaulting the parameters to None restores that behavior while
        # keeping explicit arguments working exactly as before.
        self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
        self.region = region or os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
        self.model = model or os.environ.get("VERTEX_AI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Configurable timeout, consistent with the other providers.
        self.timeout = timeout
        self._credentials = credentials
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get a Google Cloud access token.

        Uses the supplied credentials or Application Default Credentials,
        caching the token until shortly before it expires.

        Returns:
            Access token string.

        Raises:
            ImportError: If google-auth is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer)
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token

        try:
            import google.auth
            from google.auth.transport.requests import Request
        except ImportError:
            raise ImportError(
                "google-auth package is required for Vertex AI authentication. "
                "Install with: pip install google-auth"
            )

        if self._credentials is None:
            # Application Default Credentials (env, gcloud, metadata server).
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )

        if not self._credentials.valid:
            self._credentials.refresh(Request())

        self._token = self._credentials.token
        # Tokens typically expire in 1 hour; refresh slightly early.
        self._token_expires_at = time.time() + 3500

        return self._token

    def _get_api_url(self, model: str | None = None) -> str:
        """Build the Vertex AI generateContent URL for a model.

        Args:
            model: Model name. Uses the provider default if not specified.

        Returns:
            Full API URL.
        """
        m = model or self.model
        return (
            f"https://{self.region}-aiplatform.googleapis.com/v1/"
            f"projects/{self.project}/locations/{self.region}/"
            f"publishers/google/models/{m}:generateContent"
        )

    @staticmethod
    def _convert_messages(messages: list[dict]) -> tuple[list[dict], str | None]:
        """Convert OpenAI-style chat messages to Gemini `contents`.

        Args:
            messages: List of message dicts with 'role' and 'content'.

        Returns:
            Tuple of (gemini_contents, system_instruction). The system
            message, if any, is returned separately because Gemini passes
            it as `systemInstruction` rather than as a conversation turn.
        """
        gemini_contents = []
        system_instruction = None

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                system_instruction = content
            elif role == "assistant":
                # Assistant turns map to role "model"; tool invocations
                # become functionCall parts.
                parts = []
                if content:
                    parts.append({"text": content})
                for tc in msg.get("tool_calls") or []:
                    func = tc.get("function", {})
                    args = func.get("arguments", {})
                    if isinstance(args, str):
                        # OpenAI-style tool calls carry JSON-encoded args.
                        try:
                            args = json.loads(args)
                        except json.JSONDecodeError:
                            args = {}
                    parts.append(
                        {
                            "functionCall": {
                                "name": func.get("name", ""),
                                "args": args,
                            }
                        }
                    )
                gemini_contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                # Tool results map to functionResponse parts.
                gemini_contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                # User (or unknown) message.
                gemini_contents.append(
                    {"role": "user", "parts": [{"text": content}]}
                )

        return gemini_contents, system_instruction

    @staticmethod
    def _convert_tools(tools: list[dict] | None) -> list[dict] | None:
        """Convert OpenAI-style tool definitions to Gemini `tools`.

        Args:
            tools: List of tool definitions in OpenAI format, or None.

        Returns:
            Gemini tools payload, or None when there is nothing to send.
        """
        if not tools:
            return None
        function_declarations = [
            {
                "name": tool["function"]["name"],
                "description": tool["function"].get("description", ""),
                "parameters": tool["function"].get("parameters", {}),
            }
            for tool in tools
            if tool.get("type") == "function"
        ]
        if not function_declarations:
            return None
        return [{"functionDeclarations": function_declarations}]

    def _parse_response(self, data: dict, model: str) -> LLMResponse:
        """Build an LLMResponse from a raw Gemini JSON payload.

        Text parts accumulate into `content`; functionCall parts become
        ToolCall entries (None when the model made no calls).
        """
        content = ""
        tool_calls = None

        candidates = data.get("candidates", [])
        if candidates:
            for part in candidates[0].get("content", {}).get("parts", []):
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            # Gemini doesn't return call IDs; synthesize one.
                            id=f"call_{len(tool_calls)}",
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )

        # Total usage = prompt tokens + candidate (completion) tokens.
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = candidates[0].get("finishReason") if candidates else None

        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to Vertex AI Gemini.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If no GCP project is configured.
            requests.HTTPError: If the API returns an error status.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model", self.model)
        token = self._get_token()

        response = requests.post(
            self._get_api_url(model),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            # Configurable timeout, consistent with the other providers
            # (previously hard-coded to 120s).
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._parse_response(response.json(), model)

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to Vertex AI Gemini with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If no GCP project is configured.
            requests.HTTPError: If the API returns an error status.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model", self.model)
        token = self._get_token()

        gemini_contents, system_instruction = self._convert_messages(messages)
        gemini_tools = self._convert_tools(tools)

        request_body = {
            "contents": gemini_contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.temperature),
                "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
            },
        }

        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }

        if gemini_tools:
            request_body["tools"] = gemini_tools

        response = requests.post(
            self._get_api_url(model),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json=request_body,
            # Configurable timeout, consistent with the other providers
            # (previously hard-coded to 120s).
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._parse_response(response.json(), model)
|
||||
Reference in New Issue
Block a user