Files
openrabbit/tools/ai-review/agents/architecture_agent.py
latte e8d28225e0
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
just why not
2026-01-07 21:19:46 +01:00

548 lines
20 KiB
Python

"""Architecture Compliance Agent
AI agent for enforcing architectural patterns and layer separation.
Detects cross-layer violations and circular dependencies.
"""
import base64
import os
import re
from dataclasses import dataclass, field
from agents.base_agent import AgentContext, AgentResult, BaseAgent
@dataclass
class ArchitectureViolation:
    """A single architecture rule violation detected in one file.

    Instances are created by the analysis passes and rendered into the
    Markdown report posted on the PR/issue.
    """

    file: str  # repo-relative path of the offending file
    line: int  # 1-based line number; 0 when line tracking is unavailable
    violation_type: str  # cross_layer, circular, naming, structure
    severity: str  # HIGH, MEDIUM, LOW
    description: str  # human-readable explanation of what was violated
    recommendation: str  # suggested remediation, shown in the report
    source_layer: str | None = None  # layer the file belongs to, if detected
    target_layer: str | None = None  # forbidden layer being imported, if detected
@dataclass
class ArchitectureReport:
    """Aggregated result of one architecture analysis run."""

    violations: list[ArchitectureViolation]  # all detected violations
    layers_detected: dict[str, list[str]]  # layer name -> files assigned to it
    circular_dependencies: list[tuple[str, str]]  # sorted pairs of mutually dependent layers
    compliance_score: float  # 0.0-1.0; 1.0 means no violations found
    recommendations: list[str]  # high-level remediation suggestions
