first commit
This commit is contained in:
10
tools/ai-review/clients/__init__.py
Normal file
10
tools/ai-review/clients/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""API Clients Package
|
||||
|
||||
This package contains client wrappers for external services
|
||||
like Gitea API and LLM providers.
|
||||
"""
|
||||
|
||||
from clients.gitea_client import GiteaClient
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
__all__ = ["GiteaClient", "LLMClient"]
|
||||
447
tools/ai-review/clients/gitea_client.py
Normal file
447
tools/ai-review/clients/gitea_client.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Gitea API Client
|
||||
|
||||
A unified client for interacting with the Gitea REST API.
|
||||
Provides methods for issues, pull requests, comments, and repository operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class GiteaClient:
|
||||
"""Client for Gitea API operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_url: str | None = None,
|
||||
token: str | None = None,
|
||||
timeout: int = 30,
|
||||
):
|
||||
"""Initialize the Gitea client.
|
||||
|
||||
Args:
|
||||
api_url: Gitea API base URL. Defaults to AI_REVIEW_API_URL env var.
|
||||
token: API token. Defaults to AI_REVIEW_TOKEN env var.
|
||||
timeout: Request timeout in seconds.
|
||||
"""
|
||||
self.api_url = api_url or os.environ.get("AI_REVIEW_API_URL", "")
|
||||
self.token = token or os.environ.get("AI_REVIEW_TOKEN", "")
|
||||
self.timeout = timeout
|
||||
|
||||
if not self.api_url:
|
||||
raise ValueError("Gitea API URL is required")
|
||||
if not self.token:
|
||||
raise ValueError("Gitea API token is required")
|
||||
|
||||
self.headers = {
|
||||
"Authorization": f"token {self.token}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
def _request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
json: dict | None = None,
|
||||
params: dict | None = None,
|
||||
) -> dict | list:
|
||||
"""Make an API request.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, PATCH, DELETE).
|
||||
endpoint: API endpoint (without base URL).
|
||||
json: Request body for POST/PATCH.
|
||||
params: Query parameters.
|
||||
|
||||
Returns:
|
||||
Response JSON data.
|
||||
|
||||
Raises:
|
||||
requests.HTTPError: If the request fails.
|
||||
"""
|
||||
url = f"{self.api_url}{endpoint}"
|
||||
response = requests.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=self.headers,
|
||||
json=json,
|
||||
params=params,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status_code == 204:
|
||||
return {}
|
||||
return response.json()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Issue Operations
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def create_issue(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
title: str,
|
||||
body: str,
|
||||
labels: list[int] | None = None,
|
||||
) -> dict:
|
||||
"""Create a new issue.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
title: Issue title.
|
||||
body: Issue body.
|
||||
labels: Optional list of label IDs.
|
||||
|
||||
Returns:
|
||||
Created issue object.
|
||||
"""
|
||||
payload = {
|
||||
"title": title,
|
||||
"body": body,
|
||||
}
|
||||
if labels:
|
||||
payload["labels"] = labels
|
||||
|
||||
return self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/issues",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
def update_issue(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
index: int,
|
||||
title: str | None = None,
|
||||
body: str | None = None,
|
||||
state: str | None = None,
|
||||
) -> dict:
|
||||
"""Update an existing issue.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
index: Issue number.
|
||||
title: New title.
|
||||
body: New body.
|
||||
state: New state (open, closed).
|
||||
|
||||
Returns:
|
||||
Updated issue object.
|
||||
"""
|
||||
payload = {}
|
||||
if title:
|
||||
payload["title"] = title
|
||||
if body:
|
||||
payload["body"] = body
|
||||
if state:
|
||||
payload["state"] = state
|
||||
|
||||
return self._request(
|
||||
"PATCH",
|
||||
f"/repos/{owner}/{repo}/issues/{index}",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
def list_issues(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
state: str = "open",
|
||||
labels: list[str] | None = None,
|
||||
page: int = 1,
|
||||
limit: int = 30,
|
||||
) -> list[dict]:
|
||||
"""List issues in a repository.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
state: Issue state (open, closed, all).
|
||||
labels: Filter by labels.
|
||||
page: Page number.
|
||||
limit: Items per page.
|
||||
|
||||
Returns:
|
||||
List of issue objects.
|
||||
"""
|
||||
params = {
|
||||
"state": state,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
}
|
||||
if labels:
|
||||
params["labels"] = ",".join(labels)
|
||||
|
||||
return self._request("GET", f"/repos/{owner}/{repo}/issues", params=params)
|
||||
|
||||
def get_issue(self, owner: str, repo: str, index: int) -> dict:
|
||||
"""Get a single issue.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
index: Issue number.
|
||||
|
||||
Returns:
|
||||
Issue object.
|
||||
"""
|
||||
return self._request("GET", f"/repos/{owner}/{repo}/issues/{index}")
|
||||
|
||||
def create_issue_comment(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
index: int,
|
||||
body: str,
|
||||
) -> dict:
|
||||
"""Create a comment on an issue.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
index: Issue number.
|
||||
body: Comment body.
|
||||
|
||||
Returns:
|
||||
Created comment object.
|
||||
"""
|
||||
return self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/issues/{index}/comments",
|
||||
json={"body": body},
|
||||
)
|
||||
|
||||
def update_issue_comment(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
comment_id: int,
|
||||
body: str,
|
||||
) -> dict:
|
||||
"""Update an existing comment.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
comment_id: Comment ID.
|
||||
body: Updated comment body.
|
||||
|
||||
Returns:
|
||||
Updated comment object.
|
||||
"""
|
||||
return self._request(
|
||||
"PATCH",
|
||||
f"/repos/{owner}/{repo}/issues/comments/{comment_id}",
|
||||
json={"body": body},
|
||||
)
|
||||
|
||||
def list_issue_comments(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
index: int,
|
||||
) -> list[dict]:
|
||||
"""List comments on an issue.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
index: Issue number.
|
||||
|
||||
Returns:
|
||||
List of comment objects.
|
||||
"""
|
||||
return self._request("GET", f"/repos/{owner}/{repo}/issues/{index}/comments")
|
||||
|
||||
def add_issue_labels(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
index: int,
|
||||
labels: list[int],
|
||||
) -> list[dict]:
|
||||
"""Add labels to an issue.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
index: Issue number.
|
||||
labels: List of label IDs to add.
|
||||
|
||||
Returns:
|
||||
List of label objects.
|
||||
"""
|
||||
return self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/issues/{index}/labels",
|
||||
json={"labels": labels},
|
||||
)
|
||||
|
||||
def get_repo_labels(self, owner: str, repo: str) -> list[dict]:
|
||||
"""Get all labels for a repository.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
|
||||
Returns:
|
||||
List of label objects.
|
||||
"""
|
||||
return self._request("GET", f"/repos/{owner}/{repo}/labels")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Pull Request Operations
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_pull_request(self, owner: str, repo: str, index: int) -> dict:
|
||||
"""Get a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
index: PR number.
|
||||
|
||||
Returns:
|
||||
Pull request object.
|
||||
"""
|
||||
return self._request("GET", f"/repos/{owner}/{repo}/pulls/{index}")
|
||||
|
||||
def get_pull_request_diff(self, owner: str, repo: str, index: int) -> str:
|
||||
"""Get the diff for a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
index: PR number.
|
||||
|
||||
Returns:
|
||||
Diff text.
|
||||
"""
|
||||
url = f"{self.api_url}/repos/{owner}/{repo}/pulls/{index}.diff"
|
||||
response = requests.get(
|
||||
url,
|
||||
headers={
|
||||
"Authorization": f"token {self.token}",
|
||||
"Accept": "text/plain",
|
||||
},
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
|
||||
def list_pull_request_files(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
index: int,
|
||||
) -> list[dict]:
|
||||
"""List files changed in a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
index: PR number.
|
||||
|
||||
Returns:
|
||||
List of changed file objects.
|
||||
"""
|
||||
return self._request("GET", f"/repos/{owner}/{repo}/pulls/{index}/files")
|
||||
|
||||
def create_pull_request_review(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
index: int,
|
||||
body: str,
|
||||
event: str = "COMMENT",
|
||||
comments: list[dict] | None = None,
|
||||
) -> dict:
|
||||
"""Create a review on a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
index: PR number.
|
||||
body: Review body.
|
||||
event: Review event (APPROVE, REQUEST_CHANGES, COMMENT).
|
||||
comments: List of inline comments.
|
||||
|
||||
Returns:
|
||||
Created review object.
|
||||
"""
|
||||
payload: dict[str, Any] = {
|
||||
"body": body,
|
||||
"event": event,
|
||||
}
|
||||
if comments:
|
||||
payload["comments"] = comments
|
||||
|
||||
return self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/pulls/{index}/reviews",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Repository Operations
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def get_repository(self, owner: str, repo: str) -> dict:
|
||||
"""Get repository information.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
|
||||
Returns:
|
||||
Repository object.
|
||||
"""
|
||||
return self._request("GET", f"/repos/{owner}/{repo}")
|
||||
|
||||
def get_file_contents(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
filepath: str,
|
||||
ref: str | None = None,
|
||||
) -> dict:
|
||||
"""Get file contents from a repository.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
filepath: Path to file.
|
||||
ref: Git ref (branch, tag, commit).
|
||||
|
||||
Returns:
|
||||
File content object with base64-encoded content.
|
||||
"""
|
||||
params = {}
|
||||
if ref:
|
||||
params["ref"] = ref
|
||||
return self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/contents/{filepath}",
|
||||
params=params,
|
||||
)
|
||||
|
||||
def get_branch(self, owner: str, repo: str, branch: str) -> dict:
|
||||
"""Get branch information.
|
||||
|
||||
Args:
|
||||
owner: Repository owner.
|
||||
repo: Repository name.
|
||||
branch: Branch name.
|
||||
|
||||
Returns:
|
||||
Branch object.
|
||||
"""
|
||||
return self._request("GET", f"/repos/{owner}/{repo}/branches/{branch}")
|
||||
482
tools/ai-review/clients/llm_client.py
Normal file
482
tools/ai-review/clients/llm_client.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""LLM Client
|
||||
|
||||
A unified client for interacting with multiple LLM providers.
|
||||
Supports OpenAI, OpenRouter, Ollama, and extensible for more providers.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@dataclass
class ToolCall:
    """Represents a tool call from the LLM.

    Built by the providers from the OpenAI-style ``tool_calls`` payload.
    """

    # Provider-assigned identifier for this tool call.
    id: str
    # Name of the function/tool the model wants to invoke.
    name: str
    # Arguments, already parsed from the provider's JSON string.
    arguments: dict
