first commit
This commit is contained in:
470
tools/ai-review/agents/chat_agent.py
Normal file
470
tools/ai-review/agents/chat_agent.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""Chat Agent (Bartender)
|
||||
|
||||
Interactive AI chat agent with tool use capabilities.
|
||||
Can search the codebase and web to answer user questions.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
import requests
|
||||
|
||||
from agents.base_agent import AgentContext, AgentResult, BaseAgent
|
||||
from clients.llm_client import ToolCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""A message in the chat conversation."""
|
||||
|
||||
role: str # 'user', 'assistant', or 'tool'
|
||||
content: str
|
||||
tool_call_id: str | None = None
|
||||
name: str | None = None # Tool name for tool responses
|
||||
|
||||
|
||||
class ChatAgent(BaseAgent):
|
||||
"""Interactive chat agent with tool capabilities."""
|
||||
|
||||
# Marker for chat responses
|
||||
CHAT_AI_MARKER = "<!-- AI_CHAT_RESPONSE -->"
|
||||
|
||||
# Tool definitions in OpenAI format
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_codebase",
|
||||
"description": "Search the repository codebase for files, functions, classes, or patterns. Use this to find relevant code.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query - can be a filename, function name, class name, or code pattern",
|
||||
},
|
||||
"file_pattern": {
|
||||
"type": "string",
|
||||
"description": "Optional file pattern to filter results (e.g., '*.py', 'src/*.js')",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read the contents of a specific file from the repository.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to read",
|
||||
},
|
||||
},
|
||||
"required": ["filepath"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_web",
|
||||
"description": "Search the web for information using SearXNG. Use this for external documentation, tutorials, or general knowledge.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query",
|
||||
},
|
||||
"categories": {
|
||||
"type": "string",
|
||||
"description": "Optional: comma-separated categories (general, images, videos, news, science, it)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# System prompt for the chat agent
|
||||
SYSTEM_PROMPT = """You are Bartender, a helpful AI assistant for code review and development tasks.
|
||||
|
||||
You have access to tools to help answer questions:
|
||||
- search_codebase: Search the repository for code, files, functions, or patterns
|
||||
- read_file: Read specific files from the repository
|
||||
- search_web: Search the web for documentation, tutorials, or external information
|
||||
|
||||
When helping users:
|
||||
1. Use tools to gather information before answering questions about code
|
||||
2. Be concise but thorough in your explanations
|
||||
3. Provide code examples when helpful
|
||||
4. If you're unsure, say so and suggest alternatives
|
||||
|
||||
Repository context: {owner}/{repo}
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._searxng_url = self.config.get("agents", {}).get("chat", {}).get(
|
||||
"searxng_url", os.environ.get("SEARXNG_URL", "")
|
||||
)
|
||||
|
||||
def can_handle(self, event_type: str, event_data: dict) -> bool:
|
||||
"""Check if this agent handles the given event."""
|
||||
agent_config = self.config.get("agents", {}).get("chat", {})
|
||||
if not agent_config.get("enabled", True):
|
||||
return False
|
||||
|
||||
# Handle issue comment with @ai-bot chat or just @ai-bot
|
||||
if event_type == "issue_comment":
|
||||
comment_body = event_data.get("comment", {}).get("body", "")
|
||||
mention_prefix = self.config.get("interaction", {}).get(
|
||||
"mention_prefix", "@ai-bot"
|
||||
)
|
||||
# Check if this is a chat request (any @ai-bot mention that isn't a specific command)
|
||||
if mention_prefix in comment_body:
|
||||
# Check it's not another specific command
|
||||
specific_commands = ["summarize", "explain", "suggest", "security", "codebase"]
|
||||
body_lower = comment_body.lower()
|
||||
for cmd in specific_commands:
|
||||
if f"{mention_prefix} {cmd}" in body_lower:
|
||||
return False
|
||||
return True
|
||||
|
||||
# Handle direct chat command
|
||||
if event_type == "chat":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def execute(self, context: AgentContext) -> AgentResult:
|
||||
"""Execute the chat agent."""
|
||||
self.logger.info(f"Starting chat for {context.owner}/{context.repo}")
|
||||
|
||||
# Extract user message
|
||||
if context.event_type == "issue_comment":
|
||||
user_message = context.event_data.get("comment", {}).get("body", "")
|
||||
issue_index = context.event_data.get("issue", {}).get("number")
|
||||
# Remove the @ai-bot prefix
|
||||
mention_prefix = self.config.get("interaction", {}).get(
|
||||
"mention_prefix", "@ai-bot"
|
||||
)
|
||||
user_message = user_message.replace(mention_prefix, "").strip()
|
||||
else:
|
||||
user_message = context.event_data.get("message", "")
|
||||
issue_index = context.event_data.get("issue_number")
|
||||
|
||||
if not user_message:
|
||||
return AgentResult(
|
||||
success=False,
|
||||
message="No message provided",
|
||||
)
|
||||
|
||||
# Build conversation
|
||||
system_prompt = self.SYSTEM_PROMPT.format(
|
||||
owner=context.owner,
|
||||
repo=context.repo,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_message},
|
||||
]
|
||||
|
||||
# Run the chat loop with tool execution
|
||||
response_content, tools_used = self._run_chat_loop(
|
||||
context, messages, max_iterations=5
|
||||
)
|
||||
|
||||
actions_taken = []
|
||||
if tools_used:
|
||||
actions_taken.append(f"Used tools: {', '.join(tools_used)}")
|
||||
|
||||
# Post response if this is an issue comment
|
||||
if issue_index:
|
||||
comment_body = self._format_response(response_content)
|
||||
self.upsert_comment(
|
||||
context.owner,
|
||||
context.repo,
|
||||
issue_index,
|
||||
comment_body,
|
||||
marker=self.CHAT_AI_MARKER,
|
||||
)
|
||||
actions_taken.append("Posted chat response")
|
||||
|
||||
return AgentResult(
|
||||
success=True,
|
||||
message="Chat completed",
|
||||
data={"response": response_content, "tools_used": tools_used},
|
||||
actions_taken=actions_taken,
|
||||
)
|
||||
|
||||
def _run_chat_loop(
|
||||
self,
|
||||
context: AgentContext,
|
||||
messages: list[dict],
|
||||
max_iterations: int = 5,
|
||||
) -> tuple[str, list[str]]:
|
||||
"""Run the chat loop with tool execution.
|
||||
|
||||
Returns:
|
||||
Tuple of (final response content, list of tools used)
|
||||
"""
|
||||
tools_used = []
|
||||
|
||||
for _ in range(max_iterations):
|
||||
self._rate_limit()
|
||||
response = self.llm.call_with_tools(messages, tools=self.TOOLS)
|
||||
|
||||
# If no tool calls, we're done
|
||||
if not response.tool_calls:
|
||||
return response.content, tools_used
|
||||
|
||||
# Add assistant message with tool calls
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": response.content or "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": str(tc.arguments),
|
||||
},
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
],
|
||||
})
|
||||
|
||||
# Execute each tool call
|
||||
for tool_call in response.tool_calls:
|
||||
tool_result = self._execute_tool(context, tool_call)
|
||||
tools_used.append(tool_call.name)
|
||||
|
||||
# Add tool result to messages
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_result,
|
||||
})
|
||||
|
||||
# If we hit max iterations, make one final call without tools
|
||||
self._rate_limit()
|
||||
final_response = self.llm.call_with_tools(
|
||||
messages, tools=None, tool_choice="none"
|
||||
)
|
||||
return final_response.content, tools_used
|
||||
|
||||
def _execute_tool(self, context: AgentContext, tool_call: ToolCall) -> str:
|
||||
"""Execute a tool call and return the result."""
|
||||
self.logger.info(f"Executing tool: {tool_call.name}")
|
||||
|
||||
try:
|
||||
if tool_call.name == "search_codebase":
|
||||
return self._tool_search_codebase(
|
||||
context,
|
||||
tool_call.arguments.get("query", ""),
|
||||
tool_call.arguments.get("file_pattern"),
|
||||
)
|
||||
elif tool_call.name == "read_file":
|
||||
return self._tool_read_file(
|
||||
context,
|
||||
tool_call.arguments.get("filepath", ""),
|
||||
)
|
||||
elif tool_call.name == "search_web":
|
||||
return self._tool_search_web(
|
||||
tool_call.arguments.get("query", ""),
|
||||
tool_call.arguments.get("categories"),
|
||||
)
|
||||
else:
|
||||
return f"Unknown tool: {tool_call.name}"
|
||||
except Exception as e:
|
||||
self.logger.error(f"Tool execution failed: {e}")
|
||||
return f"Error executing tool: {e}"
|
||||
|
||||
def _tool_search_codebase(
|
||||
self,
|
||||
context: AgentContext,
|
||||
query: str,
|
||||
file_pattern: str | None = None,
|
||||
) -> str:
|
||||
"""Search the codebase for files matching a query."""
|
||||
results = []
|
||||
|
||||
# Get repository file list
|
||||
try:
|
||||
files = self._collect_files(context.owner, context.repo, file_pattern)
|
||||
except Exception as e:
|
||||
return f"Error listing files: {e}"
|
||||
|
||||
query_lower = query.lower()
|
||||
|
||||
# Search through files
|
||||
for file_info in files[:50]: # Limit to prevent API exhaustion
|
||||
filepath = file_info.get("path", "")
|
||||
|
||||
# Check filename match
|
||||
if query_lower in filepath.lower():
|
||||
results.append(f"File: {filepath}")
|
||||
continue
|
||||
|
||||
# Check content for code patterns
|
||||
try:
|
||||
content_data = self.gitea.get_file_contents(
|
||||
context.owner, context.repo, filepath
|
||||
)
|
||||
if content_data.get("content"):
|
||||
content = base64.b64decode(content_data["content"]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Search for query in content
|
||||
lines = content.splitlines()
|
||||
matching_lines = []
|
||||
for i, line in enumerate(lines, 1):
|
||||
if query_lower in line.lower():
|
||||
matching_lines.append(f" L{i}: {line.strip()[:100]}")
|
||||
|
||||
if matching_lines:
|
||||
results.append(f"File: {filepath}")
|
||||
results.extend(matching_lines[:5]) # Max 5 matches per file
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not results:
|
||||
return f"No results found for '{query}'"
|
||||
|
||||
return "\n".join(results[:30]) # Limit total results
|
||||
|
||||
def _collect_files(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
file_pattern: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""Collect files from the repository."""
|
||||
files = []
|
||||
|
||||
# Code extensions to search
|
||||
code_extensions = {
|
||||
".py", ".js", ".ts", ".go", ".rs", ".java", ".rb",
|
||||
".php", ".c", ".cpp", ".h", ".cs", ".swift", ".kt",
|
||||
".md", ".yml", ".yaml", ".json", ".toml",
|
||||
}
|
||||
|
||||
# Patterns to ignore
|
||||
ignore_patterns = [
|
||||
"node_modules/", "vendor/", ".git/", "__pycache__/",
|
||||
".venv/", "dist/", "build/", ".min.js", ".min.css",
|
||||
]
|
||||
|
||||
def traverse(path: str = ""):
|
||||
try:
|
||||
contents = self.gitea.get_file_contents(owner, repo, path or ".")
|
||||
if isinstance(contents, list):
|
||||
for item in contents:
|
||||
item_path = item.get("path", "")
|
||||
|
||||
if any(p in item_path for p in ignore_patterns):
|
||||
continue
|
||||
|
||||
if item.get("type") == "file":
|
||||
ext = os.path.splitext(item_path)[1]
|
||||
if ext in code_extensions:
|
||||
# Check file pattern if provided
|
||||
if file_pattern:
|
||||
if not self._match_pattern(item_path, file_pattern):
|
||||
continue
|
||||
files.append(item)
|
||||
elif item.get("type") == "dir":
|
||||
traverse(item_path)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to list {path}: {e}")
|
||||
|
||||
traverse()
|
||||
return files[:100] # Limit to prevent API exhaustion
|
||||
|
||||
def _match_pattern(self, filepath: str, pattern: str) -> bool:
|
||||
"""Check if filepath matches a simple glob pattern."""
|
||||
import fnmatch
|
||||
return fnmatch.fnmatch(filepath, pattern)
|
||||
|
||||
def _tool_read_file(self, context: AgentContext, filepath: str) -> str:
|
||||
"""Read a file from the repository."""
|
||||
try:
|
||||
content_data = self.gitea.get_file_contents(
|
||||
context.owner, context.repo, filepath
|
||||
)
|
||||
if content_data.get("content"):
|
||||
content = base64.b64decode(content_data["content"]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
# Truncate if too long
|
||||
if len(content) > 8000:
|
||||
content = content[:8000] + "\n... (truncated)"
|
||||
return f"File: {filepath}\n\n```\n{content}\n```"
|
||||
return f"File not found: {filepath}"
|
||||
except Exception as e:
|
||||
return f"Error reading file: {e}"
|
||||
|
||||
def _tool_search_web(
|
||||
self,
|
||||
query: str,
|
||||
categories: str | None = None,
|
||||
) -> str:
|
||||
"""Search the web using SearXNG."""
|
||||
if not self._searxng_url:
|
||||
return "Web search is not configured. Set SEARXNG_URL environment variable."
|
||||
|
||||
try:
|
||||
params = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
}
|
||||
if categories:
|
||||
params["categories"] = categories
|
||||
|
||||
response = requests.get(
|
||||
f"{self._searxng_url}/search",
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = data.get("results", [])
|
||||
if not results:
|
||||
return f"No web results found for '{query}'"
|
||||
|
||||
# Format results
|
||||
output = []
|
||||
for i, result in enumerate(results[:5], 1): # Top 5 results
|
||||
title = result.get("title", "No title")
|
||||
url = result.get("url", "")
|
||||
content = result.get("content", "")[:200]
|
||||
output.append(f"{i}. **{title}**\n {url}\n {content}")
|
||||
|
||||
return "\n\n".join(output)
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"Web search failed: {e}"
|
||||
|
||||
def _format_response(self, content: str) -> str:
|
||||
"""Format the chat response with disclaimer."""
|
||||
lines = [
|
||||
f"{self.AI_DISCLAIMER}",
|
||||
"",
|
||||
"---",
|
||||
"",
|
||||
content,
|
||||
]
|
||||
return "\n".join(lines)
|
||||
Reference in New Issue
Block a user