class ArchitectureAgent(BaseAgent):
    """Agent for enforcing architectural compliance.

    Maps files to layers via path patterns, verifies that each layer only
    imports from layers it is allowed to depend on, detects circular
    dependencies between layers, and posts a Markdown report on the PR or
    issue that triggered the check.
    """

    # Hidden HTML marker used by upsert_comment to find and update a
    # previously posted architecture report instead of posting a new one.
    ARCH_AI_MARKER = "<!-- AI_ARCHITECTURE_CHECK -->"

    # Emoji used to render a violation's severity in the report.
    SEVERITY_EMOJI = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}

    # Default layer definitions.  Each layer lists the path patterns that
    # identify files belonging to it, plus which layers it may / may not
    # import from.  Overridable via config["agents"]["architecture"]["layers"].
    DEFAULT_LAYERS = {
        "api": {
            "patterns": ["api/", "routes/", "controllers/", "handlers/", "views/"],
            "can_import": ["services", "models", "utils", "config"],
            "cannot_import": ["db", "repositories", "infrastructure"],
        },
        "services": {
            "patterns": ["services/", "usecases/", "application/"],
            "can_import": ["models", "repositories", "utils", "config"],
            "cannot_import": ["api", "controllers", "handlers"],
        },
        "repositories": {
            "patterns": ["repositories/", "repos/", "data/"],
            "can_import": ["models", "db", "utils", "config"],
            "cannot_import": ["api", "services", "controllers"],
        },
        "models": {
            "patterns": ["models/", "entities/", "domain/", "schemas/"],
            "can_import": ["utils", "config"],
            "cannot_import": ["api", "services", "repositories", "db"],
        },
        "db": {
            "patterns": ["db/", "database/", "infrastructure/"],
            "can_import": ["models", "config"],
            "cannot_import": ["api", "services"],
        },
        "utils": {
            "patterns": ["utils/", "helpers/", "common/", "lib/"],
            "can_import": ["config"],
            "cannot_import": [],
        },
    }

    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Return True if this agent should run for the given event.

        Triggers on PR opened/synchronize events, and on issue comments
        containing "<prefix> architecture" or "<prefix> arch-check".
        """
        agent_config = self.config.get("agents", {}).get("architecture", {})
        if not agent_config.get("enabled", False):
            return False
        # Handle PR events
        if event_type == "pull_request":
            action = event_data.get("action", "")
            if action in ["opened", "synchronize"]:
                return True
        # Handle @codebot architecture / arch-check commands
        if event_type == "issue_comment":
            comment_body = event_data.get("comment", {}).get("body", "")
            mention_prefix = self.config.get("interaction", {}).get(
                "mention_prefix", "@codebot"
            )
            if f"{mention_prefix} architecture" in comment_body.lower():
                return True
            if f"{mention_prefix} arch-check" in comment_body.lower():
                return True
        return False

    def execute(self, context: AgentContext) -> AgentResult:
        """Run the architecture check and post the resulting report.

        For PR events only the PR diff is analyzed; for plain issues (or
        when no issue number is available) the repository structure is
        scanned directly.
        """
        self.logger.info(f"Checking architecture for {context.owner}/{context.repo}")
        actions_taken = []
        # Layer configuration: config override falls back to defaults
        agent_config = self.config.get("agents", {}).get("architecture", {})
        layers = agent_config.get("layers", self.DEFAULT_LAYERS)
        # Determine issue number, requesting user, and whether this is a PR
        if context.event_type == "issue_comment":
            issue = context.event_data.get("issue", {})
            issue_number = issue.get("number")
            comment_author = (
                context.event_data.get("comment", {})
                .get("user", {})
                .get("login", "user")
            )
            # Issue payloads carry a "pull_request" key only for PR comments
            is_pr = issue.get("pull_request") is not None
        else:
            pr = context.event_data.get("pull_request", {})
            issue_number = pr.get("number")
            comment_author = None
            is_pr = True
        if is_pr and issue_number:
            # Analyze only the files touched by the PR diff
            diff = self._get_pr_diff(context.owner, context.repo, issue_number)
            report = self._analyze_diff(diff, layers)
            actions_taken.append("Analyzed PR diff for architecture violations")
        else:
            # Fall back to scanning the repository structure
            report = self._analyze_repository(context.owner, context.repo, layers)
            actions_taken.append("Analyzed repository architecture")
        # Post (or update) the report comment
        if issue_number:
            comment = self._format_architecture_report(report, comment_author)
            self.upsert_comment(
                context.owner,
                context.repo,
                issue_number,
                comment,
                marker=self.ARCH_AI_MARKER,
            )
            actions_taken.append("Posted architecture report")
        return AgentResult(
            success=len(report.violations) == 0 or report.compliance_score >= 0.8,
            message=f"Architecture check: {len(report.violations)} violations, {report.compliance_score:.0%} compliance",
            data={
                "violations_count": len(report.violations),
                "compliance_score": report.compliance_score,
                "circular_dependencies": len(report.circular_dependencies),
            },
            actions_taken=actions_taken,
        )

    def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
        """Fetch the unified diff of a PR; return "" on any API failure."""
        try:
            return self.gitea.get_pull_request_diff(owner, repo, pr_number)
        except Exception as e:
            self.logger.error(f"Failed to get PR diff: {e}")
            return ""

    def _analyze_diff(self, diff: str, layers: dict) -> ArchitectureReport:
        """Analyze a unified PR diff for architecture violations.

        Only added lines ("+") are scanned for imports; per-import line
        numbers are not tracked.
        """
        imports_by_file = {}
        current_file = None
        current_language = None
        for line in diff.splitlines():
            # Track which file the subsequent hunk lines belong to
            if line.startswith("diff --git"):
                match = re.search(r"b/(.+)$", line)
                if match:
                    current_file = match.group(1)
                    current_language = self._detect_language(current_file)
                    imports_by_file[current_file] = []
            # Collect imports from added lines only (skip the "+++" header)
            if line.startswith("+") and not line.startswith("+++"):
                if current_file and current_language:
                    imports = self._extract_imports(line[1:], current_language)
                    imports_by_file.setdefault(current_file, []).extend(imports)
        violations = self._find_cross_layer_violations(
            imports_by_file,
            layers,
            "Move this import to an allowed layer or refactor the dependency",
        )
        return self._build_report(imports_by_file, layers, violations)

    def _analyze_repository(
        self, owner: str, repo: str, layers: dict
    ) -> ArchitectureReport:
        """Analyze the repository's directory structure for compliance.

        Fetches up to 20 files per layer directory and inspects their
        imports via the Gitea contents API.
        """
        imports_by_file = {}
        # Collect files from each configured layer directory
        for layer_name, layer_config in layers.items():
            for pattern in layer_config.get("patterns", []):
                try:
                    path = pattern.rstrip("/")
                    contents = self.gitea.get_file_contents(owner, repo, path)
                    # A directory listing comes back as a list of entries
                    if isinstance(contents, list):
                        for item in contents[:20]:  # Limit files per layer
                            if item.get("type") == "file":
                                filepath = item.get("path", "")
                                imports = self._get_file_imports(owner, repo, filepath)
                                imports_by_file[filepath] = imports
                except Exception:
                    # Layer directory may simply not exist in this repo
                    pass
        violations = self._find_cross_layer_violations(
            imports_by_file,
            layers,
            "Refactor to remove dependency on '{target_layer}'",
        )
        return self._build_report(imports_by_file, layers, violations)

    def _find_cross_layer_violations(
        self, imports_by_file: dict, layers: dict, recommendation: str
    ) -> list[ArchitectureViolation]:
        """Return cross-layer violations for the given per-file imports.

        ``recommendation`` is a str.format template that may reference
        ``{target_layer}``.  Shared by the diff and repository analyses.
        """
        violations = []
        for file_path, imports in imports_by_file.items():
            source_layer = self._detect_layer(file_path, layers)
            if not source_layer:
                continue
            cannot_import = layers.get(source_layer, {}).get("cannot_import", [])
            for imp in imports:
                target_layer = self._detect_layer_from_import(imp, layers)
                if target_layer and target_layer in cannot_import:
                    violations.append(
                        ArchitectureViolation(
                            file=file_path,
                            line=0,  # Line number not tracked in this simple implementation
                            violation_type="cross_layer",
                            severity="HIGH",
                            description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
                            recommendation=recommendation.format(
                                target_layer=target_layer
                            ),
                            source_layer=source_layer,
                            target_layer=target_layer,
                        )
                    )
        return violations

    def _build_report(
        self,
        imports_by_file: dict,
        layers: dict,
        violations: list[ArchitectureViolation],
    ) -> ArchitectureReport:
        """Assemble an ArchitectureReport from the collected imports.

        Compliance is 1 - violations/imports (clamped to >= 0); an empty
        import set scores a perfect 1.0.
        """
        circular = self._detect_circular_dependencies(imports_by_file, layers)
        total_imports = sum(len(imps) for imps in imports_by_file.values())
        if total_imports > 0:
            compliance = 1.0 - (len(violations) / max(total_imports, 1))
        else:
            compliance = 1.0
        return ArchitectureReport(
            violations=violations,
            layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
            circular_dependencies=circular,
            compliance_score=max(0.0, compliance),
            recommendations=self._generate_recommendations(violations),
        )

    def _get_file_imports(self, owner: str, repo: str, filepath: str) -> list[str]:
        """Fetch a file's content and extract its imports.

        Returns [] for unknown languages or on any fetch/decode failure.
        """
        imports = []
        language = self._detect_language(filepath)
        if not language:
            return imports
        try:
            content_data = self.gitea.get_file_contents(owner, repo, filepath)
            # Gitea returns file content base64-encoded
            if content_data.get("content"):
                content = base64.b64decode(content_data["content"]).decode(
                    "utf-8", errors="ignore"
                )
                for line in content.splitlines():
                    imports.extend(self._extract_imports(line, language))
        except Exception:
            # Best-effort: unreadable files simply contribute no imports
            pass
        return imports

    def _detect_language(self, filepath: str) -> str | None:
        """Map a file extension to a language name, or None if unknown."""
        ext_map = {
            ".py": "python",
            ".js": "javascript",
            ".ts": "typescript",
            ".go": "go",
            ".java": "java",
            ".rb": "ruby",
        }
        ext = os.path.splitext(filepath)[1]
        return ext_map.get(ext)

    def _extract_imports(self, line: str, language: str) -> list[str]:
        """Extract imported module/package names from one line of code.

        Returns top-level names only (e.g. "a" from "a.b.c"); relative JS
        imports keep their full path.  Ruby is detected by
        _detect_language but has no extraction rule here.
        """
        imports = []
        line = line.strip()
        if language == "python":
            # from x import y, import x -> top-level package name
            match = re.match(r"^(?:from\s+(\S+)|import\s+(\S+))", line)
            if match:
                imp = match.group(1) or match.group(2)
                if imp:
                    imports.append(imp.split(".")[0])
        elif language in ("javascript", "typescript"):
            # import x from 'y', require('y')
            match = re.search(
                r"(?:from\s+['\"]([^'\"]+)['\"]|require\(['\"]([^'\"]+)['\"]\))", line
            )
            if match:
                imp = match.group(1) or match.group(2)
                if imp and not imp.startswith("."):
                    # Bare package specifier: keep only the package name
                    imports.append(imp.split("/")[0])
                elif imp:
                    # Relative import: keep the path as-is
                    imports.append(imp)
        elif language == "go":
            # import "path/to/package" -> last path segment
            match = re.search(r'import\s+["\']([^"\']+)["\']', line)
            if match:
                imports.append(match.group(1).split("/")[-1])
        elif language == "java":
            # import package.Class -> package name (second-to-last segment)
            match = re.match(r"^import\s+(?:static\s+)?([^;]+);", line)
            if match:
                parts = match.group(1).split(".")
                if len(parts) > 1:
                    imports.append(parts[-2])  # Package name
        return imports

    def _detect_layer(self, filepath: str, layers: dict) -> str | None:
        """Return the first layer whose path pattern occurs in filepath."""
        for layer_name, layer_config in layers.items():
            for pattern in layer_config.get("patterns", []):
                if pattern.rstrip("/") in filepath:
                    return layer_name
        return None

    def _detect_layer_from_import(self, import_path: str, layers: dict) -> str | None:
        """Return the layer an import name appears to refer to, if any.

        Matches the last path segment of each layer pattern against the
        import (substring or prefix), so heuristic false positives are
        possible for short names.
        """
        for layer_name, layer_config in layers.items():
            for pattern in layer_config.get("patterns", []):
                pattern_name = pattern.rstrip("/").split("/")[-1]
                if pattern_name in import_path or import_path.startswith(pattern_name):
                    return layer_name
        return None

    def _detect_circular_dependencies(
        self, imports_by_file: dict, layers: dict
    ) -> list[tuple[str, str]]:
        """Detect mutual (two-layer) dependency cycles.

        Only direct A<->B cycles are reported; longer cycles (A->B->C->A)
        are not detected.  Pairs are returned name-sorted and deduplicated.
        """
        circular = []
        # Build the layer-level dependency graph from per-file imports
        layer_deps = {}
        for file_path, imports in imports_by_file.items():
            source_layer = self._detect_layer(file_path, layers)
            if not source_layer:
                continue
            if source_layer not in layer_deps:
                layer_deps[source_layer] = set()
            for imp in imports:
                target_layer = self._detect_layer_from_import(imp, layers)
                if target_layer and target_layer != source_layer:
                    layer_deps[source_layer].add(target_layer)
        # Report every pair of layers that depend on each other
        for layer_a, deps_a in layer_deps.items():
            for layer_b in deps_a:
                if layer_b in layer_deps and layer_a in layer_deps.get(layer_b, set()):
                    pair = tuple(sorted([layer_a, layer_b]))
                    if pair not in circular:
                        circular.append(pair)
        return circular

    def _group_files_by_layer(
        self, files: list[str], layers: dict
    ) -> dict[str, list[str]]:
        """Group file paths by detected layer; unassigned files are dropped."""
        grouped = {}
        for filepath in files:
            layer = self._detect_layer(filepath, layers)
            if layer:
                if layer not in grouped:
                    grouped[layer] = []
                grouped[layer].append(filepath)
        return grouped

    def _generate_recommendations(
        self, violations: list[ArchitectureViolation]
    ) -> list[str]:
        """Generate high-level recommendations based on violation counts."""
        recommendations = []
        # Count cross-layer violations (the only type currently produced)
        cross_layer = sum(1 for v in violations if v.violation_type == "cross_layer")
        if cross_layer > 0:
            recommendations.append(
                f"Fix {cross_layer} cross-layer violations by moving imports or creating interfaces"
            )
        if cross_layer > 5:
            recommendations.append(
                "Consider using dependency injection to reduce coupling between layers"
            )
        return recommendations

    def _format_architecture_report(
        self, report: ArchitectureReport, user: str | None
    ) -> str:
        """Render the report as a Markdown comment body.

        ``user`` (the requesting commenter) is @-mentioned when present.
        """
        lines = []
        if user:
            lines.append(f"@{user}")
            lines.append("")
        lines.extend(
            [
                # AI_DISCLAIMER presumably defined on BaseAgent -- TODO confirm
                f"{self.AI_DISCLAIMER}",
                "",
                "## 🏗️ Architecture Compliance Check",
                "",
                "### Summary",
                "",
                "| Metric | Value |",
                "|--------|-------|",
                f"| Compliance Score | {report.compliance_score:.0%} |",
                f"| Violations | {len(report.violations)} |",
                f"| Circular Dependencies | {len(report.circular_dependencies)} |",
                f"| Layers Detected | {len(report.layers_detected)} |",
                "",
            ]
        )
        # Visual compliance bar: 10 segments, filled proportionally.
        # (Was "" * filled + "" * (10 - filled), which always produced an
        # empty bar -- the block characters had been lost.)
        filled = int(report.compliance_score * 10)
        bar = "█" * filled + "░" * (10 - filled)
        lines.append(f"`[{bar}]` {report.compliance_score:.0%}")
        lines.append("")
        # Violations (display capped at 10)
        if report.violations:
            lines.append("### 🚨 Violations")
            lines.append("")
            for v in report.violations[:10]:  # Limit display
                lines.append(
                    f"{self.SEVERITY_EMOJI.get(v.severity, '')} **{v.violation_type.upper()}** in `{v.file}`"
                )
                lines.append(f"  - {v.description}")
                lines.append(f"  - 💡 {v.recommendation}")
            lines.append("")
            if len(report.violations) > 10:
                lines.append(f"*... and {len(report.violations) - 10} more violations*")
                lines.append("")
        # Circular dependencies
        if report.circular_dependencies:
            lines.append("### 🔄 Circular Dependencies")
            lines.append("")
            for a, b in report.circular_dependencies:
                lines.append(f"- `{a}` ↔ `{b}`")
            lines.append("")
        # Layers detected
        if report.layers_detected:
            lines.append("### 📁 Layers Detected")
            lines.append("")
            for layer, files in report.layers_detected.items():
                lines.append(f"- **{layer}**: {len(files)} files")
            lines.append("")
        # Recommendations
        if report.recommendations:
            lines.append("### 💡 Recommendations")
            lines.append("")
            for rec in report.recommendations:
                lines.append(f"- {rec}")
            lines.append("")
        # Overall status footer
        if report.compliance_score >= 0.9:
            lines.append("---")
            lines.append("✅ **Excellent architecture compliance!**")
        elif report.compliance_score >= 0.7:
            lines.append("---")
            lines.append("⚠️ **Some architectural issues to address**")
        else:
            lines.append("---")
            lines.append("❌ **Significant architectural violations detected**")
        return "\n".join(lines)