|
||||
|
||||
|
||||
@dataclass
class LLMResponse:
    """Response from an LLM call."""

    # Generated text content ("" when the model only returned tool calls).
    content: str
    # Model identifier reported by the provider (falls back to the
    # configured model name when the provider omits it).
    model: str
    # Provider key, e.g. "openai", "openrouter", or "ollama".
    provider: str
    # Total tokens consumed, when the provider reports usage.
    tokens_used: int | None = None
    # Provider-reported finish reason (e.g. "stop"), if available.
    finish_reason: str | None = None
    # Parsed tool calls, when the model requested any.
    tool_calls: list[ToolCall] | None = None
|
||||
|
||||
|
||||
class BaseLLMProvider(ABC):
    """Abstract base class for LLM providers.

    Concrete providers must implement ``call``; ``call_with_tools`` is
    optional and raises NotImplementedError by default.
    """

    @abstractmethod
    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the LLM.

        Args:
            prompt: The prompt to send.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with the generated content.
        """
        pass

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the LLM with tool/function calling support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            NotImplementedError: If the provider does not override this
                method (tool calling unsupported).
        """
        raise NotImplementedError("Tool calling not supported by this provider")
|
||||
|
||||
|
||||
class OpenAIProvider(BaseLLMProvider):
    """OpenAI API provider."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "gpt-4o-mini",
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the provider.

        Args:
            api_key: OpenAI API key. Defaults to OPENAI_API_KEY env var.
            model: Default model name.
            temperature: Default sampling temperature.
            max_tokens: Default completion token limit.
        """
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.api_url = "https://api.openai.com/v1/chat/completions"

    def _base_body(self, **kwargs) -> dict:
        """Build the common request fields, letting kwargs override defaults."""
        return {
            "model": kwargs.get("model", self.model),
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }

    def _post(self, request_body: dict) -> dict:
        """POST a chat-completions request and return the parsed JSON.

        Raises:
            ValueError: If no API key is configured.
            requests.HTTPError: If the request fails.
        """
        if not self.api_key:
            raise ValueError("OpenAI API key is required")

        response = requests.post(
            self.api_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        return response.json()

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Call OpenAI API.

        Args:
            prompt: The prompt to send as a single user message.
            **kwargs: Overrides for model/temperature/max_tokens.

        Returns:
            LLMResponse with the generated content.
        """
        request_body = self._base_body(**kwargs)
        request_body["messages"] = [{"role": "user", "content": prompt}]

        data = self._post(request_body)
        choice = data["choices"][0]
        usage = data.get("usage", {})

        return LLMResponse(
            # Normalize a null content field to "" so the str-typed
            # dataclass field never holds None.
            content=choice["message"].get("content") or "",
            model=data["model"],
            provider="openai",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Call OpenAI API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Overrides (model/temperature/max_tokens/tool_choice).

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        request_body = self._base_body(**kwargs)
        request_body["messages"] = messages

        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")

        data = self._post(request_body)
        choice = data["choices"][0]
        usage = data.get("usage", {})
        message = choice["message"]

        # Parse tool calls if present; arguments arrive as a JSON string.
        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = [
                ToolCall(
                    id=tc["id"],
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                )
                for tc in message["tool_calls"]
            ]

        return LLMResponse(
            content=message.get("content") or "",
            model=data["model"],
            provider="openai",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
            tool_calls=tool_calls,
        )
|
||||
|
||||
|
||||
class OpenRouterProvider(BaseLLMProvider):
    """OpenRouter API provider."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "anthropic/claude-3.5-sonnet",
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the provider.

        Args:
            api_key: OpenRouter API key. Defaults to OPENROUTER_API_KEY env var.
            model: Default model name.
            temperature: Default sampling temperature.
            max_tokens: Default completion token limit.
        """
        self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.api_url = "https://openrouter.ai/api/v1/chat/completions"

    def _base_body(self, **kwargs) -> dict:
        """Build the common request fields, letting kwargs override defaults."""
        return {
            "model": kwargs.get("model", self.model),
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
        }

    def _post(self, request_body: dict) -> dict:
        """POST a chat-completions request and return the parsed JSON.

        Raises:
            ValueError: If no API key is configured.
            requests.HTTPError: If the request fails.
        """
        if not self.api_key:
            raise ValueError("OpenRouter API key is required")

        response = requests.post(
            self.api_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        return response.json()

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Call OpenRouter API.

        Args:
            prompt: The prompt to send as a single user message.
            **kwargs: Overrides for model/temperature/max_tokens.

        Returns:
            LLMResponse with the generated content.
        """
        request_body = self._base_body(**kwargs)
        request_body["messages"] = [{"role": "user", "content": prompt}]

        data = self._post(request_body)
        choice = data["choices"][0]
        usage = data.get("usage", {})

        return LLMResponse(
            # Normalize a null content field to "" so the str-typed
            # dataclass field never holds None.
            content=choice["message"].get("content") or "",
            model=data.get("model", self.model),
            provider="openrouter",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Call OpenRouter API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Overrides (model/temperature/max_tokens/tool_choice).

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        request_body = self._base_body(**kwargs)
        request_body["messages"] = messages

        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")

        data = self._post(request_body)
        choice = data["choices"][0]
        usage = data.get("usage", {})
        message = choice["message"]

        # Parse tool calls if present; arguments arrive as a JSON string.
        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = [
                ToolCall(
                    id=tc["id"],
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                )
                for tc in message["tool_calls"]
            ]

        return LLMResponse(
            content=message.get("content") or "",
            model=data.get("model", self.model),
            provider="openrouter",
            tokens_used=usage.get("total_tokens"),
            finish_reason=choice.get("finish_reason"),
            tool_calls=tool_calls,
        )
|
||||
|
||||
|
||||
class OllamaProvider(BaseLLMProvider):
    """Ollama (self-hosted) provider."""

    def __init__(
        self,
        host: str | None = None,
        model: str = "codellama:13b",
        temperature: float = 0,
    ):
        """Initialize the provider.

        Args:
            host: Ollama server URL. Defaults to OLLAMA_HOST env var,
                falling back to http://localhost:11434.
            model: Default model name.
            temperature: Default sampling temperature.
        """
        self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
        self.model = model
        self.temperature = temperature

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Call Ollama API.

        Args:
            prompt: The prompt to send.
            **kwargs: Overrides for model/temperature.

        Returns:
            LLMResponse with the generated content.
        """
        payload = {
            "model": kwargs.get("model", self.model),
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": kwargs.get("temperature", self.temperature),
            },
        }
        # Local models can be slow; allow a much longer timeout than
        # the hosted providers.
        resp = requests.post(f"{self.host}/api/generate", json=payload, timeout=300)
        resp.raise_for_status()
        result = resp.json()

        return LLMResponse(
            content=result["response"],
            model=result.get("model", self.model),
            provider="ollama",
            tokens_used=result.get("eval_count"),
            finish_reason="stop" if result.get("done") else None,
        )
|
||||
|
||||
|
||||
class LLMClient:
    """Unified LLM client supporting multiple providers."""

    # Registry of supported provider implementations.
    PROVIDERS = {
        "openai": OpenAIProvider,
        "openrouter": OpenRouterProvider,
        "ollama": OllamaProvider,
    }

    def __init__(
        self,
        provider: str = "openai",
        config: dict | None = None,
    ):
        """Initialize the LLM client.

        Args:
            provider: Provider name (openai, openrouter, ollama).
            config: Provider-specific configuration.

        Raises:
            ValueError: If the provider name is not registered.
        """
        if provider not in self.PROVIDERS:
            raise ValueError(f"Unknown provider: {provider}. Available: {list(self.PROVIDERS.keys())}")

        self.provider_name = provider
        self.config = config or {}
        self._provider = self.PROVIDERS[provider](**self.config)

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the configured LLM provider.

        Args:
            prompt: The prompt to send.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with the generated content.
        """
        return self._provider.call(prompt, **kwargs)

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call with tool/function calling support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Provider-specific options.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        return self._provider.call_with_tools(messages, tools, **kwargs)

    def call_json(self, prompt: str, **kwargs) -> dict:
        """Make a call and parse the response as JSON.

        Args:
            prompt: The prompt to send (should request JSON output).
            **kwargs: Provider-specific options.

        Returns:
            Parsed JSON response.

        Raises:
            ValueError: If no JSON object can be extracted from the
                response (raised by _extract_json after all repair
                attempts fail).
        """
        response = self.call(prompt, **kwargs)
        # _extract_json strips the content itself.
        return self._extract_json(response.content)

    def _extract_json(self, content: str) -> dict:
        """Extract and parse JSON from content string.

        Handles markdown code blocks, preamble text, comments, and
        trailing commas.

        Raises:
            ValueError: If every parse/repair attempt fails.
        """
        import re

        content = content.strip()

        # Attempt 1: direct parse.
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass

        # Attempt 2: extract from a markdown code block.
        if "```" in content:
            match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", content)
            if match:
                try:
                    return json.loads(match.group(1))
                except json.JSONDecodeError:
                    pass

        # Attempt 3: slice from the first "{" to the last "}" to drop
        # any preamble/postamble prose.
        start = content.find("{")
        end = content.rfind("}")
        if start != -1 and end != -1:
            try:
                return json.loads(content[start : end + 1])
            except json.JSONDecodeError:
                pass

        # Attempt 4: fix common JSON errors (comments, trailing commas).
        # This is risky but helpful for LLM output.
        try:
            # NOTE(review): the line-comment regex also clips "//" inside
            # string values (e.g. URLs); acceptable only as a last resort.
            json_str = re.sub(r"//.*", "", content)
            json_str = re.sub(r"/\*[\s\S]*?\*/", "", json_str)
            # Remove trailing commas before a closing brace/bracket.
            json_str = re.sub(r",\s*([}\]])", r"\1", json_str)
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            # All attempts failed; surface a snippet for debugging.
            snippet = content[:500] + "..." if len(content) > 500 else content
            raise ValueError(f"Failed to parse JSON response: {e}. Raw content snippet: {snippet!r}") from e

    @classmethod
    def from_config(cls, config: dict) -> "LLMClient":
        """Create an LLM client from a configuration dictionary.

        Args:
            config: Configuration with 'provider' key and provider-specific
                settings. The 'model' entry is expected to be a mapping of
                provider name -> model name.

        Returns:
            Configured LLMClient instance.
        """
        provider = config.get("provider", "openai")
        provider_config = {}

        # Map config keys to provider-specific settings.
        # NOTE(review): assumes config["model"] is a dict keyed by
        # provider; a plain string would raise AttributeError - confirm
        # against the config schema.
        if provider == "openai":
            provider_config = {
                "model": config.get("model", {}).get("openai", "gpt-4o-mini"),
                "temperature": config.get("temperature", 0),
                "max_tokens": config.get("max_tokens", 16000),
            }
        elif provider == "openrouter":
            provider_config = {
                "model": config.get("model", {}).get("openrouter", "anthropic/claude-3.5-sonnet"),
                "temperature": config.get("temperature", 0),
                "max_tokens": config.get("max_tokens", 16000),
            }
        elif provider == "ollama":
            provider_config = {
                "model": config.get("model", {}).get("ollama", "codellama:13b"),
                "temperature": config.get("temperature", 0),
            }

        return cls(provider=provider, config=provider_config)
|
||||
Reference in New Issue
Block a user