All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
516 lines
16 KiB
Python
516 lines
16 KiB
Python
"""LLM Client
|
|
|
|
A unified client for interacting with multiple LLM providers.
|
|
Supports OpenAI, OpenRouter, Ollama, and extensible for more providers.
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
|
|
import requests
|
|
|
|
|
|
@dataclass
class ToolCall:
    """Represents a tool call from the LLM.

    Mirrors the OpenAI tool-call shape: an opaque call id, the tool name,
    and the already-deserialized argument mapping.
    """

    # Provider-assigned identifier for this call (echoed back in tool results).
    id: str
    # Name of the tool/function the model wants to invoke.
    name: str
    # Arguments for the call, parsed from the provider's JSON string.
    arguments: dict
|
|
|
|
|
|
@dataclass
class LLMResponse:
    """Response from an LLM call, normalized across providers."""

    # Generated text; empty string when the model only returned tool calls.
    content: str
    # Model identifier reported by the provider (or the configured default).
    model: str
    # Provider key, e.g. "openai", "openrouter", "ollama".
    provider: str
    # Total tokens consumed, when the provider reports usage.
    tokens_used: int | None = None
    # Provider-reported finish reason (e.g. "stop"), when available.
    finish_reason: str | None = None
    # Parsed tool calls, when the model requested any.
    tool_calls: list[ToolCall] | None = None
|
|
|
|
|
|
class BaseLLMProvider(ABC):
    """Common interface shared by all LLM provider implementations."""

    @abstractmethod
    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Send a single prompt to the LLM and return its response.

        Args:
            prompt: The prompt to send.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with the generated content.
        """

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Send a chat-style conversation with optional tool definitions.

        Providers that support function calling override this method; the
        base implementation only signals lack of support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Tool calling not supported by this provider")
|
|
|
|
|
|
class OpenAIProvider(BaseLLMProvider):
    """OpenAI API provider (chat-completions endpoint)."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "gpt-4o-mini",
        temperature: float = 0,
        max_tokens: int = 4096,
        timeout: int = 120,
    ):
        """Configure the provider.

        Args:
            api_key: API key; falls back to the OPENAI_API_KEY env var.
            model: Default model name.
            temperature: Default sampling temperature.
            max_tokens: Default cap on generated tokens.
            timeout: HTTP request timeout in seconds.
        """
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.api_url = "https://api.openai.com/v1/chat/completions"

    def _headers(self) -> dict:
        """Build the HTTP headers shared by all requests."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _request_body(self, messages: list[dict], **kwargs) -> dict:
        """Build the common chat-completions request body.

        Per-call kwargs override the configured defaults.
        """
        return {
            "model": kwargs.get("model", self.model),
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "messages": messages,
        }

    @staticmethod
    def _parse_tool_calls(message: dict) -> list[ToolCall] | None:
        """Convert raw API tool calls from a message into ToolCall objects.

        Returns None when the message contains no tool calls. An empty
        ``arguments`` string (emitted by some models for zero-argument
        tools) is decoded as an empty dict instead of raising
        json.JSONDecodeError.
        """
        raw_calls = message.get("tool_calls")
        if not raw_calls:
            return None
        return [
            ToolCall(
                id=tc["id"],
                name=tc["function"]["name"],
                # "" would crash json.loads; treat it as "no arguments".
                arguments=json.loads(tc["function"]["arguments"] or "{}"),
            )
            for tc in raw_calls
        ]

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Call OpenAI API with a single user message.

        Args:
            prompt: The prompt to send.
            **kwargs: Optional overrides: model, temperature, max_tokens.

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If no API key is configured.
            requests.HTTPError: On a non-2xx API response.
        """
        if not self.api_key:
            raise ValueError("OpenAI API key is required")

        response = requests.post(
            self.api_url,
            headers=self._headers(),
            json=self._request_body(
                [{"role": "user", "content": prompt}], **kwargs
            ),
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()

        choice = data["choices"][0]
        usage = data.get("usage", {})

        return LLMResponse(
            content=choice["message"]["content"],
            model=data["model"],
            provider="openai",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Call OpenAI API with tool/function calling support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Optional overrides: model, temperature, max_tokens,
                tool_choice.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If no API key is configured.
            requests.HTTPError: On a non-2xx API response.
        """
        if not self.api_key:
            raise ValueError("OpenAI API key is required")

        request_body = self._request_body(messages, **kwargs)
        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")

        response = requests.post(
            self.api_url,
            headers=self._headers(),
            json=request_body,
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()

        choice = data["choices"][0]
        usage = data.get("usage", {})
        message = choice["message"]

        return LLMResponse(
            # content may be null when the model only emits tool calls.
            content=message.get("content") or "",
            model=data["model"],
            provider="openai",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
            tool_calls=self._parse_tool_calls(message),
        )
|
|
|
|
|
|
class OpenRouterProvider(BaseLLMProvider):
    """OpenRouter API provider (OpenAI-compatible chat-completions API)."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "anthropic/claude-3.5-sonnet",
        temperature: float = 0,
        max_tokens: int = 4096,
        timeout: int = 120,
    ):
        """Configure the provider.

        Args:
            api_key: API key; falls back to the OPENROUTER_API_KEY env var.
            model: Default model name.
            temperature: Default sampling temperature.
            max_tokens: Default cap on generated tokens.
            timeout: HTTP request timeout in seconds.
        """
        self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.api_url = "https://openrouter.ai/api/v1/chat/completions"

    def _headers(self) -> dict:
        """Build the HTTP headers shared by all requests."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _request_body(self, messages: list[dict], **kwargs) -> dict:
        """Build the common chat-completions request body.

        Per-call kwargs override the configured defaults.
        """
        return {
            "model": kwargs.get("model", self.model),
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "messages": messages,
        }

    @staticmethod
    def _parse_tool_calls(message: dict) -> list[ToolCall] | None:
        """Convert raw API tool calls from a message into ToolCall objects.

        Returns None when the message contains no tool calls. An empty
        ``arguments`` string (emitted by some models for zero-argument
        tools) is decoded as an empty dict instead of raising
        json.JSONDecodeError.
        """
        raw_calls = message.get("tool_calls")
        if not raw_calls:
            return None
        return [
            ToolCall(
                id=tc["id"],
                name=tc["function"]["name"],
                # "" would crash json.loads; treat it as "no arguments".
                arguments=json.loads(tc["function"]["arguments"] or "{}"),
            )
            for tc in raw_calls
        ]

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Call OpenRouter API with a single user message.

        Args:
            prompt: The prompt to send.
            **kwargs: Optional overrides: model, temperature, max_tokens.

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If no API key is configured.
            requests.HTTPError: On a non-2xx API response.
        """
        if not self.api_key:
            raise ValueError("OpenRouter API key is required")

        response = requests.post(
            self.api_url,
            headers=self._headers(),
            json=self._request_body(
                [{"role": "user", "content": prompt}], **kwargs
            ),
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()

        choice = data["choices"][0]
        usage = data.get("usage", {})

        return LLMResponse(
            content=choice["message"]["content"],
            # OpenRouter may omit "model"; fall back to the configured one.
            model=data.get("model", self.model),
            provider="openrouter",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Call OpenRouter API with tool/function calling support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Optional overrides: model, temperature, max_tokens,
                tool_choice.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If no API key is configured.
            requests.HTTPError: On a non-2xx API response.
        """
        if not self.api_key:
            raise ValueError("OpenRouter API key is required")

        request_body = self._request_body(messages, **kwargs)
        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")

        response = requests.post(
            self.api_url,
            headers=self._headers(),
            json=request_body,
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()

        choice = data["choices"][0]
        usage = data.get("usage", {})
        message = choice["message"]

        return LLMResponse(
            # content may be null when the model only emits tool calls.
            content=message.get("content") or "",
            model=data.get("model", self.model),
            provider="openrouter",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
            tool_calls=self._parse_tool_calls(message),
        )
|
|
|
|
|
|
class OllamaProvider(BaseLLMProvider):
    """Ollama (self-hosted) provider."""

    def __init__(
        self,
        host: str | None = None,
        model: str = "codellama:13b",
        temperature: float = 0,
        timeout: int = 300,
    ):
        """Configure the provider.

        Args:
            host: Base URL of the Ollama server; falls back to the
                OLLAMA_HOST env var, then to the local default.
            model: Default model name.
            temperature: Default sampling temperature.
            timeout: HTTP request timeout in seconds.
        """
        self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
        self.model = model
        self.temperature = temperature
        self.timeout = timeout

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Call the Ollama /api/generate endpoint and wrap the reply.

        Args:
            prompt: The prompt to send.
            **kwargs: Optional overrides: model, temperature.

        Returns:
            LLMResponse with the generated content.

        Raises:
            requests.HTTPError: On a non-2xx API response.
        """
        payload = {
            "model": kwargs.get("model", self.model),
            "prompt": prompt,
            # Request the full reply in one shot rather than streaming.
            "stream": False,
            "options": {
                "temperature": kwargs.get("temperature", self.temperature),
            },
        }
        resp = requests.post(
            f"{self.host}/api/generate",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        body = resp.json()

        # Ollama reports completion via "done"; map it to a finish reason.
        finish = "stop" if body.get("done") else None
        return LLMResponse(
            content=body["response"],
            model=body.get("model", self.model),
            provider="ollama",
            tokens_used=body.get("eval_count"),
            finish_reason=finish,
        )
|
|
|
|
|
|
class LLMClient:
    """Unified LLM client supporting multiple providers."""

    # Registry mapping provider names to their implementation classes.
    PROVIDERS = {
        "openai": OpenAIProvider,
        "openrouter": OpenRouterProvider,
        "ollama": OllamaProvider,
    }

    def __init__(
        self,
        provider: str = "openai",
        config: dict | None = None,
    ):
        """Initialize the LLM client.

        Args:
            provider: Provider name (openai, openrouter, ollama).
            config: Provider-specific configuration.

        Raises:
            ValueError: If *provider* is not a registered provider name.
        """
        if provider not in self.PROVIDERS:
            raise ValueError(
                f"Unknown provider: {provider}. Available: {list(self.PROVIDERS.keys())}"
            )

        self.provider_name = provider
        self.config = config or {}
        self._provider = self.PROVIDERS[provider](**self.config)

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the configured LLM provider.

        Args:
            prompt: The prompt to send.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with the generated content.
        """
        return self._provider.call(prompt, **kwargs)

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call with tool/function calling support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            NotImplementedError: If the provider does not support tools.
        """
        return self._provider.call_with_tools(messages, tools, **kwargs)

    def call_json(self, prompt: str, **kwargs) -> dict:
        """Make a call and parse the response as JSON.

        Args:
            prompt: The prompt to send (should request JSON output).
            **kwargs: Provider-specific options.

        Returns:
            Parsed JSON response.

        Raises:
            ValueError: If no parseable JSON can be extracted from the
                response (raised by ``_extract_json``).
        """
        response = self.call(prompt, **kwargs)
        content = response.content.strip()

        return self._extract_json(content)

    def _extract_json(self, content: str) -> dict:
        """Extract and parse JSON from content string.

        Handles markdown code blocks, preamble text, comments, and
        trailing commas, in escalating fallback attempts.

        Raises:
            ValueError: If no attempt yields parseable JSON.
        """
        import re

        content = content.strip()

        # Attempt 1: direct parse
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass

        # Attempt 2: Extract from markdown code blocks
        if "```" in content:
            # Try multiple code block patterns, most specific first.
            patterns = [
                r"```json\s*\n([\s\S]*?)\n```",  # ```json with newlines
                r"```json\s*([\s\S]*?)```",  # ```json without newlines
                r"```\s*\n([\s\S]*?)\n```",  # ``` with newlines
                r"```\s*([\s\S]*?)```",  # ``` without newlines
            ]

            for pattern in patterns:
                match = re.search(pattern, content)
                if match:
                    try:
                        return json.loads(match.group(1).strip())
                    except json.JSONDecodeError:
                        continue

        # Attempt 3: take the span between the first { and the last }
        try:
            start = content.find("{")
            end = content.rfind("}")
            if start != -1 and end != -1:
                return json.loads(content[start : end + 1])
        except json.JSONDecodeError:
            pass

        # Attempt 4: Fix common JSON errors (comments, trailing commas)
        try:
            # Strip // line comments; the negative lookbehind keeps the
            # "//" inside URLs (e.g. "https://...") intact.
            json_str = re.sub(r"(?<!:)//.*", "", content)
            # Strip /* ... */ block comments.
            json_str = re.sub(r"/\*[\s\S]*?\*/", "", json_str)
            # Drop trailing commas before a closing brace/bracket.
            json_str = re.sub(r",\s*([}\]])", r"\1", json_str)
            # Try to extract JSON after cleaning
            start = json_str.find("{")
            end = json_str.rfind("}")
            if start != -1 and end != -1:
                return json.loads(json_str[start : end + 1])
        except json.JSONDecodeError:
            pass

        # All attempts failed: surface a truncated snippet for debugging.
        snippet = content[:500] + "..." if len(content) > 500 else content
        raise ValueError(f"Failed to parse JSON response: {snippet!r}")

    @classmethod
    def from_config(cls, config: dict) -> "LLMClient":
        """Create an LLM client from a configuration dictionary.

        Args:
            config: Configuration with 'provider' key and provider-specific settings.

        Returns:
            Configured LLMClient instance.
        """
        provider = config.get("provider", "openai")
        provider_config = {}

        # Get timeout configuration
        timeouts = config.get("timeouts", {})
        llm_timeout = timeouts.get("llm", 120)
        ollama_timeout = timeouts.get("ollama", 300)

        # Map config keys to provider-specific settings
        if provider == "openai":
            provider_config = {
                "model": config.get("model", {}).get("openai", "gpt-4o-mini"),
                "temperature": config.get("temperature", 0),
                "max_tokens": config.get("max_tokens", 16000),
                "timeout": llm_timeout,
            }
        elif provider == "openrouter":
            provider_config = {
                "model": config.get("model", {}).get(
                    "openrouter", "anthropic/claude-3.5-sonnet"
                ),
                "temperature": config.get("temperature", 0),
                "max_tokens": config.get("max_tokens", 16000),
                "timeout": llm_timeout,
            }
        elif provider == "ollama":
            provider_config = {
                "model": config.get("model", {}).get("ollama", "codellama:13b"),
                "temperature": config.get("temperature", 0),
                "timeout": ollama_timeout,
            }

        return cls(provider=provider, config=provider_config)
|