first commit
This commit is contained in:
257
tools/ai-review/agents/base_agent.py
Normal file
257
tools/ai-review/agents/base_agent.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""Base Agent
|
||||
|
||||
Abstract base class for all AI agents. Provides common functionality
|
||||
for Gitea API interaction, LLM calls, logging, and rate limiting.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from clients.gitea_client import GiteaClient
|
||||
from clients.llm_client import LLMClient, LLMResponse
|
||||
|
||||
|
||||
@dataclass
class AgentContext:
    """Context passed to agent during execution."""

    # Repository owner (user or organization).
    owner: str
    # Repository name.
    repo: str
    # Type of the triggering event (e.g. "issue", "pull_request").
    event_type: str
    # Raw event payload delivered with the trigger.
    event_data: dict
    # Agent configuration; defaults to an empty dict.
    config: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class AgentResult:
    """Result from agent execution."""

    # Whether the agent completed its task successfully.
    success: bool
    # Human-readable summary of the outcome.
    message: str
    # Arbitrary structured data produced by the agent.
    data: dict = field(default_factory=dict)
    # Descriptions of the actions the agent performed.
    actions_taken: list[str] = field(default_factory=list)
    # Error message when execution failed; None on success.
    error: str | None = None
|
||||
|
||||
|
||||
class BaseAgent(ABC):
    """Abstract base class for AI agents.

    Provides the shared plumbing concrete agents rely on: Gitea API
    access, rate-limited LLM calls, prompt-template loading, and
    marker-based comment management (so an agent updates its previous
    comment instead of posting duplicates).
    """

    # Hidden HTML marker embedded in comment bodies so previously posted
    # AI comments can be found and updated.
    AI_MARKER = "<!-- AI_CODE_REVIEW -->"

    # Disclaimer text prepended to AI-generated content.
    AI_DISCLAIMER = (
        "**Note:** This review was generated by an AI assistant. "
        "While it aims to be accurate and helpful, it may contain mistakes "
        "or miss important issues. Please verify all findings before taking action."
    )

    def __init__(
        self,
        config: dict | None = None,
        gitea_client: GiteaClient | None = None,
        llm_client: LLMClient | None = None,
    ):
        """Initialize the base agent.

        Args:
            config: Agent configuration dictionary. Loaded from
                config.yml when omitted.
            gitea_client: Optional pre-configured Gitea client.
            llm_client: Optional pre-configured LLM client.
        """
        self.config = config or self._load_config()
        self.gitea = gitea_client or GiteaClient()
        self.llm = llm_client or LLMClient.from_config(self.config)
        self.logger = logging.getLogger(self.__class__.__name__)

        # Rate-limiting state: timestamp of the last LLM request and the
        # minimum spacing enforced between consecutive requests. Uses a
        # monotonic clock so wall-clock adjustments (NTP, DST) cannot
        # produce negative or inflated elapsed times.
        self._last_request_time = 0.0
        self._min_request_interval = 1.0  # seconds

    @staticmethod
    def _load_config() -> dict:
        """Load configuration from config.yml.

        Returns:
            Parsed configuration, or an empty dict when the file is
            missing or empty.
        """
        config_path = os.path.join(os.path.dirname(__file__), "..", "config.yml")
        if os.path.exists(config_path):
            with open(config_path) as f:
                # safe_load returns None for an empty file; normalize to {}
                # so callers can always treat the config as a dict.
                return yaml.safe_load(f) or {}
        return {}

    def _rate_limit(self):
        """Sleep as needed to honor the minimum request interval."""
        elapsed = time.monotonic() - self._last_request_time
        if elapsed < self._min_request_interval:
            time.sleep(self._min_request_interval - elapsed)
        self._last_request_time = time.monotonic()

    def load_prompt(self, prompt_name: str) -> str:
        """Load a prompt template from the prompts directory.

        Args:
            prompt_name: Name of the prompt file (without extension).

        Returns:
            Prompt template content.

        Raises:
            FileNotFoundError: If the prompt file does not exist.
        """
        prompt_path = os.path.join(
            os.path.dirname(__file__), "..", "prompts", f"{prompt_name}.md"
        )
        if not os.path.exists(prompt_path):
            raise FileNotFoundError(f"Prompt not found: {prompt_path}")
        with open(prompt_path) as f:
            return f.read()

    def call_llm(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a rate-limited call to the LLM.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional LLM options.

        Returns:
            LLM response.
        """
        self._rate_limit()
        return self.llm.call(prompt, **kwargs)

    def call_llm_json(self, prompt: str, **kwargs) -> dict:
        """Make a rate-limited call and parse JSON response.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional LLM options.

        Returns:
            Parsed JSON response.
        """
        self._rate_limit()
        return self.llm.call_json(prompt, **kwargs)

    def find_ai_comment(
        self,
        owner: str,
        repo: str,
        issue_index: int,
        marker: str | None = None,
    ) -> int | None:
        """Find an existing AI comment by marker.

        Args:
            owner: Repository owner.
            repo: Repository name.
            issue_index: Issue or PR number.
            marker: Custom marker to search for. Defaults to AI_MARKER.

        Returns:
            Comment ID if found, None otherwise.
        """
        marker = marker or self.AI_MARKER
        comments = self.gitea.list_issue_comments(owner, repo, issue_index)

        # Return the first comment whose body carries the marker.
        for comment in comments:
            if marker in comment.get("body", ""):
                return comment["id"]

        return None

    def upsert_comment(
        self,
        owner: str,
        repo: str,
        issue_index: int,
        body: str,
        marker: str | None = None,
    ) -> dict:
        """Create or update an AI comment.

        Args:
            owner: Repository owner.
            repo: Repository name.
            issue_index: Issue or PR number.
            body: Comment body (marker will be prepended if not present).
            marker: Custom marker. Defaults to AI_MARKER.

        Returns:
            Created or updated comment.
        """
        marker = marker or self.AI_MARKER

        # Ensure the marker is present so a later upsert can find this comment.
        if marker not in body:
            body = f"{marker}\n{body}"

        # Update in place when a marked comment already exists; otherwise
        # create a fresh one.
        existing_id = self.find_ai_comment(owner, repo, issue_index, marker)

        if existing_id:
            return self.gitea.update_issue_comment(owner, repo, existing_id, body)
        else:
            return self.gitea.create_issue_comment(owner, repo, issue_index, body)

    def format_with_disclaimer(self, content: str) -> str:
        """Add AI disclaimer to content.

        Args:
            content: The main content.

        Returns:
            Content with disclaimer (and the AI marker) prepended.
        """
        return f"{self.AI_DISCLAIMER}\n\n{self.AI_MARKER}\n{content}"

    @abstractmethod
    def execute(self, context: AgentContext) -> AgentResult:
        """Execute the agent's main task.

        Args:
            context: Execution context with event data.

        Returns:
            Result of the agent execution.
        """

    @abstractmethod
    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Check if this agent can handle the given event.

        Args:
            event_type: Type of event (issue, pull_request, etc).
            event_data: Event payload data.

        Returns:
            True if this agent can handle the event.
        """

    def run(self, context: AgentContext) -> AgentResult:
        """Run the agent with error handling.

        Args:
            context: Execution context.

        Returns:
            Agent result, including any errors.
        """
        try:
            self.logger.info(
                "Running %s for %s/%s",
                self.__class__.__name__,
                context.owner,
                context.repo,
            )
            result = self.execute(context)
            self.logger.info(
                "Completed with success=%s: %s", result.success, result.message
            )
            return result
        except Exception as e:
            # Broad catch at the agent boundary: convert any failure into a
            # structured AgentResult instead of crashing the runner.
            self.logger.exception("Agent execution failed: %s", e)
            return AgentResult(
                success=False,
                message="Agent execution failed",
                error=str(e),
            )
|
||||
Reference in New Issue
Block a user