"""LLM Client

A unified client for interacting with multiple LLM providers.
Supports OpenAI, OpenRouter, and Ollama, and is extensible to more providers.
"""

import json
import os
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass

import requests

# Fenced markdown code-block patterns, tried in this order when pulling JSON
# out of a free-form model response.
_CODE_BLOCK_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r"```json\s*\n([\s\S]*?)\n```",  # ```json with newlines
        r"```json\s*([\s\S]*?)```",  # ```json without newlines
        r"```\s*\n([\s\S]*?)\n```",  # ``` with newlines
        r"```\s*([\s\S]*?)```",  # ``` without newlines
    )
)

# A complete double-quoted JSON string literal, escapes included. The repair
# regexes capture it in group 1 and put it back verbatim, so comment markers
# or commas that appear *inside* strings (e.g. the "//" of a URL) survive.
_JSON_STRING = r'"(?:\\[\s\S]|[^"\\])*"'
_COMMENTS_OUTSIDE_STRINGS = re.compile(
    "(" + _JSON_STRING + r")|//[^\n]*|/\*[\s\S]*?\*/"
)
_TRAILING_COMMAS_OUTSIDE_STRINGS = re.compile(
    "(" + _JSON_STRING + r")|,(?=\s*[}\]])"
)


def _keep_strings(match: re.Match) -> str:
    """re.sub callback: keep a matched string literal, drop anything else."""
    return match.group(1) or ""


def _repair_json_text(text: str) -> str:
    """Remove ``//`` / ``/* */`` comments and trailing commas outside strings.

    Comments are stripped first so that a comma followed only by a comment and
    a closing bracket is still recognised as trailing.
    """
    without_comments = _COMMENTS_OUTSIDE_STRINGS.sub(_keep_strings, text)
    return _TRAILING_COMMAS_OUTSIDE_STRINGS.sub(_keep_strings, without_comments)


def _outermost_braces(text: str) -> str | None:
    """Return the span from the first ``{`` to the last ``}``, or None if absent."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return None
    return text[start : end + 1]


@dataclass
class ToolCall:
    """Represents a tool call from the LLM."""

    id: str
    name: str
    arguments: dict


@dataclass
class LLMResponse:
    """Response from an LLM call."""

    content: str
    model: str
    provider: str
    tokens_used: int | None = None
    finish_reason: str | None = None
    tool_calls: list[ToolCall] | None = None


def _parse_tool_arguments(raw) -> dict:
    """Decode a tool call's ``function.arguments`` payload into a dict.

    The chat-completions format carries arguments as a JSON-encoded string.
    A missing or blank value (a tool that takes no parameters) becomes ``{}``;
    an already-decoded dict is accepted as-is, defensively.

    Raises:
        json.JSONDecodeError: If a non-blank string is not valid JSON.
    """
    if isinstance(raw, dict):
        return raw
    if raw is None or (isinstance(raw, str) and not raw.strip()):
        return {}
    return json.loads(raw)


def _parse_tool_calls(message: dict) -> list[ToolCall] | None:
    """Convert a chat-completions ``message["tool_calls"]`` list to ToolCalls.

    Returns None when the message carries no tool calls.
    """
    raw_calls = message.get("tool_calls")
    if not raw_calls:
        return None
    return [
        ToolCall(
            id=call["id"],
            name=call["function"]["name"],
            arguments=_parse_tool_arguments(call["function"].get("arguments")),
        )
        for call in raw_calls
    ]


class BaseLLMProvider(ABC):
    """Abstract base class for LLM providers."""

    @abstractmethod
    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the LLM.

        Args:
            prompt: The prompt to send.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with the generated content.
        """
        pass

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the LLM with tool/function calling support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            NotImplementedError: Unless a subclass overrides this method.
        """
        raise NotImplementedError("Tool calling not supported by this provider")


class _ChatCompletionsProvider(BaseLLMProvider):
    """Shared implementation for OpenAI-compatible ``/chat/completions`` APIs.

    Subclasses set ``_PROVIDER`` and ``_DISPLAY_NAME`` and, in ``__init__``,
    the ``api_key``, ``model``, ``temperature``, ``max_tokens`` and ``api_url``
    instance attributes that are read here.
    """

    _PROVIDER = ""  # value reported in LLMResponse.provider
    _DISPLAY_NAME = ""  # human-readable name used in error messages
    _TIMEOUT = 120  # seconds per HTTP request

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Send a single user prompt and return the first choice.

        Args:
            prompt: The prompt to send as one ``user`` message.
            **kwargs: Optional ``model``, ``temperature`` and ``max_tokens``
                overrides for this call.

        Raises:
            ValueError: If no API key is configured.
            requests.HTTPError: If the API returns an error status.
        """
        messages = [{"role": "user", "content": prompt}]
        return self._complete(messages, None, kwargs)

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Send a chat history, optionally advertising tools to the model.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Tool definitions in OpenAI format; omitted when empty.
            **kwargs: Optional ``model``, ``temperature``, ``max_tokens`` and
                ``tool_choice`` (default ``"auto"``, sent only with tools).

        Returns:
            LLMResponse whose ``tool_calls`` is populated when the model
            requested tool invocations.

        Raises:
            ValueError: If no API key is configured.
            requests.HTTPError: If the API returns an error status.
            json.JSONDecodeError: If a tool call's arguments are malformed JSON.
        """
        return self._complete(messages, tools, kwargs)

    def _complete(
        self,
        messages: list[dict],
        tools: list[dict] | None,
        options: dict,
    ) -> LLMResponse:
        """POST one chat-completions request and convert the first choice."""
        if not self.api_key:
            raise ValueError(f"{self._DISPLAY_NAME} API key is required")

        model = options.get("model", self.model)
        request_body = {
            "model": model,
            "temperature": options.get("temperature", self.temperature),
            "max_tokens": options.get("max_tokens", self.max_tokens),
            "messages": messages,
        }
        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = options.get("tool_choice", "auto")

        response = requests.post(
            self.api_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=self._TIMEOUT,
        )
        response.raise_for_status()
        data = response.json()

        choice = data["choices"][0]
        message = choice["message"]
        # Guard against a missing *or* null usage block.
        usage = data.get("usage") or {}
        return LLMResponse(
            # Content can be null (e.g. a pure tool-call reply); keep it a str.
            content=message.get("content") or "",
            model=data.get("model", model),
            provider=self._PROVIDER,
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
            tool_calls=_parse_tool_calls(message),
        )


class OpenAIProvider(_ChatCompletionsProvider):
    """OpenAI API provider."""

    _PROVIDER = "openai"
    _DISPLAY_NAME = "OpenAI"

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "gpt-4o-mini",
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Configure the provider; the key falls back to ``OPENAI_API_KEY``."""
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.api_url = "https://api.openai.com/v1/chat/completions"


class OpenRouterProvider(_ChatCompletionsProvider):
    """OpenRouter API provider."""

    _PROVIDER = "openrouter"
    _DISPLAY_NAME = "OpenRouter"

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "anthropic/claude-3.5-sonnet",
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Configure the provider; the key falls back to ``OPENROUTER_API_KEY``."""
        self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.api_url = "https://openrouter.ai/api/v1/chat/completions"


class OllamaProvider(BaseLLMProvider):
    """Ollama (self-hosted) provider."""

    def __init__(
        self,
        host: str | None = None,
        model: str = "codellama:13b",
        temperature: float = 0,
    ):
        """Configure the provider; the host falls back to ``OLLAMA_HOST``."""
        self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
        self.model = model
        self.temperature = temperature

    def _base_url(self) -> str:
        """Return ``self.host`` as an absolute URL without a trailing slash.

        Accepts hosts given as bare ``host:port`` (no scheme) by assuming http.
        """
        host = self.host.rstrip("/")
        if "://" not in host:
            host = f"http://{host}"
        return host

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Call Ollama's non-streaming ``/api/generate`` endpoint.

        Args:
            prompt: The prompt to send.
            **kwargs: Optional ``model`` and ``temperature`` overrides.

        Raises:
            requests.HTTPError: If the server returns an error status.
        """
        model = kwargs.get("model", self.model)
        response = requests.post(
            f"{self._base_url()}/api/generate",
            json={
                "model": model,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": kwargs.get("temperature", self.temperature),
                },
            },
            timeout=300,  # Longer timeout for local models
        )
        response.raise_for_status()
        data = response.json()

        # Report prompt + completion tokens so tokens_used means the same
        # thing as the "total_tokens" figure of the chat-completions providers.
        counts = [data.get("prompt_eval_count"), data.get("eval_count")]
        known = [count for count in counts if count is not None]
        return LLMResponse(
            content=data["response"],
            model=data.get("model", model),
            provider="ollama",
            tokens_used=sum(known) if known else None,
            finish_reason="stop" if data.get("done") else None,
        )


class LLMClient:
    """Unified LLM client supporting multiple providers."""

    PROVIDERS = {
        "openai": OpenAIProvider,
        "openrouter": OpenRouterProvider,
        "ollama": OllamaProvider,
    }

    def __init__(
        self,
        provider: str = "openai",
        config: dict | None = None,
    ):
        """Initialize the LLM client.

        Args:
            provider: Provider name (openai, openrouter, ollama).
            config: Keyword arguments passed to the provider's constructor.

        Raises:
            ValueError: If ``provider`` is not a key of ``PROVIDERS``.
        """
        if provider not in self.PROVIDERS:
            raise ValueError(
                f"Unknown provider: {provider}. Available: {list(self.PROVIDERS.keys())}"
            )

        self.provider_name = provider
        self.config = config or {}
        self._provider = self.PROVIDERS[provider](**self.config)

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the configured LLM provider.

        Args:
            prompt: The prompt to send.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with the generated content.
        """
        return self._provider.call(prompt, **kwargs)

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call with tool/function calling support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            NotImplementedError: If the provider does not support tools.
        """
        return self._provider.call_with_tools(messages, tools, **kwargs)

    def call_json(self, prompt: str, **kwargs) -> dict:
        """Make a call and parse the response as JSON.

        Args:
            prompt: The prompt to send (should request JSON output).
            **kwargs: Provider-specific options.

        Returns:
            Parsed JSON response.

        Raises:
            ValueError: If no valid JSON can be recovered from the response.
        """
        response = self.call(prompt, **kwargs)
        return self._extract_json(response.content)

    def _extract_json(self, content: str) -> dict:
        """Extract and parse JSON from a model response.

        Tries, in order: the whole text, fenced markdown code blocks, the
        outermost ``{...}`` span, and finally that span after removing
        comments and trailing commas that lie outside string literals.

        Raises:
            ValueError: If every attempt fails; the message includes the
                first 500 characters of the content for debugging.
        """
        content = content.strip()

        # Attempt 1: the response is already valid JSON.
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass

        # Attempt 2: JSON wrapped in a markdown code block.
        if "```" in content:
            for pattern in _CODE_BLOCK_PATTERNS:
                match = pattern.search(content)
                if match:
                    try:
                        return json.loads(match.group(1).strip())
                    except json.JSONDecodeError:
                        continue

        # Attempt 3: JSON surrounded by prose - take first "{" to last "}".
        span = _outermost_braces(content)
        if span is not None:
            try:
                return json.loads(span)
            except json.JSONDecodeError:
                pass

        # Attempt 4: fix common LLM JSON errors (comments, trailing commas)
        # without touching text inside strings, then retry the outer object.
        span = _outermost_braces(_repair_json_text(content))
        if span is not None:
            try:
                return json.loads(span)
            except json.JSONDecodeError:
                pass

        # If all attempts fail, raise an error with the content for debugging
        snippet = content[:500] + "..." if len(content) > 500 else content
        raise ValueError(f"Failed to parse JSON response: {snippet!r}")

    @classmethod
    def from_config(cls, config: dict) -> "LLMClient":
        """Create an LLM client from a configuration dictionary.

        Args:
            config: Configuration with an optional ``provider`` key (default
                ``"openai"``), ``temperature``, ``max_tokens`` (chat-completions
                providers only) and ``model``. ``model`` may be a mapping of
                provider name to model id, or a single model id string.

        Returns:
            Configured LLMClient instance.

        Raises:
            ValueError: If the provider name is unknown.
        """
        provider = config.get("provider", "openai")
        default_models = {
            "openai": "gpt-4o-mini",
            "openrouter": "anthropic/claude-3.5-sonnet",
            "ollama": "codellama:13b",
        }
        if provider not in default_models:
            # Defer to __init__, which raises the "Unknown provider" error.
            return cls(provider=provider, config={})

        model_setting = config.get("model") or {}
        if isinstance(model_setting, str):
            model = model_setting
        else:
            model = model_setting.get(provider, default_models[provider])

        provider_config = {
            "model": model,
            "temperature": config.get("temperature", 0),
        }
        # Only the chat-completions providers accept max_tokens.
        if provider != "ollama":
            provider_config["max_tokens"] = config.get("max_tokens", 16000)

        return cls(provider=provider, config=provider_config)