"""Google Gemini Provider.

Integration with Google's Gemini API for GCP customers.
Supports Gemini Pro, Gemini Ultra, and other models.

Two providers are defined:

- ``GeminiProvider``: public Gemini API, authenticated with an API key.
- ``VertexAIGeminiProvider``: Vertex AI endpoint, authenticated with a
  Google Cloud access token.

Both accept OpenAI-style messages/tools and share the private conversion
helpers defined at module level.
"""

import json
import os
import sys
import time
from datetime import timezone

import requests

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall

_DEFAULT_MODEL = "gemini-1.5-pro"
_DEFAULT_REGION = "us-central1"
_REQUEST_TIMEOUT_SECONDS = 120
# A cached access token is treated as stale this many seconds before expiry.
_TOKEN_REFRESH_BUFFER_SECONDS = 300
# Lifetime assumed for a token whose credentials report no expiry.
_FALLBACK_TOKEN_LIFETIME_SECONDS = 3500


def _option(kwargs: dict, key: str, default):
    """Return ``kwargs[key]``, or *default* when it is absent or ``None``.

    An ``is None`` test is used (not truthiness) because ``0`` is a valid
    temperature.
    """
    value = kwargs.get(key)
    return default if value is None else value


def _parse_tool_arguments(raw) -> dict:
    """Normalize OpenAI tool-call ``arguments`` into a dict.

    OpenAI encodes arguments as a JSON string; callers may also pass a dict.
    Undecodable strings and non-dict values become ``{}``.
    """
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return {}
    return raw if isinstance(raw, dict) else {}


def _convert_messages(messages: list[dict]) -> tuple[list[dict], str | None]:
    """Translate OpenAI-style chat messages into Gemini ``contents``.

    Args:
        messages: Message dicts with ``role`` and ``content``; assistant
            messages may carry ``tool_calls`` and tool messages may carry
            ``name`` and/or ``tool_call_id``.

    Returns:
        ``(contents, system_instruction)``. ``system_instruction`` is the
        text of all system messages joined by blank lines, or ``None`` when
        there are none.
    """
    contents: list[dict] = []
    system_texts: list[str] = []
    # Maps tool_call id -> function name so a tool reply that only carries
    # ``tool_call_id`` can still be labelled with the function it answers.
    tool_names_by_id: dict[str, str] = {}

    for msg in messages:
        role = msg.get("role", "user")
        # ``content`` is commonly None on assistant tool-call turns.
        content = msg.get("content") or ""

        if role == "system":
            if content:
                system_texts.append(str(content))
        elif role == "assistant":
            parts = []
            if content:
                parts.append({"text": content})
            for tool_call in msg.get("tool_calls") or []:
                func = tool_call.get("function", {})
                name = func.get("name", "")
                if tool_call.get("id"):
                    tool_names_by_id[tool_call["id"]] = name
                parts.append(
                    {
                        "functionCall": {
                            "name": name,
                            "args": _parse_tool_arguments(
                                func.get("arguments", {})
                            ),
                        }
                    }
                )
            # NOTE(review): an entry with an empty ``parts`` list is expected
            # to be rejected by the API, so empty assistant turns are dropped.
            if parts:
                contents.append({"role": "model", "parts": parts})
        elif role == "tool":
            name = msg.get("name") or tool_names_by_id.get(
                msg.get("tool_call_id"), ""
            )
            contents.append(
                {
                    "role": "function",
                    "parts": [
                        {
                            "functionResponse": {
                                "name": name,
                                "response": {"result": content},
                            }
                        }
                    ],
                }
            )
        else:
            # Any other role is sent as a user turn.
            contents.append({"role": "user", "parts": [{"text": content}]})

    return contents, "\n\n".join(system_texts) or None


def _convert_tools(tools: list[dict] | None) -> list[dict] | None:
    """Translate OpenAI-style tool definitions into Gemini ``tools``.

    Only entries with ``type == "function"`` are converted.

    Returns:
        A one-element list holding ``functionDeclarations``, or ``None``
        when there is nothing to declare.
    """
    if not tools:
        return None
    declarations = [
        {
            "name": tool["function"]["name"],
            "description": tool["function"].get("description", ""),
            "parameters": tool["function"].get("parameters", {}),
        }
        for tool in tools
        if tool.get("type") == "function"
    ]
    return [{"functionDeclarations": declarations}] if declarations else None


def _build_request_body(
    contents: list[dict],
    kwargs: dict,
    default_temperature: float,
    default_max_tokens: int,
    system_instruction: str | None = None,
    tools: list[dict] | None = None,
) -> dict:
    """Assemble the JSON body for a ``generateContent`` request.

    ``temperature`` and ``max_tokens`` are read from *kwargs*, falling back
    to the supplied provider defaults.
    """
    body = {
        "contents": contents,
        "generationConfig": {
            "temperature": _option(kwargs, "temperature", default_temperature),
            "maxOutputTokens": _option(kwargs, "max_tokens", default_max_tokens),
        },
    }
    if system_instruction:
        body["systemInstruction"] = {"parts": [{"text": system_instruction}]}
    if tools:
        body["tools"] = tools
    return body


