All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
481 lines
17 KiB
Python
481 lines
17 KiB
Python
"""Test Coverage Agent
|
|
|
|
AI agent for analyzing code changes and suggesting test cases.
|
|
Helps improve test coverage by identifying untested code paths.
|
|
"""
|
|
|
|
import base64
|
|
import os
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
|
|
from agents.base_agent import AgentContext, AgentResult, BaseAgent
|
|
|
|
|
|
@dataclass
class TestSuggestion:
    """A suggested test case for a single function."""

    # Name of the function the suggested test targets.
    function_name: str
    # Repo-relative path of the file containing the function.
    file_path: str
    test_type: str  # unit, integration, edge_case
    # Human-readable summary of what the test should cover.
    description: str
    # Optional example test code snippet (may be omitted by the LLM).
    example_code: str | None = None
    priority: str = "MEDIUM"  # HIGH, MEDIUM, LOW
|
|
|
|
|
|
@dataclass
class CoverageReport:
    """Report of test coverage analysis."""

    # Total number of functions inspected (from a PR diff or repo scan).
    functions_analyzed: int
    # Estimated count of functions that already have tests.
    functions_with_tests: int
    # Estimated count of functions lacking tests.
    functions_without_tests: int
    # Suggested test cases, typically LLM-generated.
    suggestions: list[TestSuggestion]
    # Paths of test files discovered in the repository.
    existing_tests: list[str]
    # Rough coverage fraction in [0.0, 1.0].
    coverage_estimate: float
|
|
|
|
|
|
class TestCoverageAgent(BaseAgent):
    """Agent for analyzing test coverage and suggesting tests."""

    # Marker for test coverage comments; lets upsert_comment find and
    # update a previously posted report instead of posting a new one.
    TEST_AI_MARKER = "<!-- AI_TEST_COVERAGE -->"

    # Test file patterns by language (regexes matched against file paths).
    TEST_PATTERNS = {
        "python": [r"test_.*\.py$", r".*_test\.py$", r"tests?/.*\.py$"],
        "javascript": [
            r".*\.test\.[jt]sx?$",
            r".*\.spec\.[jt]sx?$",
            r"__tests__/.*\.[jt]sx?$",
        ],
        "go": [r".*_test\.go$"],
        "rust": [r"tests?/.*\.rs$"],
        "java": [r".*Test\.java$", r".*Tests\.java$"],
        "ruby": [r".*_spec\.rb$", r"test_.*\.rb$"],
    }

    # Function/method definition patterns by language. Each pattern captures
    # the function name in one (or, for javascript, one of several) groups.
    FUNCTION_PATTERNS = {
        "python": r"^\s*(?:async\s+)?def\s+(\w+)\s*\(",
        "javascript": r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>))",
        "go": r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(",
        "rust": r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)",
        "java": r"(?:public|private|protected)\s+(?:static\s+)?(?:\w+\s+)?(\w+)\s*\([^)]*\)\s*\{",
        "ruby": r"^\s*def\s+(\w+)",
    }
|
|
|
|
def can_handle(self, event_type: str, event_data: dict) -> bool:
|
|
"""Check if this agent handles the given event."""
|
|
agent_config = self.config.get("agents", {}).get("test_coverage", {})
|
|
if not agent_config.get("enabled", True):
|
|
return False
|
|
|
|
# Handle @codebot suggest-tests command
|
|
if event_type == "issue_comment":
|
|
comment_body = event_data.get("comment", {}).get("body", "")
|
|
mention_prefix = self.config.get("interaction", {}).get(
|
|
"mention_prefix", "@codebot"
|
|
)
|
|
if f"{mention_prefix} suggest-tests" in comment_body.lower():
|
|
return True
|
|
|
|
return False
|
|
|
|
    def execute(self, context: AgentContext) -> AgentResult:
        """Execute the test coverage agent.

        Orchestration: decide the target (PR diff vs. whole repository),
        collect candidate functions and existing test files, generate
        suggestions (LLM-backed), then post/update a single report comment.

        Returns an AgentResult summarizing the analysis.
        """
        self.logger.info(f"Analyzing test coverage for {context.owner}/{context.repo}")

        actions_taken = []

        # Get issue/PR number and author
        issue = context.event_data.get("issue", {})
        issue_number = issue.get("number")
        comment_author = (
            context.event_data.get("comment", {}).get("user", {}).get("login", "user")
        )

        # Check if this is a PR — issue payloads for PRs carry a
        # "pull_request" key.
        is_pr = issue.get("pull_request") is not None

        if is_pr:
            # Analyze PR diff for changed functions
            diff = self._get_pr_diff(context.owner, context.repo, issue_number)
            changed_functions = self._extract_changed_functions(diff)
            actions_taken.append(f"Analyzed {len(changed_functions)} changed functions")
        else:
            # Analyze entire repository
            changed_functions = self._analyze_repository(context.owner, context.repo)
            actions_taken.append(
                f"Analyzed {len(changed_functions)} functions in repository"
            )

        # Find existing tests
        existing_tests = self._find_existing_tests(context.owner, context.repo)
        actions_taken.append(f"Found {len(existing_tests)} existing test files")

        # Generate test suggestions using LLM
        report = self._generate_suggestions(
            context.owner, context.repo, changed_functions, existing_tests
        )

        # Post report; the marker lets re-runs edit the same comment
        # instead of posting duplicates.
        if issue_number:
            comment = self._format_coverage_report(report, comment_author, is_pr)
            self.upsert_comment(
                context.owner,
                context.repo,
                issue_number,
                comment,
                marker=self.TEST_AI_MARKER,
            )
            actions_taken.append("Posted test coverage report")

        return AgentResult(
            success=True,
            message=f"Generated {len(report.suggestions)} test suggestions",
            data={
                "functions_analyzed": report.functions_analyzed,
                "suggestions_count": len(report.suggestions),
                "coverage_estimate": report.coverage_estimate,
            },
            actions_taken=actions_taken,
        )
|
|
|
|
def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
|
|
"""Get the PR diff."""
|
|
try:
|
|
return self.gitea.get_pull_request_diff(owner, repo, pr_number)
|
|
except Exception as e:
|
|
self.logger.error(f"Failed to get PR diff: {e}")
|
|
return ""
|
|
|
|
def _extract_changed_functions(self, diff: str) -> list[dict]:
|
|
"""Extract changed functions from diff."""
|
|
functions = []
|
|
current_file = None
|
|
current_language = None
|
|
|
|
for line in diff.splitlines():
|
|
# Track current file
|
|
if line.startswith("diff --git"):
|
|
match = re.search(r"b/(.+)$", line)
|
|
if match:
|
|
current_file = match.group(1)
|
|
current_language = self._detect_language(current_file)
|
|
|
|
# Look for function definitions in added lines
|
|
if line.startswith("+") and not line.startswith("+++"):
|
|
if current_language and current_language in self.FUNCTION_PATTERNS:
|
|
pattern = self.FUNCTION_PATTERNS[current_language]
|
|
match = re.search(pattern, line[1:]) # Skip the + prefix
|
|
if match:
|
|
func_name = next(g for g in match.groups() if g)
|
|
functions.append(
|
|
{
|
|
"name": func_name,
|
|
"file": current_file,
|
|
"language": current_language,
|
|
"line": line[1:].strip(),
|
|
}
|
|
)
|
|
|
|
return functions
|
|
|
|
def _analyze_repository(self, owner: str, repo: str) -> list[dict]:
|
|
"""Analyze repository for functions without tests."""
|
|
functions = []
|
|
code_extensions = {".py", ".js", ".ts", ".go", ".rs", ".java", ".rb"}
|
|
|
|
# Get repository contents (limited to avoid API exhaustion)
|
|
try:
|
|
contents = self.gitea.get_file_contents(owner, repo, "")
|
|
if isinstance(contents, list):
|
|
for item in contents[:50]: # Limit files
|
|
if item.get("type") == "file":
|
|
filepath = item.get("path", "")
|
|
ext = os.path.splitext(filepath)[1]
|
|
if ext in code_extensions:
|
|
file_functions = self._extract_functions_from_file(
|
|
owner, repo, filepath
|
|
)
|
|
functions.extend(file_functions)
|
|
except Exception as e:
|
|
self.logger.warning(f"Failed to analyze repository: {e}")
|
|
|
|
return functions[:100] # Limit total functions
|
|
|
|
def _extract_functions_from_file(
|
|
self, owner: str, repo: str, filepath: str
|
|
) -> list[dict]:
|
|
"""Extract function definitions from a file."""
|
|
functions = []
|
|
language = self._detect_language(filepath)
|
|
|
|
if not language or language not in self.FUNCTION_PATTERNS:
|
|
return functions
|
|
|
|
try:
|
|
content_data = self.gitea.get_file_contents(owner, repo, filepath)
|
|
if content_data.get("content"):
|
|
content = base64.b64decode(content_data["content"]).decode(
|
|
"utf-8", errors="ignore"
|
|
)
|
|
|
|
pattern = self.FUNCTION_PATTERNS[language]
|
|
for i, line in enumerate(content.splitlines(), 1):
|
|
match = re.search(pattern, line)
|
|
if match:
|
|
func_name = next((g for g in match.groups() if g), None)
|
|
if func_name and not func_name.startswith("_"):
|
|
functions.append(
|
|
{
|
|
"name": func_name,
|
|
"file": filepath,
|
|
"language": language,
|
|
"line_number": i,
|
|
}
|
|
)
|
|
except Exception:
|
|
pass
|
|
|
|
return functions
|
|
|
|
def _detect_language(self, filepath: str) -> str | None:
|
|
"""Detect programming language from file path."""
|
|
ext_map = {
|
|
".py": "python",
|
|
".js": "javascript",
|
|
".jsx": "javascript",
|
|
".ts": "javascript",
|
|
".tsx": "javascript",
|
|
".go": "go",
|
|
".rs": "rust",
|
|
".java": "java",
|
|
".rb": "ruby",
|
|
}
|
|
ext = os.path.splitext(filepath)[1]
|
|
return ext_map.get(ext)
|
|
|
|
def _find_existing_tests(self, owner: str, repo: str) -> list[str]:
|
|
"""Find existing test files in the repository."""
|
|
test_files = []
|
|
|
|
# Common test directories
|
|
test_dirs = ["tests", "test", "__tests__", "spec"]
|
|
|
|
for test_dir in test_dirs:
|
|
try:
|
|
contents = self.gitea.get_file_contents(owner, repo, test_dir)
|
|
if isinstance(contents, list):
|
|
for item in contents:
|
|
if item.get("type") == "file":
|
|
test_files.append(item.get("path", ""))
|
|
except Exception:
|
|
pass
|
|
|
|
# Also check root for test files
|
|
try:
|
|
contents = self.gitea.get_file_contents(owner, repo, "")
|
|
if isinstance(contents, list):
|
|
for item in contents:
|
|
if item.get("type") == "file":
|
|
filepath = item.get("path", "")
|
|
if self._is_test_file(filepath):
|
|
test_files.append(filepath)
|
|
except Exception:
|
|
pass
|
|
|
|
return test_files
|
|
|
|
def _is_test_file(self, filepath: str) -> bool:
|
|
"""Check if a file is a test file."""
|
|
for lang, patterns in self.TEST_PATTERNS.items():
|
|
for pattern in patterns:
|
|
if re.search(pattern, filepath):
|
|
return True
|
|
return False
|
|
|
|
    def _generate_suggestions(
        self,
        owner: str,
        repo: str,
        functions: list[dict],
        existing_tests: list[str],
    ) -> CoverageReport:
        """Generate test suggestions using LLM.

        Builds a JSON-answer prompt from the (capped) function list, parses
        the LLM response into TestSuggestion objects, and falls back to
        generic per-function suggestions if the call fails. `owner` and
        `repo` are currently unused; kept for interface stability.

        Returns a CoverageReport with suggestions and a coverage estimate.
        """
        suggestions = []

        # Build prompt for LLM
        if functions:
            functions_text = "\n".join(
                [
                    f"- {f['name']} in {f['file']} ({f['language']})"
                    for f in functions[:20]  # Limit for prompt size
                ]
            )

            prompt = f"""Analyze these functions and suggest test cases:

Functions to test:
{functions_text}

Existing test files:
{", ".join(existing_tests[:10]) if existing_tests else "None found"}

For each function, suggest:
1. What to test (happy path, edge cases, error handling)
2. Priority (HIGH for public APIs, MEDIUM for internal, LOW for utilities)
3. Brief example test code if possible

Respond in JSON format:
{{
    "suggestions": [
        {{
            "function_name": "function_name",
            "file_path": "path/to/file",
            "test_type": "unit|integration|edge_case",
            "description": "What to test",
            "example_code": "brief example or null",
            "priority": "HIGH|MEDIUM|LOW"
        }}
    ],
    "coverage_estimate": 0.0 to 1.0
}}
"""

            try:
                result = self.call_llm_json(prompt)

                # Missing fields get conservative defaults rather than
                # failing the whole response.
                for s in result.get("suggestions", []):
                    suggestions.append(
                        TestSuggestion(
                            function_name=s.get("function_name", ""),
                            file_path=s.get("file_path", ""),
                            test_type=s.get("test_type", "unit"),
                            description=s.get("description", ""),
                            example_code=s.get("example_code"),
                            priority=s.get("priority", "MEDIUM"),
                        )
                    )

                coverage_estimate = result.get("coverage_estimate", 0.5)

            except Exception as e:
                self.logger.warning(f"LLM suggestion failed: {e}")
                # Generate basic suggestions without LLM
                for f in functions[:10]:
                    suggestions.append(
                        TestSuggestion(
                            function_name=f["name"],
                            file_path=f["file"],
                            test_type="unit",
                            description=f"Add unit tests for {f['name']}",
                            priority="MEDIUM",
                        )
                    )
                coverage_estimate = 0.5

        else:
            # No functions to analyze: presence of any test files is
            # treated as full coverage, absence as none.
            coverage_estimate = 1.0 if existing_tests else 0.0

        # Estimate functions with tests
        functions_with_tests = int(len(functions) * coverage_estimate)

        return CoverageReport(
            functions_analyzed=len(functions),
            functions_with_tests=functions_with_tests,
            functions_without_tests=len(functions) - functions_with_tests,
            suggestions=suggestions,
            existing_tests=existing_tests,
            coverage_estimate=coverage_estimate,
        )
