just why not
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
This commit is contained in:
547
tools/ai-review/agents/architecture_agent.py
Normal file
547
tools/ai-review/agents/architecture_agent.py
Normal file
@@ -0,0 +1,547 @@
|
||||
"""Architecture Compliance Agent
|
||||
|
||||
AI agent for enforcing architectural patterns and layer separation.
|
||||
Detects cross-layer violations and circular dependencies.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from agents.base_agent import AgentContext, AgentResult, BaseAgent
|
||||
|
||||
|
||||
@dataclass
class ArchitectureViolation:
    """An architecture violation.

    One rule breach found during analysis, with enough context to render
    a human-readable entry in the report comment.
    """

    # Path of the offending file (repo-relative).
    file: str
    # Line number of the violation; 0 when the analyzer did not track it.
    line: int
    violation_type: str  # cross_layer, circular, naming, structure
    severity: str  # HIGH, MEDIUM, LOW
    # Human-readable explanation of which rule was broken.
    description: str
    # Suggested fix shown next to the violation in the report.
    recommendation: str
    # Layers involved in a cross-layer violation (None for other types).
    source_layer: str | None = None
    target_layer: str | None = None
|
||||
|
||||
|
||||
@dataclass
class ArchitectureReport:
    """Report of architecture analysis."""

    # All violations found in this run.
    violations: list[ArchitectureViolation]
    # Layer name -> list of file paths detected as belonging to it.
    layers_detected: dict[str, list[str]]
    # Pairs of layer names that depend on each other (each pair sorted).
    circular_dependencies: list[tuple[str, str]]
    # Fraction of analyzed imports that do not violate layer rules, in [0, 1].
    compliance_score: float
    # Actionable suggestions derived from the violations.
    recommendations: list[str]
|
||||
|
||||
|
||||
class ArchitectureAgent(BaseAgent):
    """Agent for enforcing architectural compliance."""

    # HTML marker embedded in posted comments so upsert_comment can find and
    # update this agent's previous report instead of posting a duplicate.
    ARCH_AI_MARKER = "<!-- AI_ARCHITECTURE_CHECK -->"

    # Default layer definitions. Each layer lists the directory patterns that
    # identify it, the layers it may import from ("can_import", informational),
    # and the layers it must never import from ("cannot_import", enforced).
    # Projects can override this via the agents.architecture.layers config key.
    DEFAULT_LAYERS = {
        "api": {
            "patterns": ["api/", "routes/", "controllers/", "handlers/", "views/"],
            "can_import": ["services", "models", "utils", "config"],
            "cannot_import": ["db", "repositories", "infrastructure"],
        },
        "services": {
            "patterns": ["services/", "usecases/", "application/"],
            "can_import": ["models", "repositories", "utils", "config"],
            "cannot_import": ["api", "controllers", "handlers"],
        },
        "repositories": {
            "patterns": ["repositories/", "repos/", "data/"],
            "can_import": ["models", "db", "utils", "config"],
            "cannot_import": ["api", "services", "controllers"],
        },
        "models": {
            "patterns": ["models/", "entities/", "domain/", "schemas/"],
            "can_import": ["utils", "config"],
            "cannot_import": ["api", "services", "repositories", "db"],
        },
        "db": {
            "patterns": ["db/", "database/", "infrastructure/"],
            "can_import": ["models", "config"],
            "cannot_import": ["api", "services"],
        },
        "utils": {
            "patterns": ["utils/", "helpers/", "common/", "lib/"],
            "can_import": ["config"],
            "cannot_import": [],
        },
    }
|
||||
|
||||
def can_handle(self, event_type: str, event_data: dict) -> bool:
|
||||
"""Check if this agent handles the given event."""
|
||||
agent_config = self.config.get("agents", {}).get("architecture", {})
|
||||
if not agent_config.get("enabled", False):
|
||||
return False
|
||||
|
||||
# Handle PR events
|
||||
if event_type == "pull_request":
|
||||
action = event_data.get("action", "")
|
||||
if action in ["opened", "synchronize"]:
|
||||
return True
|
||||
|
||||
# Handle @codebot architecture command
|
||||
if event_type == "issue_comment":
|
||||
comment_body = event_data.get("comment", {}).get("body", "")
|
||||
mention_prefix = self.config.get("interaction", {}).get(
|
||||
"mention_prefix", "@codebot"
|
||||
)
|
||||
if f"{mention_prefix} architecture" in comment_body.lower():
|
||||
return True
|
||||
if f"{mention_prefix} arch-check" in comment_body.lower():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
    def execute(self, context: AgentContext) -> AgentResult:
        """Execute the architecture agent.

        For PR events (or comment commands on a PR) the PR diff is analyzed;
        otherwise the repository's directory structure is analyzed. The
        resulting report is posted (or updated in place) as an issue comment,
        and an AgentResult summarizing the findings is returned.
        """
        self.logger.info(f"Checking architecture for {context.owner}/{context.repo}")

        actions_taken = []

        # Layer rules come from config, falling back to the built-in defaults.
        agent_config = self.config.get("agents", {}).get("architecture", {})
        layers = agent_config.get("layers", self.DEFAULT_LAYERS)

        # Determine issue number and, for comment-triggered runs, who asked.
        if context.event_type == "issue_comment":
            issue = context.event_data.get("issue", {})
            issue_number = issue.get("number")
            comment_author = (
                context.event_data.get("comment", {})
                .get("user", {})
                .get("login", "user")
            )
            # Comments on PRs carry a "pull_request" key on the issue payload.
            is_pr = issue.get("pull_request") is not None
        else:
            pr = context.event_data.get("pull_request", {})
            issue_number = pr.get("number")
            comment_author = None
            is_pr = True

        if is_pr and issue_number:
            # Analyze only the changes introduced by the PR.
            diff = self._get_pr_diff(context.owner, context.repo, issue_number)
            report = self._analyze_diff(diff, layers)
            actions_taken.append(f"Analyzed PR diff for architecture violations")
        else:
            # No PR context: analyze the repository's overall structure.
            report = self._analyze_repository(context.owner, context.repo, layers)
            actions_taken.append(f"Analyzed repository architecture")

        # Post (or update) the report comment, keyed by ARCH_AI_MARKER.
        if issue_number:
            comment = self._format_architecture_report(report, comment_author)
            self.upsert_comment(
                context.owner,
                context.repo,
                issue_number,
                comment,
                marker=self.ARCH_AI_MARKER,
            )
            actions_taken.append("Posted architecture report")

        # Success means either a clean report or compliance of at least 80%.
        return AgentResult(
            success=len(report.violations) == 0 or report.compliance_score >= 0.8,
            message=f"Architecture check: {len(report.violations)} violations, {report.compliance_score:.0%} compliance",
            data={
                "violations_count": len(report.violations),
                "compliance_score": report.compliance_score,
                "circular_dependencies": len(report.circular_dependencies),
            },
            actions_taken=actions_taken,
        )
|
||||
|
||||
def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
|
||||
"""Get the PR diff."""
|
||||
try:
|
||||
return self.gitea.get_pull_request_diff(owner, repo, pr_number)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get PR diff: {e}")
|
||||
return ""
|
||||
|
||||
def _analyze_diff(self, diff: str, layers: dict) -> ArchitectureReport:
|
||||
"""Analyze PR diff for architecture violations."""
|
||||
violations = []
|
||||
imports_by_file = {}
|
||||
|
||||
current_file = None
|
||||
current_language = None
|
||||
|
||||
for line in diff.splitlines():
|
||||
# Track current file
|
||||
if line.startswith("diff --git"):
|
||||
match = re.search(r"b/(.+)$", line)
|
||||
if match:
|
||||
current_file = match.group(1)
|
||||
current_language = self._detect_language(current_file)
|
||||
imports_by_file[current_file] = []
|
||||
|
||||
# Look for import statements in added lines
|
||||
if line.startswith("+") and not line.startswith("+++"):
|
||||
if current_file and current_language:
|
||||
imports = self._extract_imports(line[1:], current_language)
|
||||
imports_by_file.setdefault(current_file, []).extend(imports)
|
||||
|
||||
# Check for violations
|
||||
for file_path, imports in imports_by_file.items():
|
||||
source_layer = self._detect_layer(file_path, layers)
|
||||
if not source_layer:
|
||||
continue
|
||||
|
||||
layer_config = layers.get(source_layer, {})
|
||||
cannot_import = layer_config.get("cannot_import", [])
|
||||
|
||||
for imp in imports:
|
||||
target_layer = self._detect_layer_from_import(imp, layers)
|
||||
if target_layer and target_layer in cannot_import:
|
||||
violations.append(
|
||||
ArchitectureViolation(
|
||||
file=file_path,
|
||||
line=0, # Line number not tracked in this simple implementation
|
||||
violation_type="cross_layer",
|
||||
severity="HIGH",
|
||||
description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
|
||||
recommendation=f"Move this import to an allowed layer or refactor the dependency",
|
||||
source_layer=source_layer,
|
||||
target_layer=target_layer,
|
||||
)
|
||||
)
|
||||
|
||||
# Detect circular dependencies
|
||||
circular = self._detect_circular_dependencies(imports_by_file, layers)
|
||||
|
||||
# Calculate compliance score
|
||||
total_imports = sum(len(imps) for imps in imports_by_file.values())
|
||||
if total_imports > 0:
|
||||
compliance = 1.0 - (len(violations) / max(total_imports, 1))
|
||||
else:
|
||||
compliance = 1.0
|
||||
|
||||
return ArchitectureReport(
|
||||
violations=violations,
|
||||
layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
|
||||
circular_dependencies=circular,
|
||||
compliance_score=max(0.0, compliance),
|
||||
recommendations=self._generate_recommendations(violations),
|
||||
)
|
||||
|
||||
    def _analyze_repository(
        self, owner: str, repo: str, layers: dict
    ) -> ArchitectureReport:
        """Analyze repository structure for architecture compliance.

        Lists files under each configured layer directory via the Gitea
        contents API (capped at 20 files per directory), extracts their
        imports, and checks those imports against the layer rules.
        """
        violations = []
        imports_by_file = {}

        # Collect files from each layer
        for layer_name, layer_config in layers.items():
            for pattern in layer_config.get("patterns", []):
                try:
                    path = pattern.rstrip("/")
                    contents = self.gitea.get_file_contents(owner, repo, path)
                    # A directory listing comes back as a list of entries;
                    # anything else means the pattern matched a file or nothing.
                    if isinstance(contents, list):
                        for item in contents[:20]:  # Limit files per layer
                            if item.get("type") == "file":
                                filepath = item.get("path", "")
                                imports = self._get_file_imports(owner, repo, filepath)
                                imports_by_file[filepath] = imports
                except Exception:
                    # Best-effort scan: a missing directory is expected for
                    # patterns the repo doesn't use, so skip silently.
                    pass

        # Check for violations
        for file_path, imports in imports_by_file.items():
            source_layer = self._detect_layer(file_path, layers)
            if not source_layer:
                continue

            layer_config = layers.get(source_layer, {})
            cannot_import = layer_config.get("cannot_import", [])

            for imp in imports:
                target_layer = self._detect_layer_from_import(imp, layers)
                if target_layer and target_layer in cannot_import:
                    violations.append(
                        ArchitectureViolation(
                            file=file_path,
                            line=0,  # line numbers are not tracked in full-repo scans
                            violation_type="cross_layer",
                            severity="HIGH",
                            description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
                            recommendation=f"Refactor to remove dependency on '{target_layer}'",
                            source_layer=source_layer,
                            target_layer=target_layer,
                        )
                    )

        # Detect circular dependencies
        circular = self._detect_circular_dependencies(imports_by_file, layers)

        # Compliance: fraction of scanned imports that do not violate a rule.
        total_imports = sum(len(imps) for imps in imports_by_file.values())
        if total_imports > 0:
            compliance = 1.0 - (len(violations) / max(total_imports, 1))
        else:
            compliance = 1.0

        return ArchitectureReport(
            violations=violations,
            layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
            circular_dependencies=circular,
            compliance_score=max(0.0, compliance),
            recommendations=self._generate_recommendations(violations),
        )
|
||||
|
||||
def _get_file_imports(self, owner: str, repo: str, filepath: str) -> list[str]:
|
||||
"""Get imports from a file."""
|
||||
imports = []
|
||||
language = self._detect_language(filepath)
|
||||
|
||||
if not language:
|
||||
return imports
|
||||
|
||||
try:
|
||||
content_data = self.gitea.get_file_contents(owner, repo, filepath)
|
||||
if content_data.get("content"):
|
||||
content = base64.b64decode(content_data["content"]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
for line in content.splitlines():
|
||||
imports.extend(self._extract_imports(line, language))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return imports
|
||||
|
||||
def _detect_language(self, filepath: str) -> str | None:
|
||||
"""Detect programming language from file path."""
|
||||
ext_map = {
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
".ts": "typescript",
|
||||
".go": "go",
|
||||
".java": "java",
|
||||
".rb": "ruby",
|
||||
}
|
||||
ext = os.path.splitext(filepath)[1]
|
||||
return ext_map.get(ext)
|
||||
|
||||
def _extract_imports(self, line: str, language: str) -> list[str]:
|
||||
"""Extract import statements from a line of code."""
|
||||
imports = []
|
||||
line = line.strip()
|
||||
|
||||
if language == "python":
|
||||
# from x import y, import x
|
||||
match = re.match(r"^(?:from\s+(\S+)|import\s+(\S+))", line)
|
||||
if match:
|
||||
imp = match.group(1) or match.group(2)
|
||||
if imp:
|
||||
imports.append(imp.split(".")[0])
|
||||
|
||||
elif language in ("javascript", "typescript"):
|
||||
# import x from 'y', require('y')
|
||||
match = re.search(
|
||||
r"(?:from\s+['\"]([^'\"]+)['\"]|require\(['\"]([^'\"]+)['\"]\))", line
|
||||
)
|
||||
if match:
|
||||
imp = match.group(1) or match.group(2)
|
||||
if imp and not imp.startswith("."):
|
||||
imports.append(imp.split("/")[0])
|
||||
elif imp:
|
||||
imports.append(imp)
|
||||
|
||||
elif language == "go":
|
||||
# import "package"
|
||||
match = re.search(r'import\s+["\']([^"\']+)["\']', line)
|
||||
if match:
|
||||
imports.append(match.group(1).split("/")[-1])
|
||||
|
||||
elif language == "java":
|
||||
# import package.Class
|
||||
match = re.match(r"^import\s+(?:static\s+)?([^;]+);", line)
|
||||
if match:
|
||||
parts = match.group(1).split(".")
|
||||
if len(parts) > 1:
|
||||
imports.append(parts[-2]) # Package name
|
||||
|
||||
return imports
|
||||
|
||||
def _detect_layer(self, filepath: str, layers: dict) -> str | None:
|
||||
"""Detect which layer a file belongs to."""
|
||||
for layer_name, layer_config in layers.items():
|
||||
for pattern in layer_config.get("patterns", []):
|
||||
if pattern.rstrip("/") in filepath:
|
||||
return layer_name
|
||||
return None
|
||||
|
||||
def _detect_layer_from_import(self, import_path: str, layers: dict) -> str | None:
|
||||
"""Detect which layer an import refers to."""
|
||||
for layer_name, layer_config in layers.items():
|
||||
for pattern in layer_config.get("patterns", []):
|
||||
pattern_name = pattern.rstrip("/").split("/")[-1]
|
||||
if pattern_name in import_path or import_path.startswith(pattern_name):
|
||||
return layer_name
|
||||
return None
|
||||
|
||||
def _detect_circular_dependencies(
|
||||
self, imports_by_file: dict, layers: dict
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Detect circular dependencies between layers."""
|
||||
circular = []
|
||||
|
||||
# Build layer dependency graph
|
||||
layer_deps = {}
|
||||
for file_path, imports in imports_by_file.items():
|
||||
source_layer = self._detect_layer(file_path, layers)
|
||||
if not source_layer:
|
||||
continue
|
||||
|
||||
if source_layer not in layer_deps:
|
||||
layer_deps[source_layer] = set()
|
||||
|
||||
for imp in imports:
|
||||
target_layer = self._detect_layer_from_import(imp, layers)
|
||||
if target_layer and target_layer != source_layer:
|
||||
layer_deps[source_layer].add(target_layer)
|
||||
|
||||
# Check for circular dependencies
|
||||
for layer_a, deps_a in layer_deps.items():
|
||||
for layer_b in deps_a:
|
||||
if layer_b in layer_deps and layer_a in layer_deps.get(layer_b, set()):
|
||||
pair = tuple(sorted([layer_a, layer_b]))
|
||||
if pair not in circular:
|
||||
circular.append(pair)
|
||||
|
||||
return circular
|
||||
|
||||
def _group_files_by_layer(
|
||||
self, files: list[str], layers: dict
|
||||
) -> dict[str, list[str]]:
|
||||
"""Group files by their layer."""
|
||||
grouped = {}
|
||||
for filepath in files:
|
||||
layer = self._detect_layer(filepath, layers)
|
||||
if layer:
|
||||
if layer not in grouped:
|
||||
grouped[layer] = []
|
||||
grouped[layer].append(filepath)
|
||||
return grouped
|
||||
|
||||
def _generate_recommendations(
|
||||
self, violations: list[ArchitectureViolation]
|
||||
) -> list[str]:
|
||||
"""Generate recommendations based on violations."""
|
||||
recommendations = []
|
||||
|
||||
# Count violations by type
|
||||
cross_layer = sum(1 for v in violations if v.violation_type == "cross_layer")
|
||||
|
||||
if cross_layer > 0:
|
||||
recommendations.append(
|
||||
f"Fix {cross_layer} cross-layer violations by moving imports or creating interfaces"
|
||||
)
|
||||
|
||||
if cross_layer > 5:
|
||||
recommendations.append(
|
||||
"Consider using dependency injection to reduce coupling between layers"
|
||||
)
|
||||
|
||||
return recommendations
|
||||
|
||||
def _format_architecture_report(
|
||||
self, report: ArchitectureReport, user: str | None
|
||||
) -> str:
|
||||
"""Format the architecture report as a comment."""
|
||||
lines = []
|
||||
|
||||
if user:
|
||||
lines.append(f"@{user}")
|
||||
lines.append("")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
f"{self.AI_DISCLAIMER}",
|
||||
"",
|
||||
"## 🏗️ Architecture Compliance Check",
|
||||
"",
|
||||
"### Summary",
|
||||
"",
|
||||
f"| Metric | Value |",
|
||||
f"|--------|-------|",
|
||||
f"| Compliance Score | {report.compliance_score:.0%} |",
|
||||
f"| Violations | {len(report.violations)} |",
|
||||
f"| Circular Dependencies | {len(report.circular_dependencies)} |",
|
||||
f"| Layers Detected | {len(report.layers_detected)} |",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
# Compliance bar
|
||||
filled = int(report.compliance_score * 10)
|
||||
bar = "█" * filled + "░" * (10 - filled)
|
||||
lines.append(f"`[{bar}]` {report.compliance_score:.0%}")
|
||||
lines.append("")
|
||||
|
||||
# Violations
|
||||
if report.violations:
|
||||
lines.append("### 🚨 Violations")
|
||||
lines.append("")
|
||||
|
||||
for v in report.violations[:10]: # Limit display
|
||||
severity_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}
|
||||
lines.append(
|
||||
f"{severity_emoji.get(v.severity, '⚪')} **{v.violation_type.upper()}** in `{v.file}`"
|
||||
)
|
||||
lines.append(f" - {v.description}")
|
||||
lines.append(f" - 💡 {v.recommendation}")
|
||||
lines.append("")
|
||||
|
||||
if len(report.violations) > 10:
|
||||
lines.append(f"*... and {len(report.violations) - 10} more violations*")
|
||||
lines.append("")
|
||||
|
||||
# Circular dependencies
|
||||
if report.circular_dependencies:
|
||||
lines.append("### 🔄 Circular Dependencies")
|
||||
lines.append("")
|
||||
for a, b in report.circular_dependencies:
|
||||
lines.append(f"- `{a}` ↔ `{b}`")
|
||||
lines.append("")
|
||||
|
||||
# Layers detected
|
||||
if report.layers_detected:
|
||||
lines.append("### 📁 Layers Detected")
|
||||
lines.append("")
|
||||
for layer, files in report.layers_detected.items():
|
||||
lines.append(f"- **{layer}**: {len(files)} files")
|
||||
lines.append("")
|
||||
|
||||
# Recommendations
|
||||
if report.recommendations:
|
||||
lines.append("### 💡 Recommendations")
|
||||
lines.append("")
|
||||
for rec in report.recommendations:
|
||||
lines.append(f"- {rec}")
|
||||
lines.append("")
|
||||
|
||||
# Overall status
|
||||
if report.compliance_score >= 0.9:
|
||||
lines.append("---")
|
||||
lines.append("✅ **Excellent architecture compliance!**")
|
||||
elif report.compliance_score >= 0.7:
|
||||
lines.append("---")
|
||||
lines.append("⚠️ **Some architectural issues to address**")
|
||||
else:
|
||||
lines.append("---")
|
||||
lines.append("❌ **Significant architectural violations detected**")
|
||||
|
||||
return "\n".join(lines)
|
||||
Reference in New Issue
Block a user