# File: openrabbit/tools/ai-review/agents/chat_agent.py
# Timestamp: 2025-12-21 13:42:30 +01:00
# Size: 471 lines, 16 KiB (Python)
"""Chat Agent (Bartender)
Interactive AI chat agent with tool use capabilities.
Can search the codebase and web to answer user questions.
"""
import base64
import json
import logging
import os
import re
from dataclasses import dataclass

import requests

from agents.base_agent import AgentContext, AgentResult, BaseAgent
from clients.llm_client import ToolCall
@dataclass
class ChatMessage:
"""A message in the chat conversation."""
role: str # 'user', 'assistant', or 'tool'
content: str
tool_call_id: str | None = None
name: str | None = None # Tool name for tool responses
class ChatAgent(BaseAgent):
    """Interactive chat agent ("Bartender") with tool capabilities.

    The agent answers a user message by repeatedly calling the LLM; whenever
    the model requests a tool (``search_codebase``, ``read_file`` or
    ``search_web``) the tool is executed and its textual result is appended
    to the conversation. When an issue number is available, the final answer
    is handed to ``upsert_comment``.
    """

    # Marker for chat responses
    CHAT_AI_MARKER = "<!-- AI_CHAT_RESPONSE -->"

    # Commands that may follow the mention prefix and are NOT chat requests.
    _SPECIFIC_COMMANDS = ("summarize", "explain", "suggest", "security", "codebase")

    # Upper bound on files gathered by _collect_files (limits API calls).
    _MAX_COLLECTED_FILES = 100

    # Tool definitions in OpenAI format
    TOOLS = [
        {
            "type": "function",
            "function": {
                "name": "search_codebase",
                "description": "Search the repository codebase for files, functions, classes, or patterns. Use this to find relevant code.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query - can be a filename, function name, class name, or code pattern",
                        },
                        "file_pattern": {
                            "type": "string",
                            "description": "Optional file pattern to filter results (e.g., '*.py', 'src/*.js')",
                        },
                    },
                    "required": ["query"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "read_file",
                "description": "Read the contents of a specific file from the repository.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file to read",
                        },
                    },
                    "required": ["filepath"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "search_web",
                "description": "Search the web for information using SearXNG. Use this for external documentation, tutorials, or general knowledge.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query",
                        },
                        "categories": {
                            "type": "string",
                            "description": "Optional: comma-separated categories (general, images, videos, news, science, it)",
                        },
                    },
                    "required": ["query"],
                },
            },
        },
    ]

    # System prompt for the chat agent; formatted with {owner} and {repo}.
    SYSTEM_PROMPT = """You are Bartender, a helpful AI assistant for code review and development tasks.
You have access to tools to help answer questions:
- search_codebase: Search the repository for code, files, functions, or patterns
- read_file: Read specific files from the repository
- search_web: Search the web for documentation, tutorials, or external information
When helping users:
1. Use tools to gather information before answering questions about code
2. Be concise but thorough in your explanations
3. Provide code examples when helpful
4. If you're unsure, say so and suggest alternatives
Repository context: {owner}/{repo}
"""

    def __init__(self, *args, **kwargs):
        """Initialise the base agent and resolve the SearXNG endpoint.

        The URL is read from ``agents.chat.searxng_url`` in the config and
        falls back to the ``SEARXNG_URL`` environment variable (empty string
        when neither is set).
        """
        super().__init__(*args, **kwargs)
        chat_config = self.config.get("agents", {}).get("chat", {})
        self._searxng_url = chat_config.get(
            "searxng_url", os.environ.get("SEARXNG_URL", "")
        )

    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Check if this agent handles the given event.

        Args:
            event_type: Event name; ``"issue_comment"`` and ``"chat"`` are
                the only types this agent accepts.
            event_data: Event payload; for issue comments the text is read
                from ``event_data["comment"]["body"]``.

        Returns:
            True for ``"chat"`` events, and for issue comments that contain
            the mention prefix without one of the specific commands directly
            after it. False when the agent is disabled in the config.
        """
        agent_config = self.config.get("agents", {}).get("chat", {})
        if not agent_config.get("enabled", True):
            return False

        # Handle direct chat command
        if event_type == "chat":
            return True

        if event_type != "issue_comment":
            return False

        # Tolerate a missing/null comment object or body.
        comment_body = (event_data.get("comment") or {}).get("body") or ""
        mention_prefix = self.config.get("interaction", {}).get(
            "mention_prefix", "@ai-bot"
        )
        if mention_prefix not in comment_body:
            return False

        # The body is lowercased for the command check, so the prefix must be
        # lowercased too; otherwise a mixed-case prefix can never match and
        # every specific command would be treated as a chat request.
        body_lower = comment_body.lower()
        prefix_lower = mention_prefix.lower()
        return not any(
            f"{prefix_lower} {cmd}" in body_lower for cmd in self._SPECIFIC_COMMANDS
        )

    def execute(self, context: AgentContext) -> AgentResult:
        """Execute the chat agent.

        Extracts the user message from the event, runs the tool-enabled chat
        loop and, when an issue number is present, posts the formatted answer
        through ``upsert_comment``.

        Args:
            context: Agent context; this method reads ``owner``, ``repo``,
                ``event_type`` and ``event_data`` from it.

        Returns:
            An unsuccessful ``AgentResult`` when no message was provided,
            otherwise a successful one whose ``data`` holds the response text
            and the list of tools used.
        """
        self.logger.info(f"Starting chat for {context.owner}/{context.repo}")

        # Extract user message
        if context.event_type == "issue_comment":
            user_message = (context.event_data.get("comment") or {}).get("body") or ""
            issue_index = context.event_data.get("issue", {}).get("number")
            # Remove the @ai-bot prefix
            mention_prefix = self.config.get("interaction", {}).get(
                "mention_prefix", "@ai-bot"
            )
            user_message = user_message.replace(mention_prefix, "").strip()
        else:
            user_message = context.event_data.get("message", "")
            issue_index = context.event_data.get("issue_number")

        if not user_message:
            return AgentResult(
                success=False,
                message="No message provided",
            )

        # Build conversation
        system_prompt = self.SYSTEM_PROMPT.format(
            owner=context.owner,
            repo=context.repo,
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        # Run the chat loop with tool execution
        response_content, tools_used = self._run_chat_loop(
            context, messages, max_iterations=5
        )

        actions_taken = []
        if tools_used:
            actions_taken.append(f"Used tools: {', '.join(tools_used)}")

        # Post response if this is an issue comment
        if issue_index:
            comment_body = self._format_response(response_content)
            self.upsert_comment(
                context.owner,
                context.repo,
                issue_index,
                comment_body,
                marker=self.CHAT_AI_MARKER,
            )
            actions_taken.append("Posted chat response")

        return AgentResult(
            success=True,
            message="Chat completed",
            data={"response": response_content, "tools_used": tools_used},
            actions_taken=actions_taken,
        )

    def _run_chat_loop(
        self,
        context: AgentContext,
        messages: list[dict],
        max_iterations: int = 5,
    ) -> tuple[str, list[str]]:
        """Run the chat loop with tool execution.

        ``messages`` is extended in place with the assistant tool-call
        messages and the tool results.

        Args:
            context: Agent context forwarded to the tools.
            messages: Conversation so far, in chat-completion message format.
            max_iterations: Maximum number of tool-enabled LLM calls before a
                final call without tools is made.

        Returns:
            Tuple of (final response content, list of tools used). The
            content is an empty string when the model returned none.
        """
        tools_used: list[str] = []

        for _ in range(max_iterations):
            self._rate_limit()
            response = self.llm.call_with_tools(messages, tools=self.TOOLS)

            # If no tool calls, we're done
            if not response.tool_calls:
                return response.content or "", tools_used

            # Add assistant message with tool calls. The arguments field must
            # be a JSON string; str() on a dict yields a Python repr (single
            # quotes) that is not valid JSON.
            messages.append({
                "role": "assistant",
                "content": response.content or "",
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.name,
                            "arguments": (
                                tc.arguments
                                if isinstance(tc.arguments, str)
                                else json.dumps(tc.arguments)
                            ),
                        },
                    }
                    for tc in response.tool_calls
                ],
            })

            # Execute each tool call
            for tool_call in response.tool_calls:
                tool_result = self._execute_tool(context, tool_call)
                tools_used.append(tool_call.name)
                # Add tool result to messages
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_result,
                })

        # If we hit max iterations, make one final call without tools
        self._rate_limit()
        final_response = self.llm.call_with_tools(
            messages, tools=None, tool_choice="none"
        )
        return final_response.content or "", tools_used

    def _execute_tool(self, context: AgentContext, tool_call: ToolCall) -> str:
        """Execute a tool call and return the result.

        Any exception raised by a tool is logged and converted into an error
        string so the chat loop can continue.
        """
        self.logger.info(f"Executing tool: {tool_call.name}")
        try:
            arguments = tool_call.arguments or {}
            if tool_call.name == "search_codebase":
                return self._tool_search_codebase(
                    context,
                    arguments.get("query", ""),
                    arguments.get("file_pattern"),
                )
            if tool_call.name == "read_file":
                return self._tool_read_file(
                    context,
                    arguments.get("filepath", ""),
                )
            if tool_call.name == "search_web":
                return self._tool_search_web(
                    arguments.get("query", ""),
                    arguments.get("categories"),
                )
            return f"Unknown tool: {tool_call.name}"
        except Exception as e:
            self.logger.error(f"Tool execution failed: {e}")
            return f"Error executing tool: {e}"

    def _tool_search_codebase(
        self,
        context: AgentContext,
        query: str,
        file_pattern: str | None = None,
    ) -> str:
        """Search the codebase for files matching a query.

        The query is matched case-insensitively against each file path and,
        when the path does not match, against each line of the file content.

        Args:
            context: Agent context providing ``owner`` and ``repo``.
            query: Substring to look for.
            file_pattern: Optional glob used to pre-filter the file list.

        Returns:
            A newline-joined report (at most 30 lines), or a message stating
            that nothing was found or that the file listing failed.
        """
        # Get repository file list
        try:
            files = self._collect_files(context.owner, context.repo, file_pattern)
        except Exception as e:
            return f"Error listing files: {e}"

        results: list[str] = []
        query_lower = query.lower()

        # Search through files
        for file_info in files[:50]:  # Limit to prevent API exhaustion
            filepath = file_info.get("path", "")

            # Check filename match
            if query_lower in filepath.lower():
                results.append(f"File: {filepath}")
                continue

            # Check content for code patterns
            try:
                content_data = self.gitea.get_file_contents(
                    context.owner, context.repo, filepath
                )
                if not content_data.get("content"):
                    continue
                content = base64.b64decode(content_data["content"]).decode(
                    "utf-8", errors="ignore"
                )
            except Exception as e:
                # Best effort: one unreadable file must not abort the search,
                # but leave a trace instead of swallowing the error silently.
                self.logger.debug(f"Skipping {filepath} during search: {e}")
                continue

            # Search for query in content
            matching_lines = [
                f" L{i}: {line.strip()[:100]}"
                for i, line in enumerate(content.splitlines(), 1)
                if query_lower in line.lower()
            ]
            if matching_lines:
                results.append(f"File: {filepath}")
                results.extend(matching_lines[:5])  # Max 5 matches per file

        if not results:
            return f"No results found for '{query}'"
        return "\n".join(results[:30])  # Limit total results

    def _collect_files(
        self,
        owner: str,
        repo: str,
        file_pattern: str | None = None,
    ) -> list[dict]:
        """Collect files from the repository.

        Walks the repository tree depth-first through
        ``self.gitea.get_file_contents`` and keeps entries of type ``file``
        whose extension is a known code/config extension, whose path contains
        none of the ignore patterns and, if given, matches ``file_pattern``.
        Traversal stops as soon as the file limit is reached so that no
        further listing requests are issued.

        Returns:
            Up to ``_MAX_COLLECTED_FILES`` file entries as returned by the
            listing call.
        """
        files: list[dict] = []
        limit = self._MAX_COLLECTED_FILES

        # Code extensions to search
        code_extensions = {
            ".py", ".js", ".ts", ".go", ".rs", ".java", ".rb",
            ".php", ".c", ".cpp", ".h", ".cs", ".swift", ".kt",
            ".md", ".yml", ".yaml", ".json", ".toml",
        }
        # Patterns to ignore
        ignore_patterns = [
            "node_modules/", "vendor/", ".git/", "__pycache__/",
            ".venv/", "dist/", "build/", ".min.js", ".min.css",
        ]

        def traverse(path: str = "") -> None:
            if len(files) >= limit:
                return
            try:
                contents = self.gitea.get_file_contents(owner, repo, path or ".")
                if not isinstance(contents, list):
                    return
                for item in contents:
                    # Stop early: anything beyond the limit would be
                    # discarded anyway, so do not spend API calls on it.
                    if len(files) >= limit:
                        return
                    item_path = item.get("path", "")
                    if any(p in item_path for p in ignore_patterns):
                        continue
                    item_type = item.get("type")
                    if item_type == "file":
                        ext = os.path.splitext(item_path)[1]
                        if ext not in code_extensions:
                            continue
                        # Check file pattern if provided
                        if file_pattern and not self._match_pattern(
                            item_path, file_pattern
                        ):
                            continue
                        files.append(item)
                    elif item_type == "dir":
                        traverse(item_path)
            except Exception as e:
                self.logger.warning(f"Failed to list {path}: {e}")

        traverse()
        return files[:limit]  # Limit to prevent API exhaustion

    def _match_pattern(self, filepath: str, pattern: str) -> bool:
        """Check if filepath matches a simple glob pattern."""
        import fnmatch

        return fnmatch.fnmatch(filepath, pattern)

    def _tool_read_file(self, context: AgentContext, filepath: str) -> str:
        """Read a file from the repository.

        Args:
            context: Agent context providing ``owner`` and ``repo``.
            filepath: Repository path of the file.

        Returns:
            The decoded file content wrapped in a fenced block (truncated to
            8000 characters), or a message describing why it could not be
            read.
        """
        try:
            content_data = self.gitea.get_file_contents(
                context.owner, context.repo, filepath
            )
            # The contents call returns a list for directories; report that
            # explicitly instead of failing on list.get().
            if isinstance(content_data, list):
                return f"Path is a directory, not a file: {filepath}"
            if content_data.get("content"):
                content = base64.b64decode(content_data["content"]).decode(
                    "utf-8", errors="ignore"
                )
                # Truncate if too long
                if len(content) > 8000:
                    content = content[:8000] + "\n... (truncated)"
                return f"File: {filepath}\n\n```\n{content}\n```"
            return f"File not found: {filepath}"
        except Exception as e:
            return f"Error reading file: {e}"

    def _tool_search_web(
        self,
        query: str,
        categories: str | None = None,
    ) -> str:
        """Search the web using SearXNG.

        Args:
            query: Search query.
            categories: Optional comma-separated category list passed through
                as the ``categories`` request parameter.

        Returns:
            The top five results formatted as a numbered list, or a message
            explaining that search is not configured, found nothing or
            failed.
        """
        if not self._searxng_url:
            return "Web search is not configured. Set SEARXNG_URL environment variable."

        params = {
            "q": query,
            "format": "json",
        }
        if categories:
            params["categories"] = categories

        try:
            response = requests.get(
                # Avoid a double slash when the configured URL ends with "/".
                f"{self._searxng_url.rstrip('/')}/search",
                params=params,
                timeout=30,
            )
            response.raise_for_status()
            data = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on requests versions whose
            # JSON error is not a RequestException.
            return f"Web search failed: {e}"

        results = data.get("results", [])
        if not results:
            return f"No web results found for '{query}'"

        # Format results
        output = []
        for i, result in enumerate(results[:5], 1):  # Top 5 results
            # Fields may be present but null; fall back instead of failing.
            title = result.get("title") or "No title"
            url = result.get("url") or ""
            content = (result.get("content") or "")[:200]
            output.append(f"{i}. **{title}**\n {url}\n {content}")
        return "\n\n".join(output)

    def _format_response(self, content: str) -> str:
        """Format the chat response with disclaimer.

        ``AI_DISCLAIMER`` is not defined in this class (presumably inherited
        from ``BaseAgent``). Empty or missing content is rendered as an empty
        body instead of raising.
        """
        lines = [
            f"{self.AI_DISCLAIMER}",
            "",
            "---",
            "",
            content or "",
        ]
        return "\n".join(lines)