|
|
|
|
def _format_coverage_report(
|
|
self, report: CoverageReport, user: str | None, is_pr: bool
|
|
) -> str:
|
|
"""Format the coverage report as a comment."""
|
|
lines = []
|
|
|
|
if user:
|
|
lines.append(f"@{user}")
|
|
lines.append("")
|
|
|
|
lines.extend(
|
|
[
|
|
f"{self.AI_DISCLAIMER}",
|
|
"",
|
|
"## 🧪 Test Coverage Analysis",
|
|
"",
|
|
"### Summary",
|
|
"",
|
|
f"| Metric | Value |",
|
|
f"|--------|-------|",
|
|
f"| Functions Analyzed | {report.functions_analyzed} |",
|
|
f"| Estimated Coverage | {report.coverage_estimate:.0%} |",
|
|
f"| Test Files Found | {len(report.existing_tests)} |",
|
|
f"| Suggestions | {len(report.suggestions)} |",
|
|
"",
|
|
]
|
|
)
|
|
|
|
# Suggestions by priority
|
|
if report.suggestions:
|
|
lines.append("### 💡 Test Suggestions")
|
|
lines.append("")
|
|
|
|
# Group by priority
|
|
by_priority = {"HIGH": [], "MEDIUM": [], "LOW": []}
|
|
for s in report.suggestions:
|
|
if s.priority in by_priority:
|
|
by_priority[s.priority].append(s)
|
|
|
|
priority_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}
|
|
|
|
for priority in ["HIGH", "MEDIUM", "LOW"]:
|
|
suggestions = by_priority[priority]
|
|
if suggestions:
|
|
lines.append(f"#### {priority_emoji[priority]} {priority} Priority")
|
|
lines.append("")
|
|
for s in suggestions[:5]: # Limit display
|
|
lines.append(f"**`{s.function_name}`** in `{s.file_path}`")
|
|
lines.append(f"- Type: {s.test_type}")
|
|
lines.append(f"- {s.description}")
|
|
if s.example_code:
|
|
lines.append(f"```")
|
|
lines.append(s.example_code[:200])
|
|
lines.append(f"```")
|
|
lines.append("")
|
|
if len(suggestions) > 5:
|
|
lines.append(f"*... and {len(suggestions) - 5} more*")
|
|
lines.append("")
|
|
|
|
# Existing test files
|
|
if report.existing_tests:
|
|
lines.append("### 📁 Existing Test Files")
|
|
lines.append("")
|
|
for f in report.existing_tests[:10]:
|
|
lines.append(f"- `{f}`")
|
|
if len(report.existing_tests) > 10:
|
|
lines.append(f"- *... and {len(report.existing_tests) - 10} more*")
|
|
lines.append("")
|
|
|
|
# Coverage bar
|
|
lines.append("### 📊 Coverage Estimate")
|
|
lines.append("")
|
|
filled = int(report.coverage_estimate * 10)
|
|
bar = "█" * filled + "░" * (10 - filled)
|
|
lines.append(f"`[{bar}]` {report.coverage_estimate:.0%}")
|
|
lines.append("")
|
|
|
|
# Recommendations
|
|
if report.coverage_estimate < 0.8:
|
|
lines.append("---")
|
|
lines.append("⚠️ **Coverage below 80%** - Consider adding more tests")
|
|
elif report.coverage_estimate >= 0.8:
|
|
lines.append("---")
|
|
lines.append("✅ **Good test coverage!**")
|
|
|
|
return "\n".join(lines)
|