Files
openrabbit/tools/ai-review/agents/base_agent.py
2025-12-21 13:42:30 +01:00

258 lines
7.3 KiB
Python

"""Base Agent
Abstract base class for all AI agents. Provides common functionality
for Gitea API interaction, LLM calls, logging, and rate limiting.
"""
import logging
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
import yaml
from clients.gitea_client import GiteaClient
from clients.llm_client import LLMClient, LLMResponse
@dataclass()
class AgentContext:
    """Inputs handed to an agent for a single execution.

    Bundles the repository the event belongs to, the event itself, and any
    per-run configuration.
    """

    # Repository owner.
    owner: str
    # Repository name.
    repo: str
    # Type of event (issue, pull_request, etc).
    event_type: str
    # Event payload data.
    event_data: dict
    # Per-run configuration; every instance receives its own fresh dict.
    config: dict = field(default_factory=dict)
@dataclass
class AgentResult:
"""Result from agent execution."""
success: bool
message: str
data: dict = field(default_factory=dict)
actions_taken: list[str] = field(default_factory=list)
error: str | None = None
class BaseAgent(ABC):
    """Abstract base class for AI agents.

    Subclasses implement :meth:`execute` and :meth:`can_handle`. The base
    class provides configuration loading, prompt loading, rate-limited LLM
    calls, marker-based comment upserts through the Gitea client, and an
    error-trapping :meth:`run` wrapper.
    """

    # Marker to identify AI-generated comments
    AI_MARKER = "<!-- AI_CODE_REVIEW -->"

    # Disclaimer text
    AI_DISCLAIMER = (
        "**Note:** This review was generated by an AI assistant. "
        "While it aims to be accurate and helpful, it may contain mistakes "
        "or miss important issues. Please verify all findings before taking action."
    )

    def __init__(
        self,
        config: dict | None = None,
        gitea_client: GiteaClient | None = None,
        llm_client: LLMClient | None = None,
    ):
        """Initialize the base agent.

        Args:
            config: Agent configuration dictionary. When falsy (None or an
                empty dict) the configuration is loaded from ``config.yml``.
            gitea_client: Optional pre-configured Gitea client. A default
                ``GiteaClient()`` is created when omitted.
            llm_client: Optional pre-configured LLM client. When omitted one
                is built with ``LLMClient.from_config(self.config)``.
        """
        self.config = config or self._load_config()
        self.gitea = gitea_client or GiteaClient()
        self.llm = llm_client or LLMClient.from_config(self.config)
        self.logger = logging.getLogger(self.__class__.__name__)

        # Rate limiting. Timestamps come from time.monotonic() so wall-clock
        # adjustments cannot produce negative/huge waits; -inf means "no
        # request made yet", so the very first request never sleeps.
        self._last_request_time = float("-inf")
        self._min_request_interval = 1.0  # seconds

    @staticmethod
    def _load_config() -> dict:
        """Load configuration from ``config.yml`` in the parent directory.

        Returns:
            The parsed YAML document, or ``{}`` when the file is missing or
            empty.
        """
        config_path = os.path.join(os.path.dirname(__file__), "..", "config.yml")
        if not os.path.exists(config_path):
            return {}
        with open(config_path, encoding="utf-8") as f:
            # safe_load returns None for an empty document; normalise to {}
            # so downstream consumers (LLMClient.from_config) get a mapping.
            return yaml.safe_load(f) or {}

    def _rate_limit(self):
        """Sleep so consecutive requests are at least
        ``_min_request_interval`` seconds apart, then record the request time.
        """
        elapsed = time.monotonic() - self._last_request_time
        if elapsed < self._min_request_interval:
            time.sleep(self._min_request_interval - elapsed)
        self._last_request_time = time.monotonic()

    def load_prompt(self, prompt_name: str) -> str:
        """Load a prompt template from the prompts directory.

        Args:
            prompt_name: Name of the prompt file (without the ``.md`` extension).

        Returns:
            Prompt template content.

        Raises:
            FileNotFoundError: If ``prompts/<prompt_name>.md`` does not exist.
        """
        prompt_path = os.path.join(
            os.path.dirname(__file__), "..", "prompts", f"{prompt_name}.md"
        )
        if not os.path.exists(prompt_path):
            raise FileNotFoundError(f"Prompt not found: {prompt_path}")
        # Explicit UTF-8 so reading does not depend on the platform locale.
        with open(prompt_path, encoding="utf-8") as f:
            return f.read()

    def call_llm(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a rate-limited call to the LLM.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional LLM options, forwarded to ``self.llm.call``.

        Returns:
            LLM response.
        """
        self._rate_limit()
        return self.llm.call(prompt, **kwargs)

    def call_llm_json(self, prompt: str, **kwargs) -> dict:
        """Make a rate-limited call and parse the JSON response.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional LLM options, forwarded to ``self.llm.call_json``.

        Returns:
            Parsed JSON response.
        """
        self._rate_limit()
        return self.llm.call_json(prompt, **kwargs)

    def find_ai_comment(
        self,
        owner: str,
        repo: str,
        issue_index: int,
        marker: str | None = None,
    ) -> int | None:
        """Find the first existing comment whose body contains a marker.

        Args:
            owner: Repository owner.
            repo: Repository name.
            issue_index: Issue or PR number.
            marker: Custom marker to search for. Defaults to AI_MARKER.

        Returns:
            Comment ID if found, None otherwise.
        """
        marker = marker or self.AI_MARKER
        comments = self.gitea.list_issue_comments(owner, repo, issue_index)
        for comment in comments:
            # `or ""` also covers a present-but-null body, which
            # .get("body", "") would pass through as None.
            if marker in (comment.get("body") or ""):
                return comment["id"]
        return None

    def upsert_comment(
        self,
        owner: str,
        repo: str,
        issue_index: int,
        body: str,
        marker: str | None = None,
    ) -> dict:
        """Create or update an AI comment.

        Args:
            owner: Repository owner.
            repo: Repository name.
            issue_index: Issue or PR number.
            body: Comment body (marker will be prepended if not present).
            marker: Custom marker. Defaults to AI_MARKER.

        Returns:
            Created or updated comment.
        """
        marker = marker or self.AI_MARKER

        # Ensure the marker is in the body so the comment can be found again.
        if marker not in body:
            body = f"{marker}\n{body}"

        # Update the existing marked comment if there is one; otherwise create.
        existing_id = self.find_ai_comment(owner, repo, issue_index, marker)
        if existing_id is not None:
            return self.gitea.update_issue_comment(owner, repo, existing_id, body)
        return self.gitea.create_issue_comment(owner, repo, issue_index, body)

    def format_with_disclaimer(self, content: str) -> str:
        """Add the AI disclaimer and marker to content.

        Args:
            content: The main content.

        Returns:
            The disclaimer, a blank line, AI_MARKER, then the content.
        """
        return f"{self.AI_DISCLAIMER}\n\n{self.AI_MARKER}\n{content}"

    @abstractmethod
    def execute(self, context: AgentContext) -> AgentResult:
        """Execute the agent's main task.

        Args:
            context: Execution context with event data.

        Returns:
            Result of the agent execution.
        """

    @abstractmethod
    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Check if this agent can handle the given event.

        Args:
            event_type: Type of event (issue, pull_request, etc).
            event_data: Event payload data.

        Returns:
            True if this agent can handle the event.
        """

    def run(self, context: AgentContext) -> AgentResult:
        """Run the agent, converting any exception into a failed result.

        Args:
            context: Execution context.

        Returns:
            The result from :meth:`execute`, or a failed AgentResult carrying
            the exception text in ``error``.
        """
        try:
            self.logger.info(
                "Running %s for %s/%s",
                self.__class__.__name__,
                context.owner,
                context.repo,
            )
            result = self.execute(context)
            self.logger.info(
                "Completed with success=%s: %s", result.success, result.message
            )
            return result
        except Exception as e:
            # Top-level boundary: log with traceback and report failure
            # instead of propagating.
            self.logger.exception("Agent execution failed: %s", e)
            return AgentResult(
                success=False,
                message="Agent execution failed",
                error=str(e),
            )