258 lines
7.3 KiB
Python
258 lines
7.3 KiB
Python
"""Base Agent
|
|
|
|
Abstract base class for all AI agents. Provides common functionality
|
|
for Gitea API interaction, LLM calls, logging, and rate limiting.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
from clients.gitea_client import GiteaClient
|
|
from clients.llm_client import LLMClient, LLMResponse
|
|
|
|
|
|
@dataclass
class AgentContext:
    """Bundle of inputs handed to an agent for one execution.

    Fields appear in constructor (positional) order.
    """

    # Repository owner.
    owner: str
    # Repository name.
    repo: str
    # Type of the triggering event (e.g. issue, pull_request).
    event_type: str
    # Payload data of the triggering event.
    event_data: dict
    # Extra configuration for this run; every instance gets its own fresh dict.
    config: dict = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
|
|
class AgentResult:
|
|
"""Result from agent execution."""
|
|
|
|
success: bool
|
|
message: str
|
|
data: dict = field(default_factory=dict)
|
|
actions_taken: list[str] = field(default_factory=list)
|
|
error: str | None = None
|
|
|
|
|
|
class BaseAgent(ABC):
    """Abstract base class for AI agents.

    Subclasses implement :meth:`execute` and :meth:`can_handle`. This base
    class wires up configuration, the Gitea and LLM clients, and a logger,
    and provides helpers for loading prompt templates, making rate-limited
    LLM calls, and creating/updating a single marker-tagged AI comment on an
    issue or pull request.
    """

    # Marker to identify AI-generated comments
    AI_MARKER = "<!-- AI_CODE_REVIEW -->"

    # Disclaimer text
    AI_DISCLAIMER = (
        "**Note:** This review was generated by an AI assistant. "
        "While it aims to be accurate and helpful, it may contain mistakes "
        "or miss important issues. Please verify all findings before taking action."
    )

    def __init__(
        self,
        config: dict | None = None,
        gitea_client: GiteaClient | None = None,
        llm_client: LLMClient | None = None,
    ):
        """Initialize the base agent.

        Args:
            config: Agent configuration dictionary. Any falsy value (None or
                an empty dict) falls back to loading ``config.yml``.
            gitea_client: Optional pre-configured Gitea client; a default
                ``GiteaClient()`` is created when omitted.
            llm_client: Optional pre-configured LLM client; otherwise one is
                built with ``LLMClient.from_config(self.config)``.
        """
        self.config = config or self._load_config()
        self.gitea = gitea_client or GiteaClient()
        self.llm = llm_client or LLMClient.from_config(self.config)
        self.logger = logging.getLogger(self.__class__.__name__)

        # Rate limiting. Timestamps come from time.monotonic() so wall-clock
        # adjustments cannot cause huge or skipped sleeps; -inf means
        # "no request has been made yet", so the first call never waits.
        self._last_request_time = float("-inf")
        self._min_request_interval = 1.0  # seconds

    @staticmethod
    def _load_config() -> dict:
        """Load configuration from ``config.yml`` in the parent directory.

        Returns:
            The parsed YAML document, or an empty dict when the file does not
            exist or is empty.
        """
        config_path = os.path.join(os.path.dirname(__file__), "..", "config.yml")
        if not os.path.exists(config_path):
            return {}
        with open(config_path, encoding="utf-8") as f:
            # safe_load returns None for an empty document; normalise it to {}
            # so self.config (and LLMClient.from_config) never receive None.
            return yaml.safe_load(f) or {}

    def _rate_limit(self) -> None:
        """Wait until ``_min_request_interval`` seconds have passed since the
        previous rate-limited request, then record the current time."""
        elapsed = time.monotonic() - self._last_request_time
        if elapsed < self._min_request_interval:
            time.sleep(self._min_request_interval - elapsed)
        self._last_request_time = time.monotonic()

    def load_prompt(self, prompt_name: str) -> str:
        """Load a prompt template from the prompts directory.

        Args:
            prompt_name: Name of the prompt file (without the ``.md``
                extension), resolved under ``../prompts`` relative to this file.

        Returns:
            Prompt template content, decoded as UTF-8.

        Raises:
            FileNotFoundError: If the prompt file does not exist.
        """
        prompt_path = os.path.join(
            os.path.dirname(__file__), "..", "prompts", f"{prompt_name}.md"
        )
        try:
            with open(prompt_path, encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError:
            raise FileNotFoundError(f"Prompt not found: {prompt_path}") from None

    def call_llm(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a rate-limited call to the LLM.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options forwarded to ``LLMClient.call``.

        Returns:
            LLM response.
        """
        self._rate_limit()
        return self.llm.call(prompt, **kwargs)

    def call_llm_json(self, prompt: str, **kwargs) -> dict:
        """Make a rate-limited call and parse the JSON response.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options forwarded to ``LLMClient.call_json``.

        Returns:
            Parsed JSON response.
        """
        self._rate_limit()
        return self.llm.call_json(prompt, **kwargs)

    def find_ai_comment(
        self,
        owner: str,
        repo: str,
        issue_index: int,
        marker: str | None = None,
    ) -> int | None:
        """Find an existing AI comment by marker.

        Args:
            owner: Repository owner.
            repo: Repository name.
            issue_index: Issue or PR number.
            marker: Custom marker to search for. Defaults to AI_MARKER.

        Returns:
            ID of the first comment whose body contains the marker, or None
            if no such comment exists.
        """
        marker = marker or self.AI_MARKER
        comments = self.gitea.list_issue_comments(owner, repo, issue_index)

        for comment in comments:
            # Guard against a missing *or* null body; `.get("body", "")`
            # alone would still yield None for {"body": None}.
            if marker in (comment.get("body") or ""):
                return comment["id"]

        return None

    def upsert_comment(
        self,
        owner: str,
        repo: str,
        issue_index: int,
        body: str,
        marker: str | None = None,
    ) -> dict:
        """Create or update an AI comment.

        Args:
            owner: Repository owner.
            repo: Repository name.
            issue_index: Issue or PR number.
            body: Comment body (marker will be prepended if not present).
            marker: Custom marker. Defaults to AI_MARKER.

        Returns:
            Created or updated comment, as returned by the Gitea client.
        """
        marker = marker or self.AI_MARKER

        # Ensure marker is in the body so the comment can be found next time
        if marker not in body:
            body = f"{marker}\n{body}"

        existing_id = self.find_ai_comment(owner, repo, issue_index, marker)

        # Compare against None explicitly: a comment id of 0 is still a match.
        if existing_id is not None:
            return self.gitea.update_issue_comment(owner, repo, existing_id, body)
        return self.gitea.create_issue_comment(owner, repo, issue_index, body)

    def format_with_disclaimer(self, content: str) -> str:
        """Add AI disclaimer to content.

        Args:
            content: The main content.

        Returns:
            The disclaimer, a blank line, the AI marker, then the content.
        """
        return f"{self.AI_DISCLAIMER}\n\n{self.AI_MARKER}\n{content}"

    @abstractmethod
    def execute(self, context: AgentContext) -> AgentResult:
        """Execute the agent's main task.

        Args:
            context: Execution context with event data.

        Returns:
            Result of the agent execution.
        """

    @abstractmethod
    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Check if this agent can handle the given event.

        Args:
            event_type: Type of event (issue, pull_request, etc).
            event_data: Event payload data.

        Returns:
            True if this agent can handle the event.
        """

    def run(self, context: AgentContext) -> AgentResult:
        """Run the agent with error handling.

        Any exception raised while running is logged with its traceback and
        converted into a failed ``AgentResult`` instead of propagating.

        Args:
            context: Execution context.

        Returns:
            Agent result, including any errors.
        """
        try:
            self.logger.info(
                "Running %s for %s/%s",
                self.__class__.__name__,
                context.owner,
                context.repo,
            )
            result = self.execute(context)
            self.logger.info(
                "Completed with success=%s: %s", result.success, result.message
            )
            return result
        except Exception as e:
            # Top-level boundary for the agent: never let a failure escape.
            self.logger.exception("Agent execution failed: %s", e)
            return AgentResult(
                success=False,
                message="Agent execution failed",
                error=str(e),
            )
|