just why not
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s

This commit is contained in:
2026-01-07 21:19:46 +01:00
parent a1fe47cdf4
commit e8d28225e0
24 changed files with 6431 additions and 250 deletions

View File

@@ -2,20 +2,40 @@
This package contains the modular agent implementations for the
enterprise AI code review system.
Core Agents:
- PRAgent: Pull request review and analysis
- IssueAgent: Issue triage and response
- CodebaseAgent: Codebase health analysis
- ChatAgent: Interactive chat with tool calling
Specialized Agents:
- DependencyAgent: Dependency vulnerability scanning
- TestCoverageAgent: Test coverage analysis and suggestions
- ArchitectureAgent: Architecture compliance checking
"""
from agents.architecture_agent import ArchitectureAgent
from agents.base_agent import AgentContext, AgentResult, BaseAgent
from agents.chat_agent import ChatAgent
from agents.codebase_agent import CodebaseAgent
from agents.dependency_agent import DependencyAgent
from agents.issue_agent import IssueAgent
from agents.pr_agent import PRAgent
from agents.test_coverage_agent import TestCoverageAgent
# Names re-exported by the package; every entry corresponds to one of the
# imports placed directly above this list.
__all__ = [
# Base
"BaseAgent",
"AgentContext",
"AgentResult",
# Core Agents
"IssueAgent",
"PRAgent",
"CodebaseAgent",
"ChatAgent",
# Specialized Agents
"DependencyAgent",
"TestCoverageAgent",
"ArchitectureAgent",
]

View File

@@ -0,0 +1,547 @@
"""Architecture Compliance Agent
AI agent for enforcing architectural patterns and layer separation.
Detects cross-layer violations and circular dependencies.
"""
import base64
import os
import re
from dataclasses import dataclass, field
from agents.base_agent import AgentContext, AgentResult, BaseAgent
@dataclass
class ArchitectureViolation:
"""An architecture violation."""
file: str
line: int
violation_type: str # cross_layer, circular, naming, structure
severity: str # HIGH, MEDIUM, LOW
description: str
recommendation: str
source_layer: str | None = None
target_layer: str | None = None
@dataclass
class ArchitectureReport:
    """Aggregated result of one architecture analysis run."""

    # Every violation found during the run
    violations: list[ArchitectureViolation]
    # Layer name -> files that were assigned to that layer
    layers_detected: dict[str, list[str]]
    # Pairs of layers that import from each other
    circular_dependencies: list[tuple[str, str]]
    # Fraction of compliant imports
    compliance_score: float
    # Free-text advice derived from the violations
    recommendations: list[str]
# Agent that inspects import statements and reports imports that cross
# layer boundaries forbidden by the layer table below.
class ArchitectureAgent(BaseAgent):
"""Agent for enforcing architectural compliance."""
# Marker for architecture comments; execute() passes it as the `marker`
# argument of upsert_comment() when posting the report.
ARCH_AI_MARKER = "<!-- AI_ARCHITECTURE_CHECK -->"
# Default layer definitions, used by execute() when the configuration has no
# agents.architecture.layers entry. For each layer:
#   patterns      - path fragments used to assign files and imports to the layer
#   can_import    - layers this layer may depend on (not read by the code
#                   visible in this file)
#   cannot_import - layers whose import is reported as a cross_layer violation
# NOTE(review): some cannot_import entries ("controllers", "handlers",
# "infrastructure") are pattern names rather than layer names, and the
# detection helpers return layer names - presumably those entries never match;
# verify.
DEFAULT_LAYERS = {
"api": {
"patterns": ["api/", "routes/", "controllers/", "handlers/", "views/"],
"can_import": ["services", "models", "utils", "config"],
"cannot_import": ["db", "repositories", "infrastructure"],
},
"services": {
"patterns": ["services/", "usecases/", "application/"],
"can_import": ["models", "repositories", "utils", "config"],
"cannot_import": ["api", "controllers", "handlers"],
},
"repositories": {
"patterns": ["repositories/", "repos/", "data/"],
"can_import": ["models", "db", "utils", "config"],
"cannot_import": ["api", "services", "controllers"],
},
"models": {
"patterns": ["models/", "entities/", "domain/", "schemas/"],
"can_import": ["utils", "config"],
"cannot_import": ["api", "services", "repositories", "db"],
},
"db": {
"patterns": ["db/", "database/", "infrastructure/"],
"can_import": ["models", "config"],
"cannot_import": ["api", "services"],
},
"utils": {
"patterns": ["utils/", "helpers/", "common/", "lib/"],
"can_import": ["config"],
"cannot_import": [],
},
}
def can_handle(self, event_type: str, event_data: dict) -> bool:
    """Return True when this agent should process the given event.

    The agent is opt-in: it only reacts when ``agents.architecture.enabled``
    is truthy in the configuration. It then accepts:

    * ``pull_request`` events whose action is ``opened`` or ``synchronize``;
    * ``issue_comment`` events whose body contains
      ``<mention_prefix> architecture`` or ``<mention_prefix> arch-check``.

    Args:
        event_type: Webhook event name.
        event_data: Webhook payload.

    Returns:
        Whether the event is handled by this agent.
    """
    agent_config = self.config.get("agents", {}).get("architecture", {})
    if not agent_config.get("enabled", False):
        return False

    if event_type == "pull_request":
        return event_data.get("action", "") in ("opened", "synchronize")

    if event_type == "issue_comment":
        # A missing or null comment/body must not raise.
        body = ((event_data.get("comment") or {}).get("body") or "").lower()
        # The body is lower-cased, so the prefix has to be lower-cased as
        # well; otherwise a configured prefix such as "@CodeBot" never matches.
        prefix = (
            self.config.get("interaction", {})
            .get("mention_prefix", "@codebot")
            .lower()
        )
        return any(
            f"{prefix} {command}" in body
            for command in ("architecture", "arch-check")
        )

    return False
def execute(self, context: AgentContext) -> AgentResult:
    """Run the architecture check for the event described by ``context``.

    For a pull request (or a comment on one) the PR diff is analysed;
    otherwise the repository layout is analysed. When an issue/PR number is
    known the formatted report is posted through ``upsert_comment``.

    Returns:
        An ``AgentResult`` that is successful when there are no violations
        or the compliance score is at least 0.8.
    """
    self.logger.info(f"Checking architecture for {context.owner}/{context.repo}")
    steps = []

    # Layer rules: configured value wins, class default otherwise.
    architecture_cfg = self.config.get("agents", {}).get("architecture", {})
    layers = architecture_cfg.get("layers", self.DEFAULT_LAYERS)

    # Work out which issue/PR the report belongs to.
    payload = context.event_data
    if context.event_type != "issue_comment":
        issue_number = payload.get("pull_request", {}).get("number")
        comment_author = None
        is_pr = True
    else:
        issue = payload.get("issue", {})
        issue_number = issue.get("number")
        comment_author = (
            payload.get("comment", {}).get("user", {}).get("login", "user")
        )
        is_pr = issue.get("pull_request") is not None

    if not (is_pr and issue_number):
        # No pull request to diff: inspect the repository itself.
        report = self._analyze_repository(context.owner, context.repo, layers)
        steps.append("Analyzed repository architecture")
    else:
        pr_diff = self._get_pr_diff(context.owner, context.repo, issue_number)
        report = self._analyze_diff(pr_diff, layers)
        steps.append("Analyzed PR diff for architecture violations")

    if issue_number:
        body = self._format_architecture_report(report, comment_author)
        self.upsert_comment(
            context.owner,
            context.repo,
            issue_number,
            body,
            marker=self.ARCH_AI_MARKER,
        )
        steps.append("Posted architecture report")

    violation_count = len(report.violations)
    passed = violation_count == 0 or report.compliance_score >= 0.8
    return AgentResult(
        success=passed,
        message=f"Architecture check: {violation_count} violations, {report.compliance_score:.0%} compliance",
        data={
            "violations_count": violation_count,
            "compliance_score": report.compliance_score,
            "circular_dependencies": len(report.circular_dependencies),
        },
        actions_taken=steps,
    )
def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
"""Get the PR diff."""
try:
return self.gitea.get_pull_request_diff(owner, repo, pr_number)
except Exception as e:
self.logger.error(f"Failed to get PR diff: {e}")
return ""
def _analyze_diff(self, diff: str, layers: dict) -> ArchitectureReport:
    """Analyze a unified PR diff for architecture violations.

    Only added lines (``+...``) are inspected. Each import found on an added
    line is attributed to the file named in the preceding ``diff --git``
    header, and every import that resolves to a layer listed in the source
    layer's ``cannot_import`` is reported as a HIGH ``cross_layer`` violation.

    Args:
        diff: Unified diff text; an empty string yields an empty report.
        layers: Layer configuration (see ``DEFAULT_LAYERS`` for the shape).

    Returns:
        The report for the diff. ``compliance_score`` is the share of
        extracted imports that are not violations (1.0 when no imports were
        found).
    """
    violations = []
    imports_by_file = {}
    current_file = None
    current_language = None

    for line in diff.splitlines():
        if line.startswith("diff --git"):
            # Header shape: "diff --git a/<old> b/<new>". Anchor on the
            # " b/" separator: an unanchored "b/" also matches inside
            # directory names such as "lib/" and yields a bogus path.
            match = re.search(r" b/(.+)$", line)
            # Reset on every header so added lines are never attributed to
            # the previous file when a header cannot be parsed.
            current_file = match.group(1) if match else None
            current_language = (
                self._detect_language(current_file) if current_file else None
            )
            if current_file:
                imports_by_file[current_file] = []
            continue

        # Added lines only; "+++ b/..." is the file header, not content.
        if line.startswith("+") and not line.startswith("+++"):
            if current_file and current_language:
                imports_by_file[current_file].extend(
                    self._extract_imports(line[1:], current_language)
                )

    # Check every import against the source layer's deny list.
    for file_path, imports in imports_by_file.items():
        source_layer = self._detect_layer(file_path, layers)
        if not source_layer:
            continue
        cannot_import = layers.get(source_layer, {}).get("cannot_import", [])
        for imp in imports:
            target_layer = self._detect_layer_from_import(imp, layers)
            if target_layer and target_layer in cannot_import:
                violations.append(
                    ArchitectureViolation(
                        file=file_path,
                        line=0,  # Line number not tracked in this simple implementation
                        violation_type="cross_layer",
                        severity="HIGH",
                        description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
                        recommendation="Move this import to an allowed layer or refactor the dependency",
                        source_layer=source_layer,
                        target_layer=target_layer,
                    )
                )

    circular = self._detect_circular_dependencies(imports_by_file, layers)

    total_imports = sum(len(imps) for imps in imports_by_file.values())
    compliance = 1.0 - len(violations) / total_imports if total_imports else 1.0

    return ArchitectureReport(
        violations=violations,
        layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
        circular_dependencies=circular,
        compliance_score=max(0.0, compliance),
        recommendations=self._generate_recommendations(violations),
    )
def _analyze_repository(
    self, owner: str, repo: str, layers: dict
) -> ArchitectureReport:
    """Analyze repository structure for architecture compliance.

    Each layer pattern is requested as a directory from the repository; up
    to 20 entries per directory are scanned for imports. Every import that
    resolves to a layer listed in the source layer's ``cannot_import`` is
    reported as a HIGH ``cross_layer`` violation.

    Args:
        owner: Repository owner.
        repo: Repository name.
        layers: Layer configuration (see ``DEFAULT_LAYERS`` for the shape).

    Returns:
        The report for the scanned files. ``compliance_score`` is the share
        of extracted imports that are not violations (1.0 when no imports
        were found).
    """
    violations = []
    imports_by_file = {}

    # Collect files from each layer
    for layer_config in layers.values():
        for pattern in layer_config.get("patterns", []):
            path = pattern.rstrip("/")
            try:
                contents = self.gitea.get_file_contents(owner, repo, path)
            except Exception as exc:
                # Best effort: a failed lookup is skipped, but it is logged so
                # that API or auth failures are not mistaken for a clean
                # repository.
                self.logger.debug(f"Skipping layer path '{path}': {exc}")
                continue
            if not isinstance(contents, list):
                continue  # a single file, not a directory listing
            for item in contents[:20]:  # Limit files per layer
                if isinstance(item, dict) and item.get("type") == "file":
                    filepath = item.get("path", "")
                    imports_by_file[filepath] = self._get_file_imports(
                        owner, repo, filepath
                    )

    # Check every import against the source layer's deny list.
    for file_path, imports in imports_by_file.items():
        source_layer = self._detect_layer(file_path, layers)
        if not source_layer:
            continue
        cannot_import = layers.get(source_layer, {}).get("cannot_import", [])
        for imp in imports:
            target_layer = self._detect_layer_from_import(imp, layers)
            if target_layer and target_layer in cannot_import:
                violations.append(
                    ArchitectureViolation(
                        file=file_path,
                        line=0,
                        violation_type="cross_layer",
                        severity="HIGH",
                        description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
                        recommendation=f"Refactor to remove dependency on '{target_layer}'",
                        source_layer=source_layer,
                        target_layer=target_layer,
                    )
                )

    circular = self._detect_circular_dependencies(imports_by_file, layers)

    total_imports = sum(len(imps) for imps in imports_by_file.values())
    compliance = 1.0 - len(violations) / total_imports if total_imports else 1.0

    return ArchitectureReport(
        violations=violations,
        layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
        circular_dependencies=circular,
        compliance_score=max(0.0, compliance),
        recommendations=self._generate_recommendations(violations),
    )
def _get_file_imports(self, owner: str, repo: str, filepath: str) -> list[str]:
"""Get imports from a file."""
imports = []
language = self._detect_language(filepath)
if not language:
return imports
try:
content_data = self.gitea.get_file_contents(owner, repo, filepath)
if content_data.get("content"):
content = base64.b64decode(content_data["content"]).decode(
"utf-8", errors="ignore"
)
for line in content.splitlines():
imports.extend(self._extract_imports(line, language))
except Exception:
pass
return imports
def _detect_language(self, filepath: str) -> str | None:
"""Detect programming language from file path."""
ext_map = {
".py": "python",
".js": "javascript",
".ts": "typescript",
".go": "go",
".java": "java",
".rb": "ruby",
}
ext = os.path.splitext(filepath)[1]
return ext_map.get(ext)
def _extract_imports(self, line: str, language: str) -> list[str]:
"""Extract import statements from a line of code."""
imports = []
line = line.strip()
if language == "python":
# from x import y, import x
match = re.match(r"^(?:from\s+(\S+)|import\s+(\S+))", line)
if match:
imp = match.group(1) or match.group(2)
if imp:
imports.append(imp.split(".")[0])
elif language in ("javascript", "typescript"):
# import x from 'y', require('y')
match = re.search(
r"(?:from\s+['\"]([^'\"]+)['\"]|require\(['\"]([^'\"]+)['\"]\))", line
)
if match:
imp = match.group(1) or match.group(2)
if imp and not imp.startswith("."):
imports.append(imp.split("/")[0])
elif imp:
imports.append(imp)
elif language == "go":
# import "package"
match = re.search(r'import\s+["\']([^"\']+)["\']', line)
if match:
imports.append(match.group(1).split("/")[-1])
elif language == "java":
# import package.Class
match = re.match(r"^import\s+(?:static\s+)?([^;]+);", line)
if match:
parts = match.group(1).split(".")
if len(parts) > 1:
imports.append(parts[-2]) # Package name
return imports
def _detect_layer(self, filepath: str, layers: dict) -> str | None:
"""Detect which layer a file belongs to."""
for layer_name, layer_config in layers.items():
for pattern in layer_config.get("patterns", []):
if pattern.rstrip("/") in filepath:
return layer_name
return None
def _detect_layer_from_import(self, import_path: str, layers: dict) -> str | None:
"""Detect which layer an import refers to."""
for layer_name, layer_config in layers.items():
for pattern in layer_config.get("patterns", []):
pattern_name = pattern.rstrip("/").split("/")[-1]
if pattern_name in import_path or import_path.startswith(pattern_name):
return layer_name
return None
def _detect_circular_dependencies(
self, imports_by_file: dict, layers: dict
) -> list[tuple[str, str]]:
"""Detect circular dependencies between layers."""
circular = []
# Build layer dependency graph
layer_deps = {}
for file_path, imports in imports_by_file.items():
source_layer = self._detect_layer(file_path, layers)
if not source_layer:
continue
if source_layer not in layer_deps:
layer_deps[source_layer] = set()
for imp in imports:
target_layer = self._detect_layer_from_import(imp, layers)
if target_layer and target_layer != source_layer:
layer_deps[source_layer].add(target_layer)
# Check for circular dependencies
for layer_a, deps_a in layer_deps.items():
for layer_b in deps_a:
if layer_b in layer_deps and layer_a in layer_deps.get(layer_b, set()):
pair = tuple(sorted([layer_a, layer_b]))
if pair not in circular:
circular.append(pair)
return circular
def _group_files_by_layer(
self, files: list[str], layers: dict
) -> dict[str, list[str]]:
"""Group files by their layer."""
grouped = {}
for filepath in files:
layer = self._detect_layer(filepath, layers)
if layer:
if layer not in grouped:
grouped[layer] = []
grouped[layer].append(filepath)
return grouped
def _generate_recommendations(
self, violations: list[ArchitectureViolation]
) -> list[str]:
"""Generate recommendations based on violations."""
recommendations = []
# Count violations by type
cross_layer = sum(1 for v in violations if v.violation_type == "cross_layer")
if cross_layer > 0:
recommendations.append(
f"Fix {cross_layer} cross-layer violations by moving imports or creating interfaces"
)
if cross_layer > 5:
recommendations.append(
"Consider using dependency injection to reduce coupling between layers"
)
return recommendations
def _format_architecture_report(
self, report: ArchitectureReport, user: str | None
) -> str:
"""Format the architecture report as a comment."""
lines = []
if user:
lines.append(f"@{user}")
lines.append("")
lines.extend(
[
f"{self.AI_DISCLAIMER}",
"",
"## 🏗️ Architecture Compliance Check",
"",
"### Summary",
"",
f"| Metric | Value |",
f"|--------|-------|",
f"| Compliance Score | {report.compliance_score:.0%} |",
f"| Violations | {len(report.violations)} |",
f"| Circular Dependencies | {len(report.circular_dependencies)} |",
f"| Layers Detected | {len(report.layers_detected)} |",
"",
]
)
# Compliance bar
filled = int(report.compliance_score * 10)
bar = "" * filled + "" * (10 - filled)
lines.append(f"`[{bar}]` {report.compliance_score:.0%}")
lines.append("")
# Violations
if report.violations:
lines.append("### 🚨 Violations")
lines.append("")
for v in report.violations[:10]: # Limit display
severity_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}
lines.append(
f"{severity_emoji.get(v.severity, '')} **{v.violation_type.upper()}** in `{v.file}`"
)
lines.append(f" - {v.description}")
lines.append(f" - 💡 {v.recommendation}")
lines.append("")
if len(report.violations) > 10:
lines.append(f"*... and {len(report.violations) - 10} more violations*")
lines.append("")
# Circular dependencies
if report.circular_dependencies:
lines.append("### 🔄 Circular Dependencies")
lines.append("")
for a, b in report.circular_dependencies:
lines.append(f"- `{a}` ↔ `{b}`")
lines.append("")
# Layers detected
if report.layers_detected:
lines.append("### 📁 Layers Detected")
lines.append("")
for layer, files in report.layers_detected.items():
lines.append(f"- **{layer}**: {len(files)} files")
lines.append("")
# Recommendations
if report.recommendations:
lines.append("### 💡 Recommendations")
lines.append("")
for rec in report.recommendations:
lines.append(f"- {rec}")
lines.append("")
# Overall status
if report.compliance_score >= 0.9:
lines.append("---")
lines.append("✅ **Excellent architecture compliance!**")
elif report.compliance_score >= 0.7:
lines.append("---")
lines.append("⚠️ **Some architectural issues to address**")
else:
lines.append("---")
lines.append("❌ **Significant architectural violations detected**")
return "\n".join(lines)

View File

@@ -65,9 +65,10 @@ class BaseAgent(ABC):
self.llm = llm_client or LLMClient.from_config(self.config)
self.logger = logging.getLogger(self.__class__.__name__)
# Rate limiting
# Rate limiting - now configurable
self._last_request_time = 0.0
self._min_request_interval = 1.0 # seconds
rate_limits = self.config.get("rate_limits", {})
self._min_request_interval = rate_limits.get("min_interval", 1.0) # seconds
@staticmethod
def _load_config() -> dict:

View File

@@ -0,0 +1,548 @@
"""Dependency Security Agent
AI agent for scanning dependency files for known vulnerabilities
and outdated packages. Supports multiple package managers.
"""
import base64
import json
import logging
import os
import re
import subprocess
from dataclasses import dataclass, field
from typing import Any
from agents.base_agent import AgentContext, AgentResult, BaseAgent
@dataclass
class DependencyFinding:
"""A security finding in a dependency."""
package: str
version: str
severity: str # CRITICAL, HIGH, MEDIUM, LOW
vulnerability_id: str # CVE, GHSA, etc.
title: str
description: str
fixed_version: str | None = None
references: list[str] = field(default_factory=list)
@dataclass
class DependencyReport:
    """Aggregated result of one dependency scan."""

    # Number of packages parsed from all scanned files
    total_packages: int
    # Number of packages counted as vulnerable
    vulnerable_packages: int
    # Number of packages counted as outdated
    outdated_packages: int
    # Individual vulnerability findings
    findings: list[DependencyFinding]
    # Free-text advice derived from the scan
    recommendations: list[str]
    # Paths of the dependency files that were read
    files_scanned: list[str]
# Agent that reads dependency manifests from a repository and reports
# packages matching the built-in vulnerability table.
class DependencyAgent(BaseAgent):
"""Agent for scanning dependencies for security vulnerabilities."""
# Marker for dependency comments; execute() passes it as the `marker`
# argument of upsert_comment() when posting the report.
DEP_AI_MARKER = "<!-- AI_DEPENDENCY_SCAN -->"
# Supported dependency files, keyed by ecosystem.
# Entries starting with "*" are treated as suffix patterns by
# _is_dependency_file() and are skipped by _collect_dependency_files().
# Only the "python" and "javascript" entries are parsed by
# _analyze_dependencies(); other ecosystems contribute no packages.
DEPENDENCY_FILES = {
"python": ["requirements.txt", "Pipfile", "pyproject.toml", "setup.py"],
"javascript": ["package.json", "package-lock.json", "yarn.lock"],
"ruby": ["Gemfile", "Gemfile.lock"],
"go": ["go.mod", "go.sum"],
"rust": ["Cargo.toml", "Cargo.lock"],
"java": ["pom.xml", "build.gradle", "build.gradle.kts"],
"php": ["composer.json", "composer.lock"],
"dotnet": ["*.csproj", "packages.config", "*.fsproj"],
}
# Common vulnerable package patterns.
# Shape: ecosystem -> lower-case package name -> {constraint: description}.
# Only constraints of the form "< X" are evaluated
# (see _version_matches_constraint); the text before " - " in a description
# is used as the vulnerability id.
# NOTE(review): this is a small hand-maintained table, not an advisory feed -
# confirm the entries against the upstream advisories before relying on them.
KNOWN_VULNERABILITIES = {
"python": {
"requests": {
"< 2.31.0": "CVE-2023-32681 - Proxy-Authorization header leak"
},
"urllib3": {
"< 2.0.7": "CVE-2023-45803 - Request body not stripped on redirects"
},
"cryptography": {"< 41.0.0": "Multiple CVEs - Update recommended"},
"pillow": {"< 10.0.0": "CVE-2023-4863 - WebP vulnerability"},
"django": {"< 4.2.0": "Multiple security fixes"},
"flask": {"< 2.3.0": "Security improvements"},
"pyyaml": {"< 6.0": "CVE-2020-14343 - Arbitrary code execution"},
"jinja2": {"< 3.1.0": "Security fixes"},
},
"javascript": {
"lodash": {"< 4.17.21": "CVE-2021-23337 - Prototype pollution"},
"axios": {"< 1.6.0": "CVE-2023-45857 - CSRF vulnerability"},
"express": {"< 4.18.0": "Security updates"},
"jquery": {"< 3.5.0": "XSS vulnerabilities"},
"minimist": {"< 1.2.6": "Prototype pollution"},
"node-fetch": {"< 3.3.0": "Security fixes"},
},
}
def can_handle(self, event_type: str, event_data: dict) -> bool:
    """Return True when this agent should process the given event.

    The agent is enabled unless ``agents.dependency.enabled`` is falsy. It
    then accepts:

    * ``pull_request`` events (action ``opened`` or ``synchronize``) whose
      ``files`` list contains at least one dependency file;
    * ``issue_comment`` events whose body contains
      ``<mention_prefix> check-deps``.

    Args:
        event_type: Webhook event name.
        event_data: Webhook payload.

    Returns:
        Whether the event is handled by this agent.
    """
    agent_config = self.config.get("agents", {}).get("dependency", {})
    if not agent_config.get("enabled", True):
        return False

    if event_type == "pull_request":
        if event_data.get("action", "") not in ("opened", "synchronize"):
            return False
        # Only PRs that touch a dependency file are interesting.
        changed_files = event_data.get("files") or []
        return any(
            self._is_dependency_file(changed.get("filename", ""))
            for changed in changed_files
        )

    if event_type == "issue_comment":
        # A missing or null comment/body must not raise.
        body = ((event_data.get("comment") or {}).get("body") or "").lower()
        # The body is lower-cased, so the prefix has to be lower-cased as
        # well; otherwise a configured prefix such as "@CodeBot" never matches.
        prefix = (
            self.config.get("interaction", {})
            .get("mention_prefix", "@codebot")
            .lower()
        )
        return f"{prefix} check-deps" in body

    return False
def _is_dependency_file(self, filename: str) -> bool:
"""Check if a file is a dependency file."""
basename = os.path.basename(filename)
for lang, files in self.DEPENDENCY_FILES.items():
for pattern in files:
if pattern.startswith("*"):
if basename.endswith(pattern[1:]):
return True
elif basename == pattern:
return True
return False
def execute(self, context: AgentContext) -> AgentResult:
    """Scan the repository's dependency files and post a report.

    Steps: collect known dependency files, check the parsed packages against
    the built-in vulnerability table, optionally merge findings from
    external scanners, then post the formatted report when an issue/PR
    number is known.

    Returns:
        A successful ``AgentResult`` carrying package and finding counts. A
        repository without dependency files also yields success.
    """
    self.logger.info(f"Scanning dependencies for {context.owner}/{context.repo}")
    actions_taken = []

    # Determine if this is a command or PR event
    if context.event_type == "issue_comment":
        issue = context.event_data.get("issue", {})
        issue_number = issue.get("number")
        comment_author = (
            context.event_data.get("comment", {})
            .get("user", {})
            .get("login", "user")
        )
    else:
        pr = context.event_data.get("pull_request", {})
        issue_number = pr.get("number")
        comment_author = None

    # Collect dependency files
    dep_files = self._collect_dependency_files(context.owner, context.repo)
    if not dep_files:
        message = "No dependency files found in repository."
        if issue_number:
            self.gitea.create_issue_comment(
                context.owner,
                context.repo,
                issue_number,
                f"{self.AI_DISCLAIMER}\n\n{message}",
            )
        return AgentResult(
            success=True,
            message=message,
        )
    actions_taken.append(f"Found {len(dep_files)} dependency files")

    # Analyze dependencies
    report = self._analyze_dependencies(context.owner, context.repo, dep_files)
    actions_taken.append(f"Analyzed {report.total_packages} packages")

    # Run external scanners if available
    external_findings = self._run_external_scanners(context.owner, context.repo)
    if external_findings:
        # Keep the vulnerable-package count in step with the findings list;
        # otherwise the report lists scanner findings but still concludes
        # that nothing is vulnerable. Packages already flagged by the
        # built-in table are not counted twice.
        already_flagged = {f.package.lower() for f in report.findings}
        newly_flagged = {f.package.lower() for f in external_findings}
        report.findings.extend(external_findings)
        report.vulnerable_packages += len(newly_flagged - already_flagged)
        actions_taken.append(
            f"External scanner found {len(external_findings)} issues"
        )

    # Generate and post report
    if issue_number:
        comment = self._format_dependency_report(report, comment_author)
        self.upsert_comment(
            context.owner,
            context.repo,
            issue_number,
            comment,
            marker=self.DEP_AI_MARKER,
        )
        actions_taken.append("Posted dependency report")

    return AgentResult(
        success=True,
        message=f"Dependency scan complete: {report.vulnerable_packages} vulnerable, {report.outdated_packages} outdated",
        data={
            "total_packages": report.total_packages,
            "vulnerable_packages": report.vulnerable_packages,
            "outdated_packages": report.outdated_packages,
            "findings_count": len(report.findings),
        },
        actions_taken=actions_taken,
    )
def _collect_dependency_files(
self, owner: str, repo: str
) -> dict[str, dict[str, Any]]:
"""Collect all dependency files from the repository."""
dep_files = {}
# Common paths to check
paths_to_check = [
"", # Root
"backend/",
"frontend/",
"api/",
"services/",
]
for base_path in paths_to_check:
for lang, filenames in self.DEPENDENCY_FILES.items():
for filename in filenames:
if filename.startswith("*"):
continue # Skip glob patterns for now
filepath = f"{base_path}{filename}".lstrip("/")
try:
content_data = self.gitea.get_file_contents(
owner, repo, filepath
)
if content_data.get("content"):
content = base64.b64decode(content_data["content"]).decode(
"utf-8", errors="ignore"
)
dep_files[filepath] = {
"language": lang,
"content": content,
}
except Exception:
pass # File doesn't exist
return dep_files
def _analyze_dependencies(
    self, owner: str, repo: str, dep_files: dict
) -> DependencyReport:
    """Check parsed dependencies against the built-in vulnerability table.

    Only Python and JavaScript files are parsed; files of other ecosystems
    are listed as scanned but contribute no packages. Package names are
    matched case-insensitively.

    Args:
        owner: Repository owner (not used by the analysis itself).
        repo: Repository name (not used by the analysis itself).
        dep_files: Mapping of file path to ``{"language", "content"}`` as
            returned by ``_collect_dependency_files``.

    Returns:
        The scan report. ``vulnerable_packages`` counts findings;
        ``outdated_packages`` is always 0 because no outdated check exists.
    """
    findings = []
    total_packages = 0
    vulnerable_count = 0
    outdated_count = 0  # no outdated-package detection is implemented
    recommendations = []

    for filepath, file_info in dep_files.items():
        lang = file_info["language"]
        content = file_info["content"]

        if lang == "python":
            packages = self._parse_python_deps(content, filepath)
        elif lang == "javascript":
            packages = self._parse_javascript_deps(content, filepath)
        else:
            packages = []
        total_packages += len(packages)

        # Check for known vulnerabilities
        known_vulns = self.KNOWN_VULNERABILITIES.get(lang, {})
        for pkg_name, version in packages:
            vuln_info = known_vulns.get(pkg_name.lower())
            if not vuln_info:
                continue
            # Unpinned packages have no version; show "unknown" everywhere
            # rather than letting "version None" leak into the description.
            shown_version = version or "unknown"
            for version_constraint, vuln_desc in vuln_info.items():
                if not self._version_matches_constraint(version, version_constraint):
                    continue
                findings.append(
                    DependencyFinding(
                        package=pkg_name,
                        version=shown_version,
                        severity="HIGH",
                        # Text before " - " is the advisory id when present.
                        vulnerability_id=vuln_desc.split(" - ")[0]
                        if " - " in vuln_desc
                        else "VULN",
                        title=vuln_desc,
                        description=f"Package {pkg_name} version {shown_version} has known vulnerabilities",
                        fixed_version=version_constraint.replace("< ", ""),
                    )
                )
                vulnerable_count += 1

    # Add recommendations
    if vulnerable_count > 0:
        recommendations.append(
            f"Update {vulnerable_count} packages with known vulnerabilities"
        )
    if total_packages > 50:
        recommendations.append(
            "Consider auditing dependencies to reduce attack surface"
        )

    return DependencyReport(
        total_packages=total_packages,
        vulnerable_packages=vulnerable_count,
        outdated_packages=outdated_count,
        findings=findings,
        recommendations=recommendations,
        files_scanned=list(dep_files.keys()),
    )
def _parse_python_deps(
self, content: str, filepath: str
) -> list[tuple[str, str | None]]:
"""Parse Python dependency file."""
packages = []
if "requirements" in filepath.lower():
# requirements.txt format
for line in content.splitlines():
line = line.strip()
if not line or line.startswith("#") or line.startswith("-"):
continue
# Parse package==version, package>=version, package
match = re.match(r"([a-zA-Z0-9_-]+)([<>=!]+)?(.+)?", line)
if match:
pkg_name = match.group(1)
version = match.group(3) if match.group(3) else None
packages.append((pkg_name, version))
elif filepath.endswith("pyproject.toml"):
# pyproject.toml format
in_deps = False
for line in content.splitlines():
if (
"[project.dependencies]" in line
or "[tool.poetry.dependencies]" in line
):
in_deps = True
continue
if in_deps:
if line.startswith("["):
in_deps = False
continue
match = re.match(r'"?([a-zA-Z0-9_-]+)"?\s*[=<>]', line)
if match:
packages.append((match.group(1), None))
return packages
def _parse_javascript_deps(
self, content: str, filepath: str
) -> list[tuple[str, str | None]]:
"""Parse JavaScript dependency file."""
packages = []
if filepath.endswith("package.json"):
try:
data = json.loads(content)
for dep_type in ["dependencies", "devDependencies"]:
deps = data.get(dep_type, {})
for name, version in deps.items():
# Strip version prefixes like ^, ~, >=
clean_version = re.sub(r"^[\^~>=<]+", "", version)
packages.append((name, clean_version))
except json.JSONDecodeError:
pass
return packages
def _version_matches_constraint(self, version: str | None, constraint: str) -> bool:
"""Check if version matches a vulnerability constraint."""
if not version:
return True # Assume vulnerable if version unknown
# Simple version comparison
if constraint.startswith("< "):
target = constraint[2:]
try:
return self._compare_versions(version, target) < 0
except Exception:
return False
return False
def _compare_versions(self, v1: str, v2: str) -> int:
"""Compare two version strings. Returns -1, 0, or 1."""
def normalize(v):
return [int(x) for x in re.sub(r"[^0-9.]", "", v).split(".") if x]
try:
parts1 = normalize(v1)
parts2 = normalize(v2)
for i in range(max(len(parts1), len(parts2))):
p1 = parts1[i] if i < len(parts1) else 0
p2 = parts2[i] if i < len(parts2) else 0
if p1 < p2:
return -1
if p1 > p2:
return 1
return 0
except Exception:
return 0
def _run_external_scanners(self, owner: str, repo: str) -> list[DependencyFinding]:
"""Run external vulnerability scanners if available."""
findings = []
agent_config = self.config.get("agents", {}).get("dependency", {})
# Try pip-audit for Python
if agent_config.get("pip_audit", False):
try:
result = subprocess.run(
["pip-audit", "--format", "json"],
capture_output=True,
text=True,
timeout=60,
)
if result.returncode == 0:
data = json.loads(result.stdout)
for vuln in data.get("vulnerabilities", []):
findings.append(
DependencyFinding(
package=vuln.get("name", ""),
version=vuln.get("version", ""),
severity=vuln.get("severity", "MEDIUM"),
vulnerability_id=vuln.get("id", ""),
title=vuln.get("description", "")[:100],
description=vuln.get("description", ""),
fixed_version=vuln.get("fix_versions", [None])[0],
)
)
except Exception as e:
self.logger.debug(f"pip-audit not available: {e}")
# Try npm audit for JavaScript
if agent_config.get("npm_audit", False):
try:
result = subprocess.run(
["npm", "audit", "--json"],
capture_output=True,
text=True,
timeout=60,
)
data = json.loads(result.stdout)
for vuln_id, vuln in data.get("vulnerabilities", {}).items():
findings.append(
DependencyFinding(
package=vuln.get("name", vuln_id),
version=vuln.get("range", ""),
severity=vuln.get("severity", "moderate").upper(),
vulnerability_id=vuln_id,
title=vuln.get("title", ""),
description=vuln.get("overview", ""),
fixed_version=vuln.get("fixAvailable", {}).get("version"),
)
)
except Exception as e:
self.logger.debug(f"npm audit not available: {e}")
return findings
def _format_dependency_report(
self, report: DependencyReport, user: str | None = None
) -> str:
"""Format the dependency report as a comment."""
lines = []
if user:
lines.append(f"@{user}")
lines.append("")
lines.extend(
[
f"{self.AI_DISCLAIMER}",
"",
"## 🔍 Dependency Security Scan",
"",
"### Summary",
"",
f"| Metric | Value |",
f"|--------|-------|",
f"| Total Packages | {report.total_packages} |",
f"| Vulnerable | {report.vulnerable_packages} |",
f"| Outdated | {report.outdated_packages} |",
f"| Files Scanned | {len(report.files_scanned)} |",
"",
]
)
# Findings by severity
if report.findings:
lines.append("### 🚨 Security Findings")
lines.append("")
# Group by severity
by_severity = {"CRITICAL": [], "HIGH": [], "MEDIUM": [], "LOW": []}
for finding in report.findings:
sev = finding.severity.upper()
if sev in by_severity:
by_severity[sev].append(finding)
severity_emoji = {
"CRITICAL": "🔴",
"HIGH": "🟠",
"MEDIUM": "🟡",
"LOW": "🔵",
}
for severity in ["CRITICAL", "HIGH", "MEDIUM", "LOW"]:
findings = by_severity[severity]
if findings:
lines.append(f"#### {severity_emoji[severity]} {severity}")
lines.append("")
for f in findings[:10]: # Limit display
lines.append(f"- **{f.package}** `{f.version}`")
lines.append(f" - {f.vulnerability_id}: {f.title}")
if f.fixed_version:
lines.append(f" - ✅ Fix: Upgrade to `{f.fixed_version}`")
if len(findings) > 10:
lines.append(f" - ... and {len(findings) - 10} more")
lines.append("")
# Files scanned
lines.append("### 📁 Files Scanned")
lines.append("")
for f in report.files_scanned:
lines.append(f"- `{f}`")
lines.append("")
# Recommendations
if report.recommendations:
lines.append("### 💡 Recommendations")
lines.append("")
for rec in report.recommendations:
lines.append(f"- {rec}")
lines.append("")
# Overall status
if report.vulnerable_packages == 0:
lines.append("---")
lines.append("✅ **No known vulnerabilities detected**")
else:
lines.append("---")
lines.append(
f"⚠️ **{report.vulnerable_packages} vulnerable packages require attention**"
)
return "\n".join(lines)

View File

@@ -365,9 +365,20 @@ class IssueAgent(BaseAgent):
"commands", ["explain", "suggest", "security", "summarize", "triage"]
)
# Also check for setup-labels command (not in config since it's a setup command)
if f"{mention_prefix} setup-labels" in body.lower():
return "setup-labels"
# Built-in commands not in config
builtin_commands = [
"setup-labels",
"check-deps",
"suggest-tests",
"refactor-suggest",
"architecture",
"arch-check",
]
# Check built-in commands first
for command in builtin_commands:
if f"{mention_prefix} {command}" in body.lower():
return command
for command in commands:
if f"{mention_prefix} {command}" in body.lower():
@@ -392,6 +403,14 @@ class IssueAgent(BaseAgent):
return self._command_triage(context, issue)
elif command == "setup-labels":
return self._command_setup_labels(context, issue)
elif command == "check-deps":
return self._command_check_deps(context)
elif command == "suggest-tests":
return self._command_suggest_tests(context)
elif command == "refactor-suggest":
return self._command_refactor_suggest(context)
elif command == "architecture" or command == "arch-check":
return self._command_architecture(context)
return f"{self.AI_DISCLAIMER}\n\nSorry, I don't understand the command `{command}`."
@@ -464,6 +483,12 @@ Be practical and concise."""
- `{mention_prefix} suggest` - Solution suggestions or next steps
- `{mention_prefix} security` - Security-focused analysis of the issue
### Code Quality & Security
- `{mention_prefix} check-deps` - Scan dependencies for security vulnerabilities
- `{mention_prefix} suggest-tests` - Suggest test cases for changed/new code
- `{mention_prefix} refactor-suggest` - Suggest refactoring opportunities
- `{mention_prefix} architecture` - Check architecture compliance (alias: `arch-check`)
### Interactive Chat
- `{mention_prefix} [question]` - Ask questions about the codebase (uses search & file reading tools)
- Example: `{mention_prefix} how does authentication work?`
@@ -494,9 +519,19 @@ PR reviews run automatically when you open or update a pull request. The bot pro
{mention_prefix} triage
```
**Get help understanding:**
**Check for dependency vulnerabilities:**
```
{mention_prefix} explain
{mention_prefix} check-deps
```
**Get test suggestions:**
```
{mention_prefix} suggest-tests
```
**Check architecture compliance:**
```
{mention_prefix} architecture
```
**Ask about the codebase:**
@@ -504,11 +539,6 @@ PR reviews run automatically when you open or update a pull request. The bot pro
{mention_prefix} how does the authentication system work?
```
**Setup repository labels:**
```
{mention_prefix} setup-labels
```
---
*For full documentation, see the [README](https://github.com/YourOrg/OpenRabbit/blob/main/README.md)*
@@ -854,3 +884,145 @@ PR reviews run automatically when you open or update a pull request. The bot pro
return f"{prefix} - {value}"
else: # colon or unknown
return base_name
def _command_check_deps(self, context: AgentContext) -> str:
"""Check dependencies for security vulnerabilities."""
try:
from agents.dependency_agent import DependencyAgent
agent = DependencyAgent(config=self.config)
result = agent.run(context)
if result.success:
return result.data.get(
"report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
)
else:
return f"{self.AI_DISCLAIMER}\n\n**Dependency Check Failed**\n\n{result.error or result.message}"
except ImportError:
return f"{self.AI_DISCLAIMER}\n\n**Dependency Agent Not Available**\n\nThe dependency security agent is not installed."
except Exception as e:
self.logger.error(f"Dependency check failed: {e}")
return f"{self.AI_DISCLAIMER}\n\n**Dependency Check Error**\n\n{e}"
def _command_suggest_tests(self, context: AgentContext) -> str:
"""Suggest tests for changed or new code."""
try:
from agents.test_coverage_agent import TestCoverageAgent
agent = TestCoverageAgent(config=self.config)
result = agent.run(context)
if result.success:
return result.data.get(
"report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
)
else:
return f"{self.AI_DISCLAIMER}\n\n**Test Suggestion Failed**\n\n{result.error or result.message}"
except ImportError:
return f"{self.AI_DISCLAIMER}\n\n**Test Coverage Agent Not Available**\n\nThe test coverage agent is not installed."
except Exception as e:
self.logger.error(f"Test suggestion failed: {e}")
return f"{self.AI_DISCLAIMER}\n\n**Test Suggestion Error**\n\n{e}"
def _command_architecture(self, context: AgentContext) -> str:
"""Check architecture compliance."""
try:
from agents.architecture_agent import ArchitectureAgent
agent = ArchitectureAgent(config=self.config)
result = agent.run(context)
if result.success:
return result.data.get(
"report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
)
else:
return f"{self.AI_DISCLAIMER}\n\n**Architecture Check Failed**\n\n{result.error or result.message}"
except ImportError:
return f"{self.AI_DISCLAIMER}\n\n**Architecture Agent Not Available**\n\nThe architecture compliance agent is not installed."
except Exception as e:
self.logger.error(f"Architecture check failed: {e}")
return f"{self.AI_DISCLAIMER}\n\n**Architecture Check Error**\n\n{e}"
def _command_refactor_suggest(self, context: AgentContext) -> str:
"""Suggest refactoring opportunities."""
issue = context.event_data.get("issue", {})
title = issue.get("title", "")
body = issue.get("body", "")
# Use LLM to analyze for refactoring opportunities
prompt = f"""Analyze the following issue/context and suggest refactoring opportunities:
Issue Title: {title}
Issue Body: {body}
Based on common refactoring patterns, suggest:
1. Code smell detection (if any code is referenced)
2. Design pattern opportunities
3. Complexity reduction suggestions
4. DRY principle violations
5. SOLID principle improvements
Format your response as a structured report with actionable recommendations.
If no code is referenced in the issue, provide general refactoring guidance based on the context.
Return as JSON:
{{
"summary": "Brief summary of refactoring opportunities",
"suggestions": [
{{
"category": "Code Smell | Design Pattern | Complexity | DRY | SOLID",
"title": "Short title",
"description": "Detailed description",
"priority": "high | medium | low",
"effort": "small | medium | large"
}}
],
"general_advice": "Any general refactoring advice"
}}"""
try:
result = self.call_llm_json(prompt)
lines = [f"{self.AI_DISCLAIMER}\n"]
lines.append("## Refactoring Suggestions\n")
if result.get("summary"):
lines.append(f"**Summary:** {result['summary']}\n")
suggestions = result.get("suggestions", [])
if suggestions:
lines.append("### Recommendations\n")
lines.append("| Priority | Category | Suggestion | Effort |")
lines.append("|----------|----------|------------|--------|")
for s in suggestions:
priority = s.get("priority", "medium").upper()
priority_icon = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🟢"}.get(
priority, ""
)
lines.append(
f"| {priority_icon} {priority} | {s.get('category', 'General')} | "
f"**{s.get('title', 'Suggestion')}** | {s.get('effort', 'medium')} |"
)
lines.append("")
# Detailed descriptions
lines.append("### Details\n")
for i, s in enumerate(suggestions, 1):
lines.append(f"**{i}. {s.get('title', 'Suggestion')}**")
lines.append(f"{s.get('description', 'No description')}\n")
if result.get("general_advice"):
lines.append("### General Advice\n")
lines.append(result["general_advice"])
return "\n".join(lines)
except Exception as e:
self.logger.error(f"Refactor suggestion failed: {e}")
return (
f"{self.AI_DISCLAIMER}\n\n**Refactor Suggestion Failed**\n\nError: {e}"
)

