Files
openrabbit/tools/ai-review/agents/test_coverage_agent.py
latte e8d28225e0
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
just why not
2026-01-07 21:19:46 +01:00

481 lines
17 KiB
Python

"""Test Coverage Agent
AI agent for analyzing code changes and suggesting test cases.
Helps improve test coverage by identifying untested code paths.
"""
import base64
import os
import re
from dataclasses import dataclass, field
from agents.base_agent import AgentContext, AgentResult, BaseAgent
@dataclass
class TestSuggestion:
"""A suggested test case."""
function_name: str
file_path: str
test_type: str # unit, integration, edge_case
description: str
example_code: str | None = None
priority: str = "MEDIUM" # HIGH, MEDIUM, LOW
@dataclass
class CoverageReport:
    """Aggregate result of one test-coverage analysis run."""

    functions_analyzed: int  # Total functions inspected
    functions_with_tests: int  # Estimated count already covered
    functions_without_tests: int  # Estimated count lacking tests
    suggestions: list[TestSuggestion]  # Recommended test cases
    existing_tests: list[str]  # Paths of discovered test files
    coverage_estimate: float  # Estimated coverage in [0.0, 1.0]
class TestCoverageAgent(BaseAgent):
    """Agent for analyzing test coverage and suggesting tests.

    Triggered by an ``<mention_prefix> suggest-tests`` issue comment.  For
    pull requests it inspects the diff for newly added functions; for plain
    issues it scans a bounded slice of the repository.  It then asks the LLM
    for test suggestions and posts a formatted Markdown report as a comment.
    """

    # Marker embedded in posted comments so upsert_comment can locate and
    # update a previous report instead of posting duplicates.
    TEST_AI_MARKER = "<!-- AI_TEST_COVERAGE -->"

    # Filename regexes that identify test files, keyed by language.
    TEST_PATTERNS = {
        "python": [r"test_.*\.py$", r".*_test\.py$", r"tests?/.*\.py$"],
        "javascript": [
            r".*\.test\.[jt]sx?$",
            r".*\.spec\.[jt]sx?$",
            r"__tests__/.*\.[jt]sx?$",
        ],
        "go": [r".*_test\.go$"],
        "rust": [r"tests?/.*\.rs$"],
        "java": [r".*Test\.java$", r".*Tests\.java$"],
        "ruby": [r".*_spec\.rb$", r"test_.*\.rb$"],
    }

    # Regexes that capture a function/method name, keyed by language.  Some
    # patterns use alternation, so more than one capture group may exist and
    # only one of them matches.
    FUNCTION_PATTERNS = {
        "python": r"^\s*(?:async\s+)?def\s+(\w+)\s*\(",
        "javascript": r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>))",
        "go": r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(",
        "rust": r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)",
        "java": r"(?:public|private|protected)\s+(?:static\s+)?(?:\w+\s+)?(\w+)\s*\([^)]*\)\s*\{",
        "ruby": r"^\s*def\s+(\w+)",
    }

    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Return True if this agent should run for the given event.

        Responds only to issue comments containing the
        ``<mention_prefix> suggest-tests`` command (matched case-insensitively)
        and only when enabled in the ``agents.test_coverage`` config section.
        """
        agent_config = self.config.get("agents", {}).get("test_coverage", {})
        if not agent_config.get("enabled", True):
            return False

        # Handle the "@codebot suggest-tests" command.
        if event_type == "issue_comment":
            comment_body = event_data.get("comment", {}).get("body", "")
            mention_prefix = self.config.get("interaction", {}).get(
                "mention_prefix", "@codebot"
            )
            if f"{mention_prefix} suggest-tests" in comment_body.lower():
                return True

        return False

    def execute(self, context: AgentContext) -> AgentResult:
        """Run the coverage analysis and post a report comment.

        For pull requests only functions added in the diff are analyzed;
        otherwise a bounded scan of the repository's top-level files runs.
        Returns an AgentResult summarizing the actions taken.
        """
        self.logger.info(f"Analyzing test coverage for {context.owner}/{context.repo}")
        actions_taken = []

        # Identify the issue/PR and the user who requested the analysis.
        issue = context.event_data.get("issue", {})
        issue_number = issue.get("number")
        comment_author = (
            context.event_data.get("comment", {}).get("user", {}).get("login", "user")
        )

        # Issues that represent PRs carry a "pull_request" key.
        is_pr = issue.get("pull_request") is not None
        if is_pr:
            # Analyze only the functions touched by the PR diff.
            diff = self._get_pr_diff(context.owner, context.repo, issue_number)
            changed_functions = self._extract_changed_functions(diff)
            actions_taken.append(f"Analyzed {len(changed_functions)} changed functions")
        else:
            # Analyze a bounded slice of the whole repository.
            changed_functions = self._analyze_repository(context.owner, context.repo)
            actions_taken.append(
                f"Analyzed {len(changed_functions)} functions in repository"
            )

        # Locate existing test files for the LLM prompt and the report.
        existing_tests = self._find_existing_tests(context.owner, context.repo)
        actions_taken.append(f"Found {len(existing_tests)} existing test files")

        # Generate test suggestions using the LLM (with a non-LLM fallback).
        report = self._generate_suggestions(
            context.owner, context.repo, changed_functions, existing_tests
        )

        # Post (or update in place, via the marker) the report comment.
        if issue_number:
            comment = self._format_coverage_report(report, comment_author, is_pr)
            self.upsert_comment(
                context.owner,
                context.repo,
                issue_number,
                comment,
                marker=self.TEST_AI_MARKER,
            )
            actions_taken.append("Posted test coverage report")

        return AgentResult(
            success=True,
            message=f"Generated {len(report.suggestions)} test suggestions",
            data={
                "functions_analyzed": report.functions_analyzed,
                "suggestions_count": len(report.suggestions),
                "coverage_estimate": report.coverage_estimate,
            },
            actions_taken=actions_taken,
        )

    def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
        """Return the PR's unified diff, or "" if the fetch fails."""
        try:
            return self.gitea.get_pull_request_diff(owner, repo, pr_number)
        except Exception as e:
            # Best-effort: an empty diff just yields zero changed functions.
            self.logger.error(f"Failed to get PR diff: {e}")
            return ""

    def _extract_changed_functions(self, diff: str) -> list[dict]:
        """Extract function definitions from added lines of a unified diff.

        Returns dicts with keys "name", "file", "language" and "line" (the
        stripped source line of the definition).
        """
        functions = []
        current_file = None
        current_language = None

        for line in diff.splitlines():
            # Track which file the following hunk lines belong to.
            if line.startswith("diff --git"):
                match = re.search(r"b/(.+)$", line)
                if match:
                    current_file = match.group(1)
                    current_language = self._detect_language(current_file)

            # Only added lines are candidates; skip the "+++" file header.
            if line.startswith("+") and not line.startswith("+++"):
                if current_language and current_language in self.FUNCTION_PATTERNS:
                    pattern = self.FUNCTION_PATTERNS[current_language]
                    match = re.search(pattern, line[1:])  # Skip the + prefix
                    if match:
                        # BUG FIX: the unguarded next(...) could raise
                        # StopIteration if no capture group matched; use the
                        # same default-guarded form as
                        # _extract_functions_from_file.
                        func_name = next((g for g in match.groups() if g), None)
                        if func_name:
                            functions.append(
                                {
                                    "name": func_name,
                                    "file": current_file,
                                    "language": current_language,
                                    "line": line[1:].strip(),
                                }
                            )

        return functions

    def _analyze_repository(self, owner: str, repo: str) -> list[dict]:
        """Scan top-level repository files for function definitions.

        Only the first 50 top-level files and 100 functions are considered
        to avoid exhausting the API; failures degrade to an empty result.
        """
        functions = []
        code_extensions = {".py", ".js", ".ts", ".go", ".rs", ".java", ".rb"}

        # Get repository contents (limited to avoid API exhaustion).
        try:
            contents = self.gitea.get_file_contents(owner, repo, "")
            if isinstance(contents, list):
                for item in contents[:50]:  # Limit files
                    if item.get("type") == "file":
                        filepath = item.get("path", "")
                        ext = os.path.splitext(filepath)[1]
                        if ext in code_extensions:
                            file_functions = self._extract_functions_from_file(
                                owner, repo, filepath
                            )
                            functions.extend(file_functions)
        except Exception as e:
            self.logger.warning(f"Failed to analyze repository: {e}")

        return functions[:100]  # Limit total functions

    def _extract_functions_from_file(
        self, owner: str, repo: str, filepath: str
    ) -> list[dict]:
        """Extract public function definitions from one repository file.

        Names starting with "_" are treated as private and skipped.  Fetch
        or decode failures yield an empty list (deliberate best-effort).
        """
        functions = []
        language = self._detect_language(filepath)

        if not language or language not in self.FUNCTION_PATTERNS:
            return functions

        try:
            content_data = self.gitea.get_file_contents(owner, repo, filepath)
            if content_data.get("content"):
                # Gitea returns file content base64-encoded.
                content = base64.b64decode(content_data["content"]).decode(
                    "utf-8", errors="ignore"
                )
                pattern = self.FUNCTION_PATTERNS[language]

                for i, line in enumerate(content.splitlines(), 1):
                    match = re.search(pattern, line)
                    if match:
                        func_name = next((g for g in match.groups() if g), None)
                        if func_name and not func_name.startswith("_"):
                            functions.append(
                                {
                                    "name": func_name,
                                    "file": filepath,
                                    "language": language,
                                    "line_number": i,
                                }
                            )
        except Exception:
            # Best-effort: unreadable files simply contribute no functions.
            pass

        return functions

    def _detect_language(self, filepath: str) -> str | None:
        """Map a file extension to a language key, or None if unknown."""
        ext_map = {
            ".py": "python",
            ".js": "javascript",
            ".jsx": "javascript",
            ".ts": "javascript",
            ".tsx": "javascript",
            ".go": "go",
            ".rs": "rust",
            ".java": "java",
            ".rb": "ruby",
        }
        ext = os.path.splitext(filepath)[1]
        return ext_map.get(ext)

    def _find_existing_tests(self, owner: str, repo: str) -> list[str]:
        """Find existing test files in the repository.

        Scans the common test directories plus the repository root.  Lookup
        failures (e.g. a directory that doesn't exist) are ignored as a
        deliberate best-effort.  The result is de-duplicated, preserving
        discovery order.
        """
        test_files = []

        # Common test directories
        test_dirs = ["tests", "test", "__tests__", "spec"]
        for test_dir in test_dirs:
            try:
                contents = self.gitea.get_file_contents(owner, repo, test_dir)
                if isinstance(contents, list):
                    for item in contents:
                        if item.get("type") == "file":
                            test_files.append(item.get("path", ""))
            except Exception:
                # Directory probably doesn't exist in this repo.
                pass

        # Also check the repository root for test files.
        try:
            contents = self.gitea.get_file_contents(owner, repo, "")
            if isinstance(contents, list):
                for item in contents:
                    if item.get("type") == "file":
                        filepath = item.get("path", "")
                        if self._is_test_file(filepath):
                            test_files.append(filepath)
        except Exception:
            pass

        # ROBUSTNESS: de-duplicate while preserving order (overlapping scans
        # could otherwise report the same file twice).
        return list(dict.fromkeys(test_files))

    def _is_test_file(self, filepath: str) -> bool:
        """Return True if the path matches any known test-file pattern."""
        # IDIOM: any() over pattern values (the language key was unused).
        return any(
            re.search(pattern, filepath)
            for patterns in self.TEST_PATTERNS.values()
            for pattern in patterns
        )

    def _generate_suggestions(
        self,
        owner: str,
        repo: str,
        functions: list[dict],
        existing_tests: list[str],
    ) -> CoverageReport:
        """Ask the LLM for test suggestions and build a CoverageReport.

        Falls back to one generic suggestion per function if the LLM call
        fails.  The returned coverage estimate is always clamped to
        [0.0, 1.0].
        """
        suggestions = []

        # Build prompt for LLM
        if functions:
            functions_text = "\n".join(
                [
                    f"- {f['name']} in {f['file']} ({f['language']})"
                    for f in functions[:20]  # Limit for prompt size
                ]
            )

            prompt = f"""Analyze these functions and suggest test cases:
Functions to test:
{functions_text}
Existing test files:
{", ".join(existing_tests[:10]) if existing_tests else "None found"}
For each function, suggest:
1. What to test (happy path, edge cases, error handling)
2. Priority (HIGH for public APIs, MEDIUM for internal, LOW for utilities)
3. Brief example test code if possible
Respond in JSON format:
{{
"suggestions": [
{{
"function_name": "function_name",
"file_path": "path/to/file",
"test_type": "unit|integration|edge_case",
"description": "What to test",
"example_code": "brief example or null",
"priority": "HIGH|MEDIUM|LOW"
}}
],
"coverage_estimate": 0.0 to 1.0
}}
"""

            try:
                result = self.call_llm_json(prompt)
                for s in result.get("suggestions", []):
                    suggestions.append(
                        TestSuggestion(
                            function_name=s.get("function_name", ""),
                            file_path=s.get("file_path", ""),
                            test_type=s.get("test_type", "unit"),
                            description=s.get("description", ""),
                            example_code=s.get("example_code"),
                            priority=s.get("priority", "MEDIUM"),
                        )
                    )
                # ROBUSTNESS: the LLM may return a malformed or out-of-range
                # estimate; coerce to float and clamp to [0.0, 1.0].
                try:
                    coverage_estimate = min(
                        1.0, max(0.0, float(result.get("coverage_estimate", 0.5)))
                    )
                except (TypeError, ValueError):
                    coverage_estimate = 0.5
            except Exception as e:
                self.logger.warning(f"LLM suggestion failed: {e}")
                # Generate basic suggestions without the LLM.
                for f in functions[:10]:
                    suggestions.append(
                        TestSuggestion(
                            function_name=f["name"],
                            file_path=f["file"],
                            test_type="unit",
                            description=f"Add unit tests for {f['name']}",
                            priority="MEDIUM",
                        )
                    )
                coverage_estimate = 0.5
        else:
            # No functions found: treat as fully covered only if tests exist.
            coverage_estimate = 1.0 if existing_tests else 0.0

        # Rough estimate of how many functions already have tests.
        functions_with_tests = int(len(functions) * coverage_estimate)

        return CoverageReport(
            functions_analyzed=len(functions),
            functions_with_tests=functions_with_tests,
            functions_without_tests=len(functions) - functions_with_tests,
            suggestions=suggestions,
            existing_tests=existing_tests,
            coverage_estimate=coverage_estimate,
        )

    def _format_coverage_report(
        self, report: CoverageReport, user: str | None, is_pr: bool
    ) -> str:
        """Format the coverage report as a Markdown comment body.

        ``is_pr`` is currently unused but kept for interface stability with
        callers.
        """
        lines = []

        if user:
            lines.append(f"@{user}")
            lines.append("")

        lines.extend(
            [
                f"{self.AI_DISCLAIMER}",
                "",
                "## 🧪 Test Coverage Analysis",
                "",
                "### Summary",
                "",
                "| Metric | Value |",
                "|--------|-------|",
                f"| Functions Analyzed | {report.functions_analyzed} |",
                f"| Estimated Coverage | {report.coverage_estimate:.0%} |",
                f"| Test Files Found | {len(report.existing_tests)} |",
                f"| Suggestions | {len(report.suggestions)} |",
                "",
            ]
        )

        # Suggestions grouped by priority, highest first.
        if report.suggestions:
            lines.append("### 💡 Test Suggestions")
            lines.append("")

            by_priority = {"HIGH": [], "MEDIUM": [], "LOW": []}
            for s in report.suggestions:
                if s.priority in by_priority:
                    by_priority[s.priority].append(s)

            priority_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}
            for priority in ["HIGH", "MEDIUM", "LOW"]:
                suggestions = by_priority[priority]
                if suggestions:
                    lines.append(f"#### {priority_emoji[priority]} {priority} Priority")
                    lines.append("")
                    for s in suggestions[:5]:  # Limit display
                        lines.append(f"**`{s.function_name}`** in `{s.file_path}`")
                        lines.append(f"- Type: {s.test_type}")
                        lines.append(f"- {s.description}")
                        if s.example_code:
                            lines.append("```")
                            lines.append(s.example_code[:200])
                            lines.append("```")
                        lines.append("")
                    if len(suggestions) > 5:
                        lines.append(f"*... and {len(suggestions) - 5} more*")
                        lines.append("")

        # Existing test files
        if report.existing_tests:
            lines.append("### 📁 Existing Test Files")
            lines.append("")
            for f in report.existing_tests[:10]:
                lines.append(f"- `{f}`")
            if len(report.existing_tests) > 10:
                lines.append(f"- *... and {len(report.existing_tests) - 10} more*")
            lines.append("")

        # Coverage bar
        lines.append("### 📊 Coverage Estimate")
        lines.append("")
        filled = int(report.coverage_estimate * 10)
        # BUG FIX: the fill/empty glyphs were empty strings (lost in an
        # encoding mishap), so the bar always rendered as "[]".  Use block
        # characters for the filled/empty segments.
        bar = "█" * filled + "░" * (10 - filled)
        lines.append(f"`[{bar}]` {report.coverage_estimate:.0%}")
        lines.append("")

        # Recommendations.  BUG FIX (simplification): the original
        # `elif coverage_estimate >= 0.8` was always true at this point; a
        # plain else is equivalent and clearer.
        if report.coverage_estimate < 0.8:
            lines.append("---")
            lines.append("⚠️ **Coverage below 80%** - Consider adding more tests")
        else:
            lines.append("---")
            lines.append("✅ **Good test coverage!**")

        return "\n".join(lines)