"""Test Coverage Agent AI agent for analyzing code changes and suggesting test cases. Helps improve test coverage by identifying untested code paths. """ import base64 import os import re from dataclasses import dataclass, field from agents.base_agent import AgentContext, AgentResult, BaseAgent @dataclass class TestSuggestion: """A suggested test case.""" function_name: str file_path: str test_type: str # unit, integration, edge_case description: str example_code: str | None = None priority: str = "MEDIUM" # HIGH, MEDIUM, LOW @dataclass class CoverageReport: """Report of test coverage analysis.""" functions_analyzed: int functions_with_tests: int functions_without_tests: int suggestions: list[TestSuggestion] existing_tests: list[str] coverage_estimate: float class TestCoverageAgent(BaseAgent): """Agent for analyzing test coverage and suggesting tests.""" # Marker for test coverage comments TEST_AI_MARKER = "" # Test file patterns by language TEST_PATTERNS = { "python": [r"test_.*\.py$", r".*_test\.py$", r"tests?/.*\.py$"], "javascript": [ r".*\.test\.[jt]sx?$", r".*\.spec\.[jt]sx?$", r"__tests__/.*\.[jt]sx?$", ], "go": [r".*_test\.go$"], "rust": [r"tests?/.*\.rs$"], "java": [r".*Test\.java$", r".*Tests\.java$"], "ruby": [r".*_spec\.rb$", r"test_.*\.rb$"], } # Function/method patterns by language FUNCTION_PATTERNS = { "python": r"^\s*(?:async\s+)?def\s+(\w+)\s*\(", "javascript": r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>))", "go": r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(", "rust": r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)", "java": r"(?:public|private|protected)\s+(?:static\s+)?(?:\w+\s+)?(\w+)\s*\([^)]*\)\s*\{", "ruby": r"^\s*def\s+(\w+)", } def can_handle(self, event_type: str, event_data: dict) -> bool: """Check if this agent handles the given event.""" agent_config = self.config.get("agents", {}).get("test_coverage", {}) if not agent_config.get("enabled", True): return False # Handle @codebot suggest-tests command if event_type == "issue_comment": comment_body = event_data.get("comment", {}).get("body", "") mention_prefix = self.config.get("interaction", {}).get( "mention_prefix", "@codebot" ) if f"{mention_prefix} suggest-tests" in comment_body.lower(): return True return False def execute(self, context: AgentContext) -> AgentResult: """Execute the test coverage agent.""" self.logger.info(f"Analyzing test coverage for {context.owner}/{context.repo}") actions_taken = [] # Get issue/PR number and author issue = context.event_data.get("issue", {}) issue_number = issue.get("number") comment_author = ( context.event_data.get("comment", {}).get("user", {}).get("login", "user") ) # Check if this is a PR is_pr = issue.get("pull_request") is not None if is_pr: # Analyze PR diff for changed functions diff = self._get_pr_diff(context.owner, context.repo, issue_number) changed_functions = self._extract_changed_functions(diff) actions_taken.append(f"Analyzed {len(changed_functions)} changed functions") else: # Analyze entire repository changed_functions = self._analyze_repository(context.owner, context.repo) actions_taken.append( f"Analyzed {len(changed_functions)} functions in repository" ) # Find existing tests existing_tests = self._find_existing_tests(context.owner, context.repo) actions_taken.append(f"Found {len(existing_tests)} existing test files") # Generate test suggestions using LLM report = self._generate_suggestions( context.owner, context.repo, changed_functions, existing_tests ) # Post report if issue_number: comment = self._format_coverage_report(report, comment_author, is_pr) self.upsert_comment( context.owner, context.repo, issue_number, comment, marker=self.TEST_AI_MARKER, ) actions_taken.append("Posted test coverage report") return AgentResult( success=True, message=f"Generated {len(report.suggestions)} test suggestions", data={ "functions_analyzed": report.functions_analyzed, "suggestions_count": len(report.suggestions), "coverage_estimate": report.coverage_estimate, }, actions_taken=actions_taken, ) def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str: """Get the PR diff.""" try: return self.gitea.get_pull_request_diff(owner, repo, pr_number) except Exception as e: self.logger.error(f"Failed to get PR diff: {e}") return "" def _extract_changed_functions(self, diff: str) -> list[dict]: """Extract changed functions from diff.""" functions = [] current_file = None current_language = None for line in diff.splitlines(): # Track current file if line.startswith("diff --git"): match = re.search(r"b/(.+)$", line) if match: current_file = match.group(1) current_language = self._detect_language(current_file) # Look for function definitions in added lines if line.startswith("+") and not line.startswith("+++"): if current_language and current_language in self.FUNCTION_PATTERNS: pattern = self.FUNCTION_PATTERNS[current_language] match = re.search(pattern, line[1:]) # Skip the + prefix if match: func_name = next(g for g in match.groups() if g) functions.append( { "name": func_name, "file": current_file, "language": current_language, "line": line[1:].strip(), } ) return functions def _analyze_repository(self, owner: str, repo: str) -> list[dict]: """Analyze repository for functions without tests.""" functions = [] code_extensions = {".py", ".js", ".ts", ".go", ".rs", ".java", ".rb"} # Get repository contents (limited to avoid API exhaustion) try: contents = self.gitea.get_file_contents(owner, repo, "") if isinstance(contents, list): for item in contents[:50]: # Limit files if item.get("type") == "file": filepath = item.get("path", "") ext = os.path.splitext(filepath)[1] if ext in code_extensions: file_functions = self._extract_functions_from_file( owner, repo, filepath ) functions.extend(file_functions) except Exception as e: self.logger.warning(f"Failed to analyze repository: {e}") return functions[:100] # Limit total functions def _extract_functions_from_file( self, owner: str, repo: str, filepath: str ) -> list[dict]: """Extract function definitions from a file.""" functions = [] language = self._detect_language(filepath) if not language or language not in self.FUNCTION_PATTERNS: return functions try: content_data = self.gitea.get_file_contents(owner, repo, filepath) if content_data.get("content"): content = base64.b64decode(content_data["content"]).decode( "utf-8", errors="ignore" ) pattern = self.FUNCTION_PATTERNS[language] for i, line in enumerate(content.splitlines(), 1): match = re.search(pattern, line) if match: func_name = next((g for g in match.groups() if g), None) if func_name and not func_name.startswith("_"): functions.append( { "name": func_name, "file": filepath, "language": language, "line_number": i, } ) except Exception: pass return functions def _detect_language(self, filepath: str) -> str | None: """Detect programming language from file path.""" ext_map = { ".py": "python", ".js": "javascript", ".jsx": "javascript", ".ts": "javascript", ".tsx": "javascript", ".go": "go", ".rs": "rust", ".java": "java", ".rb": "ruby", } ext = os.path.splitext(filepath)[1] return ext_map.get(ext) def _find_existing_tests(self, owner: str, repo: str) -> list[str]: """Find existing test files in the repository.""" test_files = [] # Common test directories test_dirs = ["tests", "test", "__tests__", "spec"] for test_dir in test_dirs: try: contents = self.gitea.get_file_contents(owner, repo, test_dir) if isinstance(contents, list): for item in contents: if item.get("type") == "file": test_files.append(item.get("path", "")) except Exception: pass # Also check root for test files try: contents = self.gitea.get_file_contents(owner, repo, "") if isinstance(contents, list): for item in contents: if item.get("type") == "file": filepath = item.get("path", "") if self._is_test_file(filepath): test_files.append(filepath) except Exception: pass return test_files def _is_test_file(self, filepath: str) -> bool: """Check if a file is a test file.""" for lang, patterns in self.TEST_PATTERNS.items(): for pattern in patterns: if re.search(pattern, filepath): return True return False def _generate_suggestions( self, owner: str, repo: str, functions: list[dict], existing_tests: list[str], ) -> CoverageReport: """Generate test suggestions using LLM.""" suggestions = [] # Build prompt for LLM if functions: functions_text = "\n".join( [ f"- {f['name']} in {f['file']} ({f['language']})" for f in functions[:20] # Limit for prompt size ] ) prompt = f"""Analyze these functions and suggest test cases: Functions to test: {functions_text} Existing test files: {", ".join(existing_tests[:10]) if existing_tests else "None found"} For each function, suggest: 1. What to test (happy path, edge cases, error handling) 2. Priority (HIGH for public APIs, MEDIUM for internal, LOW for utilities) 3. Brief example test code if possible Respond in JSON format: {{ "suggestions": [ {{ "function_name": "function_name", "file_path": "path/to/file", "test_type": "unit|integration|edge_case", "description": "What to test", "example_code": "brief example or null", "priority": "HIGH|MEDIUM|LOW" }} ], "coverage_estimate": 0.0 to 1.0 }} """ try: result = self.call_llm_json(prompt) for s in result.get("suggestions", []): suggestions.append( TestSuggestion( function_name=s.get("function_name", ""), file_path=s.get("file_path", ""), test_type=s.get("test_type", "unit"), description=s.get("description", ""), example_code=s.get("example_code"), priority=s.get("priority", "MEDIUM"), ) ) coverage_estimate = result.get("coverage_estimate", 0.5) except Exception as e: self.logger.warning(f"LLM suggestion failed: {e}") # Generate basic suggestions without LLM for f in functions[:10]: suggestions.append( TestSuggestion( function_name=f["name"], file_path=f["file"], test_type="unit", description=f"Add unit tests for {f['name']}", priority="MEDIUM", ) ) coverage_estimate = 0.5 else: coverage_estimate = 1.0 if existing_tests else 0.0 # Estimate functions with tests functions_with_tests = int(len(functions) * coverage_estimate) return CoverageReport( functions_analyzed=len(functions), functions_with_tests=functions_with_tests, functions_without_tests=len(functions) - functions_with_tests, suggestions=suggestions, existing_tests=existing_tests, coverage_estimate=coverage_estimate, ) def _format_coverage_report( self, report: CoverageReport, user: str | None, is_pr: bool ) -> str: """Format the coverage report as a comment.""" lines = [] if user: lines.append(f"@{user}") lines.append("") lines.extend( [ f"{self.AI_DISCLAIMER}", "", "## ๐Ÿงช Test Coverage Analysis", "", "### Summary", "", f"| Metric | Value |", f"|--------|-------|", f"| Functions Analyzed | {report.functions_analyzed} |", f"| Estimated Coverage | {report.coverage_estimate:.0%} |", f"| Test Files Found | {len(report.existing_tests)} |", f"| Suggestions | {len(report.suggestions)} |", "", ] ) # Suggestions by priority if report.suggestions: lines.append("### ๐Ÿ’ก Test Suggestions") lines.append("") # Group by priority by_priority = {"HIGH": [], "MEDIUM": [], "LOW": []} for s in report.suggestions: if s.priority in by_priority: by_priority[s.priority].append(s) priority_emoji = {"HIGH": "๐Ÿ”ด", "MEDIUM": "๐ŸŸก", "LOW": "๐Ÿ”ต"} for priority in ["HIGH", "MEDIUM", "LOW"]: suggestions = by_priority[priority] if suggestions: lines.append(f"#### {priority_emoji[priority]} {priority} Priority") lines.append("") for s in suggestions[:5]: # Limit display lines.append(f"**`{s.function_name}`** in `{s.file_path}`") lines.append(f"- Type: {s.test_type}") lines.append(f"- {s.description}") if s.example_code: lines.append(f"```") lines.append(s.example_code[:200]) lines.append(f"```") lines.append("") if len(suggestions) > 5: lines.append(f"*... and {len(suggestions) - 5} more*") lines.append("") # Existing test files if report.existing_tests: lines.append("### ๐Ÿ“ Existing Test Files") lines.append("") for f in report.existing_tests[:10]: lines.append(f"- `{f}`") if len(report.existing_tests) > 10: lines.append(f"- *... and {len(report.existing_tests) - 10} more*") lines.append("") # Coverage bar lines.append("### ๐Ÿ“Š Coverage Estimate") lines.append("") filled = int(report.coverage_estimate * 10) bar = "โ–ˆ" * filled + "โ–‘" * (10 - filled) lines.append(f"`[{bar}]` {report.coverage_estimate:.0%}") lines.append("") # Recommendations if report.coverage_estimate < 0.8: lines.append("---") lines.append("โš ๏ธ **Coverage below 80%** - Consider adding more tests") elif report.coverage_estimate >= 0.8: lines.append("---") lines.append("โœ… **Good test coverage!**") return "\n".join(lines)