View File

@@ -0,0 +1,480 @@
"""Test Coverage Agent
AI agent for analyzing code changes and suggesting test cases.
Helps improve test coverage by identifying untested code paths.
"""
import base64
import os
import re
from dataclasses import dataclass, field
from agents.base_agent import AgentContext, AgentResult, BaseAgent
@dataclass
class TestSuggestion:
    """A suggested test case.

    Instances are created by ``TestCoverageAgent._generate_suggestions``,
    either from the LLM's JSON answer or as a generic fallback, and are
    rendered by ``TestCoverageAgent._format_coverage_report``.
    """

    function_name: str  # name of the function the test should exercise
    file_path: str  # path of the file containing that function
    test_type: str  # unit, integration, edge_case
    description: str  # what the test should verify
    example_code: str | None = None  # optional example snippet; None when not provided
    priority: str = "MEDIUM"  # HIGH, MEDIUM, LOW
@dataclass
class CoverageReport:
    """Report of test coverage analysis.

    The with/without-tests counts are derived from ``coverage_estimate`` in
    ``TestCoverageAgent._generate_suggestions``; they are estimates, not
    measurements.
    """

    functions_analyzed: int  # number of functions that were considered
    functions_with_tests: int  # int(functions_analyzed * coverage_estimate)
    functions_without_tests: int  # functions_analyzed - functions_with_tests
    suggestions: list[TestSuggestion]  # proposed test cases
    existing_tests: list[str]  # paths of test files found in the repository
    coverage_estimate: float  # estimated fraction covered; the LLM prompt asks for 0.0 to 1.0