def _post_json(url: str, headers: dict, body: dict) -> dict:
    """POST *body* as JSON and return the decoded JSON response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.post(
        url,
        headers=headers,
        json=body,
        timeout=_REQUEST_TIMEOUT_SECONDS,
    )
    response.raise_for_status()
    return response.json()


def _parse_response(
    data: dict,
) -> tuple[str, list[ToolCall] | None, int, str | None]:
    """Extract the useful fields from a ``generateContent`` response.

    Only the first candidate is inspected.

    Returns:
        ``(content, tool_calls, tokens_used, finish_reason)`` where
        ``tool_calls`` is ``None`` when the model requested no function
        calls and ``finish_reason`` is ``None`` when there are no candidates.
    """
    candidates = data.get("candidates") or []
    first = candidates[0] if candidates else {}
    parts = (first.get("content") or {}).get("parts") or []

    content = "".join(part["text"] for part in parts if "text" in part)

    # A part that carries text is treated as text only.
    function_calls = [
        part["functionCall"]
        for part in parts
        if "text" not in part and "functionCall" in part
    ]
    # The response carries no call IDs, so positional ones are synthesized.
    tool_calls = [
        ToolCall(
            id=f"call_{index}",
            name=function_call.get("name", ""),
            arguments=function_call.get("args", {}),
        )
        for index, function_call in enumerate(function_calls)
    ] or None

    usage = data.get("usageMetadata") or {}
    tokens_used = usage.get("promptTokenCount", 0) + usage.get(
        "candidatesTokenCount", 0
    )
    return content, tool_calls, tokens_used, first.get("finishReason")


def _expiry_timestamp(credentials) -> float | None:
    """Return the credentials' token expiry as a POSIX timestamp.

    Returns ``None`` when the credentials object exposes no expiry.
    """
    expiry = getattr(credentials, "expiry", None)
    if expiry is None:
        return None
    # NOTE(review): google-auth documents ``expiry`` as a naive datetime in
    # UTC; attach the zone so ``timestamp()`` does not assume local time.
    if expiry.tzinfo is None:
        expiry = expiry.replace(tzinfo=timezone.utc)
    return expiry.timestamp()


class GeminiProvider(BaseLLMProvider):
    """Google Gemini API provider.

    Provides integration with Google's Gemini models.

    Supports:
        - Gemini 1.5 Pro (gemini-1.5-pro)
        - Gemini 1.5 Flash (gemini-1.5-flash)
        - Gemini 1.0 Pro (gemini-pro)

    Environment Variables:
        - GOOGLE_API_KEY: Google AI API key
        - GEMINI_MODEL: Default model (optional)
    """

    API_URL = "https://generativelanguage.googleapis.com/v1beta/models"

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Gemini provider.

        Args:
            api_key: Google API key. Defaults to GOOGLE_API_KEY env var.
            model: Model to use. Defaults to the GEMINI_MODEL env var, then
                gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
        """
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "")
        # The default is None (not a literal model name) so that the
        # environment variable is actually consulted.
        self.model = model or os.environ.get("GEMINI_MODEL") or _DEFAULT_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _endpoint_url(self, model: str | None = None, stream: bool = False) -> str:
        """Build the endpoint URL for a model, without credentials.

        Args:
            model: Model name. Uses default if not specified.
            stream: Whether to use streaming endpoint.

        Returns:
            API URL with no query string.
        """
        name = model or self.model
        action = "streamGenerateContent" if stream else "generateContent"
        return f"{self.API_URL}/{name}:{action}"

    def _request_headers(self) -> dict:
        """Return request headers, including the API key.

        The key is sent as a header rather than a query parameter because
        ``requests`` puts the URL in ``HTTPError`` messages, which would
        otherwise expose the key in logs and tracebacks.
        """
        return {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key,
        }

    def _get_api_url(self, model: str | None = None, stream: bool = False) -> str:
        """Build the API URL for a given model, with the key embedded.

        Kept for backward compatibility; the request methods of this class
        use ``_endpoint_url`` plus ``_request_headers`` instead.

        Args:
            model: Model name. Uses default if not specified.
            stream: Whether to use streaming endpoint.

        Returns:
            Full API URL including the ``key`` query parameter.
        """
        return f"{self._endpoint_url(model, stream)}?key={self.api_key}"

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Gemini API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")

        model = kwargs.get("model") or self.model
        body = _build_request_body(
            [{"parts": [{"text": prompt}]}],
            kwargs,
            self.temperature,
            self.max_tokens,
        )
        data = _post_json(self._endpoint_url(model), self._request_headers(), body)
        content, _, tokens_used, finish_reason = _parse_response(data)

        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Gemini API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")

        model = kwargs.get("model") or self.model
        contents, system_instruction = _convert_messages(messages)
        body = _build_request_body(
            contents,
            kwargs,
            self.temperature,
            self.max_tokens,
            system_instruction=system_instruction,
            tools=_convert_tools(tools),
        )
        data = _post_json(self._endpoint_url(model), self._request_headers(), body)
        content, tool_calls, tokens_used, finish_reason = _parse_response(data)

        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )


class VertexAIGeminiProvider(BaseLLMProvider):
    """Google Vertex AI Gemini provider for enterprise GCP deployments.

    Uses Vertex AI endpoints instead of the public Gemini API.
    Supports regional deployments and IAM authentication.

    Environment Variables:
        - GOOGLE_CLOUD_PROJECT: GCP project ID
        - GOOGLE_CLOUD_REGION: GCP region (default: us-central1)
        - VERTEX_AI_MODEL: Default model (optional)
    """

    def __init__(
        self,
        project: str | None = None,
        region: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credentials=None,
    ):
        """Initialize the Vertex AI Gemini provider.

        Args:
            project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var.
            region: GCP region. Defaults to the GOOGLE_CLOUD_REGION env var,
                then us-central1.
            model: Model to use. Defaults to the VERTEX_AI_MODEL env var,
                then gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            credentials: Google credentials object. If not provided,
                uses Application Default Credentials.
        """
        self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
        # Defaults are None (not literals) so the env vars are consulted.
        self.region = (
            region or os.environ.get("GOOGLE_CLOUD_REGION") or _DEFAULT_REGION
        )
        self.model = model or os.environ.get("VERTEX_AI_MODEL") or _DEFAULT_MODEL
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._credentials = credentials
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get a Google Cloud access token, refreshing it when needed.

        A cached token is reused until it is within five minutes of expiry.

        Returns:
            Access token string.

        Raises:
            ImportError: If google-auth is not installed.
        """
        now = time.time()
        if (
            self._token
            and self._token_expires_at > now + _TOKEN_REFRESH_BUFFER_SECONDS
        ):
            return self._token

        try:
            import google.auth
            from google.auth.transport.requests import Request
        except ImportError as err:
            raise ImportError(
                "google-auth package is required for Vertex AI authentication. "
                "Install with: pip install google-auth"
            ) from err

        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )

        credentials = self._credentials
        expires_at = _expiry_timestamp(credentials)
        # Refresh when the credentials are invalid OR their token is inside
        # our safety buffer. Checking ``valid`` alone is not enough: a token
        # may still be "valid" while close to expiry, and re-caching it with
        # a fresh lifetime would keep it in use after it has expired.
        if not credentials.valid or (
            expires_at is not None
            and expires_at <= now + _TOKEN_REFRESH_BUFFER_SECONDS
        ):
            credentials.refresh(Request())
            expires_at = _expiry_timestamp(credentials)

        self._token = credentials.token
        self._token_expires_at = (
            expires_at
            if expires_at is not None
            else now + _FALLBACK_TOKEN_LIFETIME_SECONDS
        )
        return self._token

    def _get_api_url(self, model: str | None = None) -> str:
        """Build the Vertex AI API URL.

        Args:
            model: Model name. Uses default if not specified.

        Returns:
            Full API URL.
        """
        name = model or self.model
        return (
            f"https://{self.region}-aiplatform.googleapis.com/v1/"
            f"projects/{self.project}/locations/{self.region}/"
            f"publishers/google/models/{name}:generateContent"
        )

    def _request_headers(self) -> dict:
        """Return request headers carrying a bearer access token."""
        return {
            "Authorization": f"Bearer {self._get_token()}",
            "Content-Type": "application/json",
        }

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to Vertex AI Gemini.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If the GCP project ID is not set.
            ImportError: If google-auth is not installed.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model") or self.model
        headers = self._request_headers()
        body = _build_request_body(
            [{"parts": [{"text": prompt}]}],
            kwargs,
            self.temperature,
            self.max_tokens,
        )
        data = _post_json(self._get_api_url(model), headers, body)
        content, _, tokens_used, finish_reason = _parse_response(data)

        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to Vertex AI Gemini with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If the GCP project ID is not set.
            ImportError: If google-auth is not installed.
            requests.HTTPError: If the API request fails.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model") or self.model
        headers = self._request_headers()
        contents, system_instruction = _convert_messages(messages)
        body = _build_request_body(
            contents,
            kwargs,
            self.temperature,
            self.max_tokens,
            system_instruction=system_instruction,
            tools=_convert_tools(tools),
        )
        data = _post_json(self._get_api_url(model), headers, body)
        content, tool_calls, tokens_used, finish_reason = _parse_response(data)

        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )