Files
openrabbit/tools/ai-review/clients/llm_client.py
latte e8d28225e0
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
just why not
2026-01-07 21:19:46 +01:00

516 lines
16 KiB
Python

"""LLM Client
A unified client for interacting with multiple LLM providers.
Supports OpenAI, OpenRouter, Ollama, and extensible for more providers.
"""
import json
import os
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass

import requests
@dataclass
class ToolCall:
    """Represents a single tool/function call requested by the LLM."""

    # Provider-assigned identifier for this call.
    id: str
    # Name of the tool/function being invoked.
    name: str
    # Arguments for the call, decoded from the provider's JSON string form.
    arguments: dict
@dataclass
class LLMResponse:
    """Normalized response from a single LLM call, provider-agnostic."""

    # Generated text; "" when the model responded only with tool calls.
    content: str
    # Model identifier reported by the provider (or the configured model).
    model: str
    # Provider key, e.g. "openai", "openrouter", or "ollama".
    provider: str
    # Total tokens consumed, when the provider reports usage; otherwise None.
    tokens_used: int | None = None
    # Provider-reported stop reason (e.g. "stop"), when available.
    finish_reason: str | None = None
    # Parsed tool calls when the model requested any; otherwise None.
    tool_calls: list[ToolCall] | None = None
class BaseLLMProvider(ABC):
    """Common interface implemented by every concrete LLM provider."""

    @abstractmethod
    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Send a single-prompt completion request to the provider.

        Args:
            prompt: The prompt text to send.
            **kwargs: Provider-specific overrides (model, temperature, ...).

        Returns:
            An LLMResponse holding the generated content.
        """
        ...

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Send a chat request with optional tool/function definitions.

        Concrete providers that support function calling override this;
        the base implementation simply signals lack of support.

        Args:
            messages: Chat messages, each a dict with 'role' and 'content'.
            tools: Tool definitions in OpenAI function-calling format.
            **kwargs: Provider-specific overrides.

        Returns:
            An LLMResponse carrying content and/or parsed tool calls.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Tool calling not supported by this provider")
class OpenAIProvider(BaseLLMProvider):
    """OpenAI Chat Completions API provider."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "gpt-4o-mini",
        temperature: float = 0,
        max_tokens: int = 4096,
        timeout: int = 120,
    ):
        # Fall back to the conventional environment variable when no key is given.
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.api_url = "https://api.openai.com/v1/chat/completions"

    def _base_request_body(self, **kwargs) -> dict:
        """Build the payload fields shared by all calls, honoring per-call overrides."""
        return {
            "model": kwargs.get("model", self.model),
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }

    def _post(self, request_body: dict) -> dict:
        """POST a payload to the chat completions endpoint and return parsed JSON.

        Raises:
            ValueError: If no API key is configured.
            requests.HTTPError: On a non-2xx HTTP response.
        """
        if not self.api_key:
            raise ValueError("OpenAI API key is required")
        response = requests.post(
            self.api_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Call the OpenAI API with a single user message.

        Args:
            prompt: The prompt to send as the sole user message.
            **kwargs: Optional overrides: model, temperature, max_tokens.

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If no API key is configured.
        """
        request_body = self._base_request_body(**kwargs)
        request_body["messages"] = [{"role": "user", "content": prompt}]
        data = self._post(request_body)
        choice = data["choices"][0]
        usage = data.get("usage", {})
        # content may be null in the API response; normalize to "" for
        # consistency with call_with_tools instead of propagating None.
        return LLMResponse(
            content=choice["message"].get("content") or "",
            model=data["model"],
            provider="openai",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Call the OpenAI API with tool/function-calling support.

        Args:
            messages: Chat messages, each a dict with 'role' and 'content'.
            tools: Tool definitions in OpenAI format; when given, tool_choice
                is also sent (kwargs 'tool_choice', default "auto").
            **kwargs: Optional overrides: model, temperature, max_tokens,
                tool_choice.

        Returns:
            LLMResponse with content and/or parsed tool_calls.

        Raises:
            ValueError: If no API key is configured.
        """
        request_body = self._base_request_body(**kwargs)
        request_body["messages"] = messages
        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")
        data = self._post(request_body)
        choice = data["choices"][0]
        usage = data.get("usage", {})
        message = choice["message"]
        # The API encodes tool call arguments as a JSON string; decode to dict.
        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = [
                ToolCall(
                    id=tc["id"],
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                )
                for tc in message["tool_calls"]
            ]
        return LLMResponse(
            content=message.get("content") or "",
            model=data["model"],
            provider="openai",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
            tool_calls=tool_calls,
        )
class OpenRouterProvider(BaseLLMProvider):
    """OpenRouter API provider (OpenAI-compatible chat completions)."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "anthropic/claude-3.5-sonnet",
        temperature: float = 0,
        max_tokens: int = 4096,
        timeout: int = 120,
    ):
        # Fall back to the conventional environment variable when no key is given.
        self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.api_url = "https://openrouter.ai/api/v1/chat/completions"

    def _base_request_body(self, **kwargs) -> dict:
        """Build the payload fields shared by all calls, honoring per-call overrides."""
        return {
            "model": kwargs.get("model", self.model),
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }

    def _post(self, request_body: dict) -> dict:
        """POST a payload to the OpenRouter endpoint and return parsed JSON.

        Raises:
            ValueError: If no API key is configured.
            requests.HTTPError: On a non-2xx HTTP response.
        """
        if not self.api_key:
            raise ValueError("OpenRouter API key is required")
        response = requests.post(
            self.api_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Call the OpenRouter API with a single user message.

        Args:
            prompt: The prompt to send as the sole user message.
            **kwargs: Optional overrides: model, temperature, max_tokens.

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If no API key is configured.
        """
        request_body = self._base_request_body(**kwargs)
        request_body["messages"] = [{"role": "user", "content": prompt}]
        data = self._post(request_body)
        choice = data["choices"][0]
        usage = data.get("usage", {})
        # content may be null in the API response; normalize to "" for
        # consistency with call_with_tools instead of propagating None.
        # Fall back to the model actually requested (honors a kwargs override,
        # unlike the old fallback to self.model).
        return LLMResponse(
            content=choice["message"].get("content") or "",
            model=data.get("model", request_body["model"]),
            provider="openrouter",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Call the OpenRouter API with tool/function-calling support.

        Args:
            messages: Chat messages, each a dict with 'role' and 'content'.
            tools: Tool definitions in OpenAI format; when given, tool_choice
                is also sent (kwargs 'tool_choice', default "auto").
            **kwargs: Optional overrides: model, temperature, max_tokens,
                tool_choice.

        Returns:
            LLMResponse with content and/or parsed tool_calls.

        Raises:
            ValueError: If no API key is configured.
        """
        request_body = self._base_request_body(**kwargs)
        request_body["messages"] = messages
        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")
        data = self._post(request_body)
        choice = data["choices"][0]
        usage = data.get("usage", {})
        message = choice["message"]
        # The API encodes tool call arguments as a JSON string; decode to dict.
        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = [
                ToolCall(
                    id=tc["id"],
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                )
                for tc in message["tool_calls"]
            ]
        return LLMResponse(
            content=message.get("content") or "",
            model=data.get("model", request_body["model"]),
            provider="openrouter",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
            tool_calls=tool_calls,
        )
class OllamaProvider(BaseLLMProvider):
    """Provider for a self-hosted Ollama server."""

    def __init__(
        self,
        host: str | None = None,
        model: str = "codellama:13b",
        temperature: float = 0,
        timeout: int = 300,
    ):
        # Default to the local Ollama daemon unless overridden by arg or env.
        self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
        self.model = model
        self.temperature = temperature
        self.timeout = timeout

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Run a non-streaming generation against Ollama's /api/generate.

        Args:
            prompt: The prompt text to generate from.
            **kwargs: Optional overrides: model, temperature.

        Returns:
            LLMResponse with the generated content.
        """
        payload = {
            "model": kwargs.get("model", self.model),
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": kwargs.get("temperature", self.temperature),
            },
        }
        resp = requests.post(
            f"{self.host}/api/generate",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        result = resp.json()
        # Ollama reports consumed tokens as "eval_count" and completion as "done".
        finished = result.get("done")
        return LLMResponse(
            content=result["response"],
            model=result.get("model", self.model),
            provider="ollama",
            tokens_used=result.get("eval_count"),
            finish_reason="stop" if finished else None,
        )
class LLMClient:
    """Unified LLM client supporting multiple providers."""

    # Registry mapping provider names to their implementation classes.
    PROVIDERS = {
        "openai": OpenAIProvider,
        "openrouter": OpenRouterProvider,
        "ollama": OllamaProvider,
    }

    def __init__(
        self,
        provider: str = "openai",
        config: dict | None = None,
    ):
        """Initialize the LLM client.

        Args:
            provider: Provider name (openai, openrouter, ollama).
            config: Provider-specific configuration, passed as keyword
                arguments to the provider's constructor.

        Raises:
            ValueError: If `provider` is not a registered provider name.
        """
        if provider not in self.PROVIDERS:
            raise ValueError(
                f"Unknown provider: {provider}. Available: {list(self.PROVIDERS.keys())}"
            )
        self.provider_name = provider
        self.config = config or {}
        self._provider = self.PROVIDERS[provider](**self.config)

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the configured LLM provider.

        Args:
            prompt: The prompt to send.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with the generated content.
        """
        return self._provider.call(prompt, **kwargs)

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call with tool/function calling support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        return self._provider.call_with_tools(messages, tools, **kwargs)

    def call_json(self, prompt: str, **kwargs) -> dict:
        """Make a call and parse the response as JSON.

        Args:
            prompt: The prompt to send (should request JSON output).
            **kwargs: Provider-specific options.

        Returns:
            Parsed JSON response.

        Raises:
            ValueError: If the response cannot be parsed as JSON after all
                recovery strategies. (The previous docstring claimed
                json.JSONDecodeError, but _extract_json raises ValueError.)
        """
        response = self.call(prompt, **kwargs)
        content = response.content.strip()
        return self._extract_json(content)

    def _extract_json(self, content: str) -> dict:
        """Extract and parse JSON from a raw model-response string.

        Tries progressively more forgiving strategies: direct parse,
        markdown code fences, brace slicing, and comment stripping.

        Raises:
            ValueError: If no strategy yields valid JSON; the message
                includes a truncated snippet of the content for debugging.
        """
        content = content.strip()
        # Attempt 1: the content is already valid JSON.
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass
        # Attempt 2: JSON wrapped in a markdown code fence.
        if "```" in content:
            patterns = [
                r"```json\s*\n([\s\S]*?)\n```",  # ```json with newlines
                r"```json\s*([\s\S]*?)```",  # ```json without newlines
                r"```\s*\n([\s\S]*?)\n```",  # ``` with newlines
                r"```\s*([\s\S]*?)```",  # ``` without newlines
            ]
            for pattern in patterns:
                match = re.search(pattern, content)
                if match:
                    try:
                        return json.loads(match.group(1).strip())
                    except json.JSONDecodeError:
                        continue
        # Attempt 3: slice from the first '{' to the last '}'.
        # (end > start guards the degenerate "}...{" case up front; the old
        # code produced an empty slice that simply failed to parse.)
        start = content.find("{")
        end = content.rfind("}")
        if start != -1 and end > start:
            try:
                return json.loads(content[start : end + 1])
            except json.JSONDecodeError:
                pass
        # Attempt 4: strip //-line and /* */ block comments, then retry the slice.
        # NOTE(review): the // pattern also mangles URLs inside string values
        # ("https://..." -> "https:"); tolerable only as a last resort here.
        try:
            cleaned = re.sub(r"//.*", "", content)
            cleaned = re.sub(r"/\*[\s\S]*?\*/", "", cleaned)
            start = cleaned.find("{")
            end = cleaned.rfind("}")
            if start != -1 and end != -1:
                return json.loads(cleaned[start : end + 1])
        except json.JSONDecodeError:
            pass
        # All strategies failed: fail loudly with a truncated snippet.
        snippet = content[:500] + "..." if len(content) > 500 else content
        raise ValueError(f"Failed to parse JSON response: {snippet!r}")

    @classmethod
    def from_config(cls, config: dict) -> "LLMClient":
        """Create an LLM client from a configuration dictionary.

        Args:
            config: Configuration with a 'provider' key plus optional
                'model' (dict keyed by provider), 'temperature',
                'max_tokens', and 'timeouts' ({'llm': ..., 'ollama': ...}).

        Returns:
            Configured LLMClient instance.
        """
        provider = config.get("provider", "openai")
        # Per-category timeouts; Ollama gets a longer default for local models.
        timeouts = config.get("timeouts", {})
        llm_timeout = timeouts.get("llm", 120)
        ollama_timeout = timeouts.get("ollama", 300)
        models = config.get("model", {})
        # Map config keys to provider-specific constructor settings.
        provider_config = {}
        if provider == "openai":
            provider_config = {
                "model": models.get("openai", "gpt-4o-mini"),
                "temperature": config.get("temperature", 0),
                "max_tokens": config.get("max_tokens", 16000),
                "timeout": llm_timeout,
            }
        elif provider == "openrouter":
            provider_config = {
                "model": models.get("openrouter", "anthropic/claude-3.5-sonnet"),
                "temperature": config.get("temperature", 0),
                "max_tokens": config.get("max_tokens", 16000),
                "timeout": llm_timeout,
            }
        elif provider == "ollama":
            provider_config = {
                "model": models.get("ollama", "codellama:13b"),
                "temperature": config.get("temperature", 0),
                "timeout": ollama_timeout,
            }
        return cls(provider=provider, config=provider_config)