class TestCoverageAgent(BaseAgent):
    """Agent for analyzing test coverage and suggesting tests.

    Collects function definitions (from a PR diff or from the repository),
    looks for existing test files, asks the LLM for test suggestions and
    posts the result as a comment tagged with ``TEST_AI_MARKER``.
    """

    # Marker for test coverage comments; passed to upsert_comment as the
    # ``marker`` argument when the report is posted.
    TEST_AI_MARKER = "<!-- AI_TEST_COVERAGE -->"
    # Test file patterns by language. Each entry is a regex applied with
    # re.search to a file path; _is_test_file tries every pattern of every
    # language, so the language key is informational only.
    TEST_PATTERNS = {
        "python": [r"test_.*\.py$", r".*_test\.py$", r"tests?/.*\.py$"],
        "javascript": [
            r".*\.test\.[jt]sx?$",
            r".*\.spec\.[jt]sx?$",
            r"__tests__/.*\.[jt]sx?$",
        ],
        "go": [r".*_test\.go$"],
        "rust": [r"tests?/.*\.rs$"],
        "java": [r".*Test\.java$", r".*Tests\.java$"],
        "ruby": [r".*_spec\.rb$", r"test_.*\.rb$"],
    }
    # Function/method patterns by language. Each regex is applied to a single
    # source line; the function name is the first non-empty capture group
    # (the javascript pattern has two alternative groups).
    FUNCTION_PATTERNS = {
        "python": r"^\s*(?:async\s+)?def\s+(\w+)\s*\(",
        "javascript": r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>))",
        "go": r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(",
        "rust": r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)",
        "java": r"(?:public|private|protected)\s+(?:static\s+)?(?:\w+\s+)?(\w+)\s*\([^)]*\)\s*\{",
        "ruby": r"^\s*def\s+(\w+)",
    }
def can_handle(self, event_type: str, event_data: dict) -> bool:
    """Return True when an issue comment invokes ``<prefix> suggest-tests``.

    Args:
        event_type: Webhook event name; only ``"issue_comment"`` is handled.
        event_data: Event payload; the text is read from
            ``event_data["comment"]["body"]``.

    Returns:
        False when the agent is disabled via
        ``config["agents"]["test_coverage"]["enabled"]``, when the event is
        not an issue comment, or when the command is not present.
    """
    agent_config = self.config.get("agents", {}).get("test_coverage", {})
    if not agent_config.get("enabled", True):
        return False
    if event_type != "issue_comment":
        return False
    # A null body must not crash the dispatcher; treat it as empty.
    comment_body = event_data.get("comment", {}).get("body") or ""
    mention_prefix = self.config.get("interaction", {}).get(
        "mention_prefix", "@codebot"
    )
    # Lower-case BOTH sides: the body was already lower-cased, so a
    # configured prefix containing capitals could otherwise never match.
    command = f"{mention_prefix} suggest-tests".lower()
    return command in comment_body.lower()
def execute(self, context: AgentContext) -> AgentResult:
    """Analyze functions, generate test suggestions and post the report.

    For a pull request (the issue payload has a ``pull_request`` entry) the
    functions added in the PR diff are analyzed; otherwise functions are
    collected from the repository. When the payload carries an issue number
    the formatted report is upserted as a comment tagged with
    ``TEST_AI_MARKER``.

    Returns:
        A successful ``AgentResult`` whose data holds the number of analyzed
        functions, the number of suggestions and the coverage estimate.
    """
    owner, repo = context.owner, context.repo
    self.logger.info(f"Analyzing test coverage for {owner}/{repo}")

    payload = context.event_data
    issue = payload.get("issue", {})
    issue_number = issue.get("number")
    requester = payload.get("comment", {}).get("user", {}).get("login", "user")
    is_pr = issue.get("pull_request") is not None

    steps = []
    if not is_pr:
        # Plain issue: look at the repository itself.
        targets = self._analyze_repository(owner, repo)
        steps.append(f"Analyzed {len(targets)} functions in repository")
    else:
        # Pull request: only functions touched by the diff.
        diff_text = self._get_pr_diff(owner, repo, issue_number)
        targets = self._extract_changed_functions(diff_text)
        steps.append(f"Analyzed {len(targets)} changed functions")

    test_files = self._find_existing_tests(owner, repo)
    steps.append(f"Found {len(test_files)} existing test files")

    report = self._generate_suggestions(owner, repo, targets, test_files)

    if issue_number:
        body = self._format_coverage_report(report, requester, is_pr)
        self.upsert_comment(
            owner, repo, issue_number, body, marker=self.TEST_AI_MARKER
        )
        steps.append("Posted test coverage report")

    summary = {
        "functions_analyzed": report.functions_analyzed,
        "suggestions_count": len(report.suggestions),
        "coverage_estimate": report.coverage_estimate,
    }
    return AgentResult(
        success=True,
        message=f"Generated {len(report.suggestions)} test suggestions",
        data=summary,
        actions_taken=steps,
    )
def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
"""Get the PR diff."""
try:
return self.gitea.get_pull_request_diff(owner, repo, pr_number)
except Exception as e:
self.logger.error(f"Failed to get PR diff: {e}")
return ""
def _extract_changed_functions(self, diff: str) -> list[dict]:
"""Extract changed functions from diff."""
functions = []
current_file = None
current_language = None
for line in diff.splitlines():
# Track current file
if line.startswith("diff --git"):
match = re.search(r"b/(.+)$", line)
if match:
current_file = match.group(1)
current_language = self._detect_language(current_file)
# Look for function definitions in added lines
if line.startswith("+") and not line.startswith("+++"):
if current_language and current_language in self.FUNCTION_PATTERNS:
pattern = self.FUNCTION_PATTERNS[current_language]
match = re.search(pattern, line[1:]) # Skip the + prefix
if match:
func_name = next(g for g in match.groups() if g)
functions.append(
{
"name": func_name,
"file": current_file,
"language": current_language,
"line": line[1:].strip(),
}
)
return functions
def _analyze_repository(self, owner: str, repo: str) -> list[dict]:
"""Analyze repository for functions without tests."""
functions = []
code_extensions = {".py", ".js", ".ts", ".go", ".rs", ".java", ".rb"}
# Get repository contents (limited to avoid API exhaustion)
try:
contents = self.gitea.get_file_contents(owner, repo, "")
if isinstance(contents, list):
for item in contents[:50]: # Limit files
if item.get("type") == "file":
filepath = item.get("path", "")
ext = os.path.splitext(filepath)[1]
if ext in code_extensions:
file_functions = self._extract_functions_from_file(
owner, repo, filepath
)
functions.extend(file_functions)
except Exception as e:
self.logger.warning(f"Failed to analyze repository: {e}")
return functions[:100] # Limit total functions
def _extract_functions_from_file(
self, owner: str, repo: str, filepath: str
) -> list[dict]:
"""Extract function definitions from a file."""
functions = []
language = self._detect_language(filepath)
if not language or language not in self.FUNCTION_PATTERNS:
return functions
try:
content_data = self.gitea.get_file_contents(owner, repo, filepath)
if content_data.get("content"):
content = base64.b64decode(content_data["content"]).decode(
"utf-8", errors="ignore"
)
pattern = self.FUNCTION_PATTERNS[language]
for i, line in enumerate(content.splitlines(), 1):
match = re.search(pattern, line)
if match:
func_name = next((g for g in match.groups() if g), None)
if func_name and not func_name.startswith("_"):
functions.append(
{
"name": func_name,
"file": filepath,
"language": language,
"line_number": i,
}
)
except Exception:
pass
return functions
def _detect_language(self, filepath: str) -> str | None:
"""Detect programming language from file path."""
ext_map = {
".py": "python",
".js": "javascript",
".jsx": "javascript",
".ts": "javascript",
".tsx": "javascript",
".go": "go",
".rs": "rust",
".java": "java",
".rb": "ruby",
}
ext = os.path.splitext(filepath)[1]
return ext_map.get(ext)
def _find_existing_tests(self, owner: str, repo: str) -> list[str]:
"""Find existing test files in the repository."""
test_files = []
# Common test directories
test_dirs = ["tests", "test", "__tests__", "spec"]
for test_dir in test_dirs:
try:
contents = self.gitea.get_file_contents(owner, repo, test_dir)
if isinstance(contents, list):
for item in contents:
if item.get("type") == "file":
test_files.append(item.get("path", ""))
except Exception:
pass
# Also check root for test files
try:
contents = self.gitea.get_file_contents(owner, repo, "")
if isinstance(contents, list):
for item in contents:
if item.get("type") == "file":
filepath = item.get("path", "")
if self._is_test_file(filepath):
test_files.append(filepath)
except Exception:
pass
return test_files
def _is_test_file(self, filepath: str) -> bool:
"""Check if a file is a test file."""
for lang, patterns in self.TEST_PATTERNS.items():
for pattern in patterns:
if re.search(pattern, filepath):
return True
return False
def _generate_suggestions(
    self,
    owner: str,
    repo: str,
    functions: list[dict],
    existing_tests: list[str],
) -> CoverageReport:
    """Generate test suggestions using LLM.

    Args:
        owner: Repository owner (currently unused; kept for the interface).
        repo: Repository name (currently unused; kept for the interface).
        functions: Function dicts with at least ``name``, ``file`` and
            ``language``; only the first 20 are sent to the LLM.
        existing_tests: Paths of known test files; the first 10 are listed
            in the prompt.

    Returns:
        A ``CoverageReport``. If the LLM call fails, generic unit-test
        suggestions for the first 10 functions are used with an estimate of
        0.5. With no functions at all the estimate is 1.0 when test files
        exist and 0.0 otherwise. The estimate is always within [0.0, 1.0].
    """
    suggestions = []
    # Build prompt for LLM
    if functions:
        functions_text = "\n".join(
            f"- {f['name']} in {f['file']} ({f['language']})"
            for f in functions[:20]  # Limit for prompt size
        )
        tests_text = ", ".join(existing_tests[:10]) if existing_tests else "None found"
        prompt = f"""Analyze these functions and suggest test cases:
Functions to test:
{functions_text}
Existing test files:
{tests_text}
For each function, suggest:
1. What to test (happy path, edge cases, error handling)
2. Priority (HIGH for public APIs, MEDIUM for internal, LOW for utilities)
3. Brief example test code if possible
Respond in JSON format:
{{
"suggestions": [
{{
"function_name": "function_name",
"file_path": "path/to/file",
"test_type": "unit|integration|edge_case",
"description": "What to test",
"example_code": "brief example or null",
"priority": "HIGH|MEDIUM|LOW"
}}
],
"coverage_estimate": 0.0 to 1.0
}}
"""
        try:
            result = self.call_llm_json(prompt)
            # Parse into a scratch list first: if an entry is malformed the
            # fallback below starts from a clean list instead of appending
            # to half of the LLM's suggestions.
            parsed = [
                TestSuggestion(
                    function_name=s.get("function_name", ""),
                    file_path=s.get("file_path", ""),
                    test_type=s.get("test_type", "unit"),
                    description=s.get("description", ""),
                    example_code=s.get("example_code"),
                    priority=s.get("priority", "MEDIUM"),
                )
                for s in result.get("suggestions", [])
            ]
            # The estimate comes from the model and may be null, a string or
            # otherwise unusable; keep the suggestions and fall back to the
            # neutral default in that case.
            try:
                coverage_estimate = float(result.get("coverage_estimate", 0.5))
            except (TypeError, ValueError):
                coverage_estimate = 0.5
            suggestions = parsed
        except Exception as e:
            self.logger.warning(f"LLM suggestion failed: {e}")
            # Generate basic suggestions without LLM
            suggestions = [
                TestSuggestion(
                    function_name=f["name"],
                    file_path=f["file"],
                    test_type="unit",
                    description=f"Add unit tests for {f['name']}",
                    priority="MEDIUM",
                )
                for f in functions[:10]
            ]
            coverage_estimate = 0.5
    else:
        coverage_estimate = 1.0 if existing_tests else 0.0
    # Clamp so the derived counts can never go negative or exceed the total.
    coverage_estimate = min(1.0, max(0.0, coverage_estimate))
    # Estimate functions with tests
    functions_with_tests = int(len(functions) * coverage_estimate)
    return CoverageReport(
        functions_analyzed=len(functions),
        functions_with_tests=functions_with_tests,
        functions_without_tests=len(functions) - functions_with_tests,
        suggestions=suggestions,
        existing_tests=existing_tests,
        coverage_estimate=coverage_estimate,
    )
def _format_coverage_report(
self, report: CoverageReport, user: str | None, is_pr: bool
) -> str:
"""Format the coverage report as a comment."""
lines = []
if user:
lines.append(f"@{user}")
lines.append("")
lines.extend(
[
f"{self.AI_DISCLAIMER}",
"",
"## 🧪 Test Coverage Analysis",
"",
"### Summary",
"",
f"| Metric | Value |",
f"|--------|-------|",
f"| Functions Analyzed | {report.functions_analyzed} |",
f"| Estimated Coverage | {report.coverage_estimate:.0%} |",
f"| Test Files Found | {len(report.existing_tests)} |",
f"| Suggestions | {len(report.suggestions)} |",
"",
]
)
# Suggestions by priority
if report.suggestions:
lines.append("### 💡 Test Suggestions")
lines.append("")
# Group by priority
by_priority = {"HIGH": [], "MEDIUM": [], "LOW": []}
for s in report.suggestions:
if s.priority in by_priority:
by_priority[s.priority].append(s)
priority_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}
for priority in ["HIGH", "MEDIUM", "LOW"]:
suggestions = by_priority[priority]
if suggestions:
lines.append(f"#### {priority_emoji[priority]} {priority} Priority")
lines.append("")
for s in suggestions[:5]: # Limit display
lines.append(f"**`{s.function_name}`** in `{s.file_path}`")
lines.append(f"- Type: {s.test_type}")
lines.append(f"- {s.description}")
if s.example_code:
lines.append(f"```")
lines.append(s.example_code[:200])
lines.append(f"```")
lines.append("")
if len(suggestions) > 5:
lines.append(f"*... and {len(suggestions) - 5} more*")
lines.append("")
# Existing test files
if report.existing_tests:
lines.append("### 📁 Existing Test Files")
lines.append("")
for f in report.existing_tests[:10]:
lines.append(f"- `{f}`")
if len(report.existing_tests) > 10:
lines.append(f"- *... and {len(report.existing_tests) - 10} more*")
lines.append("")
# Coverage bar
lines.append("### 📊 Coverage Estimate")
lines.append("")
filled = int(report.coverage_estimate * 10)
bar = "" * filled + "" * (10 - filled)
lines.append(f"`[{bar}]` {report.coverage_estimate:.0%}")
lines.append("")
# Recommendations
if report.coverage_estimate < 0.8:
lines.append("---")
lines.append("⚠️ **Coverage below 80%** - Consider adding more tests")
elif report.coverage_estimate >= 0.8:
lines.append("---")
lines.append("✅ **Good test coverage!**")
return "\n".join(lines)