just why not

All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s

CLAUDE.md (17 lines changed)
@@ -69,11 +69,17 @@ The codebase uses an **agent-based architecture** where specialized agents handl
    - `execute(context)` - Main execution logic
    - Returns `AgentResult` with success status, message, data, and actions taken

 **Core Agents:**
 - **PRAgent** - Reviews pull requests with inline comments and security scanning
-- **IssueAgent** - Triages issues and responds to @ai-bot commands
+- **IssueAgent** - Triages issues and responds to @codebot commands
 - **CodebaseAgent** - Analyzes entire codebase health and tech debt
 - **ChatAgent** - Interactive assistant with tool calling (search_codebase, read_file, search_web)
+
+**Specialized Agents:**
+- **DependencyAgent** - Scans dependencies for security vulnerabilities (Python, JavaScript)
+- **TestCoverageAgent** - Analyzes code for test coverage gaps and suggests test cases
+- **ArchitectureAgent** - Enforces layer separation and detects architecture violations

 3. **Dispatcher** (`dispatcher.py`) - Routes events to appropriate agents:
    - Registers agents at startup
    - Determines which agents can handle each event
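The agent contract above (a `can_handle(event_type, event_data)` predicate plus an `execute(context)` method returning an `AgentResult`) can be sketched in isolation. This is a minimal, self-contained stand-in, not the project's actual `agents/base_agent.py`; the `EchoIssueAgent` class and the exact `AgentResult`/`AgentContext` fields shown are illustrative assumptions.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class AgentResult:
    """Outcome of one agent run: success flag, message, optional data/actions."""
    success: bool
    message: str
    data: dict = field(default_factory=dict)
    actions_taken: list = field(default_factory=list)


@dataclass
class AgentContext:
    """Event payload handed to an agent by the dispatcher."""
    owner: str
    repo: str
    event_type: str
    event_data: dict[str, Any]


class BaseAgent:
    """Minimal stand-in for the BaseAgent interface described above."""

    def can_handle(self, event_type: str, event_data: dict) -> bool:
        raise NotImplementedError

    def execute(self, context: AgentContext) -> AgentResult:
        raise NotImplementedError


class EchoIssueAgent(BaseAgent):
    """Toy agent: claims 'issues' events and reports what it saw."""

    def can_handle(self, event_type, event_data):
        return event_type == "issues"

    def execute(self, context):
        number = context.event_data.get("issue", {}).get("number")
        return AgentResult(success=True, message=f"handled issue #{number}")


agent = EchoIssueAgent()
ctx = AgentContext("acme", "repo", "issues", {"issue": {"number": 7}})
result = agent.execute(ctx) if agent.can_handle(ctx.event_type, ctx.event_data) else None
print(result.message)  # handled issue #7
```

A dispatcher then only needs to iterate registered agents, call `can_handle`, and collect each `AgentResult`.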
@@ -84,14 +90,23 @@ The codebase uses an **agent-based architecture** where specialized agents handl

 The `LLMClient` (`clients/llm_client.py`) provides a unified interface for multiple LLM providers:

+**Core Providers (in llm_client.py):**
 - **OpenAI** - Primary provider (gpt-4.1-mini default)
 - **OpenRouter** - Multi-provider access (claude-3.5-sonnet)
 - **Ollama** - Self-hosted models (codellama:13b)
+
+**Additional Providers (in clients/providers/):**
+- **AnthropicProvider** - Direct Anthropic Claude API (claude-3.5-sonnet)
+- **AzureOpenAIProvider** - Azure OpenAI Service with API key auth
+- **AzureOpenAIWithAADProvider** - Azure OpenAI with Azure AD authentication
+- **GeminiProvider** - Google Gemini API (public)
+- **VertexAIGeminiProvider** - Google Vertex AI Gemini (enterprise GCP)

 Key features:
 - Tool/function calling support via `call_with_tools(messages, tools)`
 - JSON response parsing with fallback extraction
 - Provider-specific configuration via `config.yml`
+- Configurable timeouts per provider

 ### Platform Abstraction
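"JSON response parsing with fallback extraction" typically means: try the raw response, then any markdown code fence, then the outermost brace pair. A hedged, self-contained sketch of that idea (the real method lives in `clients/llm_client.py`; this `extract_json` function and its exact fallback order are assumptions for illustration):

````python
import json
import re


def extract_json(content: str) -> dict:
    """Best-effort JSON extraction from an LLM response (illustrative sketch)."""
    # 1. The whole response may already be valid JSON.
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass

    # 2. Look for a fenced code block, with or without a "json" specifier.
    fence = re.search(r"```(?:json)?\s*(.*?)```", content, re.DOTALL)
    if fence:
        try:
            return json.loads(fence.group(1))
        except json.JSONDecodeError:
            pass

    # 3. Fall back to the outermost brace pair (handles preamble/postamble text).
    start, end = content.find("{"), content.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(content[start:end + 1])
        except json.JSONDecodeError:
            pass

    raise ValueError(f"Failed to parse JSON from response: {content[:80]!r}")


print(extract_json('Result:\n```json\n{"type": "bug"}\n```')["type"])  # bug
````

The brace-pair fallback is what lets a client survive chatty models that wrap JSON in prose.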
README.md (116 lines changed)
@@ -1,6 +1,6 @@
 # OpenRabbit

-Enterprise-grade AI code review system for **Gitea** with automated PR review, issue triage, interactive chat, and codebase analysis.
+Enterprise-grade AI code review system for **Gitea** and **GitHub** with automated PR review, issue triage, interactive chat, and codebase analysis.

 ---

@@ -14,9 +14,15 @@ Enterprise-grade AI code review system for **Gitea** with automated PR review, i
 | **Chat** | Interactive AI chat with codebase search and web search tools |
 | **@codebot Commands** | `@codebot summarize`, `changelog`, `explain-diff`, `explain`, `suggest`, `triage`, `review-again` in comments |
 | **Codebase Analysis** | Health scores, tech debt tracking, weekly reports |
-| **Security Scanner** | 17 OWASP-aligned rules for vulnerability detection |
+| **Security Scanner** | 17 OWASP-aligned rules + SAST integration (Bandit, Semgrep) |
+| **Dependency Scanning** | Vulnerability detection for Python, JavaScript dependencies |
+| **Test Coverage** | AI-powered test suggestions for untested code |
+| **Architecture Compliance** | Layer separation enforcement, circular dependency detection |
+| **Notifications** | Slack/Discord alerts for security findings and reviews |
+| **Compliance** | Audit trail, CODEOWNERS enforcement, regulatory support |
+| **Multi-Provider LLM** | OpenAI, Anthropic Claude, Azure OpenAI, Google Gemini, Ollama |
 | **Enterprise Ready** | Audit logging, metrics, Prometheus export |
-| **Gitea Native** | Built for Gitea workflows and API |
+| **Gitea Native** | Built for Gitea workflows and API (also works with GitHub) |

 ---

@@ -116,12 +122,28 @@ tools/ai-review/
 │   ├── issue_agent.py           # Issue triage & @codebot commands
 │   ├── pr_agent.py              # PR review with security scan
 │   ├── codebase_agent.py        # Codebase health analysis
-│   └── chat_agent.py            # Interactive chat with tool calling
+│   ├── chat_agent.py            # Interactive chat with tool calling
+│   ├── dependency_agent.py      # Dependency vulnerability scanning
+│   ├── test_coverage_agent.py   # Test coverage analysis
+│   └── architecture_agent.py    # Architecture compliance checking
 ├── clients/                     # API clients
 │   ├── gitea_client.py          # Gitea REST API wrapper
-│   └── llm_client.py            # Multi-provider LLM client with tool support
+│   ├── llm_client.py            # Multi-provider LLM client with tool support
+│   └── providers/               # Additional LLM providers
+│       ├── anthropic_provider.py  # Direct Anthropic Claude API
+│       ├── azure_provider.py      # Azure OpenAI Service
+│       └── gemini_provider.py     # Google Gemini API
 ├── security/                    # Security scanning
-│   └── security_scanner.py      # 17 OWASP-aligned rules
+│   ├── security_scanner.py      # 17 OWASP-aligned rules
+│   └── sast_scanner.py          # Bandit, Semgrep, Trivy integration
+├── notifications/               # Alerting system
+│   └── notifier.py              # Slack, Discord, webhook notifications
+├── compliance/                  # Compliance & audit
+│   ├── audit_trail.py           # Audit logging with integrity verification
+│   └── codeowners.py            # CODEOWNERS enforcement
+├── utils/                       # Utility functions
+│   ├── ignore_patterns.py       # .ai-reviewignore support
+│   └── webhook_sanitizer.py     # Input validation
 ├── enterprise/                  # Enterprise features
 │   ├── audit_logger.py          # JSONL audit logging
 │   └── metrics.py               # Prometheus-compatible metrics
@@ -182,6 +204,10 @@ In any issue comment:
 | `@codebot summarize` | Summarize the issue in 2-3 sentences |
 | `@codebot explain` | Explain what the issue is about |
 | `@codebot suggest` | Suggest solutions or next steps |
+| `@codebot check-deps` | Scan dependencies for security vulnerabilities |
+| `@codebot suggest-tests` | Suggest test cases for changed code |
+| `@codebot refactor-suggest` | Suggest refactoring opportunities |
+| `@codebot architecture` | Check architecture compliance (alias: `arch-check`) |
 | `@codebot` (any question) | Chat with AI using codebase/web search tools |

 ### Pull Request Commands
@@ -522,19 +548,91 @@ Replace `'Bartender'` with your bot's Gitea username. This prevents the bot from

 | Provider | Model | Use Case |
 |----------|-------|----------|
-| OpenAI | gpt-4.1-mini | Fast, reliable |
+| OpenAI | gpt-4.1-mini | Fast, reliable, default |
+| Anthropic | claude-3.5-sonnet | Direct Claude API access |
+| Azure OpenAI | gpt-4 (deployment) | Enterprise Azure deployments |
+| Google Gemini | gemini-1.5-pro | GCP customers, Vertex AI |
 | OpenRouter | claude-3.5-sonnet | Multi-provider access |
 | Ollama | codellama:13b | Self-hosted, private |
+
+### Provider Configuration
+
+```yaml
+# In config.yml
+provider: anthropic  # openai | anthropic | azure | gemini | openrouter | ollama
+
+# Azure OpenAI
+azure:
+  endpoint: ""  # Set via AZURE_OPENAI_ENDPOINT env var
+  deployment: "gpt-4"
+  api_version: "2024-02-15-preview"
+
+# Google Gemini (Vertex AI)
+gemini:
+  project: ""  # Set via GOOGLE_CLOUD_PROJECT env var
+  region: "us-central1"
+```
+
+### Environment Variables
+
+| Variable | Provider | Description |
+|----------|----------|-------------|
+| `OPENAI_API_KEY` | OpenAI | API key |
+| `ANTHROPIC_API_KEY` | Anthropic | API key |
+| `AZURE_OPENAI_ENDPOINT` | Azure | Service endpoint URL |
+| `AZURE_OPENAI_API_KEY` | Azure | API key |
+| `AZURE_OPENAI_DEPLOYMENT` | Azure | Deployment name |
+| `GOOGLE_API_KEY` | Gemini | API key (public API) |
+| `GOOGLE_CLOUD_PROJECT` | Vertex AI | GCP project ID |
+| `OPENROUTER_API_KEY` | OpenRouter | API key |
+| `OLLAMA_HOST` | Ollama | Server URL (default: localhost:11434) |

 ---

 ## Enterprise Features

-- **Audit Logging**: JSONL logs with daily rotation
+- **Audit Logging**: JSONL logs with integrity checksums and daily rotation
+- **Compliance**: HIPAA, SOC2, PCI-DSS, GDPR support with configurable rules
+- **CODEOWNERS Enforcement**: Validate approvals against CODEOWNERS file
+- **Notifications**: Slack/Discord webhooks for critical findings
+- **SAST Integration**: Bandit, Semgrep, Trivy for advanced security scanning
 - **Metrics**: Prometheus-compatible export
-- **Rate Limiting**: Configurable request limits
+- **Rate Limiting**: Configurable request limits and timeouts
 - **Custom Security Rules**: Define your own patterns via YAML
 - **Tool Calling**: LLM function calling for interactive chat
+- **Ignore Patterns**: `.ai-reviewignore` for excluding files from review
+
+### Notifications Configuration
+
+```yaml
+# In config.yml
+notifications:
+  enabled: true
+  threshold: "warning"  # info | warning | error | critical
+
+  slack:
+    enabled: true
+    webhook_url: ""  # Set via SLACK_WEBHOOK_URL env var
+    channel: "#code-review"
+
+  discord:
+    enabled: true
+    webhook_url: ""  # Set via DISCORD_WEBHOOK_URL env var
+```
+
+### Compliance Configuration
+
+```yaml
+compliance:
+  enabled: true
+  audit:
+    enabled: true
+    log_file: "audit.log"
+    retention_days: 90
+  codeowners:
+    enabled: true
+    require_approval: true
+```

 ---

tests/test_dispatcher.py (new file, 296 lines)

@@ -0,0 +1,296 @@
```python
"""Test Suite for Dispatcher

Tests for event routing and agent execution.
"""

import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review"))

from unittest.mock import MagicMock, Mock, patch

import pytest


class TestDispatcherCreation:
    """Test dispatcher initialization."""

    def test_create_dispatcher(self):
        """Test creating dispatcher."""
        from dispatcher import Dispatcher

        dispatcher = Dispatcher()
        assert dispatcher is not None
        assert dispatcher.agents == []

    def test_create_dispatcher_with_config(self):
        """Test creating dispatcher with config."""
        from dispatcher import Dispatcher

        config = {"dispatcher": {"max_workers": 4}}
        dispatcher = Dispatcher(config=config)
        assert dispatcher.config == config


class TestAgentRegistration:
    """Test agent registration."""

    def test_register_agent(self):
        """Test registering an agent."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "test"

            def execute(self, context):
                return AgentResult(success=True, message="done")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)

        assert len(dispatcher.agents) == 1
        assert dispatcher.agents[0] == agent

    def test_register_multiple_agents(self):
        """Test registering multiple agents."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent1(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "type1"

            def execute(self, context):
                return AgentResult(success=True, message="agent1")

        class MockAgent2(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "type2"

            def execute(self, context):
                return AgentResult(success=True, message="agent2")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent1(config={}, gitea_client=None, llm_client=None)
        )
        dispatcher.register_agent(
            MockAgent2(config={}, gitea_client=None, llm_client=None)
        )

        assert len(dispatcher.agents) == 2


class TestEventRouting:
    """Test event routing to agents."""

    def test_route_to_matching_agent(self):
        """Test that events are routed to matching agents."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="handled")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)

        result = dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )

        assert len(result.agents_run) == 1
        assert result.results[0].success is True

    def test_no_matching_agent(self):
        """Test dispatch when no agent matches."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="handled")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)

        result = dispatcher.dispatch(
            event_type="pull_request",  # Different event type
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )

        assert len(result.agents_run) == 0

    def test_multiple_matching_agents(self):
        """Test dispatch when multiple agents match."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent1(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="agent1")

        class MockAgent2(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="agent2")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent1(config={}, gitea_client=None, llm_client=None)
        )
        dispatcher.register_agent(
            MockAgent2(config={}, gitea_client=None, llm_client=None)
        )

        result = dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )

        assert len(result.agents_run) == 2


class TestDispatchResult:
    """Test dispatch result structure."""

    def test_result_structure(self):
        """Test DispatchResult has correct structure."""
        from dispatcher import DispatchResult

        result = DispatchResult(
            agents_run=["Agent1", "Agent2"],
            results=[],
            errors=[],
        )

        assert result.agents_run == ["Agent1", "Agent2"]
        assert result.results == []
        assert result.errors == []

    def test_result_with_errors(self):
        """Test DispatchResult with errors."""
        from dispatcher import DispatchResult

        result = DispatchResult(
            agents_run=["Agent1"],
            results=[],
            errors=["Error 1", "Error 2"],
        )

        assert len(result.errors) == 2


class TestAgentExecution:
    """Test agent execution through dispatcher."""

    def test_agent_receives_context(self):
        """Test that agents receive proper context."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        received_context = None

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return True

            def execute(self, context):
                nonlocal received_context
                received_context = context
                return AgentResult(success=True, message="done")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent(config={}, gitea_client=None, llm_client=None)
        )

        dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened", "issue": {"number": 123}},
            owner="testowner",
            repo="testrepo",
        )

        assert received_context is not None
        assert received_context.owner == "testowner"
        assert received_context.repo == "testrepo"
        assert received_context.event_type == "issues"
        assert received_context.event_data["action"] == "opened"

    def test_agent_failure_captured(self):
        """Test that agent failures are captured in results."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class FailingAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return True

            def execute(self, context):
                raise Exception("Test error")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            FailingAgent(config={}, gitea_client=None, llm_client=None)
        )

        result = dispatcher.dispatch(
            event_type="issues",
            event_data={},
            owner="test",
            repo="repo",
        )

        # Agent should still be in agents_run
        assert len(result.agents_run) == 1
        # Result should indicate failure
        assert result.results[0].success is False


class TestGetDispatcher:
    """Test get_dispatcher factory function."""

    def test_get_dispatcher_returns_singleton(self):
        """Test that get_dispatcher returns configured dispatcher."""
        from dispatcher import get_dispatcher

        dispatcher = get_dispatcher()
        assert dispatcher is not None

    def test_get_dispatcher_with_config(self):
        """Test get_dispatcher with custom config."""
        from dispatcher import get_dispatcher

        config = {"test": "value"}
        dispatcher = get_dispatcher(config=config)
        assert dispatcher.config.get("test") == "value"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
```
tests/test_llm_client.py (new file, 456 lines)

@@ -0,0 +1,456 @@
````python
"""Test Suite for LLM Client

Tests for LLM client functionality including provider support,
tool calling, and JSON parsing.
"""

import json
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review"))

from unittest.mock import MagicMock, Mock, patch

import pytest


class TestLLMClientCreation:
    """Test LLM client initialization."""

    def test_create_openai_client(self):
        """Test creating OpenAI client."""
        from clients.llm_client import LLMClient

        client = LLMClient(
            provider="openai",
            config={"model": "gpt-4o-mini", "api_key": "test-key"},
        )
        assert client.provider_name == "openai"

    def test_create_openrouter_client(self):
        """Test creating OpenRouter client."""
        from clients.llm_client import LLMClient

        client = LLMClient(
            provider="openrouter",
            config={"model": "anthropic/claude-3.5-sonnet", "api_key": "test-key"},
        )
        assert client.provider_name == "openrouter"

    def test_create_ollama_client(self):
        """Test creating Ollama client."""
        from clients.llm_client import LLMClient

        client = LLMClient(
            provider="ollama",
            config={"model": "codellama:13b", "host": "http://localhost:11434"},
        )
        assert client.provider_name == "ollama"

    def test_invalid_provider_raises_error(self):
        """Test that invalid provider raises ValueError."""
        from clients.llm_client import LLMClient

        with pytest.raises(ValueError, match="Unknown provider"):
            LLMClient(provider="invalid_provider")

    def test_from_config_openai(self):
        """Test creating client from config dict."""
        from clients.llm_client import LLMClient

        config = {
            "provider": "openai",
            "model": {"openai": "gpt-4o-mini"},
            "temperature": 0,
            "max_tokens": 4096,
        }
        client = LLMClient.from_config(config)
        assert client.provider_name == "openai"


class TestLLMResponse:
    """Test LLM response dataclass."""

    def test_response_creation(self):
        """Test creating LLMResponse."""
        from clients.llm_client import LLMResponse

        response = LLMResponse(
            content="Test response",
            model="gpt-4o-mini",
            provider="openai",
            tokens_used=100,
            finish_reason="stop",
        )

        assert response.content == "Test response"
        assert response.model == "gpt-4o-mini"
        assert response.provider == "openai"
        assert response.tokens_used == 100
        assert response.finish_reason == "stop"
        assert response.tool_calls is None

    def test_response_with_tool_calls(self):
        """Test LLMResponse with tool calls."""
        from clients.llm_client import LLMResponse, ToolCall

        tool_calls = [ToolCall(id="call_1", name="search", arguments={"query": "test"})]

        response = LLMResponse(
            content="",
            model="gpt-4o-mini",
            provider="openai",
            tool_calls=tool_calls,
        )

        assert response.tool_calls is not None
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "search"


class TestToolCall:
    """Test ToolCall dataclass."""

    def test_tool_call_creation(self):
        """Test creating ToolCall."""
        from clients.llm_client import ToolCall

        tool_call = ToolCall(
            id="call_123",
            name="search_codebase",
            arguments={"query": "authentication", "file_pattern": "*.py"},
        )

        assert tool_call.id == "call_123"
        assert tool_call.name == "search_codebase"
        assert tool_call.arguments["query"] == "authentication"
        assert tool_call.arguments["file_pattern"] == "*.py"


class TestJSONParsing:
    """Test JSON extraction and parsing."""

    def test_parse_direct_json(self):
        """Test parsing direct JSON response."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)

        content = '{"key": "value", "number": 42}'
        result = client._extract_json(content)

        assert result["key"] == "value"
        assert result["number"] == 42

    def test_parse_json_in_code_block(self):
        """Test parsing JSON in markdown code block."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)

        content = """Here is the analysis:

```json
{
    "type": "bug",
    "priority": "high"
}
```

That's my analysis."""

        result = client._extract_json(content)
        assert result["type"] == "bug"
        assert result["priority"] == "high"

    def test_parse_json_in_plain_code_block(self):
        """Test parsing JSON in plain code block (no json specifier)."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)

        content = """Analysis:

```
{"status": "success", "count": 5}
```
"""

        result = client._extract_json(content)
        assert result["status"] == "success"
        assert result["count"] == 5

    def test_parse_json_with_preamble(self):
        """Test parsing JSON with text before it."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)

        content = """Based on my analysis, here is the result:
{"findings": ["issue1", "issue2"], "severity": "medium"}
"""

        result = client._extract_json(content)
        assert result["findings"] == ["issue1", "issue2"]
        assert result["severity"] == "medium"

    def test_parse_json_with_postamble(self):
        """Test parsing JSON with text after it."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)

        content = """{"result": true}

Let me know if you need more details."""

        result = client._extract_json(content)
        assert result["result"] is True

    def test_parse_nested_json(self):
        """Test parsing nested JSON objects."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)

        content = """{
    "outer": {
        "inner": {
            "value": "deep"
        }
    },
    "array": [1, 2, 3]
}"""

        result = client._extract_json(content)
        assert result["outer"]["inner"]["value"] == "deep"
        assert result["array"] == [1, 2, 3]

    def test_parse_invalid_json_raises_error(self):
        """Test that invalid JSON raises ValueError."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)

        content = "This is not JSON at all"

        with pytest.raises(ValueError, match="Failed to parse JSON"):
            client._extract_json(content)

    def test_parse_truncated_json_raises_error(self):
        """Test that truncated JSON raises ValueError."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)

        content = '{"key": "value", "incomplete'

        with pytest.raises(ValueError):
            client._extract_json(content)


class TestOpenAIProvider:
    """Test OpenAI provider."""

    def test_provider_creation(self):
        """Test OpenAI provider initialization."""
        from clients.llm_client import OpenAIProvider

        provider = OpenAIProvider(
            api_key="test-key",
            model="gpt-4o-mini",
            temperature=0.5,
            max_tokens=2048,
        )

        assert provider.model == "gpt-4o-mini"
        assert provider.temperature == 0.5
        assert provider.max_tokens == 2048

    def test_provider_requires_api_key(self):
        """Test that calling without API key raises error."""
        from clients.llm_client import OpenAIProvider

        provider = OpenAIProvider(api_key="")

        with pytest.raises(ValueError, match="API key is required"):
````
||||||
|
provider.call("test prompt")
|
||||||
|
|
||||||
|
@patch("clients.llm_client.requests.post")
|
||||||
|
def test_provider_call_success(self, mock_post):
|
||||||
|
"""Test successful API call."""
|
||||||
|
from clients.llm_client import OpenAIProvider
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {"content": "Test response"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"usage": {"total_tokens": 50},
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
provider = OpenAIProvider(api_key="test-key")
|
||||||
|
response = provider.call("Hello")
|
||||||
|
|
||||||
|
assert response.content == "Test response"
|
||||||
|
assert response.provider == "openai"
|
||||||
|
assert response.tokens_used == 50
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenRouterProvider:
|
||||||
|
"""Test OpenRouter provider."""
|
||||||
|
|
||||||
|
def test_provider_creation(self):
|
||||||
|
"""Test OpenRouter provider initialization."""
|
||||||
|
from clients.llm_client import OpenRouterProvider
|
||||||
|
|
||||||
|
provider = OpenRouterProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
model="anthropic/claude-3.5-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.model == "anthropic/claude-3.5-sonnet"
|
||||||
|
|
||||||
|
@patch("clients.llm_client.requests.post")
|
||||||
|
def test_provider_call_success(self, mock_post):
|
||||||
|
"""Test successful API call."""
|
||||||
|
from clients.llm_client import OpenRouterProvider
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {"content": "Claude response"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": "anthropic/claude-3.5-sonnet",
|
||||||
|
"usage": {"total_tokens": 75},
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
provider = OpenRouterProvider(api_key="test-key")
|
||||||
|
response = provider.call("Hello")
|
||||||
|
|
||||||
|
assert response.content == "Claude response"
|
||||||
|
assert response.provider == "openrouter"
|
||||||
|
|
||||||
|
|
||||||
|
class TestOllamaProvider:
|
||||||
|
"""Test Ollama provider."""
|
||||||
|
|
||||||
|
def test_provider_creation(self):
|
||||||
|
"""Test Ollama provider initialization."""
|
||||||
|
from clients.llm_client import OllamaProvider
|
||||||
|
|
||||||
|
provider = OllamaProvider(
|
||||||
|
host="http://localhost:11434",
|
||||||
|
model="codellama:13b",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.model == "codellama:13b"
|
||||||
|
assert provider.host == "http://localhost:11434"
|
||||||
|
|
||||||
|
@patch("clients.llm_client.requests.post")
|
||||||
|
def test_provider_call_success(self, mock_post):
|
||||||
|
"""Test successful API call."""
|
||||||
|
from clients.llm_client import OllamaProvider
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"response": "Ollama response",
|
||||||
|
"model": "codellama:13b",
|
||||||
|
"done": True,
|
||||||
|
"eval_count": 30,
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
provider = OllamaProvider()
|
||||||
|
response = provider.call("Hello")
|
||||||
|
|
||||||
|
assert response.content == "Ollama response"
|
||||||
|
assert response.provider == "ollama"
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolCalling:
|
||||||
|
"""Test tool/function calling support."""
|
||||||
|
|
||||||
|
@patch("clients.llm_client.requests.post")
|
||||||
|
def test_openai_tool_calling(self, mock_post):
|
||||||
|
"""Test OpenAI tool calling."""
|
||||||
|
from clients.llm_client import OpenAIProvider
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"message": {
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_abc123",
|
||||||
|
"function": {
|
||||||
|
"name": "search_codebase",
|
||||||
|
"arguments": '{"query": "auth"}',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"usage": {"total_tokens": 100},
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
|
||||||
|
provider = OpenAIProvider(api_key="test-key")
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "search_codebase",
|
||||||
|
"description": "Search the codebase",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"query": {"type": "string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
response = provider.call_with_tools(
|
||||||
|
messages=[{"role": "user", "content": "Search for auth"}],
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.tool_calls is not None
|
||||||
|
assert len(response.tool_calls) == 1
|
||||||
|
assert response.tool_calls[0].name == "search_codebase"
|
||||||
|
assert response.tool_calls[0].arguments["query"] == "auth"
|
||||||
|
|
||||||
|
def test_ollama_tool_calling_not_supported(self):
|
||||||
|
"""Test that Ollama raises NotImplementedError for tool calling."""
|
||||||
|
from clients.llm_client import OllamaProvider
|
||||||
|
|
||||||
|
provider = OllamaProvider()
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
provider.call_with_tools(
|
||||||
|
messages=[{"role": "user", "content": "test"}],
|
||||||
|
tools=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
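The tests above pin down the expected behavior of `LLMClient._extract_json`: prefer a fenced code block if present, otherwise find the first balanced JSON object, and raise `ValueError` on failure. A minimal sketch of a parser satisfying those cases (an illustration only — the project's actual implementation is not shown here, and the naive brace counter below ignores braces inside JSON strings):

```python
import json
import re

# Build the fence pattern without a literal backtick run in the source.
FENCE = "`" * 3
FENCE_RE = re.compile(FENCE + r"(?:json)?[ \t]*\n(.*?)" + FENCE, re.DOTALL)


def extract_json(content: str) -> dict:
    """Best-effort JSON extraction (illustrative sketch, not the project's code)."""
    fence = FENCE_RE.search(content)
    candidate = fence.group(1) if fence else content

    # Scan for the first balanced {...} region and try to parse it.
    start = candidate.find("{")
    if start != -1:
        depth = 0
        for i, ch in enumerate(candidate[start:], start):
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    try:
                        return json.loads(candidate[start : i + 1])
                    except json.JSONDecodeError:
                        break  # fall through to the error below
    raise ValueError(f"Failed to parse JSON from: {content[:60]!r}")
```

This handles the preamble, postamble, fenced-block, and nested cases exercised by the tests; truncated or non-JSON input falls through to the `ValueError`.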
@@ -2,20 +2,40 @@

This package contains the modular agent implementations for the
enterprise AI code review system.

Core Agents:
- PRAgent: Pull request review and analysis
- IssueAgent: Issue triage and response
- CodebaseAgent: Codebase health analysis
- ChatAgent: Interactive chat with tool calling

Specialized Agents:
- DependencyAgent: Dependency vulnerability scanning
- TestCoverageAgent: Test coverage analysis and suggestions
- ArchitectureAgent: Architecture compliance checking
"""

from agents.architecture_agent import ArchitectureAgent
from agents.base_agent import AgentContext, AgentResult, BaseAgent
from agents.chat_agent import ChatAgent
from agents.codebase_agent import CodebaseAgent
from agents.dependency_agent import DependencyAgent
from agents.issue_agent import IssueAgent
from agents.pr_agent import PRAgent
from agents.test_coverage_agent import TestCoverageAgent

__all__ = [
    # Base
    "BaseAgent",
    "AgentContext",
    "AgentResult",
    # Core Agents
    "IssueAgent",
    "PRAgent",
    "CodebaseAgent",
    "ChatAgent",
    # Specialized Agents
    "DependencyAgent",
    "TestCoverageAgent",
    "ArchitectureAgent",
]
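The dispatcher described earlier registers these exported agents at startup and asks each whether it can handle an incoming event. A minimal sketch of that routing shape, using stand-in agent classes rather than the project's real `dispatcher.py` (whose details are not shown in this diff):

```python
class StubAgent:
    """Stand-in for the BaseAgent subclasses exported by agents/__init__.py."""

    def __init__(self, name: str, events: set[str]):
        self.name = name
        self._events = events

    def can_handle(self, event_type: str, event_data: dict) -> bool:
        return event_type in self._events


class Dispatcher:
    """Minimal event router: returns every registered agent that accepts the event."""

    def __init__(self):
        self._agents = []

    def register(self, agent) -> None:
        self._agents.append(agent)

    def route(self, event_type: str, event_data: dict) -> list[str]:
        return [a.name for a in self._agents if a.can_handle(event_type, event_data)]


dispatcher = Dispatcher()
dispatcher.register(StubAgent("PRAgent", {"pull_request"}))
dispatcher.register(StubAgent("ArchitectureAgent", {"pull_request", "issue_comment"}))

print(dispatcher.route("pull_request", {"action": "opened"}))
# ['PRAgent', 'ArchitectureAgent']
```

Multiple agents may match one event (here a PR event reaches both the reviewer and the architecture check), which matches the one-agent-per-concern design of the package.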
547 tools/ai-review/agents/architecture_agent.py (Normal file)
@@ -0,0 +1,547 @@
"""Architecture Compliance Agent
|
||||||
|
|
||||||
|
AI agent for enforcing architectural patterns and layer separation.
|
||||||
|
Detects cross-layer violations and circular dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from agents.base_agent import AgentContext, AgentResult, BaseAgent
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ArchitectureViolation:
|
||||||
|
"""An architecture violation."""
|
||||||
|
|
||||||
|
file: str
|
||||||
|
line: int
|
||||||
|
violation_type: str # cross_layer, circular, naming, structure
|
||||||
|
severity: str # HIGH, MEDIUM, LOW
|
||||||
|
description: str
|
||||||
|
recommendation: str
|
||||||
|
source_layer: str | None = None
|
||||||
|
target_layer: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ArchitectureReport:
|
||||||
|
"""Report of architecture analysis."""
|
||||||
|
|
||||||
|
violations: list[ArchitectureViolation]
|
||||||
|
layers_detected: dict[str, list[str]]
|
||||||
|
circular_dependencies: list[tuple[str, str]]
|
||||||
|
compliance_score: float
|
||||||
|
recommendations: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ArchitectureAgent(BaseAgent):
|
||||||
|
"""Agent for enforcing architectural compliance."""
|
||||||
|
|
||||||
|
# Marker for architecture comments
|
||||||
|
ARCH_AI_MARKER = "<!-- AI_ARCHITECTURE_CHECK -->"
|
||||||
|
|
||||||
|
# Default layer definitions
|
||||||
|
DEFAULT_LAYERS = {
|
||||||
|
"api": {
|
||||||
|
"patterns": ["api/", "routes/", "controllers/", "handlers/", "views/"],
|
||||||
|
"can_import": ["services", "models", "utils", "config"],
|
||||||
|
"cannot_import": ["db", "repositories", "infrastructure"],
|
||||||
|
},
|
||||||
|
"services": {
|
||||||
|
"patterns": ["services/", "usecases/", "application/"],
|
||||||
|
"can_import": ["models", "repositories", "utils", "config"],
|
||||||
|
"cannot_import": ["api", "controllers", "handlers"],
|
||||||
|
},
|
||||||
|
"repositories": {
|
||||||
|
"patterns": ["repositories/", "repos/", "data/"],
|
||||||
|
"can_import": ["models", "db", "utils", "config"],
|
||||||
|
"cannot_import": ["api", "services", "controllers"],
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"patterns": ["models/", "entities/", "domain/", "schemas/"],
|
||||||
|
"can_import": ["utils", "config"],
|
||||||
|
"cannot_import": ["api", "services", "repositories", "db"],
|
||||||
|
},
|
||||||
|
"db": {
|
||||||
|
"patterns": ["db/", "database/", "infrastructure/"],
|
||||||
|
"can_import": ["models", "config"],
|
||||||
|
"cannot_import": ["api", "services"],
|
||||||
|
},
|
||||||
|
"utils": {
|
||||||
|
"patterns": ["utils/", "helpers/", "common/", "lib/"],
|
||||||
|
"can_import": ["config"],
|
||||||
|
"cannot_import": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def can_handle(self, event_type: str, event_data: dict) -> bool:
|
||||||
|
"""Check if this agent handles the given event."""
|
||||||
|
agent_config = self.config.get("agents", {}).get("architecture", {})
|
||||||
|
if not agent_config.get("enabled", False):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Handle PR events
|
||||||
|
if event_type == "pull_request":
|
||||||
|
action = event_data.get("action", "")
|
||||||
|
if action in ["opened", "synchronize"]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Handle @codebot architecture command
|
||||||
|
if event_type == "issue_comment":
|
||||||
|
comment_body = event_data.get("comment", {}).get("body", "")
|
||||||
|
mention_prefix = self.config.get("interaction", {}).get(
|
||||||
|
"mention_prefix", "@codebot"
|
||||||
|
)
|
||||||
|
if f"{mention_prefix} architecture" in comment_body.lower():
|
||||||
|
return True
|
||||||
|
if f"{mention_prefix} arch-check" in comment_body.lower():
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def execute(self, context: AgentContext) -> AgentResult:
|
||||||
|
"""Execute the architecture agent."""
|
||||||
|
self.logger.info(f"Checking architecture for {context.owner}/{context.repo}")
|
||||||
|
|
||||||
|
actions_taken = []
|
||||||
|
|
||||||
|
# Get layer configuration
|
||||||
|
agent_config = self.config.get("agents", {}).get("architecture", {})
|
||||||
|
layers = agent_config.get("layers", self.DEFAULT_LAYERS)
|
||||||
|
|
||||||
|
# Determine issue number
|
||||||
|
if context.event_type == "issue_comment":
|
||||||
|
issue = context.event_data.get("issue", {})
|
||||||
|
issue_number = issue.get("number")
|
||||||
|
comment_author = (
|
||||||
|
context.event_data.get("comment", {})
|
||||||
|
.get("user", {})
|
||||||
|
.get("login", "user")
|
||||||
|
)
|
||||||
|
is_pr = issue.get("pull_request") is not None
|
||||||
|
else:
|
||||||
|
pr = context.event_data.get("pull_request", {})
|
||||||
|
issue_number = pr.get("number")
|
||||||
|
comment_author = None
|
||||||
|
is_pr = True
|
||||||
|
|
||||||
|
if is_pr and issue_number:
|
||||||
|
# Analyze PR diff
|
||||||
|
diff = self._get_pr_diff(context.owner, context.repo, issue_number)
|
||||||
|
report = self._analyze_diff(diff, layers)
|
||||||
|
actions_taken.append(f"Analyzed PR diff for architecture violations")
|
||||||
|
else:
|
||||||
|
# Analyze repository structure
|
||||||
|
report = self._analyze_repository(context.owner, context.repo, layers)
|
||||||
|
actions_taken.append(f"Analyzed repository architecture")
|
||||||
|
|
||||||
|
# Post report
|
||||||
|
if issue_number:
|
||||||
|
comment = self._format_architecture_report(report, comment_author)
|
||||||
|
self.upsert_comment(
|
||||||
|
context.owner,
|
||||||
|
context.repo,
|
||||||
|
issue_number,
|
||||||
|
comment,
|
||||||
|
marker=self.ARCH_AI_MARKER,
|
||||||
|
)
|
||||||
|
actions_taken.append("Posted architecture report")
|
||||||
|
|
||||||
|
return AgentResult(
|
||||||
|
success=len(report.violations) == 0 or report.compliance_score >= 0.8,
|
||||||
|
message=f"Architecture check: {len(report.violations)} violations, {report.compliance_score:.0%} compliance",
|
||||||
|
data={
|
||||||
|
"violations_count": len(report.violations),
|
||||||
|
"compliance_score": report.compliance_score,
|
||||||
|
"circular_dependencies": len(report.circular_dependencies),
|
||||||
|
},
|
||||||
|
actions_taken=actions_taken,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
|
||||||
|
"""Get the PR diff."""
|
||||||
|
try:
|
||||||
|
return self.gitea.get_pull_request_diff(owner, repo, pr_number)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to get PR diff: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _analyze_diff(self, diff: str, layers: dict) -> ArchitectureReport:
|
||||||
|
"""Analyze PR diff for architecture violations."""
|
||||||
|
violations = []
|
||||||
|
imports_by_file = {}
|
||||||
|
|
||||||
|
current_file = None
|
||||||
|
current_language = None
|
||||||
|
|
||||||
|
for line in diff.splitlines():
|
||||||
|
# Track current file
|
||||||
|
if line.startswith("diff --git"):
|
||||||
|
match = re.search(r"b/(.+)$", line)
|
||||||
|
if match:
|
||||||
|
current_file = match.group(1)
|
||||||
|
current_language = self._detect_language(current_file)
|
||||||
|
imports_by_file[current_file] = []
|
||||||
|
|
||||||
|
# Look for import statements in added lines
|
||||||
|
if line.startswith("+") and not line.startswith("+++"):
|
||||||
|
if current_file and current_language:
|
||||||
|
imports = self._extract_imports(line[1:], current_language)
|
||||||
|
imports_by_file.setdefault(current_file, []).extend(imports)
|
||||||
|
|
||||||
|
# Check for violations
|
||||||
|
for file_path, imports in imports_by_file.items():
|
||||||
|
source_layer = self._detect_layer(file_path, layers)
|
||||||
|
if not source_layer:
|
||||||
|
continue
|
||||||
|
|
||||||
|
layer_config = layers.get(source_layer, {})
|
||||||
|
cannot_import = layer_config.get("cannot_import", [])
|
||||||
|
|
||||||
|
for imp in imports:
|
||||||
|
target_layer = self._detect_layer_from_import(imp, layers)
|
||||||
|
if target_layer and target_layer in cannot_import:
|
||||||
|
violations.append(
|
||||||
|
ArchitectureViolation(
|
||||||
|
file=file_path,
|
||||||
|
line=0, # Line number not tracked in this simple implementation
|
||||||
|
violation_type="cross_layer",
|
||||||
|
severity="HIGH",
|
||||||
|
description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
|
||||||
|
recommendation=f"Move this import to an allowed layer or refactor the dependency",
|
||||||
|
source_layer=source_layer,
|
||||||
|
target_layer=target_layer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Detect circular dependencies
|
||||||
|
circular = self._detect_circular_dependencies(imports_by_file, layers)
|
||||||
|
|
||||||
|
# Calculate compliance score
|
||||||
|
total_imports = sum(len(imps) for imps in imports_by_file.values())
|
||||||
|
if total_imports > 0:
|
||||||
|
compliance = 1.0 - (len(violations) / max(total_imports, 1))
|
||||||
|
else:
|
||||||
|
compliance = 1.0
|
||||||
|
|
||||||
|
return ArchitectureReport(
|
||||||
|
violations=violations,
|
||||||
|
layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
|
||||||
|
circular_dependencies=circular,
|
||||||
|
compliance_score=max(0.0, compliance),
|
||||||
|
recommendations=self._generate_recommendations(violations),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _analyze_repository(
|
||||||
|
self, owner: str, repo: str, layers: dict
|
||||||
|
) -> ArchitectureReport:
|
||||||
|
"""Analyze repository structure for architecture compliance."""
|
||||||
|
violations = []
|
||||||
|
imports_by_file = {}
|
||||||
|
|
||||||
|
# Collect files from each layer
|
||||||
|
for layer_name, layer_config in layers.items():
|
||||||
|
for pattern in layer_config.get("patterns", []):
|
||||||
|
try:
|
||||||
|
path = pattern.rstrip("/")
|
||||||
|
contents = self.gitea.get_file_contents(owner, repo, path)
|
||||||
|
if isinstance(contents, list):
|
||||||
|
for item in contents[:20]: # Limit files per layer
|
||||||
|
if item.get("type") == "file":
|
||||||
|
filepath = item.get("path", "")
|
||||||
|
imports = self._get_file_imports(owner, repo, filepath)
|
||||||
|
imports_by_file[filepath] = imports
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Check for violations
|
||||||
|
for file_path, imports in imports_by_file.items():
|
||||||
|
source_layer = self._detect_layer(file_path, layers)
|
||||||
|
if not source_layer:
|
||||||
|
continue
|
||||||
|
|
||||||
|
layer_config = layers.get(source_layer, {})
|
||||||
|
cannot_import = layer_config.get("cannot_import", [])
|
||||||
|
|
||||||
|
for imp in imports:
|
||||||
|
target_layer = self._detect_layer_from_import(imp, layers)
|
||||||
|
if target_layer and target_layer in cannot_import:
|
||||||
|
violations.append(
|
||||||
|
ArchitectureViolation(
|
||||||
|
file=file_path,
|
||||||
|
line=0,
|
||||||
|
violation_type="cross_layer",
|
||||||
|
severity="HIGH",
|
||||||
|
description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
|
||||||
|
recommendation=f"Refactor to remove dependency on '{target_layer}'",
|
||||||
|
source_layer=source_layer,
|
||||||
|
target_layer=target_layer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Detect circular dependencies
|
||||||
|
circular = self._detect_circular_dependencies(imports_by_file, layers)
|
||||||
|
|
||||||
|
# Calculate compliance
|
||||||
|
total_imports = sum(len(imps) for imps in imports_by_file.values())
|
||||||
|
if total_imports > 0:
|
||||||
|
compliance = 1.0 - (len(violations) / max(total_imports, 1))
|
||||||
|
else:
|
||||||
|
compliance = 1.0
|
||||||
|
|
||||||
|
return ArchitectureReport(
|
||||||
|
violations=violations,
|
||||||
|
layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
|
||||||
|
circular_dependencies=circular,
|
||||||
|
compliance_score=max(0.0, compliance),
|
||||||
|
recommendations=self._generate_recommendations(violations),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_file_imports(self, owner: str, repo: str, filepath: str) -> list[str]:
|
||||||
|
"""Get imports from a file."""
|
||||||
|
imports = []
|
||||||
|
language = self._detect_language(filepath)
|
||||||
|
|
||||||
|
if not language:
|
||||||
|
return imports
|
||||||
|
|
||||||
|
try:
|
||||||
|
content_data = self.gitea.get_file_contents(owner, repo, filepath)
|
||||||
|
if content_data.get("content"):
|
||||||
|
content = base64.b64decode(content_data["content"]).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
for line in content.splitlines():
|
||||||
|
imports.extend(self._extract_imports(line, language))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return imports
|
||||||
|
|
||||||
|
def _detect_language(self, filepath: str) -> str | None:
|
||||||
|
"""Detect programming language from file path."""
|
||||||
|
ext_map = {
|
||||||
|
".py": "python",
|
||||||
|
".js": "javascript",
|
||||||
|
".ts": "typescript",
|
||||||
|
".go": "go",
|
||||||
|
".java": "java",
|
||||||
|
".rb": "ruby",
|
||||||
|
}
|
||||||
|
ext = os.path.splitext(filepath)[1]
|
||||||
|
return ext_map.get(ext)
|
||||||
|
|
||||||
|
def _extract_imports(self, line: str, language: str) -> list[str]:
|
||||||
|
"""Extract import statements from a line of code."""
|
||||||
|
imports = []
|
||||||
|
line = line.strip()
|
||||||
|
|
||||||
|
if language == "python":
|
||||||
|
# from x import y, import x
|
||||||
|
match = re.match(r"^(?:from\s+(\S+)|import\s+(\S+))", line)
|
||||||
|
if match:
|
||||||
|
imp = match.group(1) or match.group(2)
|
||||||
|
if imp:
|
||||||
|
imports.append(imp.split(".")[0])
|
||||||
|
|
||||||
|
elif language in ("javascript", "typescript"):
|
||||||
|
# import x from 'y', require('y')
|
||||||
|
match = re.search(
|
||||||
|
r"(?:from\s+['\"]([^'\"]+)['\"]|require\(['\"]([^'\"]+)['\"]\))", line
|
||||||
|
)
|
||||||
|
if match:
|
||||||
|
imp = match.group(1) or match.group(2)
|
||||||
|
if imp and not imp.startswith("."):
|
||||||
|
imports.append(imp.split("/")[0])
|
||||||
|
elif imp:
|
||||||
|
imports.append(imp)
|
||||||
|
|
||||||
|
elif language == "go":
|
||||||
|
# import "package"
|
||||||
|
match = re.search(r'import\s+["\']([^"\']+)["\']', line)
|
||||||
|
if match:
|
||||||
|
imports.append(match.group(1).split("/")[-1])
|
||||||
|
|
||||||
|
elif language == "java":
|
||||||
|
# import package.Class
|
||||||
|
match = re.match(r"^import\s+(?:static\s+)?([^;]+);", line)
|
||||||
|
if match:
|
||||||
|
parts = match.group(1).split(".")
|
||||||
|
if len(parts) > 1:
|
||||||
|
imports.append(parts[-2]) # Package name
|
||||||
|
|
||||||
|
return imports
|
||||||
|
|
||||||
|
def _detect_layer(self, filepath: str, layers: dict) -> str | None:
|
||||||
|
"""Detect which layer a file belongs to."""
|
||||||
|
for layer_name, layer_config in layers.items():
|
||||||
|
for pattern in layer_config.get("patterns", []):
|
||||||
|
if pattern.rstrip("/") in filepath:
|
||||||
|
return layer_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _detect_layer_from_import(self, import_path: str, layers: dict) -> str | None:
|
||||||
|
"""Detect which layer an import refers to."""
|
||||||
|
for layer_name, layer_config in layers.items():
|
||||||
|
for pattern in layer_config.get("patterns", []):
|
||||||
|
pattern_name = pattern.rstrip("/").split("/")[-1]
|
||||||
|
if pattern_name in import_path or import_path.startswith(pattern_name):
|
||||||
|
return layer_name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _detect_circular_dependencies(
|
||||||
|
self, imports_by_file: dict, layers: dict
|
||||||
|
) -> list[tuple[str, str]]:
|
||||||
|
"""Detect circular dependencies between layers."""
|
||||||
|
circular = []
|
||||||
|
|
||||||
|
# Build layer dependency graph
|
||||||
|
layer_deps = {}
|
||||||
|
for file_path, imports in imports_by_file.items():
|
||||||
|
source_layer = self._detect_layer(file_path, layers)
|
||||||
|
if not source_layer:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if source_layer not in layer_deps:
|
||||||
|
layer_deps[source_layer] = set()
|
||||||
|
|
||||||
|
for imp in imports:
|
||||||
|
target_layer = self._detect_layer_from_import(imp, layers)
|
||||||
|
if target_layer and target_layer != source_layer:
|
||||||
|
layer_deps[source_layer].add(target_layer)
|
||||||
|
|
||||||
|
# Check for circular dependencies
|
||||||
|
for layer_a, deps_a in layer_deps.items():
|
||||||
|
for layer_b in deps_a:
|
||||||
|
if layer_b in layer_deps and layer_a in layer_deps.get(layer_b, set()):
|
||||||
|
pair = tuple(sorted([layer_a, layer_b]))
|
||||||
|
if pair not in circular:
|
||||||
|
circular.append(pair)
|
||||||
|
|
||||||
|
return circular
|
||||||
|
|
||||||
|
def _group_files_by_layer(
|
||||||
|
self, files: list[str], layers: dict
|
||||||
|
) -> dict[str, list[str]]:
|
||||||
|
"""Group files by their layer."""
|
||||||
|
grouped = {}
|
||||||
|
for filepath in files:
|
||||||
|
layer = self._detect_layer(filepath, layers)
|
||||||
|
if layer:
|
||||||
|
if layer not in grouped:
|
||||||
|
grouped[layer] = []
|
||||||
|
grouped[layer].append(filepath)
|
||||||
|
return grouped
|
||||||
|
|
||||||
|
def _generate_recommendations(
|
||||||
|
self, violations: list[ArchitectureViolation]
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate recommendations based on violations."""
|
||||||
|
recommendations = []
|
||||||
|
|
||||||
|
# Count violations by type
|
||||||
|
cross_layer = sum(1 for v in violations if v.violation_type == "cross_layer")
|
||||||
|
|
||||||
|
if cross_layer > 0:
|
||||||
|
recommendations.append(
|
||||||
|
f"Fix {cross_layer} cross-layer violations by moving imports or creating interfaces"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cross_layer > 5:
|
||||||
|
recommendations.append(
|
||||||
|
"Consider using dependency injection to reduce coupling between layers"
|
||||||
|
)
|
||||||
|
|
||||||
|
return recommendations
|
||||||
|
|
||||||
|
def _format_architecture_report(
|
||||||
|
self, report: ArchitectureReport, user: str | None
|
||||||
|
) -> str:
|
||||||
|
"""Format the architecture report as a comment."""
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
if user:
|
||||||
|
lines.append(f"@{user}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
lines.extend(
|
||||||
|
[
|
||||||
|
f"{self.AI_DISCLAIMER}",
|
||||||
|
"",
|
||||||
|
"## 🏗️ Architecture Compliance Check",
|
||||||
|
"",
|
||||||
|
"### Summary",
|
||||||
|
"",
|
||||||
|
f"| Metric | Value |",
|
||||||
|
f"|--------|-------|",
|
||||||
|
f"| Compliance Score | {report.compliance_score:.0%} |",
|
||||||
|
f"| Violations | {len(report.violations)} |",
|
||||||
|
f"| Circular Dependencies | {len(report.circular_dependencies)} |",
|
||||||
|
f"| Layers Detected | {len(report.layers_detected)} |",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compliance bar
|
||||||
|
filled = int(report.compliance_score * 10)
|
||||||
|
bar = "█" * filled + "░" * (10 - filled)
|
||||||
|
lines.append(f"`[{bar}]` {report.compliance_score:.0%}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Violations
|
||||||
|
if report.violations:
|
||||||
|
lines.append("### 🚨 Violations")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
for v in report.violations[:10]: # Limit display
|
||||||
|
severity_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}
|
||||||
|
lines.append(
|
||||||
|
f"{severity_emoji.get(v.severity, '⚪')} **{v.violation_type.upper()}** in `{v.file}`"
|
||||||
|
)
|
||||||
|
lines.append(f" - {v.description}")
|
||||||
|
lines.append(f" - 💡 {v.recommendation}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
if len(report.violations) > 10:
|
||||||
|
lines.append(f"*... and {len(report.violations) - 10} more violations*")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Circular dependencies
|
||||||
|
if report.circular_dependencies:
|
||||||
|
lines.append("### 🔄 Circular Dependencies")
|
||||||
|
lines.append("")
|
||||||
|
for a, b in report.circular_dependencies:
|
||||||
|
lines.append(f"- `{a}` ↔ `{b}`")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Layers detected
|
||||||
|
if report.layers_detected:
|
||||||
|
lines.append("### 📁 Layers Detected")
|
||||||
|
lines.append("")
|
||||||
|
for layer, files in report.layers_detected.items():
|
||||||
|
lines.append(f"- **{layer}**: {len(files)} files")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Recommendations
|
||||||
|
if report.recommendations:
|
||||||
|
lines.append("### 💡 Recommendations")
|
||||||
|
lines.append("")
|
||||||
|
for rec in report.recommendations:
|
||||||
|
lines.append(f"- {rec}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Overall status
|
||||||
|
if report.compliance_score >= 0.9:
|
||||||
|
lines.append("---")
|
||||||
|
lines.append("✅ **Excellent architecture compliance!**")
|
||||||
|
elif report.compliance_score >= 0.7:
|
||||||
|
lines.append("---")
|
||||||
|
lines.append("⚠️ **Some architectural issues to address**")
|
||||||
|
else:
|
||||||
|
lines.append("---")
|
||||||
|
lines.append("❌ **Significant architectural violations detected**")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
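As a quick sanity check on the scoring used in both analysis paths, the compliance score is simply one minus the violation ratio, floored at zero — so 2 violations across 8 imports yields 0.75. A standalone version of that formula:

```python
def compliance_score(violations: int, total_imports: int) -> float:
    """Mirrors the ArchitectureAgent formula: 1 - violations/imports, floored at 0."""
    if total_imports == 0:
        return 1.0  # nothing to check counts as fully compliant
    return max(0.0, 1.0 - violations / total_imports)


print(compliance_score(2, 8))  # 0.75
print(compliance_score(0, 0))  # 1.0
```

Note that the floor matters: more violations than imports (possible only if the counting rules changed) would otherwise go negative.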
@@ -65,9 +65,10 @@ class BaseAgent(ABC):
         self.llm = llm_client or LLMClient.from_config(self.config)
         self.logger = logging.getLogger(self.__class__.__name__)

-        # Rate limiting
+        # Rate limiting - now configurable
         self._last_request_time = 0.0
-        self._min_request_interval = 1.0  # seconds
+        rate_limits = self.config.get("rate_limits", {})
+        self._min_request_interval = rate_limits.get("min_interval", 1.0)  # seconds

     @staticmethod
     def _load_config() -> dict:
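The hunk above makes the request interval configurable through a nested config key with the old hard-coded value as fallback. A minimal sketch of that lookup, assuming the config is a plain nested dict (`resolve_min_interval` is an illustrative helper name, not part of the codebase):

```python
def resolve_min_interval(config: dict) -> float:
    """Read rate_limits.min_interval with the same 1.0s fallback as the diff."""
    return config.get("rate_limits", {}).get("min_interval", 1.0)

# A missing or partial config falls back to the old hard-coded default,
# so existing deployments keep their behavior unchanged.
```

Chaining `.get(..., {})` before the final `.get` is what keeps a completely absent `rate_limits` section from raising.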
tools/ai-review/agents/dependency_agent.py (new file, 548 lines)
@@ -0,0 +1,548 @@
"""Dependency Security Agent

AI agent for scanning dependency files for known vulnerabilities
and outdated packages. Supports multiple package managers.
"""

import base64
import json
import logging
import os
import re
import subprocess
from dataclasses import dataclass, field
from typing import Any

from agents.base_agent import AgentContext, AgentResult, BaseAgent


@dataclass
class DependencyFinding:
    """A security finding in a dependency."""

    package: str
    version: str
    severity: str  # CRITICAL, HIGH, MEDIUM, LOW
    vulnerability_id: str  # CVE, GHSA, etc.
    title: str
    description: str
    fixed_version: str | None = None
    references: list[str] = field(default_factory=list)


@dataclass
class DependencyReport:
    """Report of dependency analysis."""

    total_packages: int
    vulnerable_packages: int
    outdated_packages: int
    findings: list[DependencyFinding]
    recommendations: list[str]
    files_scanned: list[str]


class DependencyAgent(BaseAgent):
    """Agent for scanning dependencies for security vulnerabilities."""

    # Marker for dependency comments
    DEP_AI_MARKER = "<!-- AI_DEPENDENCY_SCAN -->"

    # Supported dependency files
    DEPENDENCY_FILES = {
        "python": ["requirements.txt", "Pipfile", "pyproject.toml", "setup.py"],
        "javascript": ["package.json", "package-lock.json", "yarn.lock"],
        "ruby": ["Gemfile", "Gemfile.lock"],
        "go": ["go.mod", "go.sum"],
        "rust": ["Cargo.toml", "Cargo.lock"],
        "java": ["pom.xml", "build.gradle", "build.gradle.kts"],
        "php": ["composer.json", "composer.lock"],
        "dotnet": ["*.csproj", "packages.config", "*.fsproj"],
    }

    # Common vulnerable package patterns
    KNOWN_VULNERABILITIES = {
        "python": {
            "requests": {
                "< 2.31.0": "CVE-2023-32681 - Proxy-Authorization header leak"
            },
            "urllib3": {
                "< 2.0.7": "CVE-2023-45803 - Request body not stripped on redirects"
            },
            "cryptography": {"< 41.0.0": "Multiple CVEs - Update recommended"},
            "pillow": {"< 10.0.0": "CVE-2023-4863 - WebP vulnerability"},
            "django": {"< 4.2.0": "Multiple security fixes"},
            "flask": {"< 2.3.0": "Security improvements"},
            "pyyaml": {"< 6.0": "CVE-2020-14343 - Arbitrary code execution"},
            "jinja2": {"< 3.1.0": "Security fixes"},
        },
        "javascript": {
            "lodash": {"< 4.17.21": "CVE-2021-23337 - Prototype pollution"},
            "axios": {"< 1.6.0": "CVE-2023-45857 - CSRF vulnerability"},
            "express": {"< 4.18.0": "Security updates"},
            "jquery": {"< 3.5.0": "XSS vulnerabilities"},
            "minimist": {"< 1.2.6": "Prototype pollution"},
            "node-fetch": {"< 3.3.0": "Security fixes"},
        },
    }

    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Check if this agent handles the given event."""
        agent_config = self.config.get("agents", {}).get("dependency", {})
        if not agent_config.get("enabled", True):
            return False

        # Handle PR events that modify dependency files
        if event_type == "pull_request":
            action = event_data.get("action", "")
            if action in ["opened", "synchronize"]:
                # Check if any dependency files are modified
                files = event_data.get("files", [])
                for f in files:
                    if self._is_dependency_file(f.get("filename", "")):
                        return True

        # Handle @codebot check-deps command
        if event_type == "issue_comment":
            comment_body = event_data.get("comment", {}).get("body", "")
            mention_prefix = self.config.get("interaction", {}).get(
                "mention_prefix", "@codebot"
            )
            if f"{mention_prefix} check-deps" in comment_body.lower():
                return True

        return False

    def _is_dependency_file(self, filename: str) -> bool:
        """Check if a file is a dependency file."""
        basename = os.path.basename(filename)
        for lang, files in self.DEPENDENCY_FILES.items():
            for pattern in files:
                if pattern.startswith("*"):
                    if basename.endswith(pattern[1:]):
                        return True
                elif basename == pattern:
                    return True
        return False
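The matcher above can be exercised in isolation. A standalone sketch mirroring `_is_dependency_file`, with a trimmed-down table (the two entries are only a subset of the real `DEPENDENCY_FILES`):

```python
import os

# Trimmed-down table; the real DEPENDENCY_FILES covers eight ecosystems.
PATTERNS = {
    "python": ["requirements.txt", "pyproject.toml"],
    "dotnet": ["*.csproj"],
}

def is_dependency_file(filename: str) -> bool:
    basename = os.path.basename(filename)
    for files in PATTERNS.values():
        for pattern in files:
            if pattern.startswith("*"):
                # Glob-style entries reduce to a suffix check on the basename
                if basename.endswith(pattern[1:]):
                    return True
            elif basename == pattern:
                return True
    return False
```

Because only the basename is compared, the same table matches dependency files at any directory depth.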
    def execute(self, context: AgentContext) -> AgentResult:
        """Execute the dependency agent."""
        self.logger.info(f"Scanning dependencies for {context.owner}/{context.repo}")

        actions_taken = []

        # Determine if this is a command or PR event
        if context.event_type == "issue_comment":
            issue = context.event_data.get("issue", {})
            issue_number = issue.get("number")
            comment_author = (
                context.event_data.get("comment", {})
                .get("user", {})
                .get("login", "user")
            )
        else:
            pr = context.event_data.get("pull_request", {})
            issue_number = pr.get("number")
            comment_author = None

        # Collect dependency files
        dep_files = self._collect_dependency_files(context.owner, context.repo)
        if not dep_files:
            message = "No dependency files found in repository."
            if issue_number:
                self.gitea.create_issue_comment(
                    context.owner,
                    context.repo,
                    issue_number,
                    f"{self.AI_DISCLAIMER}\n\n{message}",
                )
            return AgentResult(
                success=True,
                message=message,
            )

        actions_taken.append(f"Found {len(dep_files)} dependency files")

        # Analyze dependencies
        report = self._analyze_dependencies(context.owner, context.repo, dep_files)
        actions_taken.append(f"Analyzed {report.total_packages} packages")

        # Run external scanners if available
        external_findings = self._run_external_scanners(context.owner, context.repo)
        if external_findings:
            report.findings.extend(external_findings)
            actions_taken.append(
                f"External scanner found {len(external_findings)} issues"
            )

        # Generate and post report
        if issue_number:
            comment = self._format_dependency_report(report, comment_author)
            self.upsert_comment(
                context.owner,
                context.repo,
                issue_number,
                comment,
                marker=self.DEP_AI_MARKER,
            )
            actions_taken.append("Posted dependency report")

        return AgentResult(
            success=True,
            message=f"Dependency scan complete: {report.vulnerable_packages} vulnerable, {report.outdated_packages} outdated",
            data={
                "total_packages": report.total_packages,
                "vulnerable_packages": report.vulnerable_packages,
                "outdated_packages": report.outdated_packages,
                "findings_count": len(report.findings),
            },
            actions_taken=actions_taken,
        )

    def _collect_dependency_files(
        self, owner: str, repo: str
    ) -> dict[str, dict[str, Any]]:
        """Collect all dependency files from the repository."""
        dep_files = {}

        # Common paths to check
        paths_to_check = [
            "",  # Root
            "backend/",
            "frontend/",
            "api/",
            "services/",
        ]

        for base_path in paths_to_check:
            for lang, filenames in self.DEPENDENCY_FILES.items():
                for filename in filenames:
                    if filename.startswith("*"):
                        continue  # Skip glob patterns for now

                    filepath = f"{base_path}{filename}".lstrip("/")
                    try:
                        content_data = self.gitea.get_file_contents(
                            owner, repo, filepath
                        )
                        if content_data.get("content"):
                            content = base64.b64decode(content_data["content"]).decode(
                                "utf-8", errors="ignore"
                            )
                            dep_files[filepath] = {
                                "language": lang,
                                "content": content,
                            }
                    except Exception:
                        pass  # File doesn't exist

        return dep_files
    def _analyze_dependencies(
        self, owner: str, repo: str, dep_files: dict
    ) -> DependencyReport:
        """Analyze dependency files for vulnerabilities."""
        findings = []
        total_packages = 0
        vulnerable_count = 0
        outdated_count = 0
        recommendations = []
        files_scanned = list(dep_files.keys())

        for filepath, file_info in dep_files.items():
            lang = file_info["language"]
            content = file_info["content"]

            if lang == "python":
                packages = self._parse_python_deps(content, filepath)
            elif lang == "javascript":
                packages = self._parse_javascript_deps(content, filepath)
            else:
                packages = []

            total_packages += len(packages)

            # Check for known vulnerabilities
            known_vulns = self.KNOWN_VULNERABILITIES.get(lang, {})
            for pkg_name, version in packages:
                if pkg_name.lower() in known_vulns:
                    vuln_info = known_vulns[pkg_name.lower()]
                    for version_constraint, vuln_desc in vuln_info.items():
                        if self._version_matches_constraint(
                            version, version_constraint
                        ):
                            findings.append(
                                DependencyFinding(
                                    package=pkg_name,
                                    version=version or "unknown",
                                    severity="HIGH",
                                    vulnerability_id=vuln_desc.split(" - ")[0]
                                    if " - " in vuln_desc
                                    else "VULN",
                                    title=vuln_desc,
                                    description=f"Package {pkg_name} version {version} has known vulnerabilities",
                                    fixed_version=version_constraint.replace("< ", ""),
                                )
                            )
                            vulnerable_count += 1

        # Add recommendations
        if vulnerable_count > 0:
            recommendations.append(
                f"Update {vulnerable_count} packages with known vulnerabilities"
            )
        if total_packages > 50:
            recommendations.append(
                "Consider auditing dependencies to reduce attack surface"
            )

        return DependencyReport(
            total_packages=total_packages,
            vulnerable_packages=vulnerable_count,
            outdated_packages=outdated_count,
            findings=findings,
            recommendations=recommendations,
            files_scanned=files_scanned,
        )
    def _parse_python_deps(
        self, content: str, filepath: str
    ) -> list[tuple[str, str | None]]:
        """Parse Python dependency file."""
        packages = []

        if "requirements" in filepath.lower():
            # requirements.txt format
            for line in content.splitlines():
                line = line.strip()
                if not line or line.startswith("#") or line.startswith("-"):
                    continue

                # Parse package==version, package>=version, package
                match = re.match(r"([a-zA-Z0-9_-]+)([<>=!]+)?(.+)?", line)
                if match:
                    pkg_name = match.group(1)
                    version = match.group(3) if match.group(3) else None
                    packages.append((pkg_name, version))

        elif filepath.endswith("pyproject.toml"):
            # pyproject.toml format
            in_deps = False
            for line in content.splitlines():
                if (
                    "[project.dependencies]" in line
                    or "[tool.poetry.dependencies]" in line
                ):
                    in_deps = True
                    continue
                if in_deps:
                    if line.startswith("["):
                        in_deps = False
                        continue
                    match = re.match(r'"?([a-zA-Z0-9_-]+)"?\s*[=<>]', line)
                    if match:
                        packages.append((match.group(1), None))

        return packages
    def _parse_javascript_deps(
        self, content: str, filepath: str
    ) -> list[tuple[str, str | None]]:
        """Parse JavaScript dependency file."""
        packages = []

        if filepath.endswith("package.json"):
            try:
                data = json.loads(content)
                for dep_type in ["dependencies", "devDependencies"]:
                    deps = data.get(dep_type, {})
                    for name, version in deps.items():
                        # Strip version prefixes like ^, ~, >=
                        clean_version = re.sub(r"^[\^~>=<]+", "", version)
                        packages.append((name, clean_version))
            except json.JSONDecodeError:
                pass

        return packages
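For package.json the only normalization is stripping range operators such as `^` and `~` off the version string. A small sketch with the same `re.sub` pattern (the sample manifest is illustrative):

```python
import json
import re

pkg_json = '{"dependencies": {"lodash": "^4.17.20", "axios": "~1.5.0"}}'
deps = json.loads(pkg_json)["dependencies"]

# Same pattern as above: drop leading range operators, keep the base version.
cleaned = {name: re.sub(r"^[\^~>=<]+", "", version) for name, version in deps.items()}
```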
    def _version_matches_constraint(self, version: str | None, constraint: str) -> bool:
        """Check if version matches a vulnerability constraint."""
        if not version:
            return True  # Assume vulnerable if version unknown

        # Simple version comparison
        if constraint.startswith("< "):
            target = constraint[2:]
            try:
                return self._compare_versions(version, target) < 0
            except Exception:
                return False

        return False

    def _compare_versions(self, v1: str, v2: str) -> int:
        """Compare two version strings. Returns -1, 0, or 1."""

        def normalize(v):
            return [int(x) for x in re.sub(r"[^0-9.]", "", v).split(".") if x]

        try:
            parts1 = normalize(v1)
            parts2 = normalize(v2)

            for i in range(max(len(parts1), len(parts2))):
                p1 = parts1[i] if i < len(parts1) else 0
                p2 = parts2[i] if i < len(parts2) else 0
                if p1 < p2:
                    return -1
                if p1 > p2:
                    return 1
            return 0
        except Exception:
            return 0
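`_compare_versions` pads the shorter version with zeros and compares numeric components left to right. A standalone sketch of the same logic:

```python
import re

def compare_versions(v1: str, v2: str) -> int:
    """Return -1, 0, or 1, comparing dotted numeric components."""
    def normalize(v):
        return [int(x) for x in re.sub(r"[^0-9.]", "", v).split(".") if x]

    parts1, parts2 = normalize(v1), normalize(v2)
    # Pad the shorter list with zeros so "4.18" compares equal to "4.18.0"
    for i in range(max(len(parts1), len(parts2))):
        p1 = parts1[i] if i < len(parts1) else 0
        p2 = parts2[i] if i < len(parts2) else 0
        if p1 != p2:
            return -1 if p1 < p2 else 1
    return 0
```

Note that the `[^0-9.]` strip mangles pre-release tags (`1.2rc1` becomes `1.21`), an accepted limitation of this naive comparator.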
    def _run_external_scanners(self, owner: str, repo: str) -> list[DependencyFinding]:
        """Run external vulnerability scanners if available."""
        findings = []
        agent_config = self.config.get("agents", {}).get("dependency", {})

        # Try pip-audit for Python
        if agent_config.get("pip_audit", False):
            try:
                result = subprocess.run(
                    ["pip-audit", "--format", "json"],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                if result.returncode == 0:
                    data = json.loads(result.stdout)
                    for vuln in data.get("vulnerabilities", []):
                        findings.append(
                            DependencyFinding(
                                package=vuln.get("name", ""),
                                version=vuln.get("version", ""),
                                severity=vuln.get("severity", "MEDIUM"),
                                vulnerability_id=vuln.get("id", ""),
                                title=vuln.get("description", "")[:100],
                                description=vuln.get("description", ""),
                                fixed_version=vuln.get("fix_versions", [None])[0],
                            )
                        )
            except Exception as e:
                self.logger.debug(f"pip-audit not available: {e}")

        # Try npm audit for JavaScript
        if agent_config.get("npm_audit", False):
            try:
                result = subprocess.run(
                    ["npm", "audit", "--json"],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                data = json.loads(result.stdout)
                for vuln_id, vuln in data.get("vulnerabilities", {}).items():
                    findings.append(
                        DependencyFinding(
                            package=vuln.get("name", vuln_id),
                            version=vuln.get("range", ""),
                            severity=vuln.get("severity", "moderate").upper(),
                            vulnerability_id=vuln_id,
                            title=vuln.get("title", ""),
                            description=vuln.get("overview", ""),
                            fixed_version=vuln.get("fixAvailable", {}).get("version"),
                        )
                    )
            except Exception as e:
                self.logger.debug(f"npm audit not available: {e}")

        return findings
    def _format_dependency_report(
        self, report: DependencyReport, user: str | None = None
    ) -> str:
        """Format the dependency report as a comment."""
        lines = []

        if user:
            lines.append(f"@{user}")
            lines.append("")

        lines.extend(
            [
                f"{self.AI_DISCLAIMER}",
                "",
                "## 🔍 Dependency Security Scan",
                "",
                "### Summary",
                "",
                "| Metric | Value |",
                "|--------|-------|",
                f"| Total Packages | {report.total_packages} |",
                f"| Vulnerable | {report.vulnerable_packages} |",
                f"| Outdated | {report.outdated_packages} |",
                f"| Files Scanned | {len(report.files_scanned)} |",
                "",
            ]
        )

        # Findings by severity
        if report.findings:
            lines.append("### 🚨 Security Findings")
            lines.append("")

            # Group by severity
            by_severity = {"CRITICAL": [], "HIGH": [], "MEDIUM": [], "LOW": []}
            for finding in report.findings:
                sev = finding.severity.upper()
                if sev in by_severity:
                    by_severity[sev].append(finding)

            severity_emoji = {
                "CRITICAL": "🔴",
                "HIGH": "🟠",
                "MEDIUM": "🟡",
                "LOW": "🔵",
            }

            for severity in ["CRITICAL", "HIGH", "MEDIUM", "LOW"]:
                findings = by_severity[severity]
                if findings:
                    lines.append(f"#### {severity_emoji[severity]} {severity}")
                    lines.append("")
                    for f in findings[:10]:  # Limit display
                        lines.append(f"- **{f.package}** `{f.version}`")
                        lines.append(f"  - {f.vulnerability_id}: {f.title}")
                        if f.fixed_version:
                            lines.append(f"  - ✅ Fix: Upgrade to `{f.fixed_version}`")
                    if len(findings) > 10:
                        lines.append(f"  - ... and {len(findings) - 10} more")
                    lines.append("")

        # Files scanned
        lines.append("### 📁 Files Scanned")
        lines.append("")
        for f in report.files_scanned:
            lines.append(f"- `{f}`")
        lines.append("")

        # Recommendations
        if report.recommendations:
            lines.append("### 💡 Recommendations")
            lines.append("")
            for rec in report.recommendations:
                lines.append(f"- {rec}")
            lines.append("")

        # Overall status
        if report.vulnerable_packages == 0:
            lines.append("---")
            lines.append("✅ **No known vulnerabilities detected**")
        else:
            lines.append("---")
            lines.append(
                f"⚠️ **{report.vulnerable_packages} vulnerable packages require attention**"
            )

        return "\n".join(lines)
@@ -365,9 +365,20 @@ class IssueAgent(BaseAgent):
             "commands", ["explain", "suggest", "security", "summarize", "triage"]
         )

-        # Also check for setup-labels command (not in config since it's a setup command)
-        if f"{mention_prefix} setup-labels" in body.lower():
-            return "setup-labels"
+        # Built-in commands not in config
+        builtin_commands = [
+            "setup-labels",
+            "check-deps",
+            "suggest-tests",
+            "refactor-suggest",
+            "architecture",
+            "arch-check",
+        ]
+
+        # Check built-in commands first
+        for command in builtin_commands:
+            if f"{mention_prefix} {command}" in body.lower():
+                return command

         for command in commands:
             if f"{mention_prefix} {command}" in body.lower():
@@ -392,6 +403,14 @@ class IssueAgent(BaseAgent):
             return self._command_triage(context, issue)
         elif command == "setup-labels":
             return self._command_setup_labels(context, issue)
+        elif command == "check-deps":
+            return self._command_check_deps(context)
+        elif command == "suggest-tests":
+            return self._command_suggest_tests(context)
+        elif command == "refactor-suggest":
+            return self._command_refactor_suggest(context)
+        elif command == "architecture" or command == "arch-check":
+            return self._command_architecture(context)

         return f"{self.AI_DISCLAIMER}\n\nSorry, I don't understand the command `{command}`."
@@ -464,6 +483,12 @@ Be practical and concise."""
 - `{mention_prefix} suggest` - Solution suggestions or next steps
 - `{mention_prefix} security` - Security-focused analysis of the issue

+### Code Quality & Security
+- `{mention_prefix} check-deps` - Scan dependencies for security vulnerabilities
+- `{mention_prefix} suggest-tests` - Suggest test cases for changed/new code
+- `{mention_prefix} refactor-suggest` - Suggest refactoring opportunities
+- `{mention_prefix} architecture` - Check architecture compliance (alias: `arch-check`)
+
 ### Interactive Chat
 - `{mention_prefix} [question]` - Ask questions about the codebase (uses search & file reading tools)
 - Example: `{mention_prefix} how does authentication work?`
@@ -494,9 +519,19 @@ PR reviews run automatically when you open or update a pull request. The bot pro
 {mention_prefix} triage
 ```

-**Get help understanding:**
+**Check for dependency vulnerabilities:**
 ```
-{mention_prefix} explain
+{mention_prefix} check-deps
+```
+
+**Get test suggestions:**
+```
+{mention_prefix} suggest-tests
+```
+
+**Check architecture compliance:**
+```
+{mention_prefix} architecture
 ```

 **Ask about the codebase:**
@@ -504,11 +539,6 @@ PR reviews run automatically when you open or update a pull request. The bot pro
 {mention_prefix} how does the authentication system work?
 ```

-**Setup repository labels:**
-```
-{mention_prefix} setup-labels
-```
-
 ---

 *For full documentation, see the [README](https://github.com/YourOrg/OpenRabbit/blob/main/README.md)*
@@ -854,3 +884,145 @@ PR reviews run automatically when you open or update a pull request. The bot pro
             return f"{prefix} - {value}"
         else:  # colon or unknown
             return base_name
+
+    def _command_check_deps(self, context: AgentContext) -> str:
+        """Check dependencies for security vulnerabilities."""
+        try:
+            from agents.dependency_agent import DependencyAgent
+
+            agent = DependencyAgent(config=self.config)
+            result = agent.run(context)
+
+            if result.success:
+                return result.data.get(
+                    "report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
+                )
+            else:
+                return f"{self.AI_DISCLAIMER}\n\n**Dependency Check Failed**\n\n{result.error or result.message}"
+        except ImportError:
+            return f"{self.AI_DISCLAIMER}\n\n**Dependency Agent Not Available**\n\nThe dependency security agent is not installed."
+        except Exception as e:
+            self.logger.error(f"Dependency check failed: {e}")
+            return f"{self.AI_DISCLAIMER}\n\n**Dependency Check Error**\n\n{e}"
+
+    def _command_suggest_tests(self, context: AgentContext) -> str:
+        """Suggest tests for changed or new code."""
+        try:
+            from agents.test_coverage_agent import TestCoverageAgent
+
+            agent = TestCoverageAgent(config=self.config)
+            result = agent.run(context)
+
+            if result.success:
+                return result.data.get(
+                    "report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
+                )
+            else:
+                return f"{self.AI_DISCLAIMER}\n\n**Test Suggestion Failed**\n\n{result.error or result.message}"
+        except ImportError:
+            return f"{self.AI_DISCLAIMER}\n\n**Test Coverage Agent Not Available**\n\nThe test coverage agent is not installed."
+        except Exception as e:
+            self.logger.error(f"Test suggestion failed: {e}")
+            return f"{self.AI_DISCLAIMER}\n\n**Test Suggestion Error**\n\n{e}"
+
+    def _command_architecture(self, context: AgentContext) -> str:
+        """Check architecture compliance."""
+        try:
+            from agents.architecture_agent import ArchitectureAgent
+
+            agent = ArchitectureAgent(config=self.config)
+            result = agent.run(context)
+
+            if result.success:
+                return result.data.get(
+                    "report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
+                )
+            else:
+                return f"{self.AI_DISCLAIMER}\n\n**Architecture Check Failed**\n\n{result.error or result.message}"
+        except ImportError:
+            return f"{self.AI_DISCLAIMER}\n\n**Architecture Agent Not Available**\n\nThe architecture compliance agent is not installed."
+        except Exception as e:
+            self.logger.error(f"Architecture check failed: {e}")
+            return f"{self.AI_DISCLAIMER}\n\n**Architecture Check Error**\n\n{e}"
+
+    def _command_refactor_suggest(self, context: AgentContext) -> str:
+        """Suggest refactoring opportunities."""
+        issue = context.event_data.get("issue", {})
+        title = issue.get("title", "")
+        body = issue.get("body", "")
+
+        # Use LLM to analyze for refactoring opportunities
+        prompt = f"""Analyze the following issue/context and suggest refactoring opportunities:
+
+Issue Title: {title}
+Issue Body: {body}
+
+Based on common refactoring patterns, suggest:
+1. Code smell detection (if any code is referenced)
+2. Design pattern opportunities
+3. Complexity reduction suggestions
+4. DRY principle violations
+5. SOLID principle improvements
+
+Format your response as a structured report with actionable recommendations.
+If no code is referenced in the issue, provide general refactoring guidance based on the context.
+
+Return as JSON:
+{{
+    "summary": "Brief summary of refactoring opportunities",
+    "suggestions": [
+        {{
+            "category": "Code Smell | Design Pattern | Complexity | DRY | SOLID",
+            "title": "Short title",
+            "description": "Detailed description",
+            "priority": "high | medium | low",
+            "effort": "small | medium | large"
+        }}
+    ],
+    "general_advice": "Any general refactoring advice"
+}}"""
+
+        try:
+            result = self.call_llm_json(prompt)
+
+            lines = [f"{self.AI_DISCLAIMER}\n"]
+            lines.append("## Refactoring Suggestions\n")
+
+            if result.get("summary"):
+                lines.append(f"**Summary:** {result['summary']}\n")
+
+            suggestions = result.get("suggestions", [])
+            if suggestions:
+                lines.append("### Recommendations\n")
+                lines.append("| Priority | Category | Suggestion | Effort |")
+                lines.append("|----------|----------|------------|--------|")
+
+                for s in suggestions:
+                    priority = s.get("priority", "medium").upper()
+                    priority_icon = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🟢"}.get(
+                        priority, "⚪"
+                    )
+                    lines.append(
+                        f"| {priority_icon} {priority} | {s.get('category', 'General')} | "
+                        f"**{s.get('title', 'Suggestion')}** | {s.get('effort', 'medium')} |"
+                    )
+
+                lines.append("")
+
+                # Detailed descriptions
+                lines.append("### Details\n")
+                for i, s in enumerate(suggestions, 1):
+                    lines.append(f"**{i}. {s.get('title', 'Suggestion')}**")
+                    lines.append(f"{s.get('description', 'No description')}\n")
+
+            if result.get("general_advice"):
+                lines.append("### General Advice\n")
+                lines.append(result["general_advice"])
+
+            return "\n".join(lines)
+
+        except Exception as e:
+            self.logger.error(f"Refactor suggestion failed: {e}")
+            return (
+                f"{self.AI_DISCLAIMER}\n\n**Refactor Suggestion Failed**\n\nError: {e}"
+            )
tools/ai-review/agents/test_coverage_agent.py (new file, 480 lines)
@@ -0,0 +1,480 @@
"""Test Coverage Agent

AI agent for analyzing code changes and suggesting test cases.
Helps improve test coverage by identifying untested code paths.
"""

import base64
import os
import re
from dataclasses import dataclass, field

from agents.base_agent import AgentContext, AgentResult, BaseAgent


@dataclass
class TestSuggestion:
    """A suggested test case."""

    function_name: str
    file_path: str
    test_type: str  # unit, integration, edge_case
    description: str
    example_code: str | None = None
    priority: str = "MEDIUM"  # HIGH, MEDIUM, LOW


@dataclass
class CoverageReport:
    """Report of test coverage analysis."""

    functions_analyzed: int
    functions_with_tests: int
    functions_without_tests: int
    suggestions: list[TestSuggestion]
    existing_tests: list[str]
    coverage_estimate: float


class TestCoverageAgent(BaseAgent):
    """Agent for analyzing test coverage and suggesting tests."""

    # Marker for test coverage comments
    TEST_AI_MARKER = "<!-- AI_TEST_COVERAGE -->"

    # Test file patterns by language
    TEST_PATTERNS = {
        "python": [r"test_.*\.py$", r".*_test\.py$", r"tests?/.*\.py$"],
        "javascript": [
            r".*\.test\.[jt]sx?$",
            r".*\.spec\.[jt]sx?$",
            r"__tests__/.*\.[jt]sx?$",
        ],
        "go": [r".*_test\.go$"],
        "rust": [r"tests?/.*\.rs$"],
        "java": [r".*Test\.java$", r".*Tests\.java$"],
        "ruby": [r".*_spec\.rb$", r"test_.*\.rb$"],
    }
|
||||||
|
|
||||||
|
# Function/method patterns by language
|
||||||
|
FUNCTION_PATTERNS = {
|
||||||
|
"python": r"^\s*(?:async\s+)?def\s+(\w+)\s*\(",
|
||||||
|
"javascript": r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>))",
|
||||||
|
"go": r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(",
|
||||||
|
"rust": r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)",
|
||||||
|
"java": r"(?:public|private|protected)\s+(?:static\s+)?(?:\w+\s+)?(\w+)\s*\([^)]*\)\s*\{",
|
||||||
|
"ruby": r"^\s*def\s+(\w+)",
|
||||||
|
}
|
||||||
|
|
||||||
|
def can_handle(self, event_type: str, event_data: dict) -> bool:
|
||||||
|
"""Check if this agent handles the given event."""
|
||||||
|
agent_config = self.config.get("agents", {}).get("test_coverage", {})
|
||||||
|
if not agent_config.get("enabled", True):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Handle @codebot suggest-tests command
|
||||||
|
if event_type == "issue_comment":
|
||||||
|
comment_body = event_data.get("comment", {}).get("body", "")
|
||||||
|
mention_prefix = self.config.get("interaction", {}).get(
|
||||||
|
"mention_prefix", "@codebot"
|
||||||
|
)
|
||||||
|
if f"{mention_prefix} suggest-tests" in comment_body.lower():
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def execute(self, context: AgentContext) -> AgentResult:
|
||||||
|
"""Execute the test coverage agent."""
|
||||||
|
self.logger.info(f"Analyzing test coverage for {context.owner}/{context.repo}")
|
||||||
|
|
||||||
|
actions_taken = []
|
||||||
|
|
||||||
|
# Get issue/PR number and author
|
||||||
|
issue = context.event_data.get("issue", {})
|
||||||
|
issue_number = issue.get("number")
|
||||||
|
comment_author = (
|
||||||
|
context.event_data.get("comment", {}).get("user", {}).get("login", "user")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if this is a PR
|
||||||
|
is_pr = issue.get("pull_request") is not None
|
||||||
|
|
||||||
|
if is_pr:
|
||||||
|
# Analyze PR diff for changed functions
|
||||||
|
diff = self._get_pr_diff(context.owner, context.repo, issue_number)
|
||||||
|
changed_functions = self._extract_changed_functions(diff)
|
||||||
|
actions_taken.append(f"Analyzed {len(changed_functions)} changed functions")
|
||||||
|
else:
|
||||||
|
# Analyze entire repository
|
||||||
|
changed_functions = self._analyze_repository(context.owner, context.repo)
|
||||||
|
actions_taken.append(
|
||||||
|
f"Analyzed {len(changed_functions)} functions in repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find existing tests
|
||||||
|
existing_tests = self._find_existing_tests(context.owner, context.repo)
|
||||||
|
actions_taken.append(f"Found {len(existing_tests)} existing test files")
|
||||||
|
|
||||||
|
# Generate test suggestions using LLM
|
||||||
|
report = self._generate_suggestions(
|
||||||
|
context.owner, context.repo, changed_functions, existing_tests
|
||||||
|
)
|
||||||
|
|
||||||
|
# Post report
|
||||||
|
if issue_number:
|
||||||
|
comment = self._format_coverage_report(report, comment_author, is_pr)
|
||||||
|
self.upsert_comment(
|
||||||
|
context.owner,
|
||||||
|
context.repo,
|
||||||
|
issue_number,
|
||||||
|
comment,
|
||||||
|
marker=self.TEST_AI_MARKER,
|
||||||
|
)
|
||||||
|
actions_taken.append("Posted test coverage report")
|
||||||
|
|
||||||
|
return AgentResult(
|
||||||
|
success=True,
|
||||||
|
message=f"Generated {len(report.suggestions)} test suggestions",
|
||||||
|
data={
|
||||||
|
"functions_analyzed": report.functions_analyzed,
|
||||||
|
"suggestions_count": len(report.suggestions),
|
||||||
|
"coverage_estimate": report.coverage_estimate,
|
||||||
|
},
|
||||||
|
actions_taken=actions_taken,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
|
||||||
|
"""Get the PR diff."""
|
||||||
|
try:
|
||||||
|
return self.gitea.get_pull_request_diff(owner, repo, pr_number)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to get PR diff: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _extract_changed_functions(self, diff: str) -> list[dict]:
|
||||||
|
"""Extract changed functions from diff."""
|
||||||
|
functions = []
|
||||||
|
current_file = None
|
||||||
|
current_language = None
|
||||||
|
|
||||||
|
for line in diff.splitlines():
|
||||||
|
# Track current file
|
||||||
|
if line.startswith("diff --git"):
|
||||||
|
match = re.search(r"b/(.+)$", line)
|
||||||
|
if match:
|
||||||
|
current_file = match.group(1)
|
||||||
|
current_language = self._detect_language(current_file)
|
||||||
|
|
||||||
|
# Look for function definitions in added lines
|
||||||
|
if line.startswith("+") and not line.startswith("+++"):
|
||||||
|
if current_language and current_language in self.FUNCTION_PATTERNS:
|
||||||
|
pattern = self.FUNCTION_PATTERNS[current_language]
|
||||||
|
match = re.search(pattern, line[1:]) # Skip the + prefix
|
||||||
|
if match:
|
||||||
|
func_name = next(g for g in match.groups() if g)
|
||||||
|
functions.append(
|
||||||
|
{
|
||||||
|
"name": func_name,
|
||||||
|
"file": current_file,
|
||||||
|
"language": current_language,
|
||||||
|
"line": line[1:].strip(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return functions
|
||||||
|
|
||||||
|
def _analyze_repository(self, owner: str, repo: str) -> list[dict]:
|
||||||
|
"""Analyze repository for functions without tests."""
|
||||||
|
functions = []
|
||||||
|
code_extensions = {".py", ".js", ".ts", ".go", ".rs", ".java", ".rb"}
|
||||||
|
|
||||||
|
# Get repository contents (limited to avoid API exhaustion)
|
||||||
|
try:
|
||||||
|
contents = self.gitea.get_file_contents(owner, repo, "")
|
||||||
|
if isinstance(contents, list):
|
||||||
|
for item in contents[:50]: # Limit files
|
||||||
|
if item.get("type") == "file":
|
||||||
|
filepath = item.get("path", "")
|
||||||
|
ext = os.path.splitext(filepath)[1]
|
||||||
|
if ext in code_extensions:
|
||||||
|
file_functions = self._extract_functions_from_file(
|
||||||
|
owner, repo, filepath
|
||||||
|
)
|
||||||
|
functions.extend(file_functions)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Failed to analyze repository: {e}")
|
||||||
|
|
||||||
|
return functions[:100] # Limit total functions
|
||||||
|
|
||||||
|
def _extract_functions_from_file(
|
||||||
|
self, owner: str, repo: str, filepath: str
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Extract function definitions from a file."""
|
||||||
|
functions = []
|
||||||
|
language = self._detect_language(filepath)
|
||||||
|
|
||||||
|
if not language or language not in self.FUNCTION_PATTERNS:
|
||||||
|
return functions
|
||||||
|
|
||||||
|
try:
|
||||||
|
content_data = self.gitea.get_file_contents(owner, repo, filepath)
|
||||||
|
if content_data.get("content"):
|
||||||
|
content = base64.b64decode(content_data["content"]).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
|
||||||
|
pattern = self.FUNCTION_PATTERNS[language]
|
||||||
|
for i, line in enumerate(content.splitlines(), 1):
|
||||||
|
match = re.search(pattern, line)
|
||||||
|
if match:
|
||||||
|
func_name = next((g for g in match.groups() if g), None)
|
||||||
|
if func_name and not func_name.startswith("_"):
|
||||||
|
functions.append(
|
||||||
|
{
|
||||||
|
"name": func_name,
|
||||||
|
"file": filepath,
|
||||||
|
"language": language,
|
||||||
|
"line_number": i,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return functions
|
||||||
|
|
||||||
|
def _detect_language(self, filepath: str) -> str | None:
|
||||||
|
"""Detect programming language from file path."""
|
||||||
|
ext_map = {
|
||||||
|
".py": "python",
|
||||||
|
".js": "javascript",
|
||||||
|
".jsx": "javascript",
|
||||||
|
".ts": "javascript",
|
||||||
|
".tsx": "javascript",
|
||||||
|
".go": "go",
|
||||||
|
".rs": "rust",
|
||||||
|
".java": "java",
|
||||||
|
".rb": "ruby",
|
||||||
|
}
|
||||||
|
ext = os.path.splitext(filepath)[1]
|
||||||
|
return ext_map.get(ext)
|
||||||
|
|
||||||
|
def _find_existing_tests(self, owner: str, repo: str) -> list[str]:
|
||||||
|
"""Find existing test files in the repository."""
|
||||||
|
test_files = []
|
||||||
|
|
||||||
|
# Common test directories
|
||||||
|
test_dirs = ["tests", "test", "__tests__", "spec"]
|
||||||
|
|
||||||
|
for test_dir in test_dirs:
|
||||||
|
try:
|
||||||
|
contents = self.gitea.get_file_contents(owner, repo, test_dir)
|
||||||
|
if isinstance(contents, list):
|
||||||
|
for item in contents:
|
||||||
|
if item.get("type") == "file":
|
||||||
|
test_files.append(item.get("path", ""))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Also check root for test files
|
||||||
|
try:
|
||||||
|
contents = self.gitea.get_file_contents(owner, repo, "")
|
||||||
|
if isinstance(contents, list):
|
||||||
|
for item in contents:
|
||||||
|
if item.get("type") == "file":
|
||||||
|
filepath = item.get("path", "")
|
||||||
|
if self._is_test_file(filepath):
|
||||||
|
test_files.append(filepath)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return test_files
|
||||||
|
|
||||||
|
def _is_test_file(self, filepath: str) -> bool:
|
||||||
|
"""Check if a file is a test file."""
|
||||||
|
for lang, patterns in self.TEST_PATTERNS.items():
|
||||||
|
for pattern in patterns:
|
||||||
|
if re.search(pattern, filepath):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _generate_suggestions(
|
||||||
|
self,
|
||||||
|
owner: str,
|
||||||
|
repo: str,
|
||||||
|
functions: list[dict],
|
||||||
|
existing_tests: list[str],
|
||||||
|
) -> CoverageReport:
|
||||||
|
"""Generate test suggestions using LLM."""
|
||||||
|
suggestions = []
|
||||||
|
|
||||||
|
# Build prompt for LLM
|
||||||
|
if functions:
|
||||||
|
functions_text = "\n".join(
|
||||||
|
[
|
||||||
|
f"- {f['name']} in {f['file']} ({f['language']})"
|
||||||
|
for f in functions[:20] # Limit for prompt size
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = f"""Analyze these functions and suggest test cases:
|
||||||
|
|
||||||
|
Functions to test:
|
||||||
|
{functions_text}
|
||||||
|
|
||||||
|
Existing test files:
|
||||||
|
{", ".join(existing_tests[:10]) if existing_tests else "None found"}
|
||||||
|
|
||||||
|
For each function, suggest:
|
||||||
|
1. What to test (happy path, edge cases, error handling)
|
||||||
|
2. Priority (HIGH for public APIs, MEDIUM for internal, LOW for utilities)
|
||||||
|
3. Brief example test code if possible
|
||||||
|
|
||||||
|
Respond in JSON format:
|
||||||
|
{{
|
||||||
|
"suggestions": [
|
||||||
|
{{
|
||||||
|
"function_name": "function_name",
|
||||||
|
"file_path": "path/to/file",
|
||||||
|
"test_type": "unit|integration|edge_case",
|
||||||
|
"description": "What to test",
|
||||||
|
"example_code": "brief example or null",
|
||||||
|
"priority": "HIGH|MEDIUM|LOW"
|
||||||
|
}}
|
||||||
|
],
|
||||||
|
"coverage_estimate": 0.0 to 1.0
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.call_llm_json(prompt)
|
||||||
|
|
||||||
|
for s in result.get("suggestions", []):
|
||||||
|
suggestions.append(
|
||||||
|
TestSuggestion(
|
||||||
|
function_name=s.get("function_name", ""),
|
||||||
|
file_path=s.get("file_path", ""),
|
||||||
|
test_type=s.get("test_type", "unit"),
|
||||||
|
description=s.get("description", ""),
|
||||||
|
example_code=s.get("example_code"),
|
||||||
|
priority=s.get("priority", "MEDIUM"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
coverage_estimate = result.get("coverage_estimate", 0.5)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"LLM suggestion failed: {e}")
|
||||||
|
# Generate basic suggestions without LLM
|
||||||
|
for f in functions[:10]:
|
||||||
|
suggestions.append(
|
||||||
|
TestSuggestion(
|
||||||
|
function_name=f["name"],
|
||||||
|
file_path=f["file"],
|
||||||
|
test_type="unit",
|
||||||
|
description=f"Add unit tests for {f['name']}",
|
||||||
|
priority="MEDIUM",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
coverage_estimate = 0.5
|
||||||
|
|
||||||
|
else:
|
||||||
|
coverage_estimate = 1.0 if existing_tests else 0.0
|
||||||
|
|
||||||
|
# Estimate functions with tests
|
||||||
|
functions_with_tests = int(len(functions) * coverage_estimate)
|
||||||
|
|
||||||
|
return CoverageReport(
|
||||||
|
functions_analyzed=len(functions),
|
||||||
|
functions_with_tests=functions_with_tests,
|
||||||
|
functions_without_tests=len(functions) - functions_with_tests,
|
||||||
|
suggestions=suggestions,
|
||||||
|
existing_tests=existing_tests,
|
||||||
|
coverage_estimate=coverage_estimate,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _format_coverage_report(
|
||||||
|
self, report: CoverageReport, user: str | None, is_pr: bool
|
||||||
|
) -> str:
|
||||||
|
"""Format the coverage report as a comment."""
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
if user:
|
||||||
|
lines.append(f"@{user}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
lines.extend(
|
||||||
|
[
|
||||||
|
f"{self.AI_DISCLAIMER}",
|
||||||
|
"",
|
||||||
|
"## 🧪 Test Coverage Analysis",
|
||||||
|
"",
|
||||||
|
"### Summary",
|
||||||
|
"",
|
||||||
|
f"| Metric | Value |",
|
||||||
|
f"|--------|-------|",
|
||||||
|
f"| Functions Analyzed | {report.functions_analyzed} |",
|
||||||
|
f"| Estimated Coverage | {report.coverage_estimate:.0%} |",
|
||||||
|
f"| Test Files Found | {len(report.existing_tests)} |",
|
||||||
|
f"| Suggestions | {len(report.suggestions)} |",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Suggestions by priority
|
||||||
|
if report.suggestions:
|
||||||
|
lines.append("### 💡 Test Suggestions")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Group by priority
|
||||||
|
by_priority = {"HIGH": [], "MEDIUM": [], "LOW": []}
|
||||||
|
for s in report.suggestions:
|
||||||
|
if s.priority in by_priority:
|
||||||
|
by_priority[s.priority].append(s)
|
||||||
|
|
||||||
|
priority_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}
|
||||||
|
|
||||||
|
for priority in ["HIGH", "MEDIUM", "LOW"]:
|
||||||
|
suggestions = by_priority[priority]
|
||||||
|
if suggestions:
|
||||||
|
lines.append(f"#### {priority_emoji[priority]} {priority} Priority")
|
||||||
|
lines.append("")
|
||||||
|
for s in suggestions[:5]: # Limit display
|
||||||
|
lines.append(f"**`{s.function_name}`** in `{s.file_path}`")
|
||||||
|
lines.append(f"- Type: {s.test_type}")
|
||||||
|
lines.append(f"- {s.description}")
|
||||||
|
if s.example_code:
|
||||||
|
lines.append(f"```")
|
||||||
|
lines.append(s.example_code[:200])
|
||||||
|
lines.append(f"```")
|
||||||
|
lines.append("")
|
||||||
|
if len(suggestions) > 5:
|
||||||
|
lines.append(f"*... and {len(suggestions) - 5} more*")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Existing test files
|
||||||
|
if report.existing_tests:
|
||||||
|
lines.append("### 📁 Existing Test Files")
|
||||||
|
lines.append("")
|
||||||
|
for f in report.existing_tests[:10]:
|
||||||
|
lines.append(f"- `{f}`")
|
||||||
|
if len(report.existing_tests) > 10:
|
||||||
|
lines.append(f"- *... and {len(report.existing_tests) - 10} more*")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Coverage bar
|
||||||
|
lines.append("### 📊 Coverage Estimate")
|
||||||
|
lines.append("")
|
||||||
|
filled = int(report.coverage_estimate * 10)
|
||||||
|
bar = "█" * filled + "░" * (10 - filled)
|
||||||
|
lines.append(f"`[{bar}]` {report.coverage_estimate:.0%}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Recommendations
|
||||||
|
if report.coverage_estimate < 0.8:
|
||||||
|
lines.append("---")
|
||||||
|
lines.append("⚠️ **Coverage below 80%** - Consider adding more tests")
|
||||||
|
elif report.coverage_estimate >= 0.8:
|
||||||
|
lines.append("---")
|
||||||
|
lines.append("✅ **Good test coverage!**")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
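The diff-scanning idea behind `_extract_changed_functions` can be shown in isolation. A standalone sketch (not part of the commit) using the same Python pattern as `FUNCTION_PATTERNS["python"]`:

```python
import re

# Same regex as FUNCTION_PATTERNS["python"] in the agent above.
PY_FUNC = r"^\s*(?:async\s+)?def\s+(\w+)\s*\("

def added_functions(diff: str) -> list[str]:
    """Return names of functions defined on added ('+') diff lines."""
    names = []
    for line in diff.splitlines():
        # '+++' is the file header, not an added line.
        if line.startswith("+") and not line.startswith("+++"):
            m = re.search(PY_FUNC, line[1:])  # strip the leading '+'
            if m:
                names.append(m.group(1))
    return names

sample = "+++ b/app.py\n+def handler(event):\n+    return None\n+async def main():\n"
print(added_functions(sample))  # → ['handler', 'main']
```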
@@ -77,11 +77,13 @@ class OpenAIProvider(BaseLLMProvider):
         model: str = "gpt-4o-mini",
         temperature: float = 0,
         max_tokens: int = 4096,
+        timeout: int = 120,
     ):
         self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
         self.model = model
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.timeout = timeout
         self.api_url = "https://api.openai.com/v1/chat/completions"

     def call(self, prompt: str, **kwargs) -> LLMResponse:
@@ -101,7 +103,7 @@ class OpenAIProvider(BaseLLMProvider):
                 "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                 "messages": [{"role": "user", "content": prompt}],
             },
-            timeout=120,
+            timeout=self.timeout,
         )
         response.raise_for_status()
         data = response.json()
@@ -145,7 +147,7 @@ class OpenAIProvider(BaseLLMProvider):
                 "Content-Type": "application/json",
             },
             json=request_body,
-            timeout=120,
+            timeout=self.timeout,
         )
         response.raise_for_status()
         data = response.json()
@@ -186,11 +188,13 @@ class OpenRouterProvider(BaseLLMProvider):
         model: str = "anthropic/claude-3.5-sonnet",
         temperature: float = 0,
         max_tokens: int = 4096,
+        timeout: int = 120,
     ):
         self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "")
         self.model = model
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.timeout = timeout
         self.api_url = "https://openrouter.ai/api/v1/chat/completions"

     def call(self, prompt: str, **kwargs) -> LLMResponse:
@@ -210,7 +214,7 @@ class OpenRouterProvider(BaseLLMProvider):
                 "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                 "messages": [{"role": "user", "content": prompt}],
             },
-            timeout=120,
+            timeout=self.timeout,
         )
         response.raise_for_status()
         data = response.json()
@@ -254,7 +258,7 @@ class OpenRouterProvider(BaseLLMProvider):
                 "Content-Type": "application/json",
             },
             json=request_body,
-            timeout=120,
+            timeout=self.timeout,
         )
         response.raise_for_status()
         data = response.json()
@@ -294,10 +298,12 @@ class OllamaProvider(BaseLLMProvider):
         host: str | None = None,
         model: str = "codellama:13b",
         temperature: float = 0,
+        timeout: int = 300,
     ):
         self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
         self.model = model
         self.temperature = temperature
+        self.timeout = timeout

     def call(self, prompt: str, **kwargs) -> LLMResponse:
         """Call Ollama API."""
@@ -311,7 +317,7 @@ class OllamaProvider(BaseLLMProvider):
                     "temperature": kwargs.get("temperature", self.temperature),
                 },
             },
-            timeout=300,  # Longer timeout for local models
+            timeout=self.timeout,
         )
         response.raise_for_status()
         data = response.json()
@@ -477,12 +483,18 @@ class LLMClient:
         provider = config.get("provider", "openai")
         provider_config = {}

+        # Get timeout configuration
+        timeouts = config.get("timeouts", {})
+        llm_timeout = timeouts.get("llm", 120)
+        ollama_timeout = timeouts.get("ollama", 300)
+
         # Map config keys to provider-specific settings
         if provider == "openai":
             provider_config = {
                 "model": config.get("model", {}).get("openai", "gpt-4o-mini"),
                 "temperature": config.get("temperature", 0),
                 "max_tokens": config.get("max_tokens", 16000),
+                "timeout": llm_timeout,
             }
         elif provider == "openrouter":
             provider_config = {
@@ -491,11 +503,13 @@ class LLMClient:
                 ),
                 "temperature": config.get("temperature", 0),
                 "max_tokens": config.get("max_tokens", 16000),
+                "timeout": llm_timeout,
             }
         elif provider == "ollama":
             provider_config = {
                 "model": config.get("model", {}).get("ollama", "codellama:13b"),
                 "temperature": config.get("temperature", 0),
                 "timeout": ollama_timeout,
             }

         return cls(provider=provider, config=provider_config)
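The timeout plumbing added above reads `timeouts.llm` and `timeouts.ollama` from the config with the same defaults as before the change (120s for API providers, 300s for Ollama). A standalone sketch of that resolution logic, mirroring the hunk in `LLMClient` (the example config values are made up):

```python
# Hypothetical config dict, as LLMClient would receive it after parsing.
config = {
    "provider": "ollama",
    "timeouts": {"llm": 60, "ollama": 600},
}

# Same resolution as the new code in LLMClient: fall back to the old
# hard-coded values when the "timeouts" block is absent.
timeouts = config.get("timeouts", {})
llm_timeout = timeouts.get("llm", 120)
ollama_timeout = timeouts.get("ollama", 300)

selected = ollama_timeout if config["provider"] == "ollama" else llm_timeout
print(selected)  # → 600
```

Configs without a `timeouts` block keep the previous behavior, so the change is backward compatible.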
27	tools/ai-review/clients/providers/__init__.py	Normal file
@@ -0,0 +1,27 @@
"""LLM Providers Package

This package contains additional LLM provider implementations
beyond the core providers in llm_client.py.

Providers:
- AnthropicProvider: Direct Anthropic Claude API
- AzureOpenAIProvider: Azure OpenAI Service with API key auth
- AzureOpenAIWithAADProvider: Azure OpenAI with Azure AD auth
- GeminiProvider: Google Gemini API (public)
- VertexAIGeminiProvider: Google Vertex AI Gemini (enterprise GCP)
"""

from clients.providers.anthropic_provider import AnthropicProvider
from clients.providers.azure_provider import (
    AzureOpenAIProvider,
    AzureOpenAIWithAADProvider,
)
from clients.providers.gemini_provider import GeminiProvider, VertexAIGeminiProvider

__all__ = [
    "AnthropicProvider",
    "AzureOpenAIProvider",
    "AzureOpenAIWithAADProvider",
    "GeminiProvider",
    "VertexAIGeminiProvider",
]
249	tools/ai-review/clients/providers/anthropic_provider.py	Normal file
@@ -0,0 +1,249 @@
"""Anthropic Claude Provider

Direct integration with Anthropic's Claude API.
Supports Claude 3.5 Sonnet, Claude 3 Opus, and other models.
"""

import json
import os

# Import base classes from parent module
import sys
from dataclasses import dataclass

import requests

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall


class AnthropicProvider(BaseLLMProvider):
    """Anthropic Claude API provider.

    Provides direct integration with Anthropic's Claude models
    without going through OpenRouter.

    Supports:
    - Claude 3.5 Sonnet (claude-3-5-sonnet-20241022)
    - Claude 3 Opus (claude-3-opus-20240229)
    - Claude 3 Sonnet (claude-3-sonnet-20240229)
    - Claude 3 Haiku (claude-3-haiku-20240307)
    """

    API_URL = "https://api.anthropic.com/v1/messages"
    API_VERSION = "2023-06-01"

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "claude-3-5-sonnet-20241022",
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Anthropic provider.

        Args:
            api_key: Anthropic API key. Defaults to ANTHROPIC_API_KEY env var.
            model: Model to use. Defaults to Claude 3.5 Sonnet.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
        """
        self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Anthropic API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Anthropic API key is required")

        response = requests.post(
            self.API_URL,
            headers={
                "x-api-key": self.api_key,
                "anthropic-version": self.API_VERSION,
                "Content-Type": "application/json",
            },
            json={
                "model": kwargs.get("model", self.model),
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
                "messages": [{"role": "user", "content": prompt}],
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Extract content from response
        content = ""
        for block in data.get("content", []):
            if block.get("type") == "text":
                content += block.get("text", "")

        return LLMResponse(
            content=content,
            model=data.get("model", self.model),
            provider="anthropic",
            tokens_used=data.get("usage", {}).get("input_tokens", 0)
            + data.get("usage", {}).get("output_tokens", 0),
            finish_reason=data.get("stop_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Anthropic API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        if not self.api_key:
            raise ValueError("Anthropic API key is required")

        # Convert OpenAI-style messages to Anthropic format
        anthropic_messages = []
        system_content = None

        for msg in messages:
            role = msg.get("role", "user")

            if role == "system":
                system_content = msg.get("content", "")
            elif role == "assistant":
                # Handle assistant messages with tool calls
                if msg.get("tool_calls"):
                    content = []
                    if msg.get("content"):
                        content.append({"type": "text", "text": msg["content"]})
                    for tc in msg["tool_calls"]:
                        content.append(
                            {
                                "type": "tool_use",
                                "id": tc["id"],
                                "name": tc["function"]["name"],
                                "input": json.loads(tc["function"]["arguments"])
                                if isinstance(tc["function"]["arguments"], str)
                                else tc["function"]["arguments"],
                            }
                        )
                    anthropic_messages.append({"role": "assistant", "content": content})
                else:
                    anthropic_messages.append(
                        {
                            "role": "assistant",
                            "content": msg.get("content", ""),
                        }
                    )
            elif role == "tool":
                # Tool response
                anthropic_messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "tool_result",
                                "tool_use_id": msg.get("tool_call_id", ""),
                                "content": msg.get("content", ""),
                            }
                        ],
                    }
                )
            else:
                anthropic_messages.append(
                    {
                        "role": "user",
                        "content": msg.get("content", ""),
                    }
                )

        # Convert OpenAI-style tools to Anthropic format
        anthropic_tools = None
        if tools:
            anthropic_tools = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool["function"]
                    anthropic_tools.append(
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "input_schema": func.get("parameters", {}),
                        }
                    )

        request_body = {
            "model": kwargs.get("model", self.model),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "temperature": kwargs.get("temperature", self.temperature),
            "messages": anthropic_messages,
        }

        if system_content:
            request_body["system"] = system_content

        if anthropic_tools:
            request_body["tools"] = anthropic_tools

        response = requests.post(
            self.API_URL,
            headers={
                "x-api-key": self.api_key,
                "anthropic-version": self.API_VERSION,
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Parse response
        content = ""
        tool_calls = None

        for block in data.get("content", []):
            if block.get("type") == "text":
                content += block.get("text", "")
            elif block.get("type") == "tool_use":
                if tool_calls is None:
                    tool_calls = []
                tool_calls.append(
                    ToolCall(
                        id=block.get("id", ""),
|
name=block.get("name", ""),
|
||||||
|
arguments=block.get("input", {}),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
model=data.get("model", self.model),
|
||||||
|
provider="anthropic",
|
||||||
|
tokens_used=data.get("usage", {}).get("input_tokens", 0)
|
||||||
|
+ data.get("usage", {}).get("output_tokens", 0),
|
||||||
|
finish_reason=data.get("stop_reason"),
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
)
tools/ai-review/clients/providers/azure_provider.py (new file)
@@ -0,0 +1,420 @@
"""Azure OpenAI Provider

Integration with Azure OpenAI Service for enterprise deployments.
Supports custom deployments, regional endpoints, and Azure AD auth.
"""

import json
import os
import sys

import requests

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall


class AzureOpenAIProvider(BaseLLMProvider):
    """Azure OpenAI Service provider.

    Provides integration with Azure-hosted OpenAI models for
    enterprise customers with Azure deployments.

    Supports:
    - GPT-4, GPT-4 Turbo, GPT-4o
    - GPT-3.5 Turbo
    - Custom fine-tuned models

    Environment Variables:
    - AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
    - AZURE_OPENAI_API_KEY: API key for authentication
    - AZURE_OPENAI_DEPLOYMENT: Default deployment name
    - AZURE_OPENAI_API_VERSION: API version (default: 2024-02-15-preview)
    """

    DEFAULT_API_VERSION = "2024-02-15-preview"

    def __init__(
        self,
        endpoint: str | None = None,
        api_key: str | None = None,
        deployment: str | None = None,
        api_version: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Azure OpenAI provider.

        Args:
            endpoint: Azure OpenAI endpoint URL.
                Defaults to AZURE_OPENAI_ENDPOINT env var.
            api_key: API key for authentication.
                Defaults to AZURE_OPENAI_API_KEY env var.
            deployment: Deployment name to use.
                Defaults to AZURE_OPENAI_DEPLOYMENT env var.
            api_version: API version string.
                Defaults to AZURE_OPENAI_API_VERSION env var or latest.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum tokens in response.
        """
        self.endpoint = (
            endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT", "")
        ).rstrip("/")
        self.api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY", "")
        self.deployment = deployment or os.environ.get("AZURE_OPENAI_DEPLOYMENT", "")
        self.api_version = api_version or os.environ.get(
            "AZURE_OPENAI_API_VERSION", self.DEFAULT_API_VERSION
        )
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _get_api_url(self, deployment: str | None = None) -> str:
        """Build the API URL for a given deployment.

        Args:
            deployment: Deployment name. Uses default if not specified.

        Returns:
            Full API URL for chat completions.
        """
        deploy = deployment or self.deployment
        return (
            f"{self.endpoint}/openai/deployments/{deploy}"
            f"/chat/completions?api-version={self.api_version}"
        )
    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Azure OpenAI API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (deployment, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If required configuration is missing.
            requests.HTTPError: If the API request fails.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.api_key:
            raise ValueError("Azure OpenAI API key is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")

        deployment = kwargs.get("deployment", self.deployment)

        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "api-key": self.api_key,
                "Content-Type": "application/json",
            },
            json={
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})

        return LLMResponse(
            content=message.get("content", ""),
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
        )
    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Azure OpenAI API with tool support.

        Azure OpenAI uses the same format as OpenAI for tools.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.api_key:
            raise ValueError("Azure OpenAI API key is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")

        deployment = kwargs.get("deployment", self.deployment)

        request_body = {
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "temperature": kwargs.get("temperature", self.temperature),
        }

        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")

        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "api-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})

        # Parse tool calls if present
        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = []
            for tc in message["tool_calls"]:
                func = tc.get("function", {})
                args = func.get("arguments", "{}")
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        args = {}
                tool_calls.append(
                    ToolCall(
                        id=tc.get("id", ""),
                        name=func.get("name", ""),
                        arguments=args,
                    )
                )

        return LLMResponse(
            content=message.get("content", "") or "",
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
            tool_calls=tool_calls,
        )
class AzureOpenAIWithAADProvider(AzureOpenAIProvider):
    """Azure OpenAI provider with Azure Active Directory authentication.

    Uses Azure AD tokens instead of API keys for authentication.
    Requires azure-identity package for token acquisition.

    Environment Variables:
    - AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
    - AZURE_OPENAI_DEPLOYMENT: Default deployment name
    - AZURE_TENANT_ID: Azure AD tenant ID (optional)
    - AZURE_CLIENT_ID: Azure AD client ID (optional)
    - AZURE_CLIENT_SECRET: Azure AD client secret (optional)
    """

    SCOPE = "https://cognitiveservices.azure.com/.default"

    def __init__(
        self,
        endpoint: str | None = None,
        deployment: str | None = None,
        api_version: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credential=None,
    ):
        """Initialize the Azure OpenAI AAD provider.

        Args:
            endpoint: Azure OpenAI endpoint URL.
            deployment: Deployment name to use.
            api_version: API version string.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum tokens in response.
            credential: Azure credential object. If not provided,
                uses DefaultAzureCredential.
        """
        super().__init__(
            endpoint=endpoint,
            api_key="",  # Not used with AAD
            deployment=deployment,
            api_version=api_version,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self._credential = credential
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get an Azure AD token for authentication.

        Returns:
            Bearer token string.

        Raises:
            ImportError: If azure-identity is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer)
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token

        try:
            from azure.identity import DefaultAzureCredential
        except ImportError:
            raise ImportError(
                "azure-identity package is required for AAD authentication. "
                "Install with: pip install azure-identity"
            )

        if self._credential is None:
            self._credential = DefaultAzureCredential()

        token = self._credential.get_token(self.SCOPE)
        self._token = token.token
        self._token_expires_at = token.expires_on

        return self._token
    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Azure OpenAI API using AAD auth.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options.

        Returns:
            LLMResponse with the generated content.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")

        deployment = kwargs.get("deployment", self.deployment)
        token = self._get_token()

        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json={
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})

        return LLMResponse(
            content=message.get("content", ""),
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
        )
    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Azure OpenAI API with tool support using AAD auth.

        Args:
            messages: List of message dicts.
            tools: List of tool definitions.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")

        deployment = kwargs.get("deployment", self.deployment)
        token = self._get_token()

        request_body = {
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "temperature": kwargs.get("temperature", self.temperature),
        }

        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")

        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})

        # Parse tool calls if present
        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = []
            for tc in message["tool_calls"]:
                func = tc.get("function", {})
                args = func.get("arguments", "{}")
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        args = {}
                tool_calls.append(
                    ToolCall(
                        id=tc.get("id", ""),
                        name=func.get("name", ""),
                        arguments=args,
                    )
                )

        return LLMResponse(
            content=message.get("content", "") or "",
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
            tool_calls=tool_calls,
        )
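Both Azure classes build their request URL from the same deployment-scoped path layout used by `_get_api_url`. That layout can be checked in isolation; a sketch with a hypothetical helper, endpoint, and deployment name (not taken from the repo):

```python
def azure_chat_url(endpoint: str, deployment: str, api_version: str) -> str:
    # Mirrors the provider's layout: /openai/deployments/<name>/chat/completions
    return (
        f"{endpoint.rstrip('/')}/openai/deployments/{deployment}"
        f"/chat/completions?api-version={api_version}"
    )

url = azure_chat_url(
    "https://example.openai.azure.com/", "gpt-4o-prod", "2024-02-15-preview"
)
```

Stripping the trailing slash before joining matters: without it, a user-supplied endpoint ending in `/` would produce a `//openai/...` path.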
tools/ai-review/clients/providers/gemini_provider.py (new file)
@@ -0,0 +1,599 @@
"""Google Gemini Provider

Integration with Google's Gemini API for GCP customers.
Supports Gemini Pro, Gemini Ultra, and other models.
"""

import json
import os
import sys

import requests

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall


class GeminiProvider(BaseLLMProvider):
    """Google Gemini API provider.

    Provides integration with Google's Gemini models.

    Supports:
    - Gemini 1.5 Pro (gemini-1.5-pro)
    - Gemini 1.5 Flash (gemini-1.5-flash)
    - Gemini 1.0 Pro (gemini-pro)

    Environment Variables:
    - GOOGLE_API_KEY: Google AI API key
    - GEMINI_MODEL: Default model (optional)
    """

    API_URL = "https://generativelanguage.googleapis.com/v1beta/models"

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "gemini-1.5-pro",
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Gemini provider.

        Args:
            api_key: Google API key. Defaults to GOOGLE_API_KEY env var.
            model: Model to use. Defaults to gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
        """
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "")
        self.model = model or os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _get_api_url(self, model: str | None = None, stream: bool = False) -> str:
        """Build the API URL for a given model.

        Args:
            model: Model name. Uses default if not specified.
            stream: Whether to use streaming endpoint.

        Returns:
            Full API URL.
        """
        m = model or self.model
        action = "streamGenerateContent" if stream else "generateContent"
        return f"{self.API_URL}/{m}:{action}?key={self.api_key}"
    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Gemini API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")

        model = kwargs.get("model", self.model)

        response = requests.post(
            self._get_api_url(model),
            headers={"Content-Type": "application/json"},
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Extract content from response
        content = ""
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]

        # Get token counts
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")

        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )
    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Gemini API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")

        model = kwargs.get("model", self.model)

        # Convert OpenAI-style messages to Gemini format
        gemini_contents = []
        system_instruction = None

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                system_instruction = content
            elif role == "assistant":
                # Handle assistant messages with tool calls
                parts = []
                if content:
                    parts.append({"text": content})

                if msg.get("tool_calls"):
                    for tc in msg["tool_calls"]:
                        func = tc.get("function", {})
                        args = func.get("arguments", {})
                        if isinstance(args, str):
                            try:
                                args = json.loads(args)
                            except json.JSONDecodeError:
                                args = {}
                        parts.append(
                            {
                                "functionCall": {
                                    "name": func.get("name", ""),
                                    "args": args,
                                }
                            }
                        )

                gemini_contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                # Tool response in Gemini format
                gemini_contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                # User message
                gemini_contents.append({"role": "user", "parts": [{"text": content}]})

        # Convert OpenAI-style tools to Gemini format
        gemini_tools = None
        if tools:
            function_declarations = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool["function"]
                    function_declarations.append(
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "parameters": func.get("parameters", {}),
                        }
                    )
            if function_declarations:
                gemini_tools = [{"functionDeclarations": function_declarations}]

        request_body = {
            "contents": gemini_contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.temperature),
                "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
            },
        }

        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }

        if gemini_tools:
            request_body["tools"] = gemini_tools

        response = requests.post(
            self._get_api_url(model),
            headers={"Content-Type": "application/json"},
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Parse response
        content = ""
        tool_calls = None

        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            id=f"call_{len(tool_calls)}",  # Gemini doesn't provide IDs
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )

        # Get token counts
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")

        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )
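The Gemini response-parsing loop above walks the first candidate's parts and separates plain `text` entries from `functionCall` entries. That traversal can be exercised on its own; a standalone sketch over a hypothetical response payload (helper name and sample data are illustrative, not from the provider):

```python
def parse_parts(data: dict) -> tuple[str, list[dict]]:
    """Collect text and functionCall parts from the first candidate."""
    text, calls = "", []
    candidates = data.get("candidates", [])
    parts = candidates[0].get("content", {}).get("parts", []) if candidates else []
    for part in parts:
        if "text" in part:
            text += part["text"]
        elif "functionCall" in part:
            fc = part["functionCall"]
            calls.append({"name": fc.get("name", ""), "args": fc.get("args", {})})
    return text, calls

text, calls = parse_parts({
    "candidates": [{"content": {"parts": [
        {"text": "Looking that up."},
        {"functionCall": {"name": "search", "args": {"q": "llm"}}},
    ]}}]
})
```

Note that a single candidate can interleave text and function calls, which is why the provider accumulates both rather than returning on the first match.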
class VertexAIGeminiProvider(BaseLLMProvider):
    """Google Vertex AI Gemini provider for enterprise GCP deployments.

    Uses Vertex AI endpoints instead of the public Gemini API.
    Supports regional deployments and IAM authentication.

    Environment Variables:
    - GOOGLE_CLOUD_PROJECT: GCP project ID
    - GOOGLE_CLOUD_REGION: GCP region (default: us-central1)
    - VERTEX_AI_MODEL: Default model (optional)
    """

    def __init__(
        self,
        project: str | None = None,
        region: str = "us-central1",
        model: str = "gemini-1.5-pro",
        temperature: float = 0,
        max_tokens: int = 4096,
        credentials=None,
    ):
        """Initialize the Vertex AI Gemini provider.

        Args:
            project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var.
            region: GCP region. Defaults to us-central1.
            model: Model to use. Defaults to gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            credentials: Google credentials object. If not provided,
                uses Application Default Credentials.
        """
        self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
        self.region = region or os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
        self.model = model or os.environ.get("VERTEX_AI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._credentials = credentials
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get a Google Cloud access token.

        Returns:
            Access token string.

        Raises:
            ImportError: If google-auth is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer)
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token

        try:
            import google.auth
            from google.auth.transport.requests import Request
        except ImportError:
            raise ImportError(
                "google-auth package is required for Vertex AI authentication. "
                "Install with: pip install google-auth"
            )

        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )

        if not self._credentials.valid:
            self._credentials.refresh(Request())

        self._token = self._credentials.token
        # Tokens typically expire in 1 hour
        self._token_expires_at = time.time() + 3500

        return self._token

    def _get_api_url(self, model: str | None = None) -> str:
        """Build the Vertex AI API URL.

        Args:
            model: Model name. Uses default if not specified.

        Returns:
            Full API URL.
        """
        m = model or self.model
        return (
            f"https://{self.region}-aiplatform.googleapis.com/v1/"
            f"projects/{self.project}/locations/{self.region}/"
            f"publishers/google/models/{m}:generateContent"
        )
    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to Vertex AI Gemini.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options.

        Returns:
            LLMResponse with the generated content.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model", self.model)
        token = self._get_token()

        response = requests.post(
            self._get_api_url(model),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Extract content from response
        content = ""
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]

        # Get token counts
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")

        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )
def call_with_tools(
|
||||||
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""Make a call to Vertex AI Gemini with tool support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dicts.
|
||||||
|
tools: List of tool definitions.
|
||||||
|
**kwargs: Additional options.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LLMResponse with content and/or tool_calls.
|
||||||
|
"""
|
||||||
|
if not self.project:
|
||||||
|
raise ValueError("GCP project ID is required")
|
||||||
|
|
||||||
|
model = kwargs.get("model", self.model)
|
||||||
|
token = self._get_token()
|
||||||
|
|
||||||
|
# Convert messages to Gemini format (same as GeminiProvider)
|
||||||
|
gemini_contents = []
|
||||||
|
system_instruction = None
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role", "user")
|
||||||
|
content = msg.get("content", "")
|
||||||
|
|
||||||
|
if role == "system":
|
||||||
|
system_instruction = content
|
||||||
|
elif role == "assistant":
|
||||||
|
parts = []
|
||||||
|
if content:
|
||||||
|
parts.append({"text": content})
|
||||||
|
if msg.get("tool_calls"):
|
||||||
|
for tc in msg["tool_calls"]:
|
||||||
|
func = tc.get("function", {})
|
||||||
|
args = func.get("arguments", {})
|
||||||
|
if isinstance(args, str):
|
||||||
|
try:
|
||||||
|
args = json.loads(args)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
args = {}
|
||||||
|
parts.append(
|
||||||
|
{
|
||||||
|
"functionCall": {
|
||||||
|
"name": func.get("name", ""),
|
||||||
|
"args": args,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
gemini_contents.append({"role": "model", "parts": parts})
|
||||||
|
elif role == "tool":
|
||||||
|
gemini_contents.append(
|
||||||
|
{
|
||||||
|
"role": "function",
|
||||||
|
"parts": [
|
||||||
|
{
|
||||||
|
"functionResponse": {
|
||||||
|
"name": msg.get("name", ""),
|
||||||
|
"response": {"result": content},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
gemini_contents.append({"role": "user", "parts": [{"text": content}]})
|
||||||
|
|
||||||
|
# Convert tools to Gemini format
|
||||||
|
gemini_tools = None
|
||||||
|
if tools:
|
||||||
|
function_declarations = []
|
||||||
|
for tool in tools:
|
||||||
|
if tool.get("type") == "function":
|
||||||
|
func = tool["function"]
|
||||||
|
function_declarations.append(
|
||||||
|
{
|
||||||
|
"name": func["name"],
|
||||||
|
"description": func.get("description", ""),
|
||||||
|
"parameters": func.get("parameters", {}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if function_declarations:
|
||||||
|
gemini_tools = [{"functionDeclarations": function_declarations}]
|
||||||
|
|
||||||
|
request_body = {
|
||||||
|
"contents": gemini_contents,
|
||||||
|
"generationConfig": {
|
||||||
|
"temperature": kwargs.get("temperature", self.temperature),
|
||||||
|
"maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if system_instruction:
|
||||||
|
request_body["systemInstruction"] = {
|
||||||
|
"parts": [{"text": system_instruction}]
|
||||||
|
}
|
||||||
|
|
||||||
|
if gemini_tools:
|
||||||
|
request_body["tools"] = gemini_tools
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self._get_api_url(model),
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json=request_body,
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
content = ""
|
||||||
|
tool_calls = None
|
||||||
|
|
||||||
|
candidates = data.get("candidates", [])
|
||||||
|
if candidates:
|
||||||
|
parts = candidates[0].get("content", {}).get("parts", [])
|
||||||
|
for part in parts:
|
||||||
|
if "text" in part:
|
||||||
|
content += part["text"]
|
||||||
|
elif "functionCall" in part:
|
||||||
|
if tool_calls is None:
|
||||||
|
tool_calls = []
|
||||||
|
fc = part["functionCall"]
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
id=f"call_{len(tool_calls)}",
|
||||||
|
name=fc.get("name", ""),
|
||||||
|
arguments=fc.get("args", {}),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = data.get("usageMetadata", {})
|
||||||
|
tokens_used = usage.get("promptTokenCount", 0) + usage.get(
|
||||||
|
"candidatesTokenCount", 0
|
||||||
|
)
|
||||||
|
|
||||||
|
finish_reason = None
|
||||||
|
if candidates:
|
||||||
|
finish_reason = candidates[0].get("finishReason")
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
model=model,
|
||||||
|
provider="vertex-ai",
|
||||||
|
tokens_used=tokens_used,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
)
|
||||||
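The request assembly in `_get_api_url` and `call` above is plain URL and dict construction, so it can be sketched standalone. This is a minimal sketch of the same shape; the project, region, and model strings passed in at the bottom are placeholder values, not real resources.

```python
def build_vertex_request(
    project: str,
    region: str,
    model: str,
    prompt: str,
    temperature: float = 0.0,
    max_tokens: int = 4096,
) -> tuple[str, dict]:
    """Return (url, json_body) for a Vertex AI generateContent call."""
    # Regional endpoint, with the model addressed as a publisher resource.
    url = (
        f"https://{region}-aiplatform.googleapis.com/v1/"
        f"projects/{project}/locations/{region}/"
        f"publishers/google/models/{model}:generateContent"
    )
    # Gemini-style body: the prompt goes in contents[].parts[].text.
    body = {
        "contents": [{"parts": [{"text": prompt}]}],
        "generationConfig": {
            "temperature": temperature,
            "maxOutputTokens": max_tokens,
        },
    }
    return url, body


url, body = build_vertex_request(
    "my-project", "us-central1", "gemini-1.5-pro", "Summarize this diff."
)
```

The bearer token from `_get_token()` would then be attached as an `Authorization` header when posting `body` to `url`.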
14  tools/ai-review/compliance/__init__.py  Normal file
@@ -0,0 +1,14 @@
"""Compliance Module

Provides audit trail, compliance reporting, and regulatory checks.
"""

from compliance.audit_trail import AuditEvent, AuditLogger, AuditTrail
from compliance.codeowners import CodeownersChecker

__all__ = [
    "AuditTrail",
    "AuditLogger",
    "AuditEvent",
    "CodeownersChecker",
]
430  tools/ai-review/compliance/audit_trail.py  Normal file
@@ -0,0 +1,430 @@
"""Audit Trail

Provides comprehensive audit logging for compliance requirements.
Supports HIPAA, SOC2, and other regulatory frameworks.
"""

import hashlib
import json
import logging
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any


class AuditAction(Enum):
    """Types of auditable actions."""

    # Review actions
    REVIEW_STARTED = "review_started"
    REVIEW_COMPLETED = "review_completed"
    REVIEW_FAILED = "review_failed"

    # Security actions
    SECURITY_SCAN_STARTED = "security_scan_started"
    SECURITY_SCAN_COMPLETED = "security_scan_completed"
    SECURITY_FINDING_DETECTED = "security_finding_detected"
    SECURITY_FINDING_RESOLVED = "security_finding_resolved"

    # Comment actions
    COMMENT_POSTED = "comment_posted"
    COMMENT_UPDATED = "comment_updated"
    COMMENT_DELETED = "comment_deleted"

    # Label actions
    LABEL_ADDED = "label_added"
    LABEL_REMOVED = "label_removed"

    # Configuration actions
    CONFIG_LOADED = "config_loaded"
    CONFIG_CHANGED = "config_changed"

    # Access actions
    API_CALL = "api_call"
    AUTHENTICATION = "authentication"

    # Approval actions
    APPROVAL_GRANTED = "approval_granted"
    APPROVAL_REVOKED = "approval_revoked"
    CHANGES_REQUESTED = "changes_requested"


@dataclass
class AuditEvent:
    """An auditable event."""

    action: AuditAction
    timestamp: str
    actor: str
    resource_type: str
    resource_id: str
    repository: str
    details: dict[str, Any] = field(default_factory=dict)
    outcome: str = "success"
    error: str | None = None
    correlation_id: str | None = None
    checksum: str | None = None

    def __post_init__(self):
        """Calculate checksum for integrity verification."""
        if not self.checksum:
            self.checksum = self._calculate_checksum()

    def _calculate_checksum(self) -> str:
        """Calculate SHA-256 checksum of event data."""
        data = {
            "action": self.action.value
            if isinstance(self.action, AuditAction)
            else self.action,
            "timestamp": self.timestamp,
            "actor": self.actor,
            "resource_type": self.resource_type,
            "resource_id": self.resource_id,
            "repository": self.repository,
            "details": self.details,
            "outcome": self.outcome,
            "error": self.error,
        }
        json_str = json.dumps(data, sort_keys=True)
        return hashlib.sha256(json_str.encode()).hexdigest()

    def to_dict(self) -> dict:
        """Convert event to dictionary."""
        data = asdict(self)
        if isinstance(self.action, AuditAction):
            data["action"] = self.action.value
        return data

    def to_json(self) -> str:
        """Convert event to JSON string."""
        return json.dumps(self.to_dict())


class AuditLogger:
    """Logger for audit events."""

    def __init__(
        self,
        log_file: str | None = None,
        log_to_stdout: bool = False,
        log_level: str = "INFO",
    ):
        """Initialize audit logger.

        Args:
            log_file: Path to audit log file.
            log_to_stdout: Also log to stdout.
            log_level: Logging level.

        """
        self.log_file = log_file
        self.log_to_stdout = log_to_stdout
        self.logger = logging.getLogger("audit")
        self.logger.setLevel(getattr(logging, log_level.upper(), logging.INFO))

        # Clear existing handlers
        self.logger.handlers = []

        # Add file handler if specified
        if log_file:
            log_dir = os.path.dirname(log_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(
                logging.Formatter("%(message)s")  # JSON lines format
            )
            self.logger.addHandler(file_handler)

        # Add stdout handler if requested
        if log_to_stdout:
            stdout_handler = logging.StreamHandler()
            stdout_handler.setFormatter(logging.Formatter("[AUDIT] %(message)s"))
            self.logger.addHandler(stdout_handler)

    def log(self, event: AuditEvent):
        """Log an audit event.

        Args:
            event: The audit event to log.

        """
        self.logger.info(event.to_json())

    def log_action(
        self,
        action: AuditAction,
        actor: str,
        resource_type: str,
        resource_id: str,
        repository: str,
        details: dict | None = None,
        outcome: str = "success",
        error: str | None = None,
        correlation_id: str | None = None,
    ):
        """Log an action as an audit event.

        Args:
            action: The action being performed.
            actor: Who performed the action.
            resource_type: Type of resource affected.
            resource_id: ID of the resource.
            repository: Repository context.
            details: Additional details.
            outcome: success, failure, or partial.
            error: Error message if failed.
            correlation_id: ID to correlate related events.

        """
        event = AuditEvent(
            action=action,
            timestamp=datetime.now(timezone.utc).isoformat(),
            actor=actor,
            resource_type=resource_type,
            resource_id=resource_id,
            repository=repository,
            details=details or {},
            outcome=outcome,
            error=error,
            correlation_id=correlation_id,
        )
        self.log(event)


class AuditTrail:
    """High-level audit trail management."""

    def __init__(self, config: dict):
        """Initialize audit trail.

        Args:
            config: Configuration dictionary.

        """
        self.config = config
        compliance_config = config.get("compliance", {})
        audit_config = compliance_config.get("audit", {})

        self.enabled = audit_config.get("enabled", False)
        self.log_file = audit_config.get("log_file", "audit.log")
        self.log_to_stdout = audit_config.get("log_to_stdout", False)
        self.retention_days = audit_config.get("retention_days", 90)

        if self.enabled:
            self.logger = AuditLogger(
                log_file=self.log_file,
                log_to_stdout=self.log_to_stdout,
            )
        else:
            self.logger = None

        self._correlation_id = None

    def set_correlation_id(self, correlation_id: str):
        """Set correlation ID for subsequent events.

        Args:
            correlation_id: ID to correlate related events.

        """
        self._correlation_id = correlation_id

    def log(
        self,
        action: AuditAction,
        actor: str,
        resource_type: str,
        resource_id: str,
        repository: str,
        details: dict | None = None,
        outcome: str = "success",
        error: str | None = None,
    ):
        """Log an audit event.

        Args:
            action: The action being performed.
            actor: Who performed the action.
            resource_type: Type of resource (pr, issue, comment, etc).
            resource_id: ID of the resource.
            repository: Repository (owner/repo).
            details: Additional details.
            outcome: success, failure, or partial.
            error: Error message if failed.

        """
        if not self.enabled or not self.logger:
            return

        self.logger.log_action(
            action=action,
            actor=actor,
            resource_type=resource_type,
            resource_id=resource_id,
            repository=repository,
            details=details,
            outcome=outcome,
            error=error,
            correlation_id=self._correlation_id,
        )

    def log_review_started(
        self,
        repository: str,
        pr_number: int,
        reviewer: str = "openrabbit",
    ):
        """Log that a review has started."""
        self.log(
            action=AuditAction.REVIEW_STARTED,
            actor=reviewer,
            resource_type="pull_request",
            resource_id=str(pr_number),
            repository=repository,
        )

    def log_review_completed(
        self,
        repository: str,
        pr_number: int,
        recommendation: str,
        findings_count: int,
        reviewer: str = "openrabbit",
    ):
        """Log that a review has completed."""
        self.log(
            action=AuditAction.REVIEW_COMPLETED,
            actor=reviewer,
            resource_type="pull_request",
            resource_id=str(pr_number),
            repository=repository,
            details={
                "recommendation": recommendation,
                "findings_count": findings_count,
            },
        )

    def log_security_finding(
        self,
        repository: str,
        pr_number: int,
        finding: dict,
        scanner: str = "openrabbit",
    ):
        """Log a security finding."""
        self.log(
            action=AuditAction.SECURITY_FINDING_DETECTED,
            actor=scanner,
            resource_type="pull_request",
            resource_id=str(pr_number),
            repository=repository,
            details={
                "severity": finding.get("severity"),
                "category": finding.get("category"),
                "file": finding.get("file"),
                "line": finding.get("line"),
                "cwe": finding.get("cwe"),
            },
        )

    def log_approval(
        self,
        repository: str,
        pr_number: int,
        approver: str,
        approval_type: str = "ai",
    ):
        """Log an approval action."""
        self.log(
            action=AuditAction.APPROVAL_GRANTED,
            actor=approver,
            resource_type="pull_request",
            resource_id=str(pr_number),
            repository=repository,
            details={"approval_type": approval_type},
        )

    def log_changes_requested(
        self,
        repository: str,
        pr_number: int,
        requester: str,
        reason: str | None = None,
    ):
        """Log a changes requested action."""
        self.log(
            action=AuditAction.CHANGES_REQUESTED,
            actor=requester,
            resource_type="pull_request",
            resource_id=str(pr_number),
            repository=repository,
            details={"reason": reason} if reason else {},
        )

    def generate_report(
        self,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        repository: str | None = None,
    ) -> dict:
        """Generate an audit report.

        Args:
            start_date: Start of reporting period.
            end_date: End of reporting period.
            repository: Filter by repository.

        Returns:
            Report dictionary with statistics and events.

        """
        if not self.log_file or not os.path.exists(self.log_file):
            return {"events": [], "statistics": {}}

        events = []
        with open(self.log_file) as f:
            for line in f:
                try:
                    event = json.loads(line.strip())
                    event_time = datetime.fromisoformat(
                        event["timestamp"].replace("Z", "+00:00")
                    )

                    # Apply filters
                    if start_date and event_time < start_date:
                        continue
                    if end_date and event_time > end_date:
                        continue
                    if repository and event.get("repository") != repository:
                        continue

                    events.append(event)
                except (json.JSONDecodeError, KeyError):
                    continue

        # Calculate statistics
        action_counts = {}
        outcome_counts = {"success": 0, "failure": 0, "partial": 0}
        security_findings = 0

        for event in events:
            action = event.get("action", "unknown")
            action_counts[action] = action_counts.get(action, 0) + 1

            outcome = event.get("outcome", "success")
            if outcome in outcome_counts:
                outcome_counts[outcome] += 1

            if action == "security_finding_detected":
                security_findings += 1

        return {
            "events": events,
            "statistics": {
                "total_events": len(events),
                "action_counts": action_counts,
                "outcome_counts": outcome_counts,
                "security_findings": security_findings,
            },
            "period": {
                "start": start_date.isoformat() if start_date else None,
                "end": end_date.isoformat() if end_date else None,
            },
        }
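The tamper-evidence scheme behind `AuditEvent._calculate_checksum` reduces to hashing a canonical JSON rendering of the event fields. A minimal standalone sketch (the event dict and its field names here are illustrative, not the full `AuditEvent` schema):

```python
import hashlib
import json


def event_checksum(event: dict) -> str:
    # Canonical form: JSON with sorted keys, so field order cannot
    # change the digest; then SHA-256 over the UTF-8 bytes.
    canonical = json.dumps(event, sort_keys=True)
    return hashlib.sha256(canonical.encode()).hexdigest()


event = {"action": "review_started", "actor": "openrabbit", "resource_id": "42"}
original = event_checksum(event)

# Changing any field yields a different digest, so a checksum stored
# alongside the log line exposes after-the-fact edits.
tampered = event_checksum({**event, "actor": "someone-else"})
```

Because the checksum is computed before the `checksum` field itself is set, verifying a log line means recomputing the digest over the other fields and comparing.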
314  tools/ai-review/compliance/codeowners.py  Normal file
@@ -0,0 +1,314 @@
"""CODEOWNERS Checker

Parses and validates CODEOWNERS files for compliance enforcement.
"""

import fnmatch
import logging
import os
import re
from dataclasses import dataclass
from pathlib import Path


@dataclass
class CodeOwnerRule:
    """A CODEOWNERS rule."""

    pattern: str
    owners: list[str]
    line_number: int
    is_negation: bool = False

    def matches(self, path: str) -> bool:
        """Check if a path matches this rule.

        Args:
            path: File path to check.

        Returns:
            True if the path matches.

        """
        path = path.lstrip("/")
        pattern = self.pattern.lstrip("/")

        # Handle directory patterns
        if pattern.endswith("/"):
            return path.startswith(pattern) or fnmatch.fnmatch(path, pattern + "*")

        # Handle ** patterns. Substitute "**" via a placeholder first so the
        # "*" replacement does not rewrite the ".*" it produces.
        if "**" in pattern:
            regex = (
                pattern.replace("**", "\x00")
                .replace("*", "[^/]*")
                .replace("\x00", ".*")
            )
            return bool(re.match(f"^{regex}$", path))

        # Standard fnmatch
        return fnmatch.fnmatch(path, pattern) or fnmatch.fnmatch(path, f"**/{pattern}")


class CodeownersChecker:
    """Checker for CODEOWNERS file compliance."""

    CODEOWNERS_LOCATIONS = [
        "CODEOWNERS",
        ".github/CODEOWNERS",
        ".gitea/CODEOWNERS",
        "docs/CODEOWNERS",
    ]

    def __init__(self, repo_root: str | None = None):
        """Initialize CODEOWNERS checker.

        Args:
            repo_root: Repository root path.

        """
        self.repo_root = repo_root or os.getcwd()
        self.rules: list[CodeOwnerRule] = []
        self.codeowners_path: str | None = None
        self.logger = logging.getLogger(__name__)

        self._load_codeowners()

    def _load_codeowners(self):
        """Load CODEOWNERS file from repository."""
        for location in self.CODEOWNERS_LOCATIONS:
            path = os.path.join(self.repo_root, location)
            if os.path.exists(path):
                self.codeowners_path = path
                self._parse_codeowners(path)
                break

    def _parse_codeowners(self, path: str):
        """Parse a CODEOWNERS file.

        Args:
            path: Path to CODEOWNERS file.

        """
        with open(path) as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()

                # Skip empty lines and comments
                if not line or line.startswith("#"):
                    continue

                # Parse pattern and owners
                parts = line.split()
                if len(parts) < 2:
                    continue

                pattern = parts[0]
                owners = parts[1:]

                # Check for negation (optional syntax)
                is_negation = pattern.startswith("!")
                if is_negation:
                    pattern = pattern[1:]

                self.rules.append(
                    CodeOwnerRule(
                        pattern=pattern,
                        owners=owners,
                        line_number=line_num,
                        is_negation=is_negation,
                    )
                )

    def get_owners(self, path: str) -> list[str]:
        """Get owners for a file path.

        Args:
            path: File path to check.

        Returns:
            List of owner usernames/teams.

        """
        owners = []

        # Apply rules in order (later rules override earlier ones)
        for rule in self.rules:
            if rule.matches(path):
                if rule.is_negation:
                    owners = []  # Clear owners for negation
                else:
                    owners = rule.owners

        return owners

    def get_owners_for_files(self, files: list[str]) -> dict[str, list[str]]:
        """Get owners for multiple files.

        Args:
            files: List of file paths.

        Returns:
            Dict mapping file paths to owner lists.

        """
        return {f: self.get_owners(f) for f in files}

    def get_required_reviewers(self, files: list[str]) -> set[str]:
        """Get all required reviewers for a set of files.

        Args:
            files: List of file paths.

        Returns:
            Set of all required reviewer usernames/teams.

        """
        reviewers = set()
        for f in files:
            reviewers.update(self.get_owners(f))
        return reviewers

    def check_approval(
        self,
        files: list[str],
        approvers: list[str],
    ) -> dict:
        """Check if files have required approvals.

        Args:
            files: List of changed files.
            approvers: List of users who approved.

        Returns:
            Dict with approval status and missing approvers.

        """
        required = self.get_required_reviewers(files)
        approvers_set = set(approvers)

        # Normalize @ prefixes
        required_normalized = {r.lstrip("@") for r in required}
        approvers_normalized = {a.lstrip("@") for a in approvers_set}

        missing = required_normalized - approvers_normalized

        # Check for team approvals (simplified - actual implementation
        # would need API calls to check team membership)
        teams = {r for r in missing if "/" in r}
        missing_users = missing - teams

        return {
            "approved": len(missing_users) == 0,
            "required_reviewers": list(required_normalized),
            "actual_approvers": list(approvers_normalized),
            "missing_approvers": list(missing_users),
            "pending_teams": list(teams),
        }

    def get_coverage_report(self, files: list[str]) -> dict:
        """Generate a coverage report for files.

        Args:
            files: List of file paths.

        Returns:
            Coverage report with owned and unowned files.

        """
        owned = []
        unowned = []

        for f in files:
            owners = self.get_owners(f)
            if owners:
                owned.append({"file": f, "owners": owners})
            else:
                unowned.append(f)

        return {
            "total_files": len(files),
            "owned_files": len(owned),
            "unowned_files": len(unowned),
            "coverage_percent": (len(owned) / len(files) * 100) if files else 0,
            "owned": owned,
            "unowned": unowned,
        }

    def validate_codeowners(self) -> dict:
        """Validate the CODEOWNERS file.

        Returns:
            Validation result with warnings and errors.

        """
        if not self.codeowners_path:
            return {
                "valid": False,
                "errors": ["No CODEOWNERS file found"],
                "warnings": [],
            }

        errors = []
        warnings = []

        # Check for empty rules
        for rule in self.rules:
            if not rule.owners:
                errors.append(
                    f"Line {rule.line_number}: Pattern '{rule.pattern}' has no owners"
                )

        # Check for invalid owner formats
        for rule in self.rules:
            for owner in rule.owners:
                if not owner.startswith("@") and "/" not in owner:
                    warnings.append(
                        f"Line {rule.line_number}: Owner '{owner}' should start with @ or be a team (org/team)"
                    )

        # Check for overlapping patterns
        patterns_seen = {}
        for rule in self.rules:
            if rule.pattern in patterns_seen:
                warnings.append(
                    f"Line {rule.line_number}: Pattern '{rule.pattern}' duplicates line {patterns_seen[rule.pattern]}"
                )
            patterns_seen[rule.pattern] = rule.line_number

        return {
            "valid": len(errors) == 0,
            "errors": errors,
            "warnings": warnings,
            "rules_count": len(self.rules),
            "file_path": self.codeowners_path,
        }

    @classmethod
    def from_content(cls, content: str) -> "CodeownersChecker":
        """Create checker from CODEOWNERS content string.

        Args:
            content: CODEOWNERS file content.

        Returns:
            CodeownersChecker instance.

        """
        checker = cls.__new__(cls)
        checker.repo_root = None
        checker.rules = []
        checker.codeowners_path = "<string>"
        checker.logger = logging.getLogger(__name__)

        for line_num, line in enumerate(content.split("\n"), 1):
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            parts = line.split()
            if len(parts) < 2:
                continue

            pattern = parts[0]
            owners = parts[1:]
            is_negation = pattern.startswith("!")
            if is_negation:
                pattern = pattern[1:]

            checker.rules.append(
                CodeOwnerRule(
                    pattern=pattern,
                    owners=owners,
                    line_number=line_num,
                    is_negation=is_negation,
                )
            )

        return checker
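The glob handling in `CodeOwnerRule.matches` dispatches on three cases: directory-prefix rules, `**` patterns compiled to a regex, and plain `fnmatch`. A condensed standalone sketch of the same dispatch (using the placeholder substitution so `**` crosses directory separators while a lone `*` does not):

```python
import fnmatch
import re


def codeowners_match(pattern: str, path: str) -> bool:
    path, pattern = path.lstrip("/"), pattern.lstrip("/")
    if pattern.endswith("/"):
        # Directory rule: owns everything beneath the directory.
        return path.startswith(pattern)
    if "**" in pattern:
        # "**" spans directories (.*); a single "*" stays within one
        # path segment ([^/]*). The \x00 placeholder keeps the second
        # replace from rewriting the ".*" produced by the first.
        regex = (
            pattern.replace("**", "\x00")
            .replace("*", "[^/]*")
            .replace("\x00", ".*")
        )
        return bool(re.match(f"^{regex}$", path))
    # Plain glob; note fnmatch's "*" also spans "/".
    return fnmatch.fnmatch(path, pattern)


codeowners_match("docs/", "docs/guide.md")                 # directory rule
codeowners_match("src/**/test_*.py", "src/a/b/test_x.py")  # ** crosses dirs
```

Without the placeholder step, a naive `pattern.replace("**", ".*").replace("*", "[^/]*")` would turn the freshly inserted `.*` into `.[^/]*`, silently breaking multi-segment `**` matches.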
@@ -1,21 +1,61 @@
# OpenRabbit AI Code Review Configuration
# =========================================

# LLM Provider Configuration
# --------------------------
# Available providers: openai | openrouter | ollama | anthropic | azure | gemini
provider: openai

model:
  openai: gpt-4.1-mini
  openrouter: anthropic/claude-3.5-sonnet
  ollama: codellama:13b
  anthropic: claude-3-5-sonnet-20241022
  azure: gpt-4 # Deployment name
  gemini: gemini-1.5-pro

temperature: 0
max_tokens: 4096

# Azure OpenAI specific settings (when provider: azure)
azure:
  endpoint: "" # Set via AZURE_OPENAI_ENDPOINT env var
  deployment: "" # Set via AZURE_OPENAI_DEPLOYMENT env var
  api_version: "2024-02-15-preview"

# Google Gemini specific settings (when provider: gemini)
gemini:
  project: "" # For Vertex AI, set via GOOGLE_CLOUD_PROJECT env var
  region: "us-central1"

# Rate Limits and Timeouts
# ------------------------
rate_limits:
  min_interval: 1.0 # Minimum seconds between API requests

timeouts:
  llm: 120 # LLM API timeout in seconds (OpenAI, OpenRouter, Anthropic, etc.)
  ollama: 300 # Ollama timeout (longer for local models)
  gitea: 30 # Gitea/GitHub API timeout

# Review settings
# ---------------
review:
  fail_on_severity: HIGH
  max_diff_lines: 800
  inline_comments: true
  security_scan: true

# File Ignore Patterns
# --------------------
# Similar to .gitignore, controls which files are excluded from review
ignore:
  use_defaults: true # Include default patterns (node_modules, .git, etc.)
  file: ".ai-reviewignore" # Custom ignore file name
  patterns: [] # Additional patterns to ignore

# Agent Configuration
# -------------------
agents:
  issue:
    enabled: true
@@ -33,46 +73,137 @@ agents:
      - opened
      - synchronize
    auto_summary:
      enabled: true
      post_as_comment: true
  codebase:
    enabled: true
    schedule: "0 0 * * 0" # Weekly on Sunday
  chat:
    enabled: true
    name: "Bartender"
    max_iterations: 5
    tools:
      - search_codebase
      - read_file
      - search_web
    searxng_url: "" # Set via SEARXNG_URL env var

  # Dependency Security Agent
  dependency:
    enabled: true
    scan_on_pr: true # Auto-scan PRs that modify dependency files
    vulnerability_threshold: "medium" # low | medium | high | critical
    update_suggestions: true # Suggest version updates

  # Test Coverage Agent
  test_coverage:
    enabled: true
    suggest_tests: true
    min_coverage_percent: 80 # Warn if coverage below this

  # Architecture Compliance Agent
  architecture:
    enabled: true
    layers:
      api:
        can_import_from: [utils, models, services]
        cannot_import_from: [db, repositories]
      services:
        can_import_from: [utils, models, repositories]
        cannot_import_from: [api]
      repositories:
        can_import_from: [utils, models, db]
        cannot_import_from: [api, services]

# Interaction Settings
# --------------------
# CUSTOMIZE YOUR BOT NAME HERE!
interaction:
  respond_to_mentions: true
  mention_prefix: "@codebot"
  commands:
    - help
    - explain
    - suggest
    - security
    - summarize
    - changelog
    - explain-diff
    - triage
    - review-again
    # New commands
    - check-deps # Check dependencies for vulnerabilities
    - suggest-tests # Suggest test cases
    - refactor-suggest # Suggest refactoring opportunities
    - architecture # Check architecture compliance
    - arch-check # Alias for architecture

# Security Scanning
# -----------------
security:
  enabled: true
  fail_on_high: true
  rules_file: "security/security_rules.yml"

  # SAST Integration
  sast:
    enabled: true
    bandit: true # Python AST-based security scanner
    semgrep: true # Polyglot security scanner with custom rules
    trivy: false # Container/filesystem scanner (requires trivy installed)

# Notifications
# -------------
notifications:
  enabled: false
  threshold: "warning" # info | warning | error | critical

  slack:
    enabled: false
    webhook_url: "" # Set via SLACK_WEBHOOK_URL env var
    channel: "" # Override channel (optional)
    username: "OpenRabbit"

  discord:
    enabled: false
    webhook_url: "" # Set via DISCORD_WEBHOOK_URL env var
    username: "OpenRabbit"
    avatar_url: ""

  # Custom webhooks for other integrations
  webhooks: []
  # Example:
  # - url: "https://your-webhook.example.com/notify"
  #   enabled: true
  #   headers:
  #     Authorization: "Bearer your-token"

# Compliance & Audit
# ------------------
compliance:
  enabled: false

  # Audit Trail
  audit:
    enabled: false
    log_file: "audit.log"
    log_to_stdout: false
    retention_days: 90

  # CODEOWNERS Enforcement
  codeowners:
    enabled: false
    require_approval: true # Require approval from code owners

  # Regulatory Compliance
  regulations:
    hipaa: false
    soc2: false
    pci_dss: false
    gdpr: false

# Enterprise Settings
# -------------------
enterprise:
  audit_log: true
  audit_path: "/var/log/ai-review/"
@@ -81,44 +212,44 @@ enterprise:
  requests_per_minute: 30
  max_concurrent: 4

# Label Mappings
# --------------
# Each label has:
#   name: The label name to use/create
#   aliases: Alternative names for auto-detection
#   color: Hex color code without #
#   description: Label description
labels:
  priority:
    critical:
      name: "priority: critical"
      color: "b60205"
      description: "Critical priority - immediate attention required"
      aliases: ["Priority - Critical", "P0", "critical", "Priority/Critical"]
    high:
      name: "priority: high"
      color: "d73a4a"
      description: "High priority issue"
      aliases: ["Priority - High", "P1", "high", "Priority/High"]
    medium:
      name: "priority: medium"
      color: "fbca04"
      description: "Medium priority issue"
      aliases: ["Priority - Medium", "P2", "medium", "Priority/Medium"]
    low:
      name: "priority: low"
      color: "28a745"
      description: "Low priority issue"
      aliases: ["Priority - Low", "P3", "low", "Priority/Low"]
  type:
    bug:
      name: "type: bug"
      color: "d73a4a"
      description: "Something isn't working"
      aliases: ["Kind/Bug", "bug", "Type: Bug", "Type/Bug", "Kind - Bug"]
    feature:
      name: "type: feature"
      color: "1d76db"
      description: "New feature request"
      aliases:
        [
@@ -132,7 +263,7 @@ labels:
        ]
    question:
      name: "type: question"
      color: "cc317c"
      description: "Further information is requested"
      aliases:
        [
@@ -144,7 +275,7 @@ labels:
        ]
    docs:
      name: "type: documentation"
      color: "0075ca"
      description: "Documentation improvements"
      aliases:
        [
@@ -157,7 +288,7 @@ labels:
        ]
    security:
      name: "type: security"
      color: "b60205"
      description: "Security vulnerability or concern"
      aliases:
        [
@@ -169,7 +300,7 @@ labels:
        ]
    testing:
      name: "type: testing"
      color: "0e8a16"
      description: "Related to testing"
      aliases:
        [
@@ -183,7 +314,7 @@ labels:
  status:
    ai_approved:
      name: "ai-approved"
      color: "28a745"
      description: "AI review approved this PR"
      aliases:
        [
@@ -194,7 +325,7 @@ labels:
        ]
    ai_changes_required:
      name: "ai-changes-required"
      color: "d73a4a"
      description: "AI review found issues requiring changes"
      aliases:
        [
@@ -205,7 +336,7 @@ labels:
        ]
    ai_reviewed:
      name: "ai-reviewed"
      color: "1d76db"
      description: "This issue/PR has been reviewed by AI"
      aliases:
        [
@@ -216,18 +347,9 @@ labels:
          "Status - Reviewed",
        ]

# Label Pattern Detection
# -----------------------
label_patterns:
  prefix_slash: "^(Kind|Type|Category)/(.+)$"
  prefix_dash: "^(Priority|Status|Reviewed) - (.+)$"
  colon: "^(type|priority|status): (.+)$"
tools/ai-review/notifications/__init__.py (new file, +20)
@@ -0,0 +1,20 @@
"""Notifications Package

Provides webhook-based notifications for Slack, Discord, and other platforms.
"""

from notifications.notifier import (
    DiscordNotifier,
    Notifier,
    NotifierFactory,
    SlackNotifier,
    WebhookNotifier,
)

__all__ = [
    "Notifier",
    "SlackNotifier",
    "DiscordNotifier",
    "WebhookNotifier",
    "NotifierFactory",
]
542
tools/ai-review/notifications/notifier.py
Normal file
542
tools/ai-review/notifications/notifier.py
Normal file
@@ -0,0 +1,542 @@
|
|||||||
|
"""Notification System
|
||||||
|
|
||||||
|
Provides webhook-based notifications for Slack, Discord, and other platforms.
|
||||||
|
Supports critical security findings, review summaries, and custom alerts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationLevel(Enum):
|
||||||
|
"""Notification severity levels."""
|
||||||
|
|
||||||
|
INFO = "info"
|
||||||
|
WARNING = "warning"
|
||||||
|
ERROR = "error"
|
||||||
|
CRITICAL = "critical"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NotificationMessage:
|
||||||
|
"""A notification message."""
|
||||||
|
|
||||||
|
title: str
|
||||||
|
message: str
|
||||||
|
level: NotificationLevel = NotificationLevel.INFO
|
||||||
|
fields: dict[str, str] | None = None
|
||||||
|
url: str | None = None
|
||||||
|
footer: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Notifier(ABC):
|
||||||
|
"""Abstract base class for notification providers."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def send(self, message: NotificationMessage) -> bool:
|
||||||
|
"""Send a notification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The notification message to send.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if sent successfully, False otherwise.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def send_raw(self, payload: dict) -> bool:
|
||||||
|
"""Send a raw payload to the webhook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Raw payload in provider-specific format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if sent successfully, False otherwise.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SlackNotifier(Notifier):
|
||||||
|
"""Slack webhook notifier."""
|
||||||
|
|
||||||
|
# Color mapping for different levels
|
||||||
|
LEVEL_COLORS = {
|
||||||
|
NotificationLevel.INFO: "#36a64f", # Green
|
||||||
|
NotificationLevel.WARNING: "#ffcc00", # Yellow
|
||||||
|
NotificationLevel.ERROR: "#ff6600", # Orange
|
||||||
|
NotificationLevel.CRITICAL: "#cc0000", # Red
|
||||||
|
}
|
||||||
|
|
||||||
|
LEVEL_EMOJIS = {
|
||||||
|
NotificationLevel.INFO: ":information_source:",
|
||||||
|
NotificationLevel.WARNING: ":warning:",
|
||||||
|
NotificationLevel.ERROR: ":x:",
|
||||||
|
NotificationLevel.CRITICAL: ":rotating_light:",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
webhook_url: str | None = None,
|
||||||
|
channel: str | None = None,
|
||||||
|
username: str = "OpenRabbit",
|
||||||
|
icon_emoji: str = ":robot_face:",
|
||||||
|
timeout: int = 10,
|
||||||
|
):
|
||||||
|
"""Initialize Slack notifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
webhook_url: Slack incoming webhook URL.
|
||||||
|
Defaults to SLACK_WEBHOOK_URL env var.
|
||||||
|
channel: Override channel (optional).
|
||||||
|
username: Bot username to display.
|
||||||
|
icon_emoji: Bot icon emoji.
|
||||||
|
timeout: Request timeout in seconds.
|
||||||
|
"""
|
||||||
|
self.webhook_url = webhook_url or os.environ.get("SLACK_WEBHOOK_URL", "")
|
||||||
|
self.channel = channel
|
||||||
|
self.username = username
|
||||||
|
self.icon_emoji = icon_emoji
|
||||||
|
self.timeout = timeout
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def send(self, message: NotificationMessage) -> bool:
|
||||||
|
"""Send a notification to Slack."""
|
||||||
|
if not self.webhook_url:
|
||||||
|
self.logger.warning("Slack webhook URL not configured")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Build attachment
|
||||||
|
attachment = {
|
||||||
|
"color": self.LEVEL_COLORS.get(message.level, "#36a64f"),
|
||||||
|
"title": f"{self.LEVEL_EMOJIS.get(message.level, '')} {message.title}",
|
||||||
|
"text": message.message,
|
||||||
|
"mrkdwn_in": ["text", "fields"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.url:
|
||||||
|
attachment["title_link"] = message.url
|
||||||
|
|
||||||
|
if message.fields:
|
||||||
|
attachment["fields"] = [
|
||||||
|
{"title": k, "value": v, "short": len(v) < 40}
|
||||||
|
for k, v in message.fields.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
if message.footer:
|
||||||
|
attachment["footer"] = message.footer
|
||||||
|
attachment["footer_icon"] = "https://github.com/favicon.ico"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"username": self.username,
|
||||||
|
"icon_emoji": self.icon_emoji,
|
||||||
|
"attachments": [attachment],
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.channel:
|
||||||
|
payload["channel"] = self.channel
|
||||||
|
|
||||||
|
return self.send_raw(payload)
|
||||||
|
|
||||||
|
def send_raw(self, payload: dict) -> bool:
|
||||||
|
"""Send raw payload to Slack webhook."""
|
||||||
|
if not self.webhook_url:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
self.webhook_url,
|
||||||
|
json=payload,
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return True
|
||||||
|
except requests.RequestException as e:
|
||||||
|
self.logger.error(f"Failed to send Slack notification: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordNotifier(Notifier):
|
||||||
|
"""Discord webhook notifier."""
|
||||||
|
|
||||||
|
# Color mapping for different levels (Discord uses decimal colors)
|
||||||
|
LEVEL_COLORS = {
|
||||||
|
NotificationLevel.INFO: 3066993, # Green
|
||||||
|
NotificationLevel.WARNING: 16776960, # Yellow
|
||||||
|
NotificationLevel.ERROR: 16744448, # Orange
|
||||||
|
NotificationLevel.CRITICAL: 13369344, # Red
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
webhook_url: str | None = None,
|
||||||
|
username: str = "OpenRabbit",
|
||||||
|
avatar_url: str | None = None,
|
||||||
|
timeout: int = 10,
|
||||||
|
):
|
||||||
|
"""Initialize Discord notifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
webhook_url: Discord webhook URL.
|
||||||
|
Defaults to DISCORD_WEBHOOK_URL env var.
|
||||||
|
username: Bot username to display.
|
||||||
|
avatar_url: Bot avatar URL.
|
||||||
|
timeout: Request timeout in seconds.
|
||||||
|
"""
|
||||||
|
self.webhook_url = webhook_url or os.environ.get("DISCORD_WEBHOOK_URL", "")
|
||||||
|
self.username = username
|
||||||
|
self.avatar_url = avatar_url
|
||||||
|
self.timeout = timeout
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def send(self, message: NotificationMessage) -> bool:
|
||||||
|
"""Send a notification to Discord."""
|
||||||
|
if not self.webhook_url:
|
||||||
|
self.logger.warning("Discord webhook URL not configured")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Build embed
|
||||||
|
embed = {
|
||||||
|
"title": message.title,
|
||||||
|
"description": message.message,
|
||||||
|
"color": self.LEVEL_COLORS.get(message.level, 3066993),
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.url:
|
||||||
|
embed["url"] = message.url
|
||||||
|
|
||||||
|
if message.fields:
|
||||||
|
embed["fields"] = [
|
||||||
|
{"name": k, "value": v, "inline": len(v) < 40}
|
||||||
|
for k, v in message.fields.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
if message.footer:
|
||||||
|
embed["footer"] = {"text": message.footer}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"username": self.username,
|
||||||
|
"embeds": [embed],
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.avatar_url:
|
||||||
|
payload["avatar_url"] = self.avatar_url
|
||||||
|
|
||||||
|
return self.send_raw(payload)
|
||||||
|
|
||||||
|
def send_raw(self, payload: dict) -> bool:
|
||||||
|
"""Send raw payload to Discord webhook."""
|
||||||
|
if not self.webhook_url:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
self.webhook_url,
|
||||||
|
json=payload,
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return True
|
||||||
|
except requests.RequestException as e:
|
||||||
|
self.logger.error(f"Failed to send Discord notification: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookNotifier(Notifier):
|
||||||
|
"""Generic webhook notifier for custom integrations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
webhook_url: str,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
timeout: int = 10,
|
||||||
|
):
|
||||||
|
"""Initialize generic webhook notifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
webhook_url: Webhook URL.
|
||||||
|
headers: Custom headers to include.
|
||||||
|
timeout: Request timeout in seconds.
|
||||||
|
"""
|
||||||
|
self.webhook_url = webhook_url
|
||||||
|
self.headers = headers or {"Content-Type": "application/json"}
|
||||||
|
self.timeout = timeout
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def send(self, message: NotificationMessage) -> bool:
|
||||||
|
"""Send a notification as JSON payload."""
|
||||||
|
payload = {
|
||||||
|
"title": message.title,
|
||||||
|
"message": message.message,
|
||||||
|
"level": message.level.value,
|
||||||
|
"fields": message.fields or {},
|
||||||
|
"url": message.url,
|
||||||
|
"footer": message.footer,
|
||||||
|
}
|
||||||
|
return self.send_raw(payload)
|
||||||
|
|
||||||
|
def send_raw(self, payload: dict) -> bool:
|
||||||
|
"""Send raw payload to webhook."""
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
self.webhook_url,
|
||||||
|
json=payload,
|
||||||
|
headers=self.headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return True
|
||||||
|
except requests.RequestException as e:
|
||||||
|
self.logger.error(f"Failed to send webhook notification: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class NotifierFactory:
|
||||||
|
"""Factory for creating notifier instances from config."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_from_config(config: dict) -> list[Notifier]:
|
||||||
|
"""Create notifier instances from configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary with 'notifications' section.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of configured notifier instances.
|
||||||
|
"""
|
||||||
|
notifiers = []
|
||||||
|
notifications_config = config.get("notifications", {})
|
||||||
|
|
||||||
|
if not notifications_config.get("enabled", False):
|
||||||
|
return notifiers
|
||||||
|
|
||||||
|
# Slack
|
||||||
|
slack_config = notifications_config.get("slack", {})
|
||||||
|
if slack_config.get("enabled", False):
|
||||||
|
webhook_url = slack_config.get("webhook_url") or os.environ.get(
|
||||||
|
"SLACK_WEBHOOK_URL"
|
||||||
|
)
|
||||||
|
if webhook_url:
|
||||||
|
notifiers.append(
|
||||||
|
SlackNotifier(
|
||||||
|
webhook_url=webhook_url,
|
||||||
|
channel=slack_config.get("channel"),
|
||||||
|
username=slack_config.get("username", "OpenRabbit"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Discord
|
||||||
|
discord_config = notifications_config.get("discord", {})
|
||||||
|
if discord_config.get("enabled", False):
|
||||||
|
webhook_url = discord_config.get("webhook_url") or os.environ.get(
|
||||||
|
"DISCORD_WEBHOOK_URL"
|
||||||
|
)
|
||||||
|
if webhook_url:
|
||||||
|
notifiers.append(
|
||||||
|
DiscordNotifier(
|
||||||
|
webhook_url=webhook_url,
|
||||||
|
username=discord_config.get("username", "OpenRabbit"),
|
||||||
|
avatar_url=discord_config.get("avatar_url"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generic webhooks
|
||||||
|
webhooks_config = notifications_config.get("webhooks", [])
|
||||||
|
for webhook in webhooks_config:
|
||||||
|
if webhook.get("enabled", True) and webhook.get("url"):
|
||||||
|
notifiers.append(
|
||||||
|
WebhookNotifier(
|
||||||
|
webhook_url=webhook["url"],
|
||||||
|
headers=webhook.get("headers"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return notifiers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def should_notify(
|
||||||
|
level: NotificationLevel,
|
||||||
|
config: dict,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if notification should be sent based on level threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Notification level.
|
||||||
|
config: Configuration dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if notification should be sent.
|
||||||
|
"""
|
||||||
|
notifications_config = config.get("notifications", {})
|
||||||
|
if not notifications_config.get("enabled", False):
|
||||||
|
return False
|
||||||
|
|
||||||
|
threshold = notifications_config.get("threshold", "warning")
|
||||||
|
level_order = ["info", "warning", "error", "critical"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
threshold_idx = level_order.index(threshold)
|
||||||
|
level_idx = level_order.index(level.value)
|
||||||
|
return level_idx >= threshold_idx
|
||||||
|
except ValueError:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationService:
|
||||||
|
"""High-level notification service for sending alerts."""
|
||||||
|
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
"""Initialize notification service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary.
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self.notifiers = NotifierFactory.create_from_config(config)
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def notify(self, message: NotificationMessage) -> bool:
|
||||||
|
"""Send notification to all configured notifiers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Notification message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if at least one notification succeeded.
|
||||||
|
"""
|
||||||
|
if not NotifierFactory.should_notify(message.level, self.config):
|
||||||
|
return True # Not an error, just below threshold
|
||||||
|
|
||||||
|
if not self.notifiers:
|
||||||
|
self.logger.debug("No notifiers configured")
|
||||||
|
return True
|
||||||
|
|
||||||
|
success = False
|
||||||
|
for notifier in self.notifiers:
|
||||||
|
try:
|
||||||
|
if notifier.send(message):
|
||||||
|
success = True
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Notifier failed: {e}")
|
||||||
|
|
||||||
|
return success
|
||||||
|
|
||||||
|
def notify_security_finding(
|
||||||
|
self,
|
||||||
|
repo: str,
|
||||||
|
pr_number: int,
|
||||||
|
finding: dict,
|
||||||
|
pr_url: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Send notification for a security finding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo: Repository name (owner/repo).
|
||||||
|
pr_number: Pull request number.
|
||||||
|
finding: Security finding dict with severity, description, etc.
|
||||||
|
pr_url: URL to the pull request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if notification succeeded.
|
||||||
|
"""
|
||||||
|
severity = finding.get("severity", "MEDIUM").upper()
|
||||||
|
level_map = {
|
||||||
|
"HIGH": NotificationLevel.CRITICAL,
|
||||||
|
"MEDIUM": NotificationLevel.WARNING,
|
||||||
|
"LOW": NotificationLevel.INFO,
|
||||||
|
}
|
||||||
|
        level = level_map.get(severity, NotificationLevel.WARNING)

        message = NotificationMessage(
            title=f"Security Finding in {repo} PR #{pr_number}",
            message=finding.get("description", "Security issue detected"),
            level=level,
            fields={
                "Severity": severity,
                "Category": finding.get("category", "Unknown"),
                "File": finding.get("file", "N/A"),
                "Line": str(finding.get("line", "N/A")),
            },
            url=pr_url,
            footer=f"CWE: {finding.get('cwe', 'N/A')}",
        )

        return self.notify(message)

    def notify_review_complete(
        self,
        repo: str,
        pr_number: int,
        summary: dict,
        pr_url: str | None = None,
    ) -> bool:
        """Send notification for completed review.

        Args:
            repo: Repository name (owner/repo).
            pr_number: Pull request number.
            summary: Review summary dict.
            pr_url: URL to the pull request.

        Returns:
            True if notification succeeded.

        """
        recommendation = summary.get("recommendation", "COMMENT")
        level_map = {
            "APPROVE": NotificationLevel.INFO,
            "COMMENT": NotificationLevel.INFO,
            "REQUEST_CHANGES": NotificationLevel.WARNING,
        }
        level = level_map.get(recommendation, NotificationLevel.INFO)

        high_issues = summary.get("high_severity_count", 0)
        if high_issues > 0:
            level = NotificationLevel.ERROR

        message = NotificationMessage(
            title=f"Review Complete: {repo} PR #{pr_number}",
            message=summary.get("summary", "AI review completed"),
            level=level,
            fields={
                "Recommendation": recommendation,
                "High Severity": str(high_issues),
                "Medium Severity": str(summary.get("medium_severity_count", 0)),
                "Files Reviewed": str(summary.get("files_reviewed", 0)),
            },
            url=pr_url,
            footer="OpenRabbit AI Review",
        )

        return self.notify(message)

    def notify_error(
        self,
        repo: str,
        error: str,
        context: str | None = None,
    ) -> bool:
        """Send notification for an error.

        Args:
            repo: Repository name.
            error: Error message.
            context: Additional context.

        Returns:
            True if notification succeeded.

        """
        message = NotificationMessage(
            title=f"Error in {repo}",
            message=error,
            level=NotificationLevel.ERROR,
            fields={"Context": context} if context else None,
            footer="OpenRabbit Error Alert",
        )

        return self.notify(message)
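The recommendation-to-level mapping in `notify_review_complete`, including its high-severity escalation, can be exercised standalone. This is a minimal sketch; the `NotificationLevel` enum values and the helper name `level_for` are assumptions for illustration:

```python
from enum import Enum


class NotificationLevel(Enum):
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"


def level_for(summary: dict) -> NotificationLevel:
    # Mirrors notify_review_complete: the recommendation picks the base
    # level, and any high-severity finding escalates it to ERROR.
    level_map = {
        "APPROVE": NotificationLevel.INFO,
        "COMMENT": NotificationLevel.INFO,
        "REQUEST_CHANGES": NotificationLevel.WARNING,
    }
    level = level_map.get(
        summary.get("recommendation", "COMMENT"), NotificationLevel.INFO
    )
    if summary.get("high_severity_count", 0) > 0:
        level = NotificationLevel.ERROR
    return level
```

Note that the escalation runs after the mapping, so even an `APPROVE` recommendation with open high-severity findings notifies at `ERROR`.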
431 tools/ai-review/security/sast_scanner.py Normal file
@@ -0,0 +1,431 @@
"""SAST Scanner Integration

Integrates with external SAST tools like Bandit and Semgrep
to provide comprehensive security analysis.
"""

import json
import logging
import os
import shutil
import subprocess
import tempfile
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class SASTFinding:
    """A finding from a SAST tool."""

    tool: str
    rule_id: str
    severity: str  # CRITICAL, HIGH, MEDIUM, LOW
    file: str
    line: int
    message: str
    code_snippet: str | None = None
    cwe: str | None = None
    owasp: str | None = None
    fix_recommendation: str | None = None


@dataclass
class SASTReport:
    """Combined report from all SAST tools."""

    total_findings: int
    findings_by_severity: dict[str, int]
    findings_by_tool: dict[str, int]
    findings: list[SASTFinding]
    tools_run: list[str]
    errors: list[str] = field(default_factory=list)


class SASTScanner:
    """Aggregator for multiple SAST tools."""

    def __init__(self, config: dict | None = None):
        """Initialize the SAST scanner.

        Args:
            config: Configuration dictionary with tool settings.

        """
        self.config = config or {}
        self.logger = logging.getLogger(self.__class__.__name__)

    def scan_directory(self, path: str) -> SASTReport:
        """Scan a directory with all enabled SAST tools.

        Args:
            path: Path to the directory to scan.

        Returns:
            Combined SASTReport from all tools.

        """
        all_findings = []
        tools_run = []
        errors = []

        sast_config = self.config.get("security", {}).get("sast", {})

        # Run Bandit (Python)
        if sast_config.get("bandit", True):
            if self._is_tool_available("bandit"):
                try:
                    findings = self._run_bandit(path)
                    all_findings.extend(findings)
                    tools_run.append("bandit")
                except Exception as e:
                    errors.append(f"Bandit error: {e}")
            else:
                self.logger.debug("Bandit not installed, skipping")

        # Run Semgrep
        if sast_config.get("semgrep", True):
            if self._is_tool_available("semgrep"):
                try:
                    findings = self._run_semgrep(path)
                    all_findings.extend(findings)
                    tools_run.append("semgrep")
                except Exception as e:
                    errors.append(f"Semgrep error: {e}")
            else:
                self.logger.debug("Semgrep not installed, skipping")

        # Run Trivy (if enabled for filesystem scanning)
        if sast_config.get("trivy", False):
            if self._is_tool_available("trivy"):
                try:
                    findings = self._run_trivy(path)
                    all_findings.extend(findings)
                    tools_run.append("trivy")
                except Exception as e:
                    errors.append(f"Trivy error: {e}")
            else:
                self.logger.debug("Trivy not installed, skipping")

        # Calculate statistics
        by_severity = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0}
        by_tool = {}

        for finding in all_findings:
            sev = finding.severity.upper()
            if sev in by_severity:
                by_severity[sev] += 1
            tool = finding.tool
            by_tool[tool] = by_tool.get(tool, 0) + 1

        return SASTReport(
            total_findings=len(all_findings),
            findings_by_severity=by_severity,
            findings_by_tool=by_tool,
            findings=all_findings,
            tools_run=tools_run,
            errors=errors,
        )

    def scan_content(self, content: str, filename: str) -> list[SASTFinding]:
        """Scan file content with SAST tools.

        Args:
            content: File content to scan.
            filename: Name of the file (for language detection).

        Returns:
            List of SASTFinding objects.

        """
        # Write the content into an isolated temporary directory so the
        # scan covers only this file, not everything else in the system
        # temp directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, os.path.basename(filename))
            with open(temp_path, "w") as f:
                f.write(content)

            report = self.scan_directory(tmp_dir)
            findings = report.findings
            # Report findings against the original filename
            for finding in findings:
                finding.file = filename
            return findings

    def scan_diff(self, diff: str) -> list[SASTFinding]:
        """Scan a diff for security issues.

        Only scans added/modified lines.

        Args:
            diff: Git diff content.

        Returns:
            List of SASTFinding objects.

        """
        findings = []

        # Parse diff and extract added content per file
        files_content = {}
        current_file = None
        current_content = []

        for line in diff.splitlines():
            if line.startswith("diff --git"):
                if current_file and current_content:
                    files_content[current_file] = "\n".join(current_content)
                current_file = None
                current_content = []
                # Extract filename
                match = line.split(" b/")
                if len(match) > 1:
                    current_file = match[1]
            elif line.startswith("+") and not line.startswith("+++"):
                if current_file:
                    current_content.append(line[1:])  # Remove + prefix

        # Don't forget last file
        if current_file and current_content:
            files_content[current_file] = "\n".join(current_content)

        # Scan each file's content
        for filename, content in files_content.items():
            if content.strip():
                file_findings = self.scan_content(content, filename)
                findings.extend(file_findings)

        return findings

    def _is_tool_available(self, tool: str) -> bool:
        """Check if a tool is installed and available."""
        return shutil.which(tool) is not None

    def _run_bandit(self, path: str) -> list[SASTFinding]:
        """Run Bandit security scanner.

        Args:
            path: Path to scan.

        Returns:
            List of SASTFinding objects.

        """
        findings = []

        try:
            result = subprocess.run(
                [
                    "bandit",
                    "-r",
                    path,
                    "-f",
                    "json",
                    "-ll",  # Only high and medium severity
                    "--quiet",
                ],
                capture_output=True,
                text=True,
                timeout=120,
            )

            if result.stdout:
                data = json.loads(result.stdout)

                for issue in data.get("results", []):
                    severity = issue.get("issue_severity", "MEDIUM").upper()

                    findings.append(
                        SASTFinding(
                            tool="bandit",
                            rule_id=issue.get("test_id", ""),
                            severity=severity,
                            file=issue.get("filename", ""),
                            line=issue.get("line_number", 0),
                            message=issue.get("issue_text", ""),
                            code_snippet=issue.get("code", ""),
                            cwe=f"CWE-{issue.get('issue_cwe', {}).get('id', '')}"
                            if issue.get("issue_cwe")
                            else None,
                            fix_recommendation=issue.get("more_info", ""),
                        )
                    )

        except subprocess.TimeoutExpired:
            self.logger.warning("Bandit scan timed out")
        except json.JSONDecodeError as e:
            self.logger.warning(f"Failed to parse Bandit output: {e}")
        except Exception as e:
            self.logger.warning(f"Bandit scan failed: {e}")

        return findings

    def _run_semgrep(self, path: str) -> list[SASTFinding]:
        """Run Semgrep security scanner.

        Args:
            path: Path to scan.

        Returns:
            List of SASTFinding objects.

        """
        findings = []

        # Get Semgrep config from settings
        sast_config = self.config.get("security", {}).get("sast", {})
        semgrep_rules = sast_config.get("semgrep_rules", "p/security-audit")

        try:
            result = subprocess.run(
                [
                    "semgrep",
                    "--config",
                    semgrep_rules,
                    "--json",
                    "--quiet",
                    path,
                ],
                capture_output=True,
                text=True,
                timeout=180,
            )

            if result.stdout:
                data = json.loads(result.stdout)

                for finding in data.get("results", []):
                    # Map Semgrep severity to our scale
                    sev_map = {
                        "ERROR": "HIGH",
                        "WARNING": "MEDIUM",
                        "INFO": "LOW",
                    }
                    severity = sev_map.get(
                        finding.get("extra", {}).get("severity", "WARNING"), "MEDIUM"
                    )

                    metadata = finding.get("extra", {}).get("metadata", {})

                    findings.append(
                        SASTFinding(
                            tool="semgrep",
                            rule_id=finding.get("check_id", ""),
                            severity=severity,
                            file=finding.get("path", ""),
                            line=finding.get("start", {}).get("line", 0),
                            message=finding.get("extra", {}).get("message", ""),
                            code_snippet=finding.get("extra", {}).get("lines", ""),
                            cwe=metadata.get("cwe", [None])[0]
                            if metadata.get("cwe")
                            else None,
                            owasp=metadata.get("owasp", [None])[0]
                            if metadata.get("owasp")
                            else None,
                            fix_recommendation=metadata.get("fix", ""),
                        )
                    )

        except subprocess.TimeoutExpired:
            self.logger.warning("Semgrep scan timed out")
        except json.JSONDecodeError as e:
            self.logger.warning(f"Failed to parse Semgrep output: {e}")
        except Exception as e:
            self.logger.warning(f"Semgrep scan failed: {e}")

        return findings

    def _run_trivy(self, path: str) -> list[SASTFinding]:
        """Run Trivy filesystem scanner.

        Args:
            path: Path to scan.

        Returns:
            List of SASTFinding objects.

        """
        findings = []

        try:
            result = subprocess.run(
                [
                    "trivy",
                    "fs",
                    "--format",
                    "json",
                    "--security-checks",
                    "vuln,secret,config",
                    path,
                ],
                capture_output=True,
                text=True,
                timeout=180,
            )

            if result.stdout:
                data = json.loads(result.stdout)

                for result_item in data.get("Results", []):
                    target = result_item.get("Target", "")

                    # Process vulnerabilities
                    for vuln in result_item.get("Vulnerabilities", []):
                        severity = vuln.get("Severity", "MEDIUM").upper()

                        findings.append(
                            SASTFinding(
                                tool="trivy",
                                rule_id=vuln.get("VulnerabilityID", ""),
                                severity=severity,
                                file=target,
                                line=0,
                                message=vuln.get("Title", ""),
                                cwe=vuln.get("CweIDs", [None])[0]
                                if vuln.get("CweIDs")
                                else None,
                                fix_recommendation=f"Upgrade to {vuln.get('FixedVersion', 'latest')}"
                                if vuln.get("FixedVersion")
                                else None,
                            )
                        )

                    # Process secrets
                    for secret in result_item.get("Secrets", []):
                        findings.append(
                            SASTFinding(
                                tool="trivy",
                                rule_id=secret.get("RuleID", ""),
                                severity="HIGH",
                                file=target,
                                line=secret.get("StartLine", 0),
                                message=f"Secret detected: {secret.get('Title', '')}",
                                code_snippet=secret.get("Match", ""),
                            )
                        )

        except subprocess.TimeoutExpired:
            self.logger.warning("Trivy scan timed out")
        except json.JSONDecodeError as e:
            self.logger.warning(f"Failed to parse Trivy output: {e}")
        except Exception as e:
            self.logger.warning(f"Trivy scan failed: {e}")

        return findings


def get_sast_scanner(config: dict | None = None) -> SASTScanner:
    """Get a configured SAST scanner instance.

    Args:
        config: Configuration dictionary.

    Returns:
        Configured SASTScanner instance.

    """
    return SASTScanner(config=config)
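The added-lines extraction performed by `scan_diff` can be sketched as a standalone helper, which makes the parsing rule easy to test in isolation: only `+` lines count, `+++` file headers are skipped, and content is grouped per `diff --git` section. The helper name `added_lines_by_file` is hypothetical:

```python
def added_lines_by_file(diff: str) -> dict[str, str]:
    # Mirrors the parsing loop in SASTScanner.scan_diff: collect the
    # added lines of each file section, keyed by the "b/" path.
    files: dict[str, str] = {}
    current: str | None = None
    buf: list[str] = []
    for line in diff.splitlines():
        if line.startswith("diff --git"):
            if current and buf:
                files[current] = "\n".join(buf)
            current, buf = None, []
            parts = line.split(" b/")
            if len(parts) > 1:
                current = parts[1]
        elif line.startswith("+") and not line.startswith("+++"):
            if current:
                buf.append(line[1:])  # strip the leading "+"
    if current and buf:  # flush the last file section
        files[current] = "\n".join(buf)
    return files
```

Because deleted and context lines never reach the buffer, the SAST tools only see code that the pull request actually introduces.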
@@ -1,9 +1,14 @@
"""Utility Functions Package

This package contains utility functions for webhook sanitization,
safe event dispatching, ignore patterns, and other helper functions.
"""

from utils.ignore_patterns import (
    IgnorePatterns,
    get_ignore_patterns,
    should_ignore_file,
)
from utils.webhook_sanitizer import (
    extract_minimal_context,
    sanitize_webhook_data,
@@ -16,4 +21,7 @@ __all__ = [
    "validate_repository_format",
    "extract_minimal_context",
    "validate_webhook_signature",
    "IgnorePatterns",
    "get_ignore_patterns",
    "should_ignore_file",
]
358 tools/ai-review/utils/ignore_patterns.py Normal file
@@ -0,0 +1,358 @@
"""AI Review Ignore Patterns

Provides .gitignore-style pattern matching for excluding files from AI review.
Reads patterns from .ai-reviewignore files in the repository.
"""

import os
import re
from dataclasses import dataclass


@dataclass
class IgnoreRule:
    """A single ignore rule."""

    pattern: str
    negation: bool = False
    directory_only: bool = False
    anchored: bool = False
    regex: re.Pattern | None = None

    def __post_init__(self):
        """Compile the pattern to regex."""
        self.regex = self._compile_pattern()

    def _compile_pattern(self) -> re.Pattern:
        """Convert gitignore pattern to regex."""
        pattern = self.pattern

        # Handle directory-only patterns
        if pattern.endswith("/"):
            pattern = pattern[:-1]
            self.directory_only = True

        # Handle anchored patterns (starting with /)
        if pattern.startswith("/"):
            pattern = pattern[1:]
            self.anchored = True

        # Escape special regex characters except * and ?
        regex_pattern = ""
        i = 0
        while i < len(pattern):
            c = pattern[i]
            if c == "*":
                if i + 1 < len(pattern) and pattern[i + 1] == "*":
                    # ** matches everything including /
                    if i + 2 < len(pattern) and pattern[i + 2] == "/":
                        regex_pattern += "(?:.*/)?"
                        i += 3
                        continue
                    else:
                        regex_pattern += ".*"
                        i += 2
                        continue
                else:
                    # * matches everything except /
                    regex_pattern += "[^/]*"
            elif c == "?":
                regex_pattern += "[^/]"
            elif c == "[":
                # Character class
                j = i + 1
                if j < len(pattern) and pattern[j] == "!":
                    regex_pattern += "[^"
                    j += 1
                else:
                    regex_pattern += "["
                while j < len(pattern) and pattern[j] != "]":
                    regex_pattern += pattern[j]
                    j += 1
                if j < len(pattern):
                    regex_pattern += "]"
                i = j
            elif c in ".^$+{}|()":
                regex_pattern += "\\" + c
            else:
                regex_pattern += c
            i += 1

        # Anchor pattern (matches() normalizes paths to start with "/")
        if self.anchored:
            regex_pattern = "^/" + regex_pattern
        else:
            regex_pattern = "(?:^|/)" + regex_pattern

        # Match to end or as directory prefix
        regex_pattern += "(?:$|/)"

        return re.compile(regex_pattern)

    def matches(self, path: str, is_directory: bool = False) -> bool:
        """Check if a path matches this rule.

        Args:
            path: Relative path to check.
            is_directory: Whether the path is a directory.

        Returns:
            True if the path matches.

        """
        if self.directory_only and not is_directory:
            return False

        # Normalize path
        path = path.replace("\\", "/")
        if not path.startswith("/"):
            path = "/" + path

        return bool(self.regex.search(path))


class IgnorePatterns:
    """Manages .ai-reviewignore patterns for a repository."""

    DEFAULT_PATTERNS = [
        # Version control
        ".git/",
        ".svn/",
        ".hg/",
        # Dependencies
        "node_modules/",
        "vendor/",
        "venv/",
        ".venv/",
        "__pycache__/",
        "*.pyc",
        # Build outputs
        "dist/",
        "build/",
        "out/",
        "target/",
        "*.min.js",
        "*.min.css",
        "*.bundle.js",
        # IDE/Editor
        ".idea/",
        ".vscode/",
        "*.swp",
        "*.swo",
        # Generated files
        "*.lock",
        "package-lock.json",
        "yarn.lock",
        "poetry.lock",
        "Pipfile.lock",
        "Cargo.lock",
        "go.sum",
        # Binary files
        "*.exe",
        "*.dll",
        "*.so",
        "*.dylib",
        "*.bin",
        "*.o",
        "*.a",
        # Media files
        "*.png",
        "*.jpg",
        "*.jpeg",
        "*.gif",
        "*.svg",
        "*.ico",
        "*.mp3",
        "*.mp4",
        "*.wav",
        "*.pdf",
        # Archives
        "*.zip",
        "*.tar",
        "*.gz",
        "*.rar",
        "*.7z",
        # Large data files
        "*.csv",
        "*.json.gz",
        "*.sql",
        "*.sqlite",
        "*.db",
    ]

    def __init__(
        self,
        repo_root: str | None = None,
        ignore_file: str = ".ai-reviewignore",
        use_defaults: bool = True,
    ):
        """Initialize ignore patterns.

        Args:
            repo_root: Repository root path.
            ignore_file: Name of ignore file to read.
            use_defaults: Whether to include default patterns.

        """
        self.repo_root = repo_root or os.getcwd()
        self.ignore_file = ignore_file
        self.rules: list[IgnoreRule] = []

        # Load default patterns
        if use_defaults:
            for pattern in self.DEFAULT_PATTERNS:
                self._add_pattern(pattern)

        # Load patterns from ignore file
        self._load_ignore_file()

    def _load_ignore_file(self):
        """Load patterns from .ai-reviewignore file."""
        ignore_path = os.path.join(self.repo_root, self.ignore_file)

        if not os.path.exists(ignore_path):
            return

        try:
            with open(ignore_path) as f:
                for line in f:
                    line = line.rstrip("\n\r")

                    # Skip empty lines and comments
                    if not line or line.startswith("#"):
                        continue

                    self._add_pattern(line)
        except Exception:
            pass  # Ignore errors reading the file

    def _add_pattern(self, pattern: str):
        """Add a pattern to the rules list.

        Args:
            pattern: Pattern string (gitignore format).

        """
        # Check for negation
        negation = False
        if pattern.startswith("!"):
            negation = True
            pattern = pattern[1:]

        if not pattern:
            return

        self.rules.append(IgnoreRule(pattern=pattern, negation=negation))

    def is_ignored(self, path: str, is_directory: bool = False) -> bool:
        """Check if a path should be ignored.

        Args:
            path: Relative path to check.
            is_directory: Whether the path is a directory.

        Returns:
            True if the path should be ignored.

        """
        # Normalize path
        path = path.replace("\\", "/").lstrip("/")

        # Check each rule in order (later rules override earlier ones)
        ignored = False
        for rule in self.rules:
            if rule.matches(path, is_directory):
                ignored = not rule.negation

        return ignored

    def filter_paths(self, paths: list[str]) -> list[str]:
        """Filter a list of paths, removing ignored ones.

        Args:
            paths: List of relative paths.

        Returns:
            Filtered list of paths.

        """
        return [p for p in paths if not self.is_ignored(p)]

    def filter_diff_files(self, files: list[dict]) -> list[dict]:
        """Filter diff file objects, removing ignored ones.

        Args:
            files: List of file dicts with 'filename' or 'path' key.

        Returns:
            Filtered list of file dicts.

        """
        result = []
        for f in files:
            path = f.get("filename") or f.get("path") or f.get("name", "")
            if not self.is_ignored(path):
                result.append(f)
        return result

    def should_review_file(self, filename: str) -> bool:
        """Check if a file should be reviewed.

        Args:
            filename: File path to check.

        Returns:
            True if the file should be reviewed.

        """
        return not self.is_ignored(filename)

    @classmethod
    def from_config(
        cls, config: dict, repo_root: str | None = None
    ) -> "IgnorePatterns":
        """Create IgnorePatterns from config.

        Args:
            config: Configuration dictionary.
            repo_root: Repository root path.

        Returns:
            IgnorePatterns instance.

        """
        ignore_config = config.get("ignore", {})
        use_defaults = ignore_config.get("use_defaults", True)
        ignore_file = ignore_config.get("file", ".ai-reviewignore")

        instance = cls(
            repo_root=repo_root,
            ignore_file=ignore_file,
            use_defaults=use_defaults,
        )

        # Add any extra patterns from config
        extra_patterns = ignore_config.get("patterns", [])
        for pattern in extra_patterns:
            instance._add_pattern(pattern)

        return instance


def get_ignore_patterns(repo_root: str | None = None) -> IgnorePatterns:
    """Get ignore patterns for a repository.

    Args:
        repo_root: Repository root path.

    Returns:
        IgnorePatterns instance.

    """
    return IgnorePatterns(repo_root=repo_root)


def should_ignore_file(filename: str, repo_root: str | None = None) -> bool:
    """Quick check if a file should be ignored.

    Args:
        filename: File path to check.
        repo_root: Repository root path.

    Returns:
        True if the file should be ignored.

    """
    return get_ignore_patterns(repo_root).is_ignored(filename)
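The pattern translation and last-match-wins semantics above can be condensed into a standalone sketch. This is a simplified mirror of `IgnoreRule._compile_pattern` and `IgnorePatterns.is_ignored` (no character classes or directory-only handling); the helper names are assumptions:

```python
import re


def gitignore_to_regex(pattern: str) -> re.Pattern:
    # '**/' may span directories, '*' stops at '/', anchored patterns
    # ('/dist') match only at the repo root, unanchored at any depth.
    anchored = pattern.startswith("/")
    if anchored:
        pattern = pattern[1:]
    out, i = "", 0
    while i < len(pattern):
        c = pattern[i]
        if c == "*":
            if pattern[i + 1 : i + 2] == "*":
                if pattern[i + 2 : i + 3] == "/":
                    out += "(?:.*/)?"
                    i += 3
                    continue
                out += ".*"
                i += 2
                continue
            out += "[^/]*"
        elif c == "?":
            out += "[^/]"
        elif c in ".^$+{}|()":
            out += "\\" + c
        else:
            out += c
        i += 1
    prefix = "^/" if anchored else "(?:^|/)"
    return re.compile(prefix + out + "(?:$|/)")


def is_ignored(path: str, patterns: list[str]) -> bool:
    # Last matching rule wins; a leading '!' negates, as in gitignore.
    path = "/" + path.lstrip("/")
    ignored = False
    for p in patterns:
        negation = p.startswith("!")
        if negation:
            p = p[1:]
        if gitignore_to_regex(p).search(path):
            ignored = not negation
    return ignored
```

The trailing `(?:$|/)` is what makes a bare `node_modules` rule also exclude everything beneath a `node_modules` directory.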