471 lines
16 KiB
Python
471 lines
16 KiB
Python
"""Chat Agent (Bartender)
|
|
|
|
Interactive AI chat agent with tool use capabilities.
|
|
Can search the codebase and web to answer user questions.
|
|
"""
|
|
|
|
import base64
|
|
import logging
|
|
import os
|
|
import re
|
|
from dataclasses import dataclass
|
|
|
|
import requests
|
|
|
|
from agents.base_agent import AgentContext, AgentResult, BaseAgent
|
|
from clients.llm_client import ToolCall
|
|
|
|
|
|
@dataclass
|
|
class ChatMessage:
|
|
"""A message in the chat conversation."""
|
|
|
|
role: str # 'user', 'assistant', or 'tool'
|
|
content: str
|
|
tool_call_id: str | None = None
|
|
name: str | None = None # Tool name for tool responses
|
|
|
|
|
|
class ChatAgent(BaseAgent):
    """Interactive chat agent ("Bartender") with tool-calling capabilities.

    Handles ``issue_comment`` events that mention the configured bot prefix
    (except the dedicated sub-commands listed in ``_SPECIFIC_COMMANDS``) and
    direct ``chat`` events. The LLM may call the tools declared in ``TOOLS``
    to search/read repository files or query a SearXNG instance before it
    produces a final answer.
    """

    # Hidden marker passed to upsert_comment; presumably lets the base class
    # find and update this agent's earlier chat reply instead of re-posting.
    CHAT_AI_MARKER = "<!-- AI_CHAT_RESPONSE -->"

    # Sub-commands owned by other agents; a mention followed by one of these
    # is not treated as a chat request.
    _SPECIFIC_COMMANDS = ("summarize", "explain", "suggest", "security", "codebase")

    # Limits that bound API usage and the size of tool output fed to the LLM.
    _MAX_COLLECTED_FILES = 100  # files gathered by _collect_files
    _MAX_SEARCH_FILES = 50  # files whose content search_codebase may fetch
    _MAX_SEARCH_RESULTS = 30  # output lines returned by search_codebase
    _MAX_MATCHES_PER_FILE = 5  # matching lines reported per file
    _MAX_READ_CHARS = 8000  # characters returned by read_file
    _MAX_WEB_RESULTS = 5  # web results returned by search_web

    # Extensions (compared lower-case) considered searchable source/config.
    _CODE_EXTENSIONS = frozenset({
        ".py", ".js", ".ts", ".go", ".rs", ".java", ".rb",
        ".php", ".c", ".cpp", ".h", ".cs", ".swift", ".kt",
        ".md", ".yml", ".yaml", ".json", ".toml",
    })

    # Any path containing one of these as a whole segment is skipped.
    _IGNORED_DIRS = frozenset({
        "node_modules", "vendor", ".git", "__pycache__", ".venv", "dist", "build",
    })

    # Minified bundles are skipped as well.
    _IGNORED_SUFFIXES = (".min.js", ".min.css")

    # Tool definitions in OpenAI function-calling format.
    TOOLS = [
        {
            "type": "function",
            "function": {
                "name": "search_codebase",
                "description": (
                    "Search the repository codebase for files, functions, classes, "
                    "or patterns. Use this to find relevant code."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": (
                                "Search query - can be a filename, function name, "
                                "class name, or code pattern"
                            ),
                        },
                        "file_pattern": {
                            "type": "string",
                            "description": (
                                "Optional file pattern to filter results "
                                "(e.g., '*.py', 'src/*.js')"
                            ),
                        },
                    },
                    "required": ["query"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "read_file",
                "description": "Read the contents of a specific file from the repository.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file to read",
                        },
                    },
                    "required": ["filepath"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "search_web",
                "description": (
                    "Search the web for information using SearXNG. Use this for "
                    "external documentation, tutorials, or general knowledge."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query",
                        },
                        "categories": {
                            "type": "string",
                            "description": (
                                "Optional: comma-separated categories "
                                "(general, images, videos, news, science, it)"
                            ),
                        },
                    },
                    "required": ["query"],
                },
            },
        },
    ]

    # System prompt template; {owner} and {repo} are filled in by execute().
    SYSTEM_PROMPT = """You are Bartender, a helpful AI assistant for code review and development tasks.

You have access to tools to help answer questions:
- search_codebase: Search the repository for code, files, functions, or patterns
- read_file: Read specific files from the repository
- search_web: Search the web for documentation, tutorials, or external information

When helping users:
1. Use tools to gather information before answering questions about code
2. Be concise but thorough in your explanations
3. Provide code examples when helpful
4. If you're unsure, say so and suggest alternatives

Repository context: {owner}/{repo}
"""

    def __init__(self, *args, **kwargs):
        """Initialise the base agent and resolve the SearXNG base URL.

        The URL comes from ``agents.chat.searxng_url`` in the config; a
        missing or empty value falls back to the ``SEARXNG_URL`` environment
        variable (empty string when neither is set, which disables web search).
        """
        super().__init__(*args, **kwargs)
        chat_config = self.config.get("agents", {}).get("chat", {})
        self._searxng_url = (
            chat_config.get("searxng_url") or os.environ.get("SEARXNG_URL", "")
        )

    def _mention_prefix(self) -> str:
        """Return the configured bot mention prefix (default ``@ai-bot``)."""
        return self.config.get("interaction", {}).get("mention_prefix", "@ai-bot")

    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Return True if this agent should respond to the event.

        Args:
            event_type: ``"chat"`` is always accepted (when the agent is
                enabled); ``"issue_comment"`` is accepted when the comment
                mentions the bot prefix and is not one of the sub-commands
                handled by other agents.
            event_data: Raw event payload; for comments, ``comment.body`` is read.

        Returns:
            Whether :meth:`execute` should be invoked for this event.
        """
        agent_config = self.config.get("agents", {}).get("chat", {})
        if not agent_config.get("enabled", True):
            return False

        if event_type == "chat":
            return True
        if event_type != "issue_comment":
            return False

        comment_body = (event_data.get("comment") or {}).get("body") or ""
        mention_prefix = self._mention_prefix()
        if mention_prefix not in comment_body:
            return False

        # Leave dedicated sub-commands to their own agents. Both sides are
        # lower-cased; lowering only the body meant a prefix containing
        # capitals could never match a sub-command.
        body_lower = comment_body.lower()
        prefix_lower = mention_prefix.lower()
        return not any(
            f"{prefix_lower} {cmd}" in body_lower for cmd in self._SPECIFIC_COMMANDS
        )

    def execute(self, context: AgentContext) -> AgentResult:
        """Answer the user's message, optionally posting the reply as a comment.

        Args:
            context: Event context. For ``issue_comment`` events the message is
                the comment body with the mention prefix removed and the reply
                target is ``issue.number``; otherwise ``message`` and
                ``issue_number`` are read from the event data.

        Returns:
            A failed result when no message is present; otherwise a successful
            result whose ``data`` holds the response text and the tool names
            used (in call order, duplicates included).
        """
        self.logger.info("Starting chat for %s/%s", context.owner, context.repo)

        if context.event_type == "issue_comment":
            raw_body = (context.event_data.get("comment") or {}).get("body") or ""
            issue_index = (context.event_data.get("issue") or {}).get("number")
            user_message = raw_body.replace(self._mention_prefix(), "").strip()
        else:
            user_message = context.event_data.get("message", "")
            issue_index = context.event_data.get("issue_number")

        if not user_message:
            return AgentResult(
                success=False,
                message="No message provided",
            )

        system_prompt = self.SYSTEM_PROMPT.format(
            owner=context.owner,
            repo=context.repo,
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        response_content, tools_used = self._run_chat_loop(
            context, messages, max_iterations=5
        )

        actions_taken = []
        if tools_used:
            # dict.fromkeys de-duplicates while keeping first-use order.
            actions_taken.append(f"Used tools: {', '.join(dict.fromkeys(tools_used))}")

        # Only issue-backed conversations get a posted reply.
        if issue_index:
            self.upsert_comment(
                context.owner,
                context.repo,
                issue_index,
                self._format_response(response_content),
                marker=self.CHAT_AI_MARKER,
            )
            actions_taken.append("Posted chat response")

        return AgentResult(
            success=True,
            message="Chat completed",
            data={"response": response_content, "tools_used": tools_used},
            actions_taken=actions_taken,
        )

    def _run_chat_loop(
        self,
        context: AgentContext,
        messages: list[dict],
        max_iterations: int = 5,
    ) -> tuple[str, list[str]]:
        """Alternate LLM calls and tool executions until the model answers.

        ``messages`` is extended in place with each assistant tool-call turn
        and the matching tool results. If the model still requests tools
        after ``max_iterations`` rounds, one final call is made with tools
        disabled to force a textual answer.

        Returns:
            Tuple of (final response text, never None; names of tools used in
            call order).
        """
        tools_used: list[str] = []

        for _ in range(max_iterations):
            # _rate_limit is inherited; presumably throttles LLM requests.
            self._rate_limit()
            response = self.llm.call_with_tools(messages, tools=self.TOOLS)

            if not response.tool_calls:
                # content may be None; callers join it into a comment body.
                return response.content or "", tools_used

            messages.append({
                "role": "assistant",
                "content": response.content or "",
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.name,
                            "arguments": self._serialize_arguments(tc.arguments),
                        },
                    }
                    for tc in response.tool_calls
                ],
            })

            for tool_call in response.tool_calls:
                tool_result = self._execute_tool(context, tool_call)
                tools_used.append(tool_call.name)
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_result,
                })

        # Iteration budget exhausted: ask once more without tools.
        # NOTE(review): confirm call_with_tools accepts tool_choice without
        # tools; some OpenAI-compatible APIs reject that combination.
        self._rate_limit()
        final_response = self.llm.call_with_tools(
            messages, tools=None, tool_choice="none"
        )
        return final_response.content or "", tools_used

    @staticmethod
    def _serialize_arguments(arguments) -> str:
        """Encode tool-call arguments as the JSON string the chat API expects.

        Strings are passed through unchanged. ``str(dict)`` (used previously)
        produced a Python repr with single quotes, which is not valid JSON.
        """
        import json

        if isinstance(arguments, str):
            return arguments
        return json.dumps(arguments)

    def _execute_tool(self, context: AgentContext, tool_call: ToolCall) -> str:
        """Dispatch one tool call and return its textual result.

        Failures are reported back to the model as an error string rather
        than raised, so a broken tool does not abort the conversation.
        """
        self.logger.info("Executing tool: %s", tool_call.name)

        try:
            args = tool_call.arguments or {}
            if tool_call.name == "search_codebase":
                return self._tool_search_codebase(
                    context,
                    args.get("query", ""),
                    args.get("file_pattern"),
                )
            if tool_call.name == "read_file":
                return self._tool_read_file(context, args.get("filepath", ""))
            if tool_call.name == "search_web":
                return self._tool_search_web(
                    args.get("query", ""),
                    args.get("categories"),
                )
            return f"Unknown tool: {tool_call.name}"
        except Exception as e:
            self.logger.error("Tool execution failed: %s", e)
            return f"Error executing tool: {e}"

    def _tool_search_codebase(
        self,
        context: AgentContext,
        query: str,
        file_pattern: str | None = None,
    ) -> str:
        """Search repository paths and file contents for ``query``.

        Matching is a case-insensitive substring test. A file whose path
        matches is listed without fetching its content; otherwise up to
        ``_MAX_MATCHES_PER_FILE`` matching lines (each cut to 100 characters)
        are reported for it.

        Args:
            context: Event context supplying the repository owner/name.
            query: Text to look for.
            file_pattern: Optional glob forwarded to :meth:`_collect_files`.

        Returns:
            Newline-joined result lines (at most ``_MAX_SEARCH_RESULTS``), a
            "no results" message, or an error message if listing failed.
        """
        try:
            files = self._collect_files(context.owner, context.repo, file_pattern)
        except Exception as e:
            return f"Error listing files: {e}"

        query_lower = query.lower()
        results: list[str] = []

        # Each content check costs one API request, hence the file cap.
        for file_info in files[: self._MAX_SEARCH_FILES]:
            # Output is truncated to _MAX_SEARCH_RESULTS lines, so further
            # requests could not change it.
            if len(results) >= self._MAX_SEARCH_RESULTS:
                break

            filepath = file_info.get("path", "")
            if query_lower in filepath.lower():
                results.append(f"File: {filepath}")
                continue

            try:
                content_data = self.gitea.get_file_contents(
                    context.owner, context.repo, filepath
                )
                encoded = content_data.get("content")
                if not encoded:
                    continue
                content = base64.b64decode(encoded).decode("utf-8", errors="ignore")
            except Exception as e:
                # Best effort: an unreadable file is skipped, not fatal.
                self.logger.debug("Skipping %s during search: %s", filepath, e)
                continue

            matching_lines = [
                f" L{i}: {line.strip()[:100]}"
                for i, line in enumerate(content.splitlines(), 1)
                if query_lower in line.lower()
            ]
            if matching_lines:
                results.append(f"File: {filepath}")
                results.extend(matching_lines[: self._MAX_MATCHES_PER_FILE])

        if not results:
            return f"No results found for '{query}'"

        return "\n".join(results[: self._MAX_SEARCH_RESULTS])

    def _collect_files(
        self,
        owner: str,
        repo: str,
        file_pattern: str | None = None,
    ) -> list[dict]:
        """Collect up to ``_MAX_COLLECTED_FILES`` searchable file entries.

        Walks the repository depth-first via the contents API, skipping
        ignored directories and minified files, keeping only files whose
        extension is in ``_CODE_EXTENSIONS`` and, if given, that match
        ``file_pattern``. The walk stops as soon as the limit is reached
        instead of listing the whole tree first.

        Returns:
            File entries (dicts from the contents API) in traversal order.
        """
        files: list[dict] = []
        limit = self._MAX_COLLECTED_FILES

        def traverse(path: str = "") -> None:
            if len(files) >= limit:
                return
            try:
                contents = self.gitea.get_file_contents(owner, repo, path or ".")
                # Directory listings come back as a list; anything else is
                # not a directory and contributes nothing.
                if not isinstance(contents, list):
                    return
                for item in contents:
                    if len(files) >= limit:
                        return
                    item_path = item.get("path", "")
                    if self._is_ignored_path(item_path):
                        continue

                    item_type = item.get("type")
                    if item_type == "file":
                        ext = os.path.splitext(item_path)[1].lower()
                        if ext not in self._CODE_EXTENSIONS:
                            continue
                        if file_pattern and not self._match_pattern(item_path, file_pattern):
                            continue
                        files.append(item)
                    elif item_type == "dir":
                        traverse(item_path)
            except Exception as e:
                # Best effort: an unlistable directory is skipped.
                self.logger.warning("Failed to list %s: %s", path, e)

        traverse()
        return files

    @classmethod
    def _is_ignored_path(cls, path: str) -> bool:
        """Return True if ``path`` lies in an ignored directory or is minified.

        Directories are matched per path segment, so the entry
        ``node_modules`` itself is skipped (not just its children) and
        unrelated names such as ``src/rebuild`` are not caught.
        """
        if path.endswith(cls._IGNORED_SUFFIXES):
            return True
        return any(part in cls._IGNORED_DIRS for part in path.split("/"))

    def _match_pattern(self, filepath: str, pattern: str) -> bool:
        """Check ``filepath`` against a shell-style glob.

        ``fnmatch`` does not treat ``/`` specially, so ``*.py`` also matches
        files in subdirectories.
        """
        import fnmatch

        return fnmatch.fnmatch(filepath, pattern)

    def _tool_read_file(self, context: AgentContext, filepath: str) -> str:
        """Return a file's decoded text wrapped in a Markdown code fence.

        Content longer than ``_MAX_READ_CHARS`` is truncated. Directories,
        missing/empty files and API errors yield explanatory messages.
        """
        try:
            content_data = self.gitea.get_file_contents(
                context.owner, context.repo, filepath
            )
            if isinstance(content_data, list):
                # The contents API lists a directory instead of returning a file.
                return f"{filepath} is a directory, not a file"
            if not content_data or not content_data.get("content"):
                return f"File not found: {filepath}"
            content = base64.b64decode(content_data["content"]).decode(
                "utf-8", errors="ignore"
            )
        except Exception as e:
            return f"Error reading file: {e}"

        if len(content) > self._MAX_READ_CHARS:
            content = content[: self._MAX_READ_CHARS] + "\n... (truncated)"
        return f"File: {filepath}\n\n```\n{content}\n```"

    def _tool_search_web(
        self,
        query: str,
        categories: str | None = None,
    ) -> str:
        """Query the configured SearXNG instance and format the top results.

        Args:
            query: Search terms.
            categories: Optional comma-separated SearXNG categories.

        Returns:
            Up to ``_MAX_WEB_RESULTS`` numbered results (title, URL, first 200
            characters of the snippet), or a message when search is not
            configured, fails, or finds nothing.
        """
        if not self._searxng_url:
            return "Web search is not configured. Set SEARXNG_URL environment variable."

        params = {
            "q": query,
            "format": "json",
        }
        if categories:
            params["categories"] = categories

        # Avoid a double slash when the configured URL ends with "/".
        endpoint = f"{self._searxng_url.rstrip('/')}/search"
        try:
            response = requests.get(endpoint, params=params, timeout=30)
            response.raise_for_status()
            data = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers non-JSON bodies on requests versions whose
            # JSON error is not a RequestException subclass.
            return f"Web search failed: {e}"

        results = data.get("results") or []
        if not results:
            return f"No web results found for '{query}'"

        output = []
        for i, result in enumerate(results[: self._MAX_WEB_RESULTS], 1):
            title = result.get("title") or "No title"
            url = result.get("url") or ""
            snippet = (result.get("content") or "")[:200]
            output.append(f"{i}. **{title}**\n {url}\n {snippet}")

        return "\n\n".join(output)

    def _format_response(self, content: str) -> str:
        """Prefix the chat response with the AI disclaimer and a rule line.

        A ``None`` content is rendered as an empty body instead of raising.
        """
        lines = [
            f"{self.AI_DISCLAIMER}",
            "",
            "---",
            "",
            content or "",
        ]
        return "\n".join(lines)
|