just why not
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
This commit is contained in:
17
CLAUDE.md
@@ -69,11 +69,17 @@ The codebase uses an **agent-based architecture** where specialized agents handl
   - `execute(context)` - Main execution logic
   - Returns `AgentResult` with success status, message, data, and actions taken

 **Core Agents:**
 - **PRAgent** - Reviews pull requests with inline comments and security scanning
-- **IssueAgent** - Triages issues and responds to @ai-bot commands
+- **IssueAgent** - Triages issues and responds to @codebot commands
 - **CodebaseAgent** - Analyzes entire codebase health and tech debt
 - **ChatAgent** - Interactive assistant with tool calling (search_codebase, read_file, search_web)

 **Specialized Agents:**
 - **DependencyAgent** - Scans dependencies for security vulnerabilities (Python, JavaScript)
 - **TestCoverageAgent** - Analyzes code for test coverage gaps and suggests test cases
 - **ArchitectureAgent** - Enforces layer separation and detects architecture violations

 3. **Dispatcher** (`dispatcher.py`) - Routes events to appropriate agents:
    - Registers agents at startup
    - Determines which agents can handle each event
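The agent/dispatcher contract described above can be sketched in a few lines. This is an illustrative stand-in mirroring the documented interface (`can_handle`, `execute`, `AgentResult`, `register_agent`, `dispatch`) — the real classes live in `agents/base_agent.py` and `dispatcher.py`, and their bodies will differ.

```python
# Minimal sketch of the documented agent/dispatcher flow.
# All class bodies here are illustrative, not the project's implementation.
from dataclasses import dataclass, field


@dataclass
class AgentResult:
    success: bool
    message: str
    data: dict = field(default_factory=dict)
    actions_taken: list = field(default_factory=list)


class EchoAgent:
    """Toy agent: handles 'issues' events and echoes the action."""

    def can_handle(self, event_type, event_data):
        return event_type == "issues"

    def execute(self, context):
        return AgentResult(success=True,
                           message=f"handled {context['event_data']['action']}")


class Dispatcher:
    def __init__(self):
        self.agents = []

    def register_agent(self, agent):
        self.agents.append(agent)

    def dispatch(self, event_type, event_data, owner, repo):
        # Run every registered agent that claims this event type
        context = {"event_type": event_type, "event_data": event_data,
                   "owner": owner, "repo": repo}
        return [a.execute(context) for a in self.agents
                if a.can_handle(event_type, event_data)]


dispatcher = Dispatcher()
dispatcher.register_agent(EchoAgent())
results = dispatcher.dispatch("issues", {"action": "opened"}, owner="me", repo="demo")
print(results[0].message)  # handled opened
```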
@@ -84,14 +90,23 @@ The codebase uses an **agent-based architecture** where specialized agents handl

The `LLMClient` (`clients/llm_client.py`) provides a unified interface for multiple LLM providers:

**Core Providers (in llm_client.py):**
- **OpenAI** - Primary provider (gpt-4.1-mini default)
- **OpenRouter** - Multi-provider access (claude-3.5-sonnet)
- **Ollama** - Self-hosted models (codellama:13b)

**Additional Providers (in clients/providers/):**
- **AnthropicProvider** - Direct Anthropic Claude API (claude-3.5-sonnet)
- **AzureOpenAIProvider** - Azure OpenAI Service with API key auth
- **AzureOpenAIWithAADProvider** - Azure OpenAI with Azure AD authentication
- **GeminiProvider** - Google Gemini API (public)
- **VertexAIGeminiProvider** - Google Vertex AI Gemini (enterprise GCP)

Key features:
- Tool/function calling support via `call_with_tools(messages, tools)`
- JSON response parsing with fallback extraction
- Provider-specific configuration via `config.yml`
- Configurable timeouts per provider
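The "fallback extraction" idea — try the raw string, then a fenced code block, then the first balanced `{...}` span — can be sketched as below. This is an illustrative standalone function; the project's own `_extract_json` (exercised in the test suite later in this commit) may differ in detail.

```python
# Sketch of JSON parsing with fallback extraction (illustrative only).
import json
import re


def extract_json(content: str) -> dict:
    # 1. Direct parse
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass
    # 2. Fenced code block, with or without a "json" language tag
    match = re.search(r"```(?:json)?\s*(.*?)```", content, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError:
            pass
    # 3. First balanced {...} span in the text
    start = content.find("{")
    if start != -1:
        depth = 0
        for i, ch in enumerate(content[start:], start):
            depth += {"{": 1, "}": -1}.get(ch, 0)
            if depth == 0:
                try:
                    return json.loads(content[start:i + 1])
                except json.JSONDecodeError:
                    break
    raise ValueError(f"Failed to parse JSON from: {content[:80]!r}")


print(extract_json('Here is the result: {"priority": "high"} Done.')["priority"])  # high
```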
### Platform Abstraction
116
README.md
@@ -1,6 +1,6 @@
 # OpenRabbit

-Enterprise-grade AI code review system for **Gitea** with automated PR review, issue triage, interactive chat, and codebase analysis.
+Enterprise-grade AI code review system for **Gitea** and **GitHub** with automated PR review, issue triage, interactive chat, and codebase analysis.

 ---
@@ -14,9 +14,15 @@ Enterprise-grade AI code review system for **Gitea** with automated PR review, i
 | **Chat** | Interactive AI chat with codebase search and web search tools |
 | **@codebot Commands** | `@codebot summarize`, `changelog`, `explain-diff`, `explain`, `suggest`, `triage`, `review-again` in comments |
 | **Codebase Analysis** | Health scores, tech debt tracking, weekly reports |
-| **Security Scanner** | 17 OWASP-aligned rules for vulnerability detection |
+| **Security Scanner** | 17 OWASP-aligned rules + SAST integration (Bandit, Semgrep) |
 | **Dependency Scanning** | Vulnerability detection for Python, JavaScript dependencies |
 | **Test Coverage** | AI-powered test suggestions for untested code |
 | **Architecture Compliance** | Layer separation enforcement, circular dependency detection |
 | **Notifications** | Slack/Discord alerts for security findings and reviews |
 | **Compliance** | Audit trail, CODEOWNERS enforcement, regulatory support |
 | **Multi-Provider LLM** | OpenAI, Anthropic Claude, Azure OpenAI, Google Gemini, Ollama |
 | **Enterprise Ready** | Audit logging, metrics, Prometheus export |
-| **Gitea Native** | Built for Gitea workflows and API |
+| **Gitea Native** | Built for Gitea workflows and API (also works with GitHub) |

 ---
@@ -116,12 +122,28 @@ tools/ai-review/
 │   ├── issue_agent.py            # Issue triage & @codebot commands
 │   ├── pr_agent.py               # PR review with security scan
 │   ├── codebase_agent.py         # Codebase health analysis
-│   └── chat_agent.py             # Interactive chat with tool calling
+│   ├── chat_agent.py             # Interactive chat with tool calling
+│   ├── dependency_agent.py       # Dependency vulnerability scanning
+│   ├── test_coverage_agent.py    # Test coverage analysis
+│   └── architecture_agent.py     # Architecture compliance checking
 ├── clients/                      # API clients
 │   ├── gitea_client.py           # Gitea REST API wrapper
-│   └── llm_client.py             # Multi-provider LLM client with tool support
+│   ├── llm_client.py             # Multi-provider LLM client with tool support
+│   └── providers/                # Additional LLM providers
+│       ├── anthropic_provider.py # Direct Anthropic Claude API
+│       ├── azure_provider.py     # Azure OpenAI Service
+│       └── gemini_provider.py    # Google Gemini API
 ├── security/                     # Security scanning
-│   └── security_scanner.py       # 17 OWASP-aligned rules
+│   ├── security_scanner.py       # 17 OWASP-aligned rules
+│   └── sast_scanner.py           # Bandit, Semgrep, Trivy integration
 ├── notifications/                # Alerting system
 │   └── notifier.py               # Slack, Discord, webhook notifications
 ├── compliance/                   # Compliance & audit
 │   ├── audit_trail.py            # Audit logging with integrity verification
 │   └── codeowners.py             # CODEOWNERS enforcement
 ├── utils/                        # Utility functions
 │   ├── ignore_patterns.py        # .ai-reviewignore support
 │   └── webhook_sanitizer.py      # Input validation
 ├── enterprise/                   # Enterprise features
 │   ├── audit_logger.py           # JSONL audit logging
 │   └── metrics.py                # Prometheus-compatible metrics
@@ -182,6 +204,10 @@ In any issue comment:
 | `@codebot summarize` | Summarize the issue in 2-3 sentences |
 | `@codebot explain` | Explain what the issue is about |
 | `@codebot suggest` | Suggest solutions or next steps |
 | `@codebot check-deps` | Scan dependencies for security vulnerabilities |
 | `@codebot suggest-tests` | Suggest test cases for changed code |
 | `@codebot refactor-suggest` | Suggest refactoring opportunities |
 | `@codebot architecture` | Check architecture compliance (alias: `arch-check`) |
 | `@codebot` (any question) | Chat with AI using codebase/web search tools |

 ### Pull Request Commands
@@ -522,19 +548,91 @@ Replace `'Bartender'` with your bot's Gitea username. This prevents the bot from

 | Provider | Model | Use Case |
 |----------|-------|----------|
-| OpenAI | gpt-4.1-mini | Fast, reliable |
+| OpenAI | gpt-4.1-mini | Fast, reliable, default |
 | Anthropic | claude-3.5-sonnet | Direct Claude API access |
 | Azure OpenAI | gpt-4 (deployment) | Enterprise Azure deployments |
 | Google Gemini | gemini-1.5-pro | GCP customers, Vertex AI |
 | OpenRouter | claude-3.5-sonnet | Multi-provider access |
 | Ollama | codellama:13b | Self-hosted, private |
### Provider Configuration

```yaml
# In config.yml
provider: anthropic  # openai | anthropic | azure | gemini | openrouter | ollama

# Azure OpenAI
azure:
  endpoint: ""  # Set via AZURE_OPENAI_ENDPOINT env var
  deployment: "gpt-4"
  api_version: "2024-02-15-preview"

# Google Gemini (Vertex AI)
gemini:
  project: ""  # Set via GOOGLE_CLOUD_PROJECT env var
  region: "us-central1"
```
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Provider | Description |
|
||||
|----------|----------|-------------|
|
||||
| `OPENAI_API_KEY` | OpenAI | API key |
|
||||
| `ANTHROPIC_API_KEY` | Anthropic | API key |
|
||||
| `AZURE_OPENAI_ENDPOINT` | Azure | Service endpoint URL |
|
||||
| `AZURE_OPENAI_API_KEY` | Azure | API key |
|
||||
| `AZURE_OPENAI_DEPLOYMENT` | Azure | Deployment name |
|
||||
| `GOOGLE_API_KEY` | Gemini | API key (public API) |
|
||||
| `GOOGLE_CLOUD_PROJECT` | Vertex AI | GCP project ID |
|
||||
| `OPENROUTER_API_KEY` | OpenRouter | API key |
|
||||
| `OLLAMA_HOST` | Ollama | Server URL (default: localhost:11434) |
|
||||
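A client might resolve these settings at startup roughly as follows. This is a sketch: the helper name and return shape are illustrative (not the project's code), and only the variable names and the `OLLAMA_HOST` default come from the table above.

```python
# Sketch: resolve provider settings from the environment variables above.
# resolve_provider_env is a hypothetical helper, not part of the project.
import os


def resolve_provider_env(provider: str) -> dict:
    if provider == "ollama":
        # OLLAMA_HOST is optional; documented default is localhost:11434
        return {"host": os.environ.get("OLLAMA_HOST", "localhost:11434")}
    keys = {
        "openai": ["OPENAI_API_KEY"],
        "anthropic": ["ANTHROPIC_API_KEY"],
        "azure": ["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY",
                  "AZURE_OPENAI_DEPLOYMENT"],
        "gemini": ["GOOGLE_API_KEY"],
        "openrouter": ["OPENROUTER_API_KEY"],
    }[provider]
    missing = [k for k in keys if not os.environ.get(k)]
    if missing:
        raise RuntimeError(f"{provider}: missing {', '.join(missing)}")
    return {k: os.environ[k] for k in keys}


print(resolve_provider_env("ollama"))
```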
---

## Enterprise Features

-- **Audit Logging**: JSONL logs with daily rotation
+- **Audit Logging**: JSONL logs with integrity checksums and daily rotation
- **Compliance**: HIPAA, SOC2, PCI-DSS, GDPR support with configurable rules
- **CODEOWNERS Enforcement**: Validate approvals against CODEOWNERS file
- **Notifications**: Slack/Discord webhooks for critical findings
- **SAST Integration**: Bandit, Semgrep, Trivy for advanced security scanning
- **Metrics**: Prometheus-compatible export
-- **Rate Limiting**: Configurable request limits
+- **Rate Limiting**: Configurable request limits and timeouts
- **Custom Security Rules**: Define your own patterns via YAML
- **Tool Calling**: LLM function calling for interactive chat
- **Ignore Patterns**: `.ai-reviewignore` for excluding files from review
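The `.ai-reviewignore` support mentioned above could work roughly like this, assuming a gitignore-style file of one glob pattern per line with `#` comments. The helpers below are illustrative; the real parser in `utils/ignore_patterns.py` may be richer.

```python
# Sketch of .ai-reviewignore matching, assuming gitignore-style globs.
# Illustrative only -- not the project's ignore_patterns implementation.
from fnmatch import fnmatch


def load_ignore_patterns(text: str) -> list:
    patterns = []
    for line in text.splitlines():
        line = line.strip()
        if line and not line.startswith("#"):
            patterns.append(line)
    return patterns


def is_ignored(path: str, patterns: list) -> bool:
    return any(fnmatch(path, pat) for pat in patterns)


patterns = load_ignore_patterns("# generated files\n*.min.js\ndist/*\n")
print(is_ignored("dist/app.js", patterns))   # True
print(is_ignored("src/app.py", patterns))    # False
```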
### Notifications Configuration

```yaml
# In config.yml
notifications:
  enabled: true
  threshold: "warning"  # info | warning | error | critical

  slack:
    enabled: true
    webhook_url: ""  # Set via SLACK_WEBHOOK_URL env var
    channel: "#code-review"

  discord:
    enabled: true
    webhook_url: ""  # Set via DISCORD_WEBHOOK_URL env var
```
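The `threshold` key above implies a severity ordering of info < warning < error < critical, with a finding delivered only when it is at or above the configured level. A minimal sketch (function name is illustrative):

```python
# Sketch of the notification threshold check from the config above.
LEVELS = ["info", "warning", "error", "critical"]


def should_notify(severity: str, threshold: str) -> bool:
    """Deliver only findings at or above the configured threshold."""
    return LEVELS.index(severity) >= LEVELS.index(threshold)


print(should_notify("error", "warning"))  # True
print(should_notify("info", "warning"))   # False
```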
### Compliance Configuration

```yaml
compliance:
  enabled: true
  audit:
    enabled: true
    log_file: "audit.log"
    retention_days: 90
  codeowners:
    enabled: true
    require_approval: true
```
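The audit trail is described as JSONL with integrity verification. One common way to get that property is to store a checksum of each entry alongside the entry itself; the sketch below shows the idea with SHA-256. Field names and helpers are illustrative, not the project's `audit_trail.py` format.

```python
# Sketch of JSONL audit entries with an integrity checksum (illustrative).
import hashlib
import json


def audit_entry(event: dict) -> str:
    """Serialize an event as one JSONL line with a checksum over its payload."""
    payload = json.dumps(event, sort_keys=True)
    checksum = hashlib.sha256(payload.encode()).hexdigest()
    return json.dumps({"event": event, "sha256": checksum})


def verify(line: str) -> bool:
    """Recompute the checksum and compare, detecting tampered entries."""
    record = json.loads(line)
    payload = json.dumps(record["event"], sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest() == record["sha256"]


line = audit_entry({"action": "pr_review", "repo": "demo"})
print(verify(line))  # True
```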
---
296
tests/test_dispatcher.py
Normal file
@@ -0,0 +1,296 @@

```python
"""Test Suite for Dispatcher

Tests for event routing and agent execution.
"""

import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review"))

from unittest.mock import MagicMock, Mock, patch

import pytest


class TestDispatcherCreation:
    """Test dispatcher initialization."""

    def test_create_dispatcher(self):
        """Test creating dispatcher."""
        from dispatcher import Dispatcher

        dispatcher = Dispatcher()
        assert dispatcher is not None
        assert dispatcher.agents == []

    def test_create_dispatcher_with_config(self):
        """Test creating dispatcher with config."""
        from dispatcher import Dispatcher

        config = {"dispatcher": {"max_workers": 4}}
        dispatcher = Dispatcher(config=config)
        assert dispatcher.config == config


class TestAgentRegistration:
    """Test agent registration."""

    def test_register_agent(self):
        """Test registering an agent."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "test"

            def execute(self, context):
                return AgentResult(success=True, message="done")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)

        assert len(dispatcher.agents) == 1
        assert dispatcher.agents[0] == agent

    def test_register_multiple_agents(self):
        """Test registering multiple agents."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent1(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "type1"

            def execute(self, context):
                return AgentResult(success=True, message="agent1")

        class MockAgent2(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "type2"

            def execute(self, context):
                return AgentResult(success=True, message="agent2")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent1(config={}, gitea_client=None, llm_client=None)
        )
        dispatcher.register_agent(
            MockAgent2(config={}, gitea_client=None, llm_client=None)
        )

        assert len(dispatcher.agents) == 2


class TestEventRouting:
    """Test event routing to agents."""

    def test_route_to_matching_agent(self):
        """Test that events are routed to matching agents."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="handled")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)

        result = dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )

        assert len(result.agents_run) == 1
        assert result.results[0].success is True

    def test_no_matching_agent(self):
        """Test dispatch when no agent matches."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="handled")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)

        result = dispatcher.dispatch(
            event_type="pull_request",  # Different event type
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )

        assert len(result.agents_run) == 0

    def test_multiple_matching_agents(self):
        """Test dispatch when multiple agents match."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent1(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="agent1")

        class MockAgent2(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="agent2")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent1(config={}, gitea_client=None, llm_client=None)
        )
        dispatcher.register_agent(
            MockAgent2(config={}, gitea_client=None, llm_client=None)
        )

        result = dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )

        assert len(result.agents_run) == 2


class TestDispatchResult:
    """Test dispatch result structure."""

    def test_result_structure(self):
        """Test DispatchResult has correct structure."""
        from dispatcher import DispatchResult

        result = DispatchResult(
            agents_run=["Agent1", "Agent2"],
            results=[],
            errors=[],
        )

        assert result.agents_run == ["Agent1", "Agent2"]
        assert result.results == []
        assert result.errors == []

    def test_result_with_errors(self):
        """Test DispatchResult with errors."""
        from dispatcher import DispatchResult

        result = DispatchResult(
            agents_run=["Agent1"],
            results=[],
            errors=["Error 1", "Error 2"],
        )

        assert len(result.errors) == 2


class TestAgentExecution:
    """Test agent execution through dispatcher."""

    def test_agent_receives_context(self):
        """Test that agents receive proper context."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        received_context = None

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return True

            def execute(self, context):
                nonlocal received_context
                received_context = context
                return AgentResult(success=True, message="done")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent(config={}, gitea_client=None, llm_client=None)
        )

        dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened", "issue": {"number": 123}},
            owner="testowner",
            repo="testrepo",
        )

        assert received_context is not None
        assert received_context.owner == "testowner"
        assert received_context.repo == "testrepo"
        assert received_context.event_type == "issues"
        assert received_context.event_data["action"] == "opened"

    def test_agent_failure_captured(self):
        """Test that agent failures are captured in results."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class FailingAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return True

            def execute(self, context):
                raise Exception("Test error")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            FailingAgent(config={}, gitea_client=None, llm_client=None)
        )

        result = dispatcher.dispatch(
            event_type="issues",
            event_data={},
            owner="test",
            repo="repo",
        )

        # Agent should still be in agents_run
        assert len(result.agents_run) == 1
        # Result should indicate failure
        assert result.results[0].success is False


class TestGetDispatcher:
    """Test get_dispatcher factory function."""

    def test_get_dispatcher_returns_singleton(self):
        """Test that get_dispatcher returns configured dispatcher."""
        from dispatcher import get_dispatcher

        dispatcher = get_dispatcher()
        assert dispatcher is not None

    def test_get_dispatcher_with_config(self):
        """Test get_dispatcher with custom config."""
        from dispatcher import get_dispatcher

        config = {"test": "value"}
        dispatcher = get_dispatcher(config=config)
        assert dispatcher.config.get("test") == "value"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
```
456
tests/test_llm_client.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""Test Suite for LLM Client
|
||||
|
||||
Tests for LLM client functionality including provider support,
|
||||
tool calling, and JSON parsing.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review"))
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestLLMClientCreation:
|
||||
"""Test LLM client initialization."""
|
||||
|
||||
def test_create_openai_client(self):
|
||||
"""Test creating OpenAI client."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient(
|
||||
provider="openai",
|
||||
config={"model": "gpt-4o-mini", "api_key": "test-key"},
|
||||
)
|
||||
assert client.provider_name == "openai"
|
||||
|
||||
def test_create_openrouter_client(self):
|
||||
"""Test creating OpenRouter client."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient(
|
||||
provider="openrouter",
|
||||
config={"model": "anthropic/claude-3.5-sonnet", "api_key": "test-key"},
|
||||
)
|
||||
assert client.provider_name == "openrouter"
|
||||
|
||||
def test_create_ollama_client(self):
|
||||
"""Test creating Ollama client."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient(
|
||||
provider="ollama",
|
||||
config={"model": "codellama:13b", "host": "http://localhost:11434"},
|
||||
)
|
||||
assert client.provider_name == "ollama"
|
||||
|
||||
def test_invalid_provider_raises_error(self):
|
||||
"""Test that invalid provider raises ValueError."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown provider"):
|
||||
LLMClient(provider="invalid_provider")
|
||||
|
||||
def test_from_config_openai(self):
|
||||
"""Test creating client from config dict."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
config = {
|
||||
"provider": "openai",
|
||||
"model": {"openai": "gpt-4o-mini"},
|
||||
"temperature": 0,
|
||||
"max_tokens": 4096,
|
||||
}
|
||||
client = LLMClient.from_config(config)
|
||||
assert client.provider_name == "openai"
|
||||
|
||||
|
||||
class TestLLMResponse:
|
||||
"""Test LLM response dataclass."""
|
||||
|
||||
def test_response_creation(self):
|
||||
"""Test creating LLMResponse."""
|
||||
from clients.llm_client import LLMResponse
|
||||
|
||||
response = LLMResponse(
|
||||
content="Test response",
|
||||
model="gpt-4o-mini",
|
||||
provider="openai",
|
||||
tokens_used=100,
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
assert response.content == "Test response"
|
||||
assert response.model == "gpt-4o-mini"
|
||||
assert response.provider == "openai"
|
||||
assert response.tokens_used == 100
|
||||
assert response.finish_reason == "stop"
|
||||
assert response.tool_calls is None
|
||||
|
||||
def test_response_with_tool_calls(self):
|
||||
"""Test LLMResponse with tool calls."""
|
||||
from clients.llm_client import LLMResponse, ToolCall
|
||||
|
||||
tool_calls = [ToolCall(id="call_1", name="search", arguments={"query": "test"})]
|
||||
|
||||
response = LLMResponse(
|
||||
content="",
|
||||
model="gpt-4o-mini",
|
||||
provider="openai",
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
assert response.tool_calls is not None
|
||||
assert len(response.tool_calls) == 1
|
||||
assert response.tool_calls[0].name == "search"
|
||||
|
||||
|
||||
class TestToolCall:
|
||||
"""Test ToolCall dataclass."""
|
||||
|
||||
def test_tool_call_creation(self):
|
||||
"""Test creating ToolCall."""
|
||||
from clients.llm_client import ToolCall
|
||||
|
||||
tool_call = ToolCall(
|
||||
id="call_123",
|
||||
name="search_codebase",
|
||||
arguments={"query": "authentication", "file_pattern": "*.py"},
|
||||
)
|
||||
|
||||
assert tool_call.id == "call_123"
|
||||
assert tool_call.name == "search_codebase"
|
||||
assert tool_call.arguments["query"] == "authentication"
|
||||
assert tool_call.arguments["file_pattern"] == "*.py"
|
||||
|
||||
|
||||
class TestJSONParsing:
|
||||
"""Test JSON extraction and parsing."""
|
||||
|
||||
def test_parse_direct_json(self):
|
||||
"""Test parsing direct JSON response."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
|
||||
content = '{"key": "value", "number": 42}'
|
||||
result = client._extract_json(content)
|
||||
|
||||
assert result["key"] == "value"
|
||||
assert result["number"] == 42
|
||||
|
||||
def test_parse_json_in_code_block(self):
|
||||
"""Test parsing JSON in markdown code block."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
|
||||
content = """Here is the analysis:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "bug",
|
||||
"priority": "high"
|
||||
}
|
||||
```
|
||||
|
||||
That's my analysis."""
|
||||
|
||||
result = client._extract_json(content)
|
||||
assert result["type"] == "bug"
|
||||
assert result["priority"] == "high"
|
||||
|
||||
def test_parse_json_in_plain_code_block(self):
|
||||
"""Test parsing JSON in plain code block (no json specifier)."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
|
||||
content = """Analysis:
|
||||
|
||||
```
|
||||
{"status": "success", "count": 5}
|
||||
```
|
||||
"""
|
||||
|
||||
result = client._extract_json(content)
|
||||
assert result["status"] == "success"
|
||||
assert result["count"] == 5
|
||||
|
||||
def test_parse_json_with_preamble(self):
|
||||
"""Test parsing JSON with text before it."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
|
||||
content = """Based on my analysis, here is the result:
|
||||
{"findings": ["issue1", "issue2"], "severity": "medium"}
|
||||
"""
|
||||
|
||||
result = client._extract_json(content)
|
||||
assert result["findings"] == ["issue1", "issue2"]
|
||||
assert result["severity"] == "medium"
|
||||
|
||||
def test_parse_json_with_postamble(self):
|
||||
"""Test parsing JSON with text after it."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
|
||||
content = """{"result": true}
|
||||
|
||||
Let me know if you need more details."""
|
||||
|
||||
result = client._extract_json(content)
|
||||
assert result["result"] is True
|
||||
|
||||
def test_parse_nested_json(self):
|
||||
"""Test parsing nested JSON objects."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
|
||||
content = """{
|
||||
"outer": {
|
||||
"inner": {
|
||||
"value": "deep"
|
||||
}
|
||||
},
|
||||
"array": [1, 2, 3]
|
||||
}"""
|
||||
|
||||
result = client._extract_json(content)
|
||||
assert result["outer"]["inner"]["value"] == "deep"
|
||||
assert result["array"] == [1, 2, 3]
|
||||
|
||||
def test_parse_invalid_json_raises_error(self):
|
||||
"""Test that invalid JSON raises ValueError."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
|
||||
content = "This is not JSON at all"
|
||||
|
||||
with pytest.raises(ValueError, match="Failed to parse JSON"):
|
||||
client._extract_json(content)
|
||||
|
||||
def test_parse_truncated_json_raises_error(self):
|
||||
"""Test that truncated JSON raises ValueError."""
|
||||
from clients.llm_client import LLMClient
|
||||
|
||||
client = LLMClient.__new__(LLMClient)
|
||||
|
||||
content = '{"key": "value", "incomplete'
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
client._extract_json(content)
|
||||
|
||||
|
||||
class TestOpenAIProvider:
|
||||
"""Test OpenAI provider."""
|
||||
|
||||
def test_provider_creation(self):
|
||||
"""Test OpenAI provider initialization."""
|
||||
from clients.llm_client import OpenAIProvider
|
||||
|
||||
provider = OpenAIProvider(
|
||||
api_key="test-key",
|
||||
model="gpt-4o-mini",
|
||||
temperature=0.5,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
assert provider.model == "gpt-4o-mini"
|
||||
assert provider.temperature == 0.5
|
||||
assert provider.max_tokens == 2048
|
||||
|
||||
def test_provider_requires_api_key(self):
|
||||
"""Test that calling without API key raises error."""
|
||||
from clients.llm_client import OpenAIProvider
|
||||
|
||||
provider = OpenAIProvider(api_key="")
|
||||
|
||||
with pytest.raises(ValueError, match="API key is required"):
|
||||
provider.call("test prompt")
|
||||
|
||||
@patch("clients.llm_client.requests.post")
|
||||
def test_provider_call_success(self, mock_post):
|
||||
"""Test successful API call."""
|
||||
from clients.llm_client import OpenAIProvider
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {"content": "Test response"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"model": "gpt-4o-mini",
|
||||
"usage": {"total_tokens": 50},
|
||||
}
|
||||
mock_response.raise_for_status = Mock()
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
provider = OpenAIProvider(api_key="test-key")
|
||||
response = provider.call("Hello")
|
||||
|
||||
assert response.content == "Test response"
|
||||
assert response.provider == "openai"
|
||||
assert response.tokens_used == 50
|
||||
|
||||
|
||||
class TestOpenRouterProvider:
|
||||
"""Test OpenRouter provider."""
|
||||
|
||||
def test_provider_creation(self):
|
||||
"""Test OpenRouter provider initialization."""
|
||||
from clients.llm_client import OpenRouterProvider
|
||||
|
||||
provider = OpenRouterProvider(
|
||||
api_key="test-key",
|
||||
model="anthropic/claude-3.5-sonnet",
|
||||
)
|
||||
|
||||
assert provider.model == "anthropic/claude-3.5-sonnet"
|
||||
|
||||
@patch("clients.llm_client.requests.post")
|
||||
def test_provider_call_success(self, mock_post):
|
||||
"""Test successful API call."""
|
||||
        from clients.llm_client import OpenRouterProvider

        mock_response = Mock()
        mock_response.json.return_value = {
            "choices": [
                {
                    "message": {"content": "Claude response"},
                    "finish_reason": "stop",
                }
            ],
            "model": "anthropic/claude-3.5-sonnet",
            "usage": {"total_tokens": 75},
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response

        provider = OpenRouterProvider(api_key="test-key")
        response = provider.call("Hello")

        assert response.content == "Claude response"
        assert response.provider == "openrouter"


class TestOllamaProvider:
    """Test Ollama provider."""

    def test_provider_creation(self):
        """Test Ollama provider initialization."""
        from clients.llm_client import OllamaProvider

        provider = OllamaProvider(
            host="http://localhost:11434",
            model="codellama:13b",
        )

        assert provider.model == "codellama:13b"
        assert provider.host == "http://localhost:11434"

    @patch("clients.llm_client.requests.post")
    def test_provider_call_success(self, mock_post):
        """Test successful API call."""
        from clients.llm_client import OllamaProvider

        mock_response = Mock()
        mock_response.json.return_value = {
            "response": "Ollama response",
            "model": "codellama:13b",
            "done": True,
            "eval_count": 30,
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response

        provider = OllamaProvider()
        response = provider.call("Hello")

        assert response.content == "Ollama response"
        assert response.provider == "ollama"


class TestToolCalling:
    """Test tool/function calling support."""

    @patch("clients.llm_client.requests.post")
    def test_openai_tool_calling(self, mock_post):
        """Test OpenAI tool calling."""
        from clients.llm_client import OpenAIProvider

        mock_response = Mock()
        mock_response.json.return_value = {
            "choices": [
                {
                    "message": {
                        "content": None,
                        "tool_calls": [
                            {
                                "id": "call_abc123",
                                "function": {
                                    "name": "search_codebase",
                                    "arguments": '{"query": "auth"}',
                                },
                            }
                        ],
                    },
                    "finish_reason": "tool_calls",
                }
            ],
            "model": "gpt-4o-mini",
            "usage": {"total_tokens": 100},
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response

        provider = OpenAIProvider(api_key="test-key")
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "search_codebase",
                    "description": "Search the codebase",
                    "parameters": {
                        "type": "object",
                        "properties": {"query": {"type": "string"}},
                    },
                },
            }
        ]

        response = provider.call_with_tools(
            messages=[{"role": "user", "content": "Search for auth"}],
            tools=tools,
        )

        assert response.tool_calls is not None
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "search_codebase"
        assert response.tool_calls[0].arguments["query"] == "auth"

    def test_ollama_tool_calling_not_supported(self):
        """Test that Ollama raises NotImplementedError for tool calling."""
        from clients.llm_client import OllamaProvider

        provider = OllamaProvider()

        with pytest.raises(NotImplementedError):
            provider.call_with_tools(
                messages=[{"role": "user", "content": "test"}],
                tools=[],
            )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
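The tool-calling test asserts `response.tool_calls[0].arguments["query"] == "auth"` even though the mocked API payload carries `arguments` as a JSON-encoded string, so the provider is implicitly expected to decode it. A minimal standalone sketch of that decoding step (illustration only, not the actual `OpenAIProvider` code):

```python
import json

# OpenAI-style tool calls carry "arguments" as a JSON string; the provider
# is assumed to decode it into a dict before exposing it on the response.
raw_call = {"function": {"name": "search_codebase", "arguments": '{"query": "auth"}'}}
arguments = json.loads(raw_call["function"]["arguments"])
assert arguments == {"query": "auth"}
```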
@@ -2,20 +2,40 @@

This package contains the modular agent implementations for the
enterprise AI code review system.

Core Agents:
- PRAgent: Pull request review and analysis
- IssueAgent: Issue triage and response
- CodebaseAgent: Codebase health analysis
- ChatAgent: Interactive chat with tool calling

Specialized Agents:
- DependencyAgent: Dependency vulnerability scanning
- TestCoverageAgent: Test coverage analysis and suggestions
- ArchitectureAgent: Architecture compliance checking
"""

from agents.architecture_agent import ArchitectureAgent
from agents.base_agent import AgentContext, AgentResult, BaseAgent
from agents.chat_agent import ChatAgent
from agents.codebase_agent import CodebaseAgent
from agents.dependency_agent import DependencyAgent
from agents.issue_agent import IssueAgent
from agents.pr_agent import PRAgent
from agents.test_coverage_agent import TestCoverageAgent

__all__ = [
    # Base
    "BaseAgent",
    "AgentContext",
    "AgentResult",
    # Core Agents
    "IssueAgent",
    "PRAgent",
    "CodebaseAgent",
    "ChatAgent",
    # Specialized Agents
    "DependencyAgent",
    "TestCoverageAgent",
    "ArchitectureAgent",
]
tools/ai-review/agents/architecture_agent.py (new file, 547 lines)
@@ -0,0 +1,547 @@
|
||||
"""Architecture Compliance Agent
|
||||
|
||||
AI agent for enforcing architectural patterns and layer separation.
|
||||
Detects cross-layer violations and circular dependencies.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from agents.base_agent import AgentContext, AgentResult, BaseAgent
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArchitectureViolation:
|
||||
"""An architecture violation."""
|
||||
|
||||
file: str
|
||||
line: int
|
||||
violation_type: str # cross_layer, circular, naming, structure
|
||||
severity: str # HIGH, MEDIUM, LOW
|
||||
description: str
|
||||
recommendation: str
|
||||
source_layer: str | None = None
|
||||
target_layer: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArchitectureReport:
|
||||
"""Report of architecture analysis."""
|
||||
|
||||
violations: list[ArchitectureViolation]
|
||||
layers_detected: dict[str, list[str]]
|
||||
circular_dependencies: list[tuple[str, str]]
|
||||
compliance_score: float
|
||||
recommendations: list[str]
|
||||
|
||||
|
||||
class ArchitectureAgent(BaseAgent):
|
||||
"""Agent for enforcing architectural compliance."""
|
||||
|
||||
# Marker for architecture comments
|
||||
ARCH_AI_MARKER = "<!-- AI_ARCHITECTURE_CHECK -->"
|
||||
|
||||
# Default layer definitions
|
||||
DEFAULT_LAYERS = {
|
||||
"api": {
|
||||
"patterns": ["api/", "routes/", "controllers/", "handlers/", "views/"],
|
||||
"can_import": ["services", "models", "utils", "config"],
|
||||
"cannot_import": ["db", "repositories", "infrastructure"],
|
||||
},
|
||||
"services": {
|
||||
"patterns": ["services/", "usecases/", "application/"],
|
||||
"can_import": ["models", "repositories", "utils", "config"],
|
||||
"cannot_import": ["api", "controllers", "handlers"],
|
||||
},
|
||||
"repositories": {
|
||||
"patterns": ["repositories/", "repos/", "data/"],
|
||||
"can_import": ["models", "db", "utils", "config"],
|
||||
"cannot_import": ["api", "services", "controllers"],
|
||||
},
|
||||
"models": {
|
||||
"patterns": ["models/", "entities/", "domain/", "schemas/"],
|
||||
"can_import": ["utils", "config"],
|
||||
"cannot_import": ["api", "services", "repositories", "db"],
|
||||
},
|
||||
"db": {
|
||||
"patterns": ["db/", "database/", "infrastructure/"],
|
||||
"can_import": ["models", "config"],
|
||||
"cannot_import": ["api", "services"],
|
||||
},
|
||||
"utils": {
|
||||
"patterns": ["utils/", "helpers/", "common/", "lib/"],
|
||||
"can_import": ["config"],
|
||||
"cannot_import": [],
|
||||
},
|
||||
}
|
||||
|
||||
def can_handle(self, event_type: str, event_data: dict) -> bool:
|
||||
"""Check if this agent handles the given event."""
|
||||
agent_config = self.config.get("agents", {}).get("architecture", {})
|
||||
if not agent_config.get("enabled", False):
|
||||
return False
|
||||
|
||||
# Handle PR events
|
||||
if event_type == "pull_request":
|
||||
action = event_data.get("action", "")
|
||||
if action in ["opened", "synchronize"]:
|
||||
return True
|
||||
|
||||
# Handle @codebot architecture command
|
||||
if event_type == "issue_comment":
|
||||
comment_body = event_data.get("comment", {}).get("body", "")
|
||||
mention_prefix = self.config.get("interaction", {}).get(
|
||||
"mention_prefix", "@codebot"
|
||||
)
|
||||
if f"{mention_prefix} architecture" in comment_body.lower():
|
||||
return True
|
||||
if f"{mention_prefix} arch-check" in comment_body.lower():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def execute(self, context: AgentContext) -> AgentResult:
|
||||
"""Execute the architecture agent."""
|
||||
self.logger.info(f"Checking architecture for {context.owner}/{context.repo}")
|
||||
|
||||
actions_taken = []
|
||||
|
||||
# Get layer configuration
|
||||
agent_config = self.config.get("agents", {}).get("architecture", {})
|
||||
layers = agent_config.get("layers", self.DEFAULT_LAYERS)
|
||||
|
||||
# Determine issue number
|
||||
if context.event_type == "issue_comment":
|
||||
issue = context.event_data.get("issue", {})
|
||||
issue_number = issue.get("number")
|
||||
comment_author = (
|
||||
context.event_data.get("comment", {})
|
||||
.get("user", {})
|
||||
.get("login", "user")
|
||||
)
|
||||
is_pr = issue.get("pull_request") is not None
|
||||
else:
|
||||
pr = context.event_data.get("pull_request", {})
|
||||
issue_number = pr.get("number")
|
||||
comment_author = None
|
||||
is_pr = True
|
||||
|
||||
if is_pr and issue_number:
|
||||
# Analyze PR diff
|
||||
diff = self._get_pr_diff(context.owner, context.repo, issue_number)
|
||||
report = self._analyze_diff(diff, layers)
|
||||
actions_taken.append(f"Analyzed PR diff for architecture violations")
|
||||
else:
|
||||
# Analyze repository structure
|
||||
report = self._analyze_repository(context.owner, context.repo, layers)
|
||||
actions_taken.append(f"Analyzed repository architecture")
|
||||
|
||||
# Post report
|
||||
if issue_number:
|
||||
comment = self._format_architecture_report(report, comment_author)
|
||||
self.upsert_comment(
|
||||
context.owner,
|
||||
context.repo,
|
||||
issue_number,
|
||||
comment,
|
||||
marker=self.ARCH_AI_MARKER,
|
||||
)
|
||||
actions_taken.append("Posted architecture report")
|
||||
|
||||
return AgentResult(
|
||||
success=len(report.violations) == 0 or report.compliance_score >= 0.8,
|
||||
message=f"Architecture check: {len(report.violations)} violations, {report.compliance_score:.0%} compliance",
|
||||
data={
|
||||
"violations_count": len(report.violations),
|
||||
"compliance_score": report.compliance_score,
|
||||
"circular_dependencies": len(report.circular_dependencies),
|
||||
},
|
||||
actions_taken=actions_taken,
|
||||
)
|
||||
|
||||
def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
|
||||
"""Get the PR diff."""
|
||||
try:
|
||||
return self.gitea.get_pull_request_diff(owner, repo, pr_number)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get PR diff: {e}")
|
||||
return ""
|
||||
|
||||
    def _analyze_diff(self, diff: str, layers: dict) -> ArchitectureReport:
        """Analyze PR diff for architecture violations."""
        violations = []
        imports_by_file = {}

        current_file = None
        current_language = None

        for line in diff.splitlines():
            # Track current file
            if line.startswith("diff --git"):
                match = re.search(r"b/(.+)$", line)
                if match:
                    current_file = match.group(1)
                    current_language = self._detect_language(current_file)
                    imports_by_file[current_file] = []

            # Look for import statements in added lines
            if line.startswith("+") and not line.startswith("+++"):
                if current_file and current_language:
                    imports = self._extract_imports(line[1:], current_language)
                    imports_by_file.setdefault(current_file, []).extend(imports)

        # Check for violations
        for file_path, imports in imports_by_file.items():
            source_layer = self._detect_layer(file_path, layers)
            if not source_layer:
                continue

            layer_config = layers.get(source_layer, {})
            cannot_import = layer_config.get("cannot_import", [])

            for imp in imports:
                target_layer = self._detect_layer_from_import(imp, layers)
                if target_layer and target_layer in cannot_import:
                    violations.append(
                        ArchitectureViolation(
                            file=file_path,
                            line=0,  # Line number not tracked in this simple implementation
                            violation_type="cross_layer",
                            severity="HIGH",
                            description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
                            recommendation="Move this import to an allowed layer or refactor the dependency",
                            source_layer=source_layer,
                            target_layer=target_layer,
                        )
                    )

        # Detect circular dependencies
        circular = self._detect_circular_dependencies(imports_by_file, layers)

        # Calculate compliance score
        total_imports = sum(len(imps) for imps in imports_by_file.values())
        if total_imports > 0:
            compliance = 1.0 - (len(violations) / max(total_imports, 1))
        else:
            compliance = 1.0

        return ArchitectureReport(
            violations=violations,
            layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
            circular_dependencies=circular,
            compliance_score=max(0.0, compliance),
            recommendations=self._generate_recommendations(violations),
        )

    def _analyze_repository(
        self, owner: str, repo: str, layers: dict
    ) -> ArchitectureReport:
        """Analyze repository structure for architecture compliance."""
        violations = []
        imports_by_file = {}

        # Collect files from each layer
        for layer_name, layer_config in layers.items():
            for pattern in layer_config.get("patterns", []):
                try:
                    path = pattern.rstrip("/")
                    contents = self.gitea.get_file_contents(owner, repo, path)
                    if isinstance(contents, list):
                        for item in contents[:20]:  # Limit files per layer
                            if item.get("type") == "file":
                                filepath = item.get("path", "")
                                imports = self._get_file_imports(owner, repo, filepath)
                                imports_by_file[filepath] = imports
                except Exception:
                    pass

        # Check for violations
        for file_path, imports in imports_by_file.items():
            source_layer = self._detect_layer(file_path, layers)
            if not source_layer:
                continue

            layer_config = layers.get(source_layer, {})
            cannot_import = layer_config.get("cannot_import", [])

            for imp in imports:
                target_layer = self._detect_layer_from_import(imp, layers)
                if target_layer and target_layer in cannot_import:
                    violations.append(
                        ArchitectureViolation(
                            file=file_path,
                            line=0,
                            violation_type="cross_layer",
                            severity="HIGH",
                            description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
                            recommendation=f"Refactor to remove dependency on '{target_layer}'",
                            source_layer=source_layer,
                            target_layer=target_layer,
                        )
                    )

        # Detect circular dependencies
        circular = self._detect_circular_dependencies(imports_by_file, layers)

        # Calculate compliance
        total_imports = sum(len(imps) for imps in imports_by_file.values())
        if total_imports > 0:
            compliance = 1.0 - (len(violations) / max(total_imports, 1))
        else:
            compliance = 1.0

        return ArchitectureReport(
            violations=violations,
            layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
            circular_dependencies=circular,
            compliance_score=max(0.0, compliance),
            recommendations=self._generate_recommendations(violations),
        )
    def _get_file_imports(self, owner: str, repo: str, filepath: str) -> list[str]:
        """Get imports from a file."""
        imports = []
        language = self._detect_language(filepath)

        if not language:
            return imports

        try:
            content_data = self.gitea.get_file_contents(owner, repo, filepath)
            if content_data.get("content"):
                content = base64.b64decode(content_data["content"]).decode(
                    "utf-8", errors="ignore"
                )
                for line in content.splitlines():
                    imports.extend(self._extract_imports(line, language))
        except Exception:
            pass

        return imports

    def _detect_language(self, filepath: str) -> str | None:
        """Detect programming language from file path."""
        ext_map = {
            ".py": "python",
            ".js": "javascript",
            ".ts": "typescript",
            ".go": "go",
            ".java": "java",
            ".rb": "ruby",
        }
        ext = os.path.splitext(filepath)[1]
        return ext_map.get(ext)

    def _extract_imports(self, line: str, language: str) -> list[str]:
        """Extract import statements from a line of code."""
        imports = []
        line = line.strip()

        if language == "python":
            # from x import y, import x
            match = re.match(r"^(?:from\s+(\S+)|import\s+(\S+))", line)
            if match:
                imp = match.group(1) or match.group(2)
                if imp:
                    imports.append(imp.split(".")[0])

        elif language in ("javascript", "typescript"):
            # import x from 'y', require('y')
            match = re.search(
                r"(?:from\s+['\"]([^'\"]+)['\"]|require\(['\"]([^'\"]+)['\"]\))", line
            )
            if match:
                imp = match.group(1) or match.group(2)
                if imp and not imp.startswith("."):
                    imports.append(imp.split("/")[0])
                elif imp:
                    imports.append(imp)

        elif language == "go":
            # import "package"
            match = re.search(r'import\s+["\']([^"\']+)["\']', line)
            if match:
                imports.append(match.group(1).split("/")[-1])

        elif language == "java":
            # import package.Class
            match = re.match(r"^import\s+(?:static\s+)?([^;]+);", line)
            if match:
                parts = match.group(1).split(".")
                if len(parts) > 1:
                    imports.append(parts[-2])  # Package name

        return imports

    def _detect_layer(self, filepath: str, layers: dict) -> str | None:
        """Detect which layer a file belongs to."""
        for layer_name, layer_config in layers.items():
            for pattern in layer_config.get("patterns", []):
                if pattern.rstrip("/") in filepath:
                    return layer_name
        return None

    def _detect_layer_from_import(self, import_path: str, layers: dict) -> str | None:
        """Detect which layer an import refers to."""
        for layer_name, layer_config in layers.items():
            for pattern in layer_config.get("patterns", []):
                pattern_name = pattern.rstrip("/").split("/")[-1]
                if pattern_name in import_path or import_path.startswith(pattern_name):
                    return layer_name
        return None

    def _detect_circular_dependencies(
        self, imports_by_file: dict, layers: dict
    ) -> list[tuple[str, str]]:
        """Detect circular dependencies between layers."""
        circular = []

        # Build layer dependency graph
        layer_deps = {}
        for file_path, imports in imports_by_file.items():
            source_layer = self._detect_layer(file_path, layers)
            if not source_layer:
                continue

            if source_layer not in layer_deps:
                layer_deps[source_layer] = set()

            for imp in imports:
                target_layer = self._detect_layer_from_import(imp, layers)
                if target_layer and target_layer != source_layer:
                    layer_deps[source_layer].add(target_layer)

        # Check for circular dependencies
        for layer_a, deps_a in layer_deps.items():
            for layer_b in deps_a:
                if layer_b in layer_deps and layer_a in layer_deps.get(layer_b, set()):
                    pair = tuple(sorted([layer_a, layer_b]))
                    if pair not in circular:
                        circular.append(pair)

        return circular
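`_detect_circular_dependencies` flags two layers as circular exactly when each appears in the other's dependency set. A toy illustration of that pairwise check on a made-up layer graph (standalone sketch, same logic as the method above):

```python
# Made-up layer graph: api -> services, services -> api and models.
layer_deps = {"api": {"services"}, "services": {"api", "models"}, "models": set()}

# Pairwise mutual-dependency check, with sorted tuples to deduplicate
# (a, b) vs (b, a).
circular = []
for layer_a, deps_a in layer_deps.items():
    for layer_b in deps_a:
        if layer_a in layer_deps.get(layer_b, set()):
            pair = tuple(sorted([layer_a, layer_b]))
            if pair not in circular:
                circular.append(pair)

assert circular == [("api", "services")]
```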

    def _group_files_by_layer(
        self, files: list[str], layers: dict
    ) -> dict[str, list[str]]:
        """Group files by their layer."""
        grouped = {}
        for filepath in files:
            layer = self._detect_layer(filepath, layers)
            if layer:
                if layer not in grouped:
                    grouped[layer] = []
                grouped[layer].append(filepath)
        return grouped

    def _generate_recommendations(
        self, violations: list[ArchitectureViolation]
    ) -> list[str]:
        """Generate recommendations based on violations."""
        recommendations = []

        # Count violations by type
        cross_layer = sum(1 for v in violations if v.violation_type == "cross_layer")

        if cross_layer > 0:
            recommendations.append(
                f"Fix {cross_layer} cross-layer violations by moving imports or creating interfaces"
            )

        if cross_layer > 5:
            recommendations.append(
                "Consider using dependency injection to reduce coupling between layers"
            )

        return recommendations

    def _format_architecture_report(
        self, report: ArchitectureReport, user: str | None
    ) -> str:
        """Format the architecture report as a comment."""
        lines = []

        if user:
            lines.append(f"@{user}")
            lines.append("")

        lines.extend(
            [
                f"{self.AI_DISCLAIMER}",
                "",
                "## 🏗️ Architecture Compliance Check",
                "",
                "### Summary",
                "",
                "| Metric | Value |",
                "|--------|-------|",
                f"| Compliance Score | {report.compliance_score:.0%} |",
                f"| Violations | {len(report.violations)} |",
                f"| Circular Dependencies | {len(report.circular_dependencies)} |",
                f"| Layers Detected | {len(report.layers_detected)} |",
                "",
            ]
        )

        # Compliance bar
        filled = int(report.compliance_score * 10)
        bar = "█" * filled + "░" * (10 - filled)
        lines.append(f"`[{bar}]` {report.compliance_score:.0%}")
        lines.append("")

        # Violations
        if report.violations:
            lines.append("### 🚨 Violations")
            lines.append("")

            for v in report.violations[:10]:  # Limit display
                severity_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}
                lines.append(
                    f"{severity_emoji.get(v.severity, '⚪')} **{v.violation_type.upper()}** in `{v.file}`"
                )
                lines.append(f"  - {v.description}")
                lines.append(f"  - 💡 {v.recommendation}")
                lines.append("")

            if len(report.violations) > 10:
                lines.append(f"*... and {len(report.violations) - 10} more violations*")
                lines.append("")

        # Circular dependencies
        if report.circular_dependencies:
            lines.append("### 🔄 Circular Dependencies")
            lines.append("")
            for a, b in report.circular_dependencies:
                lines.append(f"- `{a}` ↔ `{b}`")
            lines.append("")

        # Layers detected
        if report.layers_detected:
            lines.append("### 📁 Layers Detected")
            lines.append("")
            for layer, files in report.layers_detected.items():
                lines.append(f"- **{layer}**: {len(files)} files")
            lines.append("")

        # Recommendations
        if report.recommendations:
            lines.append("### 💡 Recommendations")
            lines.append("")
            for rec in report.recommendations:
                lines.append(f"- {rec}")
            lines.append("")

        # Overall status
        if report.compliance_score >= 0.9:
            lines.append("---")
            lines.append("✅ **Excellent architecture compliance!**")
        elif report.compliance_score >= 0.7:
            lines.append("---")
            lines.append("⚠️ **Some architectural issues to address**")
        else:
            lines.append("---")
            lines.append("❌ **Significant architectural violations detected**")

        return "\n".join(lines)
@@ -65,9 +65,10 @@ class BaseAgent(ABC):
         self.llm = llm_client or LLMClient.from_config(self.config)
         self.logger = logging.getLogger(self.__class__.__name__)
 
-        # Rate limiting
+        # Rate limiting - now configurable
         self._last_request_time = 0.0
-        self._min_request_interval = 1.0  # seconds
+        rate_limits = self.config.get("rate_limits", {})
+        self._min_request_interval = rate_limits.get("min_interval", 1.0)  # seconds
 
     @staticmethod
     def _load_config() -> dict:
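The hunk above makes the request-throttle interval configurable via `rate_limits.min_interval`. A minimal standalone sketch of the throttling that interval implies (hypothetical `Throttle` helper; the real logic lives in `BaseAgent`, which is not fully shown here):

```python
import time

class Throttle:
    """Sleep so consecutive requests are at least min_interval seconds apart."""

    def __init__(self, min_interval: float = 1.0):
        self.min_interval = min_interval
        self._last = 0.0

    def wait(self) -> None:
        elapsed = time.monotonic() - self._last
        if elapsed < self.min_interval:
            time.sleep(self.min_interval - elapsed)
        self._last = time.monotonic()

throttle = Throttle(min_interval=0.05)
start = time.monotonic()
throttle.wait()
throttle.wait()  # second call waits out the remaining interval
assert time.monotonic() - start >= 0.05
```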
tools/ai-review/agents/dependency_agent.py (new file, 548 lines)
@@ -0,0 +1,548 @@
"""Dependency Security Agent

AI agent for scanning dependency files for known vulnerabilities
and outdated packages. Supports multiple package managers.
"""

import base64
import json
import logging
import os
import re
import subprocess
from dataclasses import dataclass, field
from typing import Any

from agents.base_agent import AgentContext, AgentResult, BaseAgent


@dataclass
class DependencyFinding:
    """A security finding in a dependency."""

    package: str
    version: str
    severity: str  # CRITICAL, HIGH, MEDIUM, LOW
    vulnerability_id: str  # CVE, GHSA, etc.
    title: str
    description: str
    fixed_version: str | None = None
    references: list[str] = field(default_factory=list)


@dataclass
class DependencyReport:
    """Report of dependency analysis."""

    total_packages: int
    vulnerable_packages: int
    outdated_packages: int
    findings: list[DependencyFinding]
    recommendations: list[str]
    files_scanned: list[str]


class DependencyAgent(BaseAgent):
    """Agent for scanning dependencies for security vulnerabilities."""

    # Marker for dependency comments
    DEP_AI_MARKER = "<!-- AI_DEPENDENCY_SCAN -->"

    # Supported dependency files
    DEPENDENCY_FILES = {
        "python": ["requirements.txt", "Pipfile", "pyproject.toml", "setup.py"],
        "javascript": ["package.json", "package-lock.json", "yarn.lock"],
        "ruby": ["Gemfile", "Gemfile.lock"],
        "go": ["go.mod", "go.sum"],
        "rust": ["Cargo.toml", "Cargo.lock"],
        "java": ["pom.xml", "build.gradle", "build.gradle.kts"],
        "php": ["composer.json", "composer.lock"],
        "dotnet": ["*.csproj", "packages.config", "*.fsproj"],
    }

    # Common vulnerable package patterns
    KNOWN_VULNERABILITIES = {
        "python": {
            "requests": {
                "< 2.31.0": "CVE-2023-32681 - Proxy-Authorization header leak"
            },
            "urllib3": {
                "< 2.0.7": "CVE-2023-45803 - Request body not stripped on redirects"
            },
            "cryptography": {"< 41.0.0": "Multiple CVEs - Update recommended"},
            "pillow": {"< 10.0.0": "CVE-2023-4863 - WebP vulnerability"},
            "django": {"< 4.2.0": "Multiple security fixes"},
            "flask": {"< 2.3.0": "Security improvements"},
            "pyyaml": {"< 6.0": "CVE-2020-14343 - Arbitrary code execution"},
            "jinja2": {"< 3.1.0": "Security fixes"},
        },
        "javascript": {
            "lodash": {"< 4.17.21": "CVE-2021-23337 - Prototype pollution"},
            "axios": {"< 1.6.0": "CVE-2023-45857 - CSRF vulnerability"},
            "express": {"< 4.18.0": "Security updates"},
            "jquery": {"< 3.5.0": "XSS vulnerabilities"},
            "minimist": {"< 1.2.6": "Prototype pollution"},
            "node-fetch": {"< 3.3.0": "Security fixes"},
        },
    }

    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Check if this agent handles the given event."""
        agent_config = self.config.get("agents", {}).get("dependency", {})
        if not agent_config.get("enabled", True):
            return False

        # Handle PR events that modify dependency files
        if event_type == "pull_request":
            action = event_data.get("action", "")
            if action in ["opened", "synchronize"]:
                # Check if any dependency files are modified
                files = event_data.get("files", [])
                for f in files:
                    if self._is_dependency_file(f.get("filename", "")):
                        return True

        # Handle @codebot check-deps command
        if event_type == "issue_comment":
            comment_body = event_data.get("comment", {}).get("body", "")
            mention_prefix = self.config.get("interaction", {}).get(
                "mention_prefix", "@codebot"
            )
            if f"{mention_prefix} check-deps" in comment_body.lower():
                return True

        return False
    def _is_dependency_file(self, filename: str) -> bool:
        """Check if a file is a dependency file."""
        basename = os.path.basename(filename)
        for lang, files in self.DEPENDENCY_FILES.items():
            for pattern in files:
                if pattern.startswith("*"):
                    if basename.endswith(pattern[1:]):
                        return True
                elif basename == pattern:
                    return True
        return False

    def execute(self, context: AgentContext) -> AgentResult:
        """Execute the dependency agent."""
        self.logger.info(f"Scanning dependencies for {context.owner}/{context.repo}")

        actions_taken = []

        # Determine if this is a command or PR event
        if context.event_type == "issue_comment":
            issue = context.event_data.get("issue", {})
            issue_number = issue.get("number")
            comment_author = (
                context.event_data.get("comment", {})
                .get("user", {})
                .get("login", "user")
            )
        else:
            pr = context.event_data.get("pull_request", {})
            issue_number = pr.get("number")
            comment_author = None

        # Collect dependency files
        dep_files = self._collect_dependency_files(context.owner, context.repo)
        if not dep_files:
            message = "No dependency files found in repository."
            if issue_number:
                self.gitea.create_issue_comment(
                    context.owner,
                    context.repo,
                    issue_number,
                    f"{self.AI_DISCLAIMER}\n\n{message}",
                )
            return AgentResult(
                success=True,
                message=message,
            )

        actions_taken.append(f"Found {len(dep_files)} dependency files")

        # Analyze dependencies
        report = self._analyze_dependencies(context.owner, context.repo, dep_files)
        actions_taken.append(f"Analyzed {report.total_packages} packages")

        # Run external scanners if available
        external_findings = self._run_external_scanners(context.owner, context.repo)
        if external_findings:
            report.findings.extend(external_findings)
            actions_taken.append(
                f"External scanner found {len(external_findings)} issues"
            )

        # Generate and post report
        if issue_number:
            comment = self._format_dependency_report(report, comment_author)
            self.upsert_comment(
                context.owner,
                context.repo,
                issue_number,
                comment,
                marker=self.DEP_AI_MARKER,
            )
            actions_taken.append("Posted dependency report")

        return AgentResult(
            success=True,
            message=f"Dependency scan complete: {report.vulnerable_packages} vulnerable, {report.outdated_packages} outdated",
            data={
                "total_packages": report.total_packages,
                "vulnerable_packages": report.vulnerable_packages,
                "outdated_packages": report.outdated_packages,
                "findings_count": len(report.findings),
            },
            actions_taken=actions_taken,
        )
    def _collect_dependency_files(
        self, owner: str, repo: str
    ) -> dict[str, dict[str, Any]]:
        """Collect all dependency files from the repository."""
        dep_files = {}

        # Common paths to check
        paths_to_check = [
            "",  # Root
            "backend/",
            "frontend/",
            "api/",
            "services/",
        ]

        for base_path in paths_to_check:
            for lang, filenames in self.DEPENDENCY_FILES.items():
                for filename in filenames:
                    if filename.startswith("*"):
                        continue  # Skip glob patterns for now

                    filepath = f"{base_path}{filename}".lstrip("/")
                    try:
                        content_data = self.gitea.get_file_contents(
                            owner, repo, filepath
                        )
                        if content_data.get("content"):
                            content = base64.b64decode(content_data["content"]).decode(
                                "utf-8", errors="ignore"
                            )
                            dep_files[filepath] = {
                                "language": lang,
                                "content": content,
                            }
                    except Exception:
                        pass  # File doesn't exist

        return dep_files
    def _analyze_dependencies(
        self, owner: str, repo: str, dep_files: dict
    ) -> DependencyReport:
        """Analyze dependency files for vulnerabilities."""
        findings = []
        total_packages = 0
        vulnerable_count = 0
        outdated_count = 0
        recommendations = []
        files_scanned = list(dep_files.keys())

        for filepath, file_info in dep_files.items():
            lang = file_info["language"]
            content = file_info["content"]

            if lang == "python":
                packages = self._parse_python_deps(content, filepath)
            elif lang == "javascript":
                packages = self._parse_javascript_deps(content, filepath)
            else:
                packages = []

            total_packages += len(packages)

            # Check for known vulnerabilities
            known_vulns = self.KNOWN_VULNERABILITIES.get(lang, {})
            for pkg_name, version in packages:
                if pkg_name.lower() in known_vulns:
                    vuln_info = known_vulns[pkg_name.lower()]
                    for version_constraint, vuln_desc in vuln_info.items():
                        if self._version_matches_constraint(
                            version, version_constraint
                        ):
                            findings.append(
                                DependencyFinding(
                                    package=pkg_name,
                                    version=version or "unknown",
                                    severity="HIGH",
                                    vulnerability_id=vuln_desc.split(" - ")[0]
                                    if " - " in vuln_desc
                                    else "VULN",
                                    title=vuln_desc,
                                    description=f"Package {pkg_name} version {version} has known vulnerabilities",
                                    fixed_version=version_constraint.replace("< ", ""),
                                )
                            )
                            vulnerable_count += 1

        # Add recommendations
        if vulnerable_count > 0:
            recommendations.append(
                f"Update {vulnerable_count} packages with known vulnerabilities"
            )
        if total_packages > 50:
            recommendations.append(
                "Consider auditing dependencies to reduce attack surface"
            )

        return DependencyReport(
            total_packages=total_packages,
            vulnerable_packages=vulnerable_count,
            outdated_packages=outdated_count,
            findings=findings,
            recommendations=recommendations,
            files_scanned=files_scanned,
        )

    def _parse_python_deps(
        self, content: str, filepath: str
    ) -> list[tuple[str, str | None]]:
        """Parse Python dependency file."""
        packages = []

        if "requirements" in filepath.lower():
            # requirements.txt format
            for line in content.splitlines():
                line = line.strip()
                if not line or line.startswith("#") or line.startswith("-"):
                    continue

                # Parse package==version, package>=version, package
                match = re.match(r"([a-zA-Z0-9_-]+)([<>=!]+)?(.+)?", line)
                if match:
                    pkg_name = match.group(1)
                    version = match.group(3) if match.group(3) else None
                    packages.append((pkg_name, version))

        elif filepath.endswith("pyproject.toml"):
            # pyproject.toml format
            in_deps = False
            for line in content.splitlines():
                if (
                    "[project.dependencies]" in line
                    or "[tool.poetry.dependencies]" in line
                ):
                    in_deps = True
                    continue
                if in_deps:
                    if line.startswith("["):
                        in_deps = False
                        continue
                    match = re.match(r'"?([a-zA-Z0-9_-]+)"?\s*[=<>]', line)
                    if match:
                        packages.append((match.group(1), None))

        return packages

    def _parse_javascript_deps(
        self, content: str, filepath: str
    ) -> list[tuple[str, str | None]]:
        """Parse JavaScript dependency file."""
        packages = []

        if filepath.endswith("package.json"):
            try:
                data = json.loads(content)
                for dep_type in ["dependencies", "devDependencies"]:
                    deps = data.get(dep_type, {})
                    for name, version in deps.items():
                        # Strip version prefixes like ^, ~, >=
                        clean_version = re.sub(r"^[\^~>=<]+", "", version)
                        packages.append((name, clean_version))
            except json.JSONDecodeError:
                pass

        return packages

    def _version_matches_constraint(self, version: str | None, constraint: str) -> bool:
        """Check if version matches a vulnerability constraint."""
        if not version:
            return True  # Assume vulnerable if version unknown

        # Simple version comparison
        if constraint.startswith("< "):
            target = constraint[2:]
            try:
                return self._compare_versions(version, target) < 0
            except Exception:
                return False

        return False

    def _compare_versions(self, v1: str, v2: str) -> int:
        """Compare two version strings. Returns -1, 0, or 1."""

        def normalize(v):
            return [int(x) for x in re.sub(r"[^0-9.]", "", v).split(".") if x]

        try:
            parts1 = normalize(v1)
            parts2 = normalize(v2)

            for i in range(max(len(parts1), len(parts2))):
                p1 = parts1[i] if i < len(parts1) else 0
                p2 = parts2[i] if i < len(parts2) else 0
                if p1 < p2:
                    return -1
                if p1 > p2:
                    return 1
            return 0
        except Exception:
            return 0

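The normalize-and-compare logic above can be exercised on its own (a sketch of the same algorithm outside the class; the free function is illustrative):

```python
import re


def compare_versions(v1: str, v2: str) -> int:
    """Numerically compare dotted version strings; returns -1, 0, or 1."""

    def normalize(v: str) -> list[int]:
        # Keep only digits and dots, then split into integer components
        return [int(x) for x in re.sub(r"[^0-9.]", "", v).split(".") if x]

    parts1, parts2 = normalize(v1), normalize(v2)
    for i in range(max(len(parts1), len(parts2))):
        p1 = parts1[i] if i < len(parts1) else 0
        p2 = parts2[i] if i < len(parts2) else 0
        if p1 != p2:
            return -1 if p1 < p2 else 1
    return 0


print(compare_versions("1.9.0", "1.10.0"))  # → -1
```

Unlike plain string comparison (where `"1.10.0" < "1.9.0"` lexicographically), the component-wise integer compare orders `1.10.0` after `1.9.0`, which is what a vulnerability cutoff like `< 1.10.0` needs.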
    def _run_external_scanners(self, owner: str, repo: str) -> list[DependencyFinding]:
        """Run external vulnerability scanners if available."""
        findings = []
        agent_config = self.config.get("agents", {}).get("dependency", {})

        # Try pip-audit for Python
        if agent_config.get("pip_audit", False):
            try:
                result = subprocess.run(
                    ["pip-audit", "--format", "json"],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                if result.returncode == 0:
                    data = json.loads(result.stdout)
                    for vuln in data.get("vulnerabilities", []):
                        findings.append(
                            DependencyFinding(
                                package=vuln.get("name", ""),
                                version=vuln.get("version", ""),
                                severity=vuln.get("severity", "MEDIUM"),
                                vulnerability_id=vuln.get("id", ""),
                                title=vuln.get("description", "")[:100],
                                description=vuln.get("description", ""),
                                fixed_version=vuln.get("fix_versions", [None])[0],
                            )
                        )
            except Exception as e:
                self.logger.debug(f"pip-audit not available: {e}")

        # Try npm audit for JavaScript
        if agent_config.get("npm_audit", False):
            try:
                result = subprocess.run(
                    ["npm", "audit", "--json"],
                    capture_output=True,
                    text=True,
                    timeout=60,
                )
                data = json.loads(result.stdout)
                for vuln_id, vuln in data.get("vulnerabilities", {}).items():
                    # npm's fixAvailable may be a bool or an object
                    fix = vuln.get("fixAvailable")
                    findings.append(
                        DependencyFinding(
                            package=vuln.get("name", vuln_id),
                            version=vuln.get("range", ""),
                            severity=vuln.get("severity", "moderate").upper(),
                            vulnerability_id=vuln_id,
                            title=vuln.get("title", ""),
                            description=vuln.get("overview", ""),
                            fixed_version=fix.get("version") if isinstance(fix, dict) else None,
                        )
                    )
            except Exception as e:
                self.logger.debug(f"npm audit not available: {e}")

        return findings

    def _format_dependency_report(
        self, report: DependencyReport, user: str | None = None
    ) -> str:
        """Format the dependency report as a comment."""
        lines = []

        if user:
            lines.append(f"@{user}")
            lines.append("")

        lines.extend(
            [
                f"{self.AI_DISCLAIMER}",
                "",
                "## 🔍 Dependency Security Scan",
                "",
                "### Summary",
                "",
                "| Metric | Value |",
                "|--------|-------|",
                f"| Total Packages | {report.total_packages} |",
                f"| Vulnerable | {report.vulnerable_packages} |",
                f"| Outdated | {report.outdated_packages} |",
                f"| Files Scanned | {len(report.files_scanned)} |",
                "",
            ]
        )

        # Findings by severity
        if report.findings:
            lines.append("### 🚨 Security Findings")
            lines.append("")

            # Group by severity
            by_severity = {"CRITICAL": [], "HIGH": [], "MEDIUM": [], "LOW": []}
            for finding in report.findings:
                sev = finding.severity.upper()
                if sev in by_severity:
                    by_severity[sev].append(finding)

            severity_emoji = {
                "CRITICAL": "🔴",
                "HIGH": "🟠",
                "MEDIUM": "🟡",
                "LOW": "🔵",
            }

            for severity in ["CRITICAL", "HIGH", "MEDIUM", "LOW"]:
                findings = by_severity[severity]
                if findings:
                    lines.append(f"#### {severity_emoji[severity]} {severity}")
                    lines.append("")
                    for f in findings[:10]:  # Limit display
                        lines.append(f"- **{f.package}** `{f.version}`")
                        lines.append(f"  - {f.vulnerability_id}: {f.title}")
                        if f.fixed_version:
                            lines.append(f"  - ✅ Fix: Upgrade to `{f.fixed_version}`")
                    if len(findings) > 10:
                        lines.append(f"  - ... and {len(findings) - 10} more")
                    lines.append("")

        # Files scanned
        lines.append("### 📁 Files Scanned")
        lines.append("")
        for f in report.files_scanned:
            lines.append(f"- `{f}`")
        lines.append("")

        # Recommendations
        if report.recommendations:
            lines.append("### 💡 Recommendations")
            lines.append("")
            for rec in report.recommendations:
                lines.append(f"- {rec}")
            lines.append("")

        # Overall status
        if report.vulnerable_packages == 0:
            lines.append("---")
            lines.append("✅ **No known vulnerabilities detected**")
        else:
            lines.append("---")
            lines.append(
                f"⚠️ **{report.vulnerable_packages} vulnerable packages require attention**"
            )

        return "\n".join(lines)

@@ -365,9 +365,20 @@ class IssueAgent(BaseAgent):
            "commands", ["explain", "suggest", "security", "summarize", "triage"]
        )

        # Built-in commands not in config
        builtin_commands = [
            "setup-labels",
            "check-deps",
            "suggest-tests",
            "refactor-suggest",
            "architecture",
            "arch-check",
        ]

        # Check built-in commands first
        for command in builtin_commands:
            if f"{mention_prefix} {command}" in body.lower():
                return command

        for command in commands:
            if f"{mention_prefix} {command}" in body.lower():
@@ -392,6 +403,14 @@ class IssueAgent(BaseAgent):
            return self._command_triage(context, issue)
        elif command == "setup-labels":
            return self._command_setup_labels(context, issue)
        elif command == "check-deps":
            return self._command_check_deps(context)
        elif command == "suggest-tests":
            return self._command_suggest_tests(context)
        elif command == "refactor-suggest":
            return self._command_refactor_suggest(context)
        elif command in ("architecture", "arch-check"):
            return self._command_architecture(context)

        return f"{self.AI_DISCLAIMER}\n\nSorry, I don't understand the command `{command}`."

@@ -464,6 +483,12 @@ Be practical and concise."""
- `{mention_prefix} suggest` - Solution suggestions or next steps
- `{mention_prefix} security` - Security-focused analysis of the issue

### Code Quality & Security
- `{mention_prefix} check-deps` - Scan dependencies for security vulnerabilities
- `{mention_prefix} suggest-tests` - Suggest test cases for changed/new code
- `{mention_prefix} refactor-suggest` - Suggest refactoring opportunities
- `{mention_prefix} architecture` - Check architecture compliance (alias: `arch-check`)

### Interactive Chat
- `{mention_prefix} [question]` - Ask questions about the codebase (uses search & file reading tools)
- Example: `{mention_prefix} how does authentication work?`
@@ -494,9 +519,19 @@ PR reviews run automatically when you open or update a pull request. The bot pro
{mention_prefix} triage
```

**Check for dependency vulnerabilities:**
```
{mention_prefix} check-deps
```

**Get test suggestions:**
```
{mention_prefix} suggest-tests
```

**Check architecture compliance:**
```
{mention_prefix} architecture
```

**Ask about the codebase:**
@@ -504,11 +539,6 @@ PR reviews run automatically when you open or update a pull request. The bot pro
{mention_prefix} how does the authentication system work?
```

---

*For full documentation, see the [README](https://github.com/YourOrg/OpenRabbit/blob/main/README.md)*
@@ -854,3 +884,145 @@ PR reviews run automatically when you open or update a pull request. The bot pro
            return f"{prefix} - {value}"
        else:  # colon or unknown
            return base_name

    def _command_check_deps(self, context: AgentContext) -> str:
        """Check dependencies for security vulnerabilities."""
        try:
            from agents.dependency_agent import DependencyAgent

            agent = DependencyAgent(config=self.config)
            result = agent.run(context)

            if result.success:
                return result.data.get(
                    "report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
                )
            else:
                return f"{self.AI_DISCLAIMER}\n\n**Dependency Check Failed**\n\n{result.error or result.message}"
        except ImportError:
            return f"{self.AI_DISCLAIMER}\n\n**Dependency Agent Not Available**\n\nThe dependency security agent is not installed."
        except Exception as e:
            self.logger.error(f"Dependency check failed: {e}")
            return f"{self.AI_DISCLAIMER}\n\n**Dependency Check Error**\n\n{e}"

    def _command_suggest_tests(self, context: AgentContext) -> str:
        """Suggest tests for changed or new code."""
        try:
            from agents.test_coverage_agent import TestCoverageAgent

            agent = TestCoverageAgent(config=self.config)
            result = agent.run(context)

            if result.success:
                return result.data.get(
                    "report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
                )
            else:
                return f"{self.AI_DISCLAIMER}\n\n**Test Suggestion Failed**\n\n{result.error or result.message}"
        except ImportError:
            return f"{self.AI_DISCLAIMER}\n\n**Test Coverage Agent Not Available**\n\nThe test coverage agent is not installed."
        except Exception as e:
            self.logger.error(f"Test suggestion failed: {e}")
            return f"{self.AI_DISCLAIMER}\n\n**Test Suggestion Error**\n\n{e}"

    def _command_architecture(self, context: AgentContext) -> str:
        """Check architecture compliance."""
        try:
            from agents.architecture_agent import ArchitectureAgent

            agent = ArchitectureAgent(config=self.config)
            result = agent.run(context)

            if result.success:
                return result.data.get(
                    "report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
                )
            else:
                return f"{self.AI_DISCLAIMER}\n\n**Architecture Check Failed**\n\n{result.error or result.message}"
        except ImportError:
            return f"{self.AI_DISCLAIMER}\n\n**Architecture Agent Not Available**\n\nThe architecture compliance agent is not installed."
        except Exception as e:
            self.logger.error(f"Architecture check failed: {e}")
            return f"{self.AI_DISCLAIMER}\n\n**Architecture Check Error**\n\n{e}"

    def _command_refactor_suggest(self, context: AgentContext) -> str:
        """Suggest refactoring opportunities."""
        issue = context.event_data.get("issue", {})
        title = issue.get("title", "")
        body = issue.get("body", "")

        # Use LLM to analyze for refactoring opportunities
        prompt = f"""Analyze the following issue/context and suggest refactoring opportunities:

Issue Title: {title}
Issue Body: {body}

Based on common refactoring patterns, suggest:
1. Code smell detection (if any code is referenced)
2. Design pattern opportunities
3. Complexity reduction suggestions
4. DRY principle violations
5. SOLID principle improvements

Format your response as a structured report with actionable recommendations.
If no code is referenced in the issue, provide general refactoring guidance based on the context.

Return as JSON:
{{
    "summary": "Brief summary of refactoring opportunities",
    "suggestions": [
        {{
            "category": "Code Smell | Design Pattern | Complexity | DRY | SOLID",
            "title": "Short title",
            "description": "Detailed description",
            "priority": "high | medium | low",
            "effort": "small | medium | large"
        }}
    ],
    "general_advice": "Any general refactoring advice"
}}"""

        try:
            result = self.call_llm_json(prompt)

            lines = [f"{self.AI_DISCLAIMER}\n"]
            lines.append("## Refactoring Suggestions\n")

            if result.get("summary"):
                lines.append(f"**Summary:** {result['summary']}\n")

            suggestions = result.get("suggestions", [])
            if suggestions:
                lines.append("### Recommendations\n")
                lines.append("| Priority | Category | Suggestion | Effort |")
                lines.append("|----------|----------|------------|--------|")

                for s in suggestions:
                    priority = s.get("priority", "medium").upper()
                    priority_icon = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🟢"}.get(
                        priority, "⚪"
                    )
                    lines.append(
                        f"| {priority_icon} {priority} | {s.get('category', 'General')} | "
                        f"**{s.get('title', 'Suggestion')}** | {s.get('effort', 'medium')} |"
                    )

                lines.append("")

                # Detailed descriptions
                lines.append("### Details\n")
                for i, s in enumerate(suggestions, 1):
                    lines.append(f"**{i}. {s.get('title', 'Suggestion')}**")
                    lines.append(f"{s.get('description', 'No description')}\n")

            if result.get("general_advice"):
                lines.append("### General Advice\n")
                lines.append(result["general_advice"])

            return "\n".join(lines)

        except Exception as e:
            self.logger.error(f"Refactor suggestion failed: {e}")
            return (
                f"{self.AI_DISCLAIMER}\n\n**Refactor Suggestion Failed**\n\nError: {e}"
            )

480	tools/ai-review/agents/test_coverage_agent.py	Normal file
@@ -0,0 +1,480 @@
"""Test Coverage Agent

AI agent for analyzing code changes and suggesting test cases.
Helps improve test coverage by identifying untested code paths.
"""

import base64
import os
import re
from dataclasses import dataclass, field

from agents.base_agent import AgentContext, AgentResult, BaseAgent


@dataclass
class TestSuggestion:
    """A suggested test case."""

    function_name: str
    file_path: str
    test_type: str  # unit, integration, edge_case
    description: str
    example_code: str | None = None
    priority: str = "MEDIUM"  # HIGH, MEDIUM, LOW


@dataclass
class CoverageReport:
    """Report of test coverage analysis."""

    functions_analyzed: int
    functions_with_tests: int
    functions_without_tests: int
    suggestions: list[TestSuggestion]
    existing_tests: list[str]
    coverage_estimate: float


class TestCoverageAgent(BaseAgent):
    """Agent for analyzing test coverage and suggesting tests."""

    # Marker for test coverage comments
    TEST_AI_MARKER = "<!-- AI_TEST_COVERAGE -->"

    # Test file patterns by language
    TEST_PATTERNS = {
        "python": [r"test_.*\.py$", r".*_test\.py$", r"tests?/.*\.py$"],
        "javascript": [
            r".*\.test\.[jt]sx?$",
            r".*\.spec\.[jt]sx?$",
            r"__tests__/.*\.[jt]sx?$",
        ],
        "go": [r".*_test\.go$"],
        "rust": [r"tests?/.*\.rs$"],
        "java": [r".*Test\.java$", r".*Tests\.java$"],
        "ruby": [r".*_spec\.rb$", r"test_.*\.rb$"],
    }

    # Function/method patterns by language
    FUNCTION_PATTERNS = {
        "python": r"^\s*(?:async\s+)?def\s+(\w+)\s*\(",
        "javascript": r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>))",
        "go": r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(",
        "rust": r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)",
        "java": r"(?:public|private|protected)\s+(?:static\s+)?(?:\w+\s+)?(\w+)\s*\([^)]*\)\s*\{",
        "ruby": r"^\s*def\s+(\w+)",
    }

    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Check if this agent handles the given event."""
        agent_config = self.config.get("agents", {}).get("test_coverage", {})
        if not agent_config.get("enabled", True):
            return False

        # Handle @codebot suggest-tests command
        if event_type == "issue_comment":
            comment_body = event_data.get("comment", {}).get("body", "")
            mention_prefix = self.config.get("interaction", {}).get(
                "mention_prefix", "@codebot"
            )
            if f"{mention_prefix} suggest-tests" in comment_body.lower():
                return True

        return False

    def execute(self, context: AgentContext) -> AgentResult:
        """Execute the test coverage agent."""
        self.logger.info(f"Analyzing test coverage for {context.owner}/{context.repo}")

        actions_taken = []

        # Get issue/PR number and author
        issue = context.event_data.get("issue", {})
        issue_number = issue.get("number")
        comment_author = (
            context.event_data.get("comment", {}).get("user", {}).get("login", "user")
        )

        # Check if this is a PR
        is_pr = issue.get("pull_request") is not None

        if is_pr:
            # Analyze PR diff for changed functions
            diff = self._get_pr_diff(context.owner, context.repo, issue_number)
            changed_functions = self._extract_changed_functions(diff)
            actions_taken.append(f"Analyzed {len(changed_functions)} changed functions")
        else:
            # Analyze entire repository
            changed_functions = self._analyze_repository(context.owner, context.repo)
            actions_taken.append(
                f"Analyzed {len(changed_functions)} functions in repository"
            )

        # Find existing tests
        existing_tests = self._find_existing_tests(context.owner, context.repo)
        actions_taken.append(f"Found {len(existing_tests)} existing test files")

        # Generate test suggestions using LLM
        report = self._generate_suggestions(
            context.owner, context.repo, changed_functions, existing_tests
        )

        # Post report
        if issue_number:
            comment = self._format_coverage_report(report, comment_author, is_pr)
            self.upsert_comment(
                context.owner,
                context.repo,
                issue_number,
                comment,
                marker=self.TEST_AI_MARKER,
            )
            actions_taken.append("Posted test coverage report")

        return AgentResult(
            success=True,
            message=f"Generated {len(report.suggestions)} test suggestions",
            data={
                "functions_analyzed": report.functions_analyzed,
                "suggestions_count": len(report.suggestions),
                "coverage_estimate": report.coverage_estimate,
            },
            actions_taken=actions_taken,
        )

    def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
        """Get the PR diff."""
        try:
            return self.gitea.get_pull_request_diff(owner, repo, pr_number)
        except Exception as e:
            self.logger.error(f"Failed to get PR diff: {e}")
            return ""

    def _extract_changed_functions(self, diff: str) -> list[dict]:
        """Extract changed functions from diff."""
        functions = []
        current_file = None
        current_language = None

        for line in diff.splitlines():
            # Track current file
            if line.startswith("diff --git"):
                match = re.search(r"b/(.+)$", line)
                if match:
                    current_file = match.group(1)
                    current_language = self._detect_language(current_file)

            # Look for function definitions in added lines
            if line.startswith("+") and not line.startswith("+++"):
                if current_language and current_language in self.FUNCTION_PATTERNS:
                    pattern = self.FUNCTION_PATTERNS[current_language]
                    match = re.search(pattern, line[1:])  # Skip the + prefix
                    if match:
                        func_name = next(g for g in match.groups() if g)
                        functions.append(
                            {
                                "name": func_name,
                                "file": current_file,
                                "language": current_language,
                                "line": line[1:].strip(),
                            }
                        )

        return functions

    def _analyze_repository(self, owner: str, repo: str) -> list[dict]:
        """Analyze repository for functions without tests."""
        functions = []
        code_extensions = {".py", ".js", ".ts", ".go", ".rs", ".java", ".rb"}

        # Get repository contents (limited to avoid API exhaustion)
        try:
            contents = self.gitea.get_file_contents(owner, repo, "")
            if isinstance(contents, list):
                for item in contents[:50]:  # Limit files
                    if item.get("type") == "file":
                        filepath = item.get("path", "")
                        ext = os.path.splitext(filepath)[1]
                        if ext in code_extensions:
                            file_functions = self._extract_functions_from_file(
                                owner, repo, filepath
                            )
                            functions.extend(file_functions)
        except Exception as e:
            self.logger.warning(f"Failed to analyze repository: {e}")

        return functions[:100]  # Limit total functions

    def _extract_functions_from_file(
        self, owner: str, repo: str, filepath: str
    ) -> list[dict]:
        """Extract function definitions from a file."""
        functions = []
        language = self._detect_language(filepath)

        if not language or language not in self.FUNCTION_PATTERNS:
            return functions

        try:
            content_data = self.gitea.get_file_contents(owner, repo, filepath)
            if content_data.get("content"):
                content = base64.b64decode(content_data["content"]).decode(
                    "utf-8", errors="ignore"
                )

                pattern = self.FUNCTION_PATTERNS[language]
                for i, line in enumerate(content.splitlines(), 1):
                    match = re.search(pattern, line)
                    if match:
                        func_name = next((g for g in match.groups() if g), None)
                        if func_name and not func_name.startswith("_"):
                            functions.append(
                                {
                                    "name": func_name,
                                    "file": filepath,
                                    "language": language,
                                    "line_number": i,
                                }
                            )
        except Exception:
            pass

        return functions

    def _detect_language(self, filepath: str) -> str | None:
        """Detect programming language from file path."""
        ext_map = {
            ".py": "python",
            ".js": "javascript",
            ".jsx": "javascript",
            ".ts": "javascript",
            ".tsx": "javascript",
            ".go": "go",
            ".rs": "rust",
            ".java": "java",
            ".rb": "ruby",
        }
        ext = os.path.splitext(filepath)[1]
        return ext_map.get(ext)

    def _find_existing_tests(self, owner: str, repo: str) -> list[str]:
        """Find existing test files in the repository."""
        test_files = []

        # Common test directories
        test_dirs = ["tests", "test", "__tests__", "spec"]

        for test_dir in test_dirs:
            try:
                contents = self.gitea.get_file_contents(owner, repo, test_dir)
                if isinstance(contents, list):
                    for item in contents:
                        if item.get("type") == "file":
                            test_files.append(item.get("path", ""))
            except Exception:
                pass

        # Also check root for test files
        try:
            contents = self.gitea.get_file_contents(owner, repo, "")
            if isinstance(contents, list):
                for item in contents:
                    if item.get("type") == "file":
                        filepath = item.get("path", "")
                        if self._is_test_file(filepath):
                            test_files.append(filepath)
        except Exception:
            pass

        return test_files

    def _is_test_file(self, filepath: str) -> bool:
        """Check if a file is a test file."""
        for patterns in self.TEST_PATTERNS.values():
            for pattern in patterns:
                if re.search(pattern, filepath):
                    return True
        return False

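The pattern matching above can be checked standalone with a subset of `TEST_PATTERNS` (the paths below are hypothetical examples, not files from this repository):

```python
import re

# Subset of the per-language test-file patterns used by the agent
TEST_PATTERNS = {
    "python": [r"test_.*\.py$", r".*_test\.py$", r"tests?/.*\.py$"],
    "go": [r".*_test\.go$"],
}


def is_test_file(filepath: str) -> bool:
    """True if the path matches any known test-file naming convention."""
    return any(
        re.search(pattern, filepath)
        for patterns in TEST_PATTERNS.values()
        for pattern in patterns
    )


print(is_test_file("tests/test_api.py"), is_test_file("cmd/main.go"))  # → True False
```

Because `re.search` is anchored only at the end (`$`), a path like `tests/test_api.py` matches both the `test_.*\.py$` filename convention and the `tests?/` directory convention.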
    def _generate_suggestions(
        self,
        owner: str,
        repo: str,
        functions: list[dict],
        existing_tests: list[str],
    ) -> CoverageReport:
        """Generate test suggestions using LLM."""
        suggestions = []

        # Build prompt for LLM
        if functions:
            functions_text = "\n".join(
                [
                    f"- {f['name']} in {f['file']} ({f['language']})"
                    for f in functions[:20]  # Limit for prompt size
                ]
            )

            prompt = f"""Analyze these functions and suggest test cases:

Functions to test:
{functions_text}

Existing test files:
{", ".join(existing_tests[:10]) if existing_tests else "None found"}

For each function, suggest:
1. What to test (happy path, edge cases, error handling)
2. Priority (HIGH for public APIs, MEDIUM for internal, LOW for utilities)
3. Brief example test code if possible

Respond in JSON format:
{{
    "suggestions": [
        {{
            "function_name": "function_name",
            "file_path": "path/to/file",
            "test_type": "unit|integration|edge_case",
            "description": "What to test",
            "example_code": "brief example or null",
            "priority": "HIGH|MEDIUM|LOW"
        }}
    ],
    "coverage_estimate": 0.0 to 1.0
}}
"""

            try:
                result = self.call_llm_json(prompt)

                for s in result.get("suggestions", []):
                    suggestions.append(
                        TestSuggestion(
                            function_name=s.get("function_name", ""),
                            file_path=s.get("file_path", ""),
                            test_type=s.get("test_type", "unit"),
                            description=s.get("description", ""),
                            example_code=s.get("example_code"),
                            priority=s.get("priority", "MEDIUM"),
                        )
                    )

                coverage_estimate = result.get("coverage_estimate", 0.5)

            except Exception as e:
                self.logger.warning(f"LLM suggestion failed: {e}")
                # Generate basic suggestions without LLM
                for f in functions[:10]:
                    suggestions.append(
                        TestSuggestion(
                            function_name=f["name"],
                            file_path=f["file"],
                            test_type="unit",
                            description=f"Add unit tests for {f['name']}",
                            priority="MEDIUM",
                        )
                    )
                coverage_estimate = 0.5

        else:
            coverage_estimate = 1.0 if existing_tests else 0.0

        # Estimate functions with tests
        functions_with_tests = int(len(functions) * coverage_estimate)

        return CoverageReport(
            functions_analyzed=len(functions),
            functions_with_tests=functions_with_tests,
            functions_without_tests=len(functions) - functions_with_tests,
            suggestions=suggestions,
            existing_tests=existing_tests,
            coverage_estimate=coverage_estimate,
        )

    def _format_coverage_report(
        self, report: CoverageReport, user: str | None, is_pr: bool
    ) -> str:
        """Format the coverage report as a comment."""
        lines = []

        if user:
            lines.append(f"@{user}")
            lines.append("")

        lines.extend(
            [
                f"{self.AI_DISCLAIMER}",
                "",
                "## 🧪 Test Coverage Analysis",
                "",
                "### Summary",
                "",
                "| Metric | Value |",
                "|--------|-------|",
                f"| Functions Analyzed | {report.functions_analyzed} |",
                f"| Estimated Coverage | {report.coverage_estimate:.0%} |",
                f"| Test Files Found | {len(report.existing_tests)} |",
                f"| Suggestions | {len(report.suggestions)} |",
                "",
            ]
        )

        # Suggestions by priority
        if report.suggestions:
            lines.append("### 💡 Test Suggestions")
            lines.append("")

            # Group by priority
            by_priority = {"HIGH": [], "MEDIUM": [], "LOW": []}
            for s in report.suggestions:
                if s.priority in by_priority:
                    by_priority[s.priority].append(s)

            priority_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}

            for priority in ["HIGH", "MEDIUM", "LOW"]:
                suggestions = by_priority[priority]
                if suggestions:
                    lines.append(f"#### {priority_emoji[priority]} {priority} Priority")
                    lines.append("")
                    for s in suggestions[:5]:  # Limit display
                        lines.append(f"**`{s.function_name}`** in `{s.file_path}`")
                        lines.append(f"- Type: {s.test_type}")
                        lines.append(f"- {s.description}")
                        if s.example_code:
                            lines.append("```")
                            lines.append(s.example_code[:200])
                            lines.append("```")
                        lines.append("")
                    if len(suggestions) > 5:
                        lines.append(f"*... and {len(suggestions) - 5} more*")
                        lines.append("")

        # Existing test files
        if report.existing_tests:
            lines.append("### 📁 Existing Test Files")
            lines.append("")
            for f in report.existing_tests[:10]:
                lines.append(f"- `{f}`")
            if len(report.existing_tests) > 10:
                lines.append(f"- *... and {len(report.existing_tests) - 10} more*")
            lines.append("")

        # Coverage bar
        lines.append("### 📊 Coverage Estimate")
        lines.append("")
        filled = int(report.coverage_estimate * 10)
        bar = "█" * filled + "░" * (10 - filled)
        lines.append(f"`[{bar}]` {report.coverage_estimate:.0%}")
|
||||
lines.append("")
|
||||
|
||||
# Recommendations
|
||||
if report.coverage_estimate < 0.8:
|
||||
lines.append("---")
|
||||
lines.append("⚠️ **Coverage below 80%** - Consider adding more tests")
|
||||
elif report.coverage_estimate >= 0.8:
|
||||
lines.append("---")
|
||||
lines.append("✅ **Good test coverage!**")
|
||||
|
||||
return "\n".join(lines)
|
||||
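The ten-segment bar built near the end of `_format_coverage_report` is plain integer truncation; a standalone sketch of the same arithmetic (the function name here is illustrative, not part of the agent):

```python
def coverage_bar(estimate: float, width: int = 10) -> str:
    """Render the filled/empty block bar used in the coverage comment."""
    filled = int(estimate * width)  # int() truncates, so 0.79 fills 7 segments
    return "█" * filled + "░" * (width - filled)
```

Because of the truncation, estimates just below a segment boundary round down, which slightly understates coverage in the bar while the `:.0%` figure next to it rounds normally.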
@@ -77,11 +77,13 @@ class OpenAIProvider(BaseLLMProvider):
         model: str = "gpt-4o-mini",
         temperature: float = 0,
         max_tokens: int = 4096,
+        timeout: int = 120,
     ):
         self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
         self.model = model
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.timeout = timeout
         self.api_url = "https://api.openai.com/v1/chat/completions"

     def call(self, prompt: str, **kwargs) -> LLMResponse:
@@ -101,7 +103,7 @@ class OpenAIProvider(BaseLLMProvider):
                 "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                 "messages": [{"role": "user", "content": prompt}],
             },
-            timeout=120,
+            timeout=self.timeout,
         )
         response.raise_for_status()
         data = response.json()
@@ -145,7 +147,7 @@ class OpenAIProvider(BaseLLMProvider):
                 "Content-Type": "application/json",
             },
             json=request_body,
-            timeout=120,
+            timeout=self.timeout,
         )
         response.raise_for_status()
         data = response.json()
@@ -186,11 +188,13 @@ class OpenRouterProvider(BaseLLMProvider):
         model: str = "anthropic/claude-3.5-sonnet",
         temperature: float = 0,
         max_tokens: int = 4096,
+        timeout: int = 120,
     ):
         self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "")
         self.model = model
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.timeout = timeout
         self.api_url = "https://openrouter.ai/api/v1/chat/completions"

     def call(self, prompt: str, **kwargs) -> LLMResponse:
@@ -210,7 +214,7 @@ class OpenRouterProvider(BaseLLMProvider):
                 "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                 "messages": [{"role": "user", "content": prompt}],
             },
-            timeout=120,
+            timeout=self.timeout,
         )
         response.raise_for_status()
         data = response.json()
@@ -254,7 +258,7 @@ class OpenRouterProvider(BaseLLMProvider):
                 "Content-Type": "application/json",
             },
             json=request_body,
-            timeout=120,
+            timeout=self.timeout,
         )
         response.raise_for_status()
         data = response.json()
@@ -294,10 +298,12 @@ class OllamaProvider(BaseLLMProvider):
         host: str | None = None,
         model: str = "codellama:13b",
         temperature: float = 0,
+        timeout: int = 300,
     ):
         self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
         self.model = model
         self.temperature = temperature
+        self.timeout = timeout

     def call(self, prompt: str, **kwargs) -> LLMResponse:
         """Call Ollama API."""
@@ -311,7 +317,7 @@ class OllamaProvider(BaseLLMProvider):
                     "temperature": kwargs.get("temperature", self.temperature),
                 },
             },
-            timeout=300,  # Longer timeout for local models
+            timeout=self.timeout,
         )
         response.raise_for_status()
         data = response.json()
@@ -477,12 +483,18 @@ class LLMClient:
         provider = config.get("provider", "openai")
         provider_config = {}

+        # Get timeout configuration
+        timeouts = config.get("timeouts", {})
+        llm_timeout = timeouts.get("llm", 120)
+        ollama_timeout = timeouts.get("ollama", 300)
+
         # Map config keys to provider-specific settings
         if provider == "openai":
             provider_config = {
                 "model": config.get("model", {}).get("openai", "gpt-4o-mini"),
                 "temperature": config.get("temperature", 0),
                 "max_tokens": config.get("max_tokens", 16000),
+                "timeout": llm_timeout,
             }
         elif provider == "openrouter":
             provider_config = {
@@ -491,11 +503,13 @@ class LLMClient:
                 ),
                 "temperature": config.get("temperature", 0),
                 "max_tokens": config.get("max_tokens", 16000),
+                "timeout": llm_timeout,
             }
         elif provider == "ollama":
             provider_config = {
                 "model": config.get("model", {}).get("ollama", "codellama:13b"),
                 "temperature": config.get("temperature", 0),
+                "timeout": ollama_timeout,
             }

         return cls(provider=provider, config=provider_config)
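The `from_config` hunk above reads per-provider timeouts from a nested `timeouts` block, with 120s for API providers and 300s for local Ollama as fallbacks. A minimal sketch of that lookup in isolation (`resolve_timeouts` is an illustrative name, not an actual method in the client):

```python
def resolve_timeouts(config: dict) -> tuple[int, int]:
    """Return (llm_timeout, ollama_timeout) using the diff's defaults."""
    timeouts = config.get("timeouts", {})
    # Missing keys fall back to the hardcoded values the diff replaces.
    return timeouts.get("llm", 120), timeouts.get("ollama", 300)
```

A config such as `{"provider": "ollama", "timeouts": {"ollama": 600}}` would therefore extend only the local-model timeout while API providers keep the 120s default.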
27  tools/ai-review/clients/providers/__init__.py  Normal file
@@ -0,0 +1,27 @@
"""LLM Providers Package

This package contains additional LLM provider implementations
beyond the core providers in llm_client.py.

Providers:
- AnthropicProvider: Direct Anthropic Claude API
- AzureOpenAIProvider: Azure OpenAI Service with API key auth
- AzureOpenAIWithAADProvider: Azure OpenAI with Azure AD auth
- GeminiProvider: Google Gemini API (public)
- VertexAIGeminiProvider: Google Vertex AI Gemini (enterprise GCP)
"""

from clients.providers.anthropic_provider import AnthropicProvider
from clients.providers.azure_provider import (
    AzureOpenAIProvider,
    AzureOpenAIWithAADProvider,
)
from clients.providers.gemini_provider import GeminiProvider, VertexAIGeminiProvider

__all__ = [
    "AnthropicProvider",
    "AzureOpenAIProvider",
    "AzureOpenAIWithAADProvider",
    "GeminiProvider",
    "VertexAIGeminiProvider",
]
249  tools/ai-review/clients/providers/anthropic_provider.py  Normal file
@@ -0,0 +1,249 @@
"""Anthropic Claude Provider

Direct integration with Anthropic's Claude API.
Supports Claude 3.5 Sonnet, Claude 3 Opus, and other models.
"""

import json
import os

# Import base classes from parent module
import sys
from dataclasses import dataclass

import requests

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall


class AnthropicProvider(BaseLLMProvider):
    """Anthropic Claude API provider.

    Provides direct integration with Anthropic's Claude models
    without going through OpenRouter.

    Supports:
    - Claude 3.5 Sonnet (claude-3-5-sonnet-20241022)
    - Claude 3 Opus (claude-3-opus-20240229)
    - Claude 3 Sonnet (claude-3-sonnet-20240229)
    - Claude 3 Haiku (claude-3-haiku-20240307)
    """

    API_URL = "https://api.anthropic.com/v1/messages"
    API_VERSION = "2023-06-01"

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "claude-3-5-sonnet-20241022",
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Anthropic provider.

        Args:
            api_key: Anthropic API key. Defaults to ANTHROPIC_API_KEY env var.
            model: Model to use. Defaults to Claude 3.5 Sonnet.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
        """
        self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY", "")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Anthropic API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Anthropic API key is required")

        response = requests.post(
            self.API_URL,
            headers={
                "x-api-key": self.api_key,
                "anthropic-version": self.API_VERSION,
                "Content-Type": "application/json",
            },
            json={
                "model": kwargs.get("model", self.model),
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
                "messages": [{"role": "user", "content": prompt}],
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Extract content from response
        content = ""
        for block in data.get("content", []):
            if block.get("type") == "text":
                content += block.get("text", "")

        return LLMResponse(
            content=content,
            model=data.get("model", self.model),
            provider="anthropic",
            tokens_used=data.get("usage", {}).get("input_tokens", 0)
            + data.get("usage", {}).get("output_tokens", 0),
            finish_reason=data.get("stop_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Anthropic API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        if not self.api_key:
            raise ValueError("Anthropic API key is required")

        # Convert OpenAI-style messages to Anthropic format
        anthropic_messages = []
        system_content = None

        for msg in messages:
            role = msg.get("role", "user")

            if role == "system":
                system_content = msg.get("content", "")
            elif role == "assistant":
                # Handle assistant messages with tool calls
                if msg.get("tool_calls"):
                    content = []
                    if msg.get("content"):
                        content.append({"type": "text", "text": msg["content"]})
                    for tc in msg["tool_calls"]:
                        content.append(
                            {
                                "type": "tool_use",
                                "id": tc["id"],
                                "name": tc["function"]["name"],
                                "input": json.loads(tc["function"]["arguments"])
                                if isinstance(tc["function"]["arguments"], str)
                                else tc["function"]["arguments"],
                            }
                        )
                    anthropic_messages.append({"role": "assistant", "content": content})
                else:
                    anthropic_messages.append(
                        {
                            "role": "assistant",
                            "content": msg.get("content", ""),
                        }
                    )
            elif role == "tool":
                # Tool response
                anthropic_messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "tool_result",
                                "tool_use_id": msg.get("tool_call_id", ""),
                                "content": msg.get("content", ""),
                            }
                        ],
                    }
                )
            else:
                anthropic_messages.append(
                    {
                        "role": "user",
                        "content": msg.get("content", ""),
                    }
                )

        # Convert OpenAI-style tools to Anthropic format
        anthropic_tools = None
        if tools:
            anthropic_tools = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool["function"]
                    anthropic_tools.append(
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "input_schema": func.get("parameters", {}),
                        }
                    )

        request_body = {
            "model": kwargs.get("model", self.model),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "temperature": kwargs.get("temperature", self.temperature),
            "messages": anthropic_messages,
        }

        if system_content:
            request_body["system"] = system_content

        if anthropic_tools:
            request_body["tools"] = anthropic_tools

        response = requests.post(
            self.API_URL,
            headers={
                "x-api-key": self.api_key,
                "anthropic-version": self.API_VERSION,
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Parse response
        content = ""
        tool_calls = None

        for block in data.get("content", []):
            if block.get("type") == "text":
                content += block.get("text", "")
            elif block.get("type") == "tool_use":
                if tool_calls is None:
                    tool_calls = []
                tool_calls.append(
                    ToolCall(
                        id=block.get("id", ""),
                        name=block.get("name", ""),
                        arguments=block.get("input", {}),
                    )
                )

        return LLMResponse(
            content=content,
            model=data.get("model", self.model),
            provider="anthropic",
            tokens_used=data.get("usage", {}).get("input_tokens", 0)
            + data.get("usage", {}).get("output_tokens", 0),
            finish_reason=data.get("stop_reason"),
            tool_calls=tool_calls,
        )
420  tools/ai-review/clients/providers/azure_provider.py  Normal file
@@ -0,0 +1,420 @@
"""Azure OpenAI Provider

Integration with Azure OpenAI Service for enterprise deployments.
Supports custom deployments, regional endpoints, and Azure AD auth.
"""

import json
import os
import sys

import requests

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall


class AzureOpenAIProvider(BaseLLMProvider):
    """Azure OpenAI Service provider.

    Provides integration with Azure-hosted OpenAI models for
    enterprise customers with Azure deployments.

    Supports:
    - GPT-4, GPT-4 Turbo, GPT-4o
    - GPT-3.5 Turbo
    - Custom fine-tuned models

    Environment Variables:
    - AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
    - AZURE_OPENAI_API_KEY: API key for authentication
    - AZURE_OPENAI_DEPLOYMENT: Default deployment name
    - AZURE_OPENAI_API_VERSION: API version (default: 2024-02-15-preview)
    """

    DEFAULT_API_VERSION = "2024-02-15-preview"

    def __init__(
        self,
        endpoint: str | None = None,
        api_key: str | None = None,
        deployment: str | None = None,
        api_version: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Azure OpenAI provider.

        Args:
            endpoint: Azure OpenAI endpoint URL.
                Defaults to AZURE_OPENAI_ENDPOINT env var.
            api_key: API key for authentication.
                Defaults to AZURE_OPENAI_API_KEY env var.
            deployment: Deployment name to use.
                Defaults to AZURE_OPENAI_DEPLOYMENT env var.
            api_version: API version string.
                Defaults to AZURE_OPENAI_API_VERSION env var or latest.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum tokens in response.
        """
        self.endpoint = (
            endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT", "")
        ).rstrip("/")
        self.api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY", "")
        self.deployment = deployment or os.environ.get("AZURE_OPENAI_DEPLOYMENT", "")
        self.api_version = api_version or os.environ.get(
            "AZURE_OPENAI_API_VERSION", self.DEFAULT_API_VERSION
        )
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _get_api_url(self, deployment: str | None = None) -> str:
        """Build the API URL for a given deployment.

        Args:
            deployment: Deployment name. Uses default if not specified.

        Returns:
            Full API URL for chat completions.
        """
        deploy = deployment or self.deployment
        return (
            f"{self.endpoint}/openai/deployments/{deploy}"
            f"/chat/completions?api-version={self.api_version}"
        )

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Azure OpenAI API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (deployment, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If required configuration is missing.
            requests.HTTPError: If the API request fails.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.api_key:
            raise ValueError("Azure OpenAI API key is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")

        deployment = kwargs.get("deployment", self.deployment)

        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "api-key": self.api_key,
                "Content-Type": "application/json",
            },
            json={
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})

        return LLMResponse(
            content=message.get("content", ""),
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Azure OpenAI API with tool support.

        Azure OpenAI uses the same format as OpenAI for tools.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.api_key:
            raise ValueError("Azure OpenAI API key is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")

        deployment = kwargs.get("deployment", self.deployment)

        request_body = {
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "temperature": kwargs.get("temperature", self.temperature),
        }

        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")

        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "api-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})

        # Parse tool calls if present
        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = []
            for tc in message["tool_calls"]:
                func = tc.get("function", {})
                args = func.get("arguments", "{}")
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        args = {}
                tool_calls.append(
                    ToolCall(
                        id=tc.get("id", ""),
                        name=func.get("name", ""),
                        arguments=args,
                    )
                )

        return LLMResponse(
            content=message.get("content", "") or "",
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
            tool_calls=tool_calls,
        )


class AzureOpenAIWithAADProvider(AzureOpenAIProvider):
    """Azure OpenAI provider with Azure Active Directory authentication.

    Uses Azure AD tokens instead of API keys for authentication.
    Requires azure-identity package for token acquisition.

    Environment Variables:
    - AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
    - AZURE_OPENAI_DEPLOYMENT: Default deployment name
    - AZURE_TENANT_ID: Azure AD tenant ID (optional)
    - AZURE_CLIENT_ID: Azure AD client ID (optional)
    - AZURE_CLIENT_SECRET: Azure AD client secret (optional)
    """

    SCOPE = "https://cognitiveservices.azure.com/.default"

    def __init__(
        self,
        endpoint: str | None = None,
        deployment: str | None = None,
        api_version: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credential=None,
    ):
        """Initialize the Azure OpenAI AAD provider.

        Args:
            endpoint: Azure OpenAI endpoint URL.
            deployment: Deployment name to use.
            api_version: API version string.
            temperature: Sampling temperature (0-2).
            max_tokens: Maximum tokens in response.
            credential: Azure credential object. If not provided,
                uses DefaultAzureCredential.
        """
        super().__init__(
            endpoint=endpoint,
            api_key="",  # Not used with AAD
            deployment=deployment,
            api_version=api_version,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self._credential = credential
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get an Azure AD token for authentication.

        Returns:
            Bearer token string.

        Raises:
            ImportError: If azure-identity is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer)
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token

        try:
            from azure.identity import DefaultAzureCredential
        except ImportError:
            raise ImportError(
                "azure-identity package is required for AAD authentication. "
                "Install with: pip install azure-identity"
            )

        if self._credential is None:
            self._credential = DefaultAzureCredential()

        token = self._credential.get_token(self.SCOPE)
        self._token = token.token
        self._token_expires_at = token.expires_on

        return self._token

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Azure OpenAI API using AAD auth.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options.

        Returns:
            LLMResponse with the generated content.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")

        deployment = kwargs.get("deployment", self.deployment)
        token = self._get_token()

        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json={
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": kwargs.get("max_tokens", self.max_tokens),
                "temperature": kwargs.get("temperature", self.temperature),
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})

        return LLMResponse(
            content=message.get("content", ""),
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Azure OpenAI API with tool support using AAD auth.

        Args:
            messages: List of message dicts.
            tools: List of tool definitions.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        if not self.endpoint:
            raise ValueError("Azure OpenAI endpoint is required")
        if not self.deployment and not kwargs.get("deployment"):
            raise ValueError("Azure OpenAI deployment name is required")

        deployment = kwargs.get("deployment", self.deployment)
        token = self._get_token()

        request_body = {
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "temperature": kwargs.get("temperature", self.temperature),
        }

        if tools:
            request_body["tools"] = tools
            request_body["tool_choice"] = kwargs.get("tool_choice", "auto")

        response = requests.post(
            self._get_api_url(deployment),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        choice = data.get("choices", [{}])[0]
        message = choice.get("message", {})

        # Parse tool calls if present
        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = []
            for tc in message["tool_calls"]:
                func = tc.get("function", {})
                args = func.get("arguments", "{}")
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        args = {}
                tool_calls.append(
                    ToolCall(
                        id=tc.get("id", ""),
                        name=func.get("name", ""),
                        arguments=args,
                    )
                )

        return LLMResponse(
            content=message.get("content", "") or "",
            model=data.get("model", deployment),
            provider="azure",
            tokens_used=data.get("usage", {}).get("total_tokens", 0),
            finish_reason=choice.get("finish_reason"),
            tool_calls=tool_calls,
        )
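Unlike the flat OpenAI endpoint, Azure routes requests through a named deployment, which is what `_get_api_url` encodes. The same string construction in isolation (function name and endpoint value are illustrative, not from the provider):

```python
def azure_chat_url(endpoint: str, deployment: str, api_version: str) -> str:
    """Deployment-scoped chat/completions URL; trailing slash is stripped
    so a doubled "//" never appears in the path."""
    return (
        f"{endpoint.rstrip('/')}/openai/deployments/{deployment}"
        f"/chat/completions?api-version={api_version}"
    )
```

Note the `api-version` query parameter is mandatory on Azure; omitting it is a common source of 404s that never occurs with the public OpenAI API.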
599  tools/ai-review/clients/providers/gemini_provider.py  Normal file
@@ -0,0 +1,599 @@
"""Google Gemini Provider

Integration with Google's Gemini API for GCP customers.
Supports Gemini 1.5 Pro, Gemini 1.5 Flash, and other models.
"""

import json
import os
import sys

import requests

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall


class GeminiProvider(BaseLLMProvider):
    """Google Gemini API provider.

    Provides integration with Google's Gemini models.

    Supports:
    - Gemini 1.5 Pro (gemini-1.5-pro)
    - Gemini 1.5 Flash (gemini-1.5-flash)
    - Gemini 1.0 Pro (gemini-pro)

    Environment Variables:
    - GOOGLE_API_KEY: Google AI API key
    - GEMINI_MODEL: Default model (optional)
    """

    API_URL = "https://generativelanguage.googleapis.com/v1beta/models"

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "gemini-1.5-pro",
        temperature: float = 0,
        max_tokens: int = 4096,
    ):
        """Initialize the Gemini provider.

        Args:
            api_key: Google API key. Defaults to GOOGLE_API_KEY env var.
            model: Model to use. Defaults to gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
        """
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "")
        self.model = model or os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _get_api_url(self, model: str | None = None, stream: bool = False) -> str:
        """Build the API URL for a given model.

        Args:
            model: Model name. Uses default if not specified.
            stream: Whether to use streaming endpoint.

        Returns:
            Full API URL.
        """
        m = model or self.model
        action = "streamGenerateContent" if stream else "generateContent"
        return f"{self.API_URL}/{m}:{action}?key={self.api_key}"

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to the Gemini API.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options (model, temperature, max_tokens).

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If API key is not set.
            requests.HTTPError: If the API request fails.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")

        model = kwargs.get("model", self.model)

        response = requests.post(
            self._get_api_url(model),
            headers={"Content-Type": "application/json"},
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Extract content from response
        content = ""
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]

        # Get token counts
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")

        return LLMResponse(
            content=content,
            model=model,
            provider="gemini",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to the Gemini API with tool support.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: List of tool definitions in OpenAI format.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.
        """
        if not self.api_key:
            raise ValueError("Google API key is required")

        model = kwargs.get("model", self.model)

        # Convert OpenAI-style messages to Gemini format
        gemini_contents = []
|
||||
system_instruction = None
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
system_instruction = content
|
||||
elif role == "assistant":
|
||||
# Handle assistant messages with tool calls
|
||||
parts = []
|
||||
if content:
|
||||
parts.append({"text": content})
|
||||
|
||||
if msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
parts.append(
|
||||
{
|
||||
"functionCall": {
|
||||
"name": func.get("name", ""),
|
||||
"args": args,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
gemini_contents.append({"role": "model", "parts": parts})
|
||||
elif role == "tool":
|
||||
# Tool response in Gemini format
|
||||
gemini_contents.append(
|
||||
{
|
||||
"role": "function",
|
||||
"parts": [
|
||||
{
|
||||
"functionResponse": {
|
||||
"name": msg.get("name", ""),
|
||||
"response": {"result": content},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
# User message
|
||||
gemini_contents.append({"role": "user", "parts": [{"text": content}]})
|
||||
|
||||
# Convert OpenAI-style tools to Gemini format
|
||||
gemini_tools = None
|
||||
if tools:
|
||||
function_declarations = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool["function"]
|
||||
function_declarations.append(
|
||||
{
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
if function_declarations:
|
||||
gemini_tools = [{"functionDeclarations": function_declarations}]
|
||||
|
||||
request_body = {
|
||||
"contents": gemini_contents,
|
||||
"generationConfig": {
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
"maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
},
|
||||
}
|
||||
|
||||
if system_instruction:
|
||||
request_body["systemInstruction"] = {
|
||||
"parts": [{"text": system_instruction}]
|
||||
}
|
||||
|
||||
if gemini_tools:
|
||||
request_body["tools"] = gemini_tools
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(model),
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=request_body,
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Parse response
|
||||
content = ""
|
||||
tool_calls = None
|
||||
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
content += part["text"]
|
||||
elif "functionCall" in part:
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
fc = part["functionCall"]
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=f"call_{len(tool_calls)}", # Gemini doesn't provide IDs
|
||||
name=fc.get("name", ""),
|
||||
arguments=fc.get("args", {}),
|
||||
)
|
||||
)
|
||||
|
||||
# Get token counts
|
||||
usage = data.get("usageMetadata", {})
|
||||
tokens_used = usage.get("promptTokenCount", 0) + usage.get(
|
||||
"candidatesTokenCount", 0
|
||||
)
|
||||
|
||||
finish_reason = None
|
||||
if candidates:
|
||||
finish_reason = candidates[0].get("finishReason")
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=model,
|
||||
provider="gemini",
|
||||
tokens_used=tokens_used,
|
||||
finish_reason=finish_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIGeminiProvider(BaseLLMProvider):
    """Google Vertex AI Gemini provider for enterprise GCP deployments.

    Uses Vertex AI endpoints instead of the public Gemini API.
    Supports regional deployments and IAM authentication.

    Environment Variables:
    - GOOGLE_CLOUD_PROJECT: GCP project ID
    - GOOGLE_CLOUD_REGION: GCP region (default: us-central1)
    - VERTEX_AI_MODEL: Default model (optional)
    """

    def __init__(
        self,
        project: str | None = None,
        region: str | None = None,
        model: str | None = None,
        temperature: float = 0,
        max_tokens: int = 4096,
        credentials=None,
    ):
        """Initialize the Vertex AI Gemini provider.

        Args:
            project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var.
            region: GCP region. Defaults to GOOGLE_CLOUD_REGION env var, then us-central1.
            model: Model to use. Defaults to VERTEX_AI_MODEL env var, then gemini-1.5-pro.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens in response.
            credentials: Google credentials object. If not provided,
                uses Application Default Credentials.
        """
        self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
        # Explicit arguments win; otherwise fall back to env vars, then defaults.
        self.region = region or os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
        self.model = model or os.environ.get("VERTEX_AI_MODEL", "gemini-1.5-pro")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._credentials = credentials
        self._token = None
        self._token_expires_at = 0

    def _get_token(self) -> str:
        """Get a Google Cloud access token.

        Returns:
            Access token string.

        Raises:
            ImportError: If google-auth is not installed.
        """
        import time

        # Return cached token if still valid (with 5 min buffer)
        if self._token and self._token_expires_at > time.time() + 300:
            return self._token

        try:
            import google.auth
            from google.auth.transport.requests import Request
        except ImportError:
            raise ImportError(
                "google-auth package is required for Vertex AI authentication. "
                "Install with: pip install google-auth"
            )

        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )

        if not self._credentials.valid:
            self._credentials.refresh(Request())

        self._token = self._credentials.token
        # Tokens typically expire in 1 hour
        self._token_expires_at = time.time() + 3500

        return self._token

    def _get_api_url(self, model: str | None = None) -> str:
        """Build the Vertex AI API URL.

        Args:
            model: Model name. Uses default if not specified.

        Returns:
            Full API URL.
        """
        m = model or self.model
        return (
            f"https://{self.region}-aiplatform.googleapis.com/v1/"
            f"projects/{self.project}/locations/{self.region}/"
            f"publishers/google/models/{m}:generateContent"
        )

    def call(self, prompt: str, **kwargs) -> LLMResponse:
        """Make a call to Vertex AI Gemini.

        Args:
            prompt: The prompt to send.
            **kwargs: Additional options.

        Returns:
            LLMResponse with the generated content.

        Raises:
            ValueError: If the GCP project ID is not set.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model", self.model)
        token = self._get_token()

        response = requests.post(
            self._get_api_url(model),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json={
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "temperature": kwargs.get("temperature", self.temperature),
                    "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
                },
            },
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Extract content from response
        content = ""
        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]

        # Get token counts
        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")

        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
        )

    def call_with_tools(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Make a call to Vertex AI Gemini with tool support.

        Args:
            messages: List of message dicts.
            tools: List of tool definitions.
            **kwargs: Additional options.

        Returns:
            LLMResponse with content and/or tool_calls.

        Raises:
            ValueError: If the GCP project ID is not set.
        """
        if not self.project:
            raise ValueError("GCP project ID is required")

        model = kwargs.get("model", self.model)
        token = self._get_token()

        # Convert messages to Gemini format (same as GeminiProvider)
        gemini_contents = []
        system_instruction = None

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                system_instruction = content
            elif role == "assistant":
                parts = []
                if content:
                    parts.append({"text": content})
                if msg.get("tool_calls"):
                    for tc in msg["tool_calls"]:
                        func = tc.get("function", {})
                        args = func.get("arguments", {})
                        if isinstance(args, str):
                            try:
                                args = json.loads(args)
                            except json.JSONDecodeError:
                                args = {}
                        parts.append(
                            {
                                "functionCall": {
                                    "name": func.get("name", ""),
                                    "args": args,
                                }
                            }
                        )
                gemini_contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                gemini_contents.append(
                    {
                        "role": "function",
                        "parts": [
                            {
                                "functionResponse": {
                                    "name": msg.get("name", ""),
                                    "response": {"result": content},
                                }
                            }
                        ],
                    }
                )
            else:
                gemini_contents.append({"role": "user", "parts": [{"text": content}]})

        # Convert tools to Gemini format
        gemini_tools = None
        if tools:
            function_declarations = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool["function"]
                    function_declarations.append(
                        {
                            "name": func["name"],
                            "description": func.get("description", ""),
                            "parameters": func.get("parameters", {}),
                        }
                    )
            if function_declarations:
                gemini_tools = [{"functionDeclarations": function_declarations}]

        request_body = {
            "contents": gemini_contents,
            "generationConfig": {
                "temperature": kwargs.get("temperature", self.temperature),
                "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
            },
        }

        if system_instruction:
            request_body["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }

        if gemini_tools:
            request_body["tools"] = gemini_tools

        response = requests.post(
            self._get_api_url(model),
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()

        # Parse response
        content = ""
        tool_calls = None

        candidates = data.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            for part in parts:
                if "text" in part:
                    content += part["text"]
                elif "functionCall" in part:
                    if tool_calls is None:
                        tool_calls = []
                    fc = part["functionCall"]
                    tool_calls.append(
                        ToolCall(
                            id=f"call_{len(tool_calls)}",
                            name=fc.get("name", ""),
                            arguments=fc.get("args", {}),
                        )
                    )

        usage = data.get("usageMetadata", {})
        tokens_used = usage.get("promptTokenCount", 0) + usage.get(
            "candidatesTokenCount", 0
        )

        finish_reason = None
        if candidates:
            finish_reason = candidates[0].get("finishReason")

        return LLMResponse(
            content=content,
            model=model,
            provider="vertex-ai",
            tokens_used=tokens_used,
            finish_reason=finish_reason,
            tool_calls=tool_calls,
        )

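The refresh-with-buffer logic in `_get_token` is worth seeing on its own: reuse the cached token while it has at least five minutes of life left, otherwise fetch a new one. A minimal sketch with an injectable clock (`TokenCache` is an illustrative name, not part of the module):

```python
import time


class TokenCache:
    """Cache an access token, refreshing when it nears expiry (5 min buffer)."""

    def __init__(self, fetch, lifetime: float = 3500, clock=time.time):
        self._fetch = fetch        # callable returning a fresh token string
        self._lifetime = lifetime  # assumed token lifetime in seconds
        self._clock = clock        # injectable for testing
        self._token = None
        self._expires_at = 0.0

    def get(self) -> str:
        # Reuse the cached token while it is still valid with a 300 s buffer.
        if self._token and self._expires_at > self._clock() + 300:
            return self._token
        self._token = self._fetch()
        self._expires_at = self._clock() + self._lifetime
        return self._token
```

Injecting the clock makes the expiry window testable without sleeping, which is why `_get_token` above is a candidate for this kind of extraction.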
tools/ai-review/compliance/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""Compliance Module

Provides audit trail, compliance reporting, and regulatory checks.
"""

from compliance.audit_trail import AuditEvent, AuditLogger, AuditTrail
from compliance.codeowners import CodeownersChecker

__all__ = [
    "AuditTrail",
    "AuditLogger",
    "AuditEvent",
    "CodeownersChecker",
]
tools/ai-review/compliance/audit_trail.py (new file, 430 lines)
@@ -0,0 +1,430 @@
"""Audit Trail

Provides comprehensive audit logging for compliance requirements.
Supports HIPAA, SOC2, and other regulatory frameworks.
"""

import hashlib
import json
import logging
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any


class AuditAction(Enum):
    """Types of auditable actions."""

    # Review actions
    REVIEW_STARTED = "review_started"
    REVIEW_COMPLETED = "review_completed"
    REVIEW_FAILED = "review_failed"

    # Security actions
    SECURITY_SCAN_STARTED = "security_scan_started"
    SECURITY_SCAN_COMPLETED = "security_scan_completed"
    SECURITY_FINDING_DETECTED = "security_finding_detected"
    SECURITY_FINDING_RESOLVED = "security_finding_resolved"

    # Comment actions
    COMMENT_POSTED = "comment_posted"
    COMMENT_UPDATED = "comment_updated"
    COMMENT_DELETED = "comment_deleted"

    # Label actions
    LABEL_ADDED = "label_added"
    LABEL_REMOVED = "label_removed"

    # Configuration actions
    CONFIG_LOADED = "config_loaded"
    CONFIG_CHANGED = "config_changed"

    # Access actions
    API_CALL = "api_call"
    AUTHENTICATION = "authentication"

    # Approval actions
    APPROVAL_GRANTED = "approval_granted"
    APPROVAL_REVOKED = "approval_revoked"
    CHANGES_REQUESTED = "changes_requested"


@dataclass
class AuditEvent:
    """An auditable event."""

    action: AuditAction
    timestamp: str
    actor: str
    resource_type: str
    resource_id: str
    repository: str
    details: dict[str, Any] = field(default_factory=dict)
    outcome: str = "success"
    error: str | None = None
    correlation_id: str | None = None
    checksum: str | None = None

    def __post_init__(self):
        """Calculate checksum for integrity verification."""
        if not self.checksum:
            self.checksum = self._calculate_checksum()

    def _calculate_checksum(self) -> str:
        """Calculate SHA-256 checksum of event data."""
        data = {
            "action": self.action.value
            if isinstance(self.action, AuditAction)
            else self.action,
            "timestamp": self.timestamp,
            "actor": self.actor,
            "resource_type": self.resource_type,
            "resource_id": self.resource_id,
            "repository": self.repository,
            "details": self.details,
            "outcome": self.outcome,
            "error": self.error,
        }
        json_str = json.dumps(data, sort_keys=True)
        return hashlib.sha256(json_str.encode()).hexdigest()

    def to_dict(self) -> dict:
        """Convert event to dictionary."""
        data = asdict(self)
        if isinstance(self.action, AuditAction):
            data["action"] = self.action.value
        return data

    def to_json(self) -> str:
        """Convert event to JSON string."""
        return json.dumps(self.to_dict())


class AuditLogger:
    """Logger for audit events."""

    def __init__(
        self,
        log_file: str | None = None,
        log_to_stdout: bool = False,
        log_level: str = "INFO",
    ):
        """Initialize the audit logger.

        Args:
            log_file: Path to the audit log file.
            log_to_stdout: Also log to stdout.
            log_level: Logging level.
        """
        self.log_file = log_file
        self.log_to_stdout = log_to_stdout
        self.logger = logging.getLogger("audit")
        self.logger.setLevel(getattr(logging, log_level.upper(), logging.INFO))

        # Clear existing handlers
        self.logger.handlers = []

        # Add file handler if specified
        if log_file:
            log_dir = os.path.dirname(log_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(
                logging.Formatter("%(message)s")  # JSON lines format
            )
            self.logger.addHandler(file_handler)

        # Add stdout handler if requested
        if log_to_stdout:
            stdout_handler = logging.StreamHandler()
            stdout_handler.setFormatter(logging.Formatter("[AUDIT] %(message)s"))
            self.logger.addHandler(stdout_handler)

    def log(self, event: AuditEvent):
        """Log an audit event.

        Args:
            event: The audit event to log.
        """
        self.logger.info(event.to_json())

    def log_action(
        self,
        action: AuditAction,
        actor: str,
        resource_type: str,
        resource_id: str,
        repository: str,
        details: dict | None = None,
        outcome: str = "success",
        error: str | None = None,
        correlation_id: str | None = None,
    ):
        """Log an action as an audit event.

        Args:
            action: The action being performed.
            actor: Who performed the action.
            resource_type: Type of resource affected.
            resource_id: ID of the resource.
            repository: Repository context.
            details: Additional details.
            outcome: success, failure, or partial.
            error: Error message if failed.
            correlation_id: ID to correlate related events.
        """
        event = AuditEvent(
            action=action,
            timestamp=datetime.now(timezone.utc).isoformat(),
            actor=actor,
            resource_type=resource_type,
            resource_id=resource_id,
            repository=repository,
            details=details or {},
            outcome=outcome,
            error=error,
            correlation_id=correlation_id,
        )
        self.log(event)


class AuditTrail:
    """High-level audit trail management."""

    def __init__(self, config: dict):
        """Initialize the audit trail.

        Args:
            config: Configuration dictionary.
        """
        self.config = config
        compliance_config = config.get("compliance", {})
        audit_config = compliance_config.get("audit", {})

        self.enabled = audit_config.get("enabled", False)
        self.log_file = audit_config.get("log_file", "audit.log")
        self.log_to_stdout = audit_config.get("log_to_stdout", False)
        self.retention_days = audit_config.get("retention_days", 90)

        if self.enabled:
            self.logger = AuditLogger(
                log_file=self.log_file,
                log_to_stdout=self.log_to_stdout,
            )
        else:
            self.logger = None

        self._correlation_id = None

    def set_correlation_id(self, correlation_id: str):
        """Set the correlation ID for subsequent events.

        Args:
            correlation_id: ID to correlate related events.
        """
        self._correlation_id = correlation_id

    def log(
        self,
        action: AuditAction,
        actor: str,
        resource_type: str,
        resource_id: str,
        repository: str,
        details: dict | None = None,
        outcome: str = "success",
        error: str | None = None,
    ):
        """Log an audit event.

        Args:
            action: The action being performed.
            actor: Who performed the action.
            resource_type: Type of resource (pr, issue, comment, etc.).
            resource_id: ID of the resource.
            repository: Repository (owner/repo).
            details: Additional details.
            outcome: success, failure, or partial.
            error: Error message if failed.
        """
        if not self.enabled or not self.logger:
            return

        self.logger.log_action(
            action=action,
            actor=actor,
            resource_type=resource_type,
            resource_id=resource_id,
            repository=repository,
            details=details,
            outcome=outcome,
            error=error,
            correlation_id=self._correlation_id,
        )

    def log_review_started(
        self,
        repository: str,
        pr_number: int,
        reviewer: str = "openrabbit",
    ):
        """Log that a review has started."""
        self.log(
            action=AuditAction.REVIEW_STARTED,
            actor=reviewer,
            resource_type="pull_request",
            resource_id=str(pr_number),
            repository=repository,
        )

    def log_review_completed(
        self,
        repository: str,
        pr_number: int,
        recommendation: str,
        findings_count: int,
        reviewer: str = "openrabbit",
    ):
        """Log that a review has completed."""
        self.log(
            action=AuditAction.REVIEW_COMPLETED,
            actor=reviewer,
            resource_type="pull_request",
            resource_id=str(pr_number),
            repository=repository,
            details={
                "recommendation": recommendation,
                "findings_count": findings_count,
            },
        )

    def log_security_finding(
        self,
        repository: str,
        pr_number: int,
        finding: dict,
        scanner: str = "openrabbit",
    ):
        """Log a security finding."""
        self.log(
            action=AuditAction.SECURITY_FINDING_DETECTED,
            actor=scanner,
            resource_type="pull_request",
            resource_id=str(pr_number),
            repository=repository,
            details={
                "severity": finding.get("severity"),
                "category": finding.get("category"),
                "file": finding.get("file"),
                "line": finding.get("line"),
                "cwe": finding.get("cwe"),
            },
        )

    def log_approval(
        self,
        repository: str,
        pr_number: int,
        approver: str,
        approval_type: str = "ai",
    ):
        """Log an approval action."""
        self.log(
            action=AuditAction.APPROVAL_GRANTED,
            actor=approver,
            resource_type="pull_request",
            resource_id=str(pr_number),
            repository=repository,
            details={"approval_type": approval_type},
        )

    def log_changes_requested(
        self,
        repository: str,
        pr_number: int,
        requester: str,
        reason: str | None = None,
    ):
        """Log a changes-requested action."""
        self.log(
            action=AuditAction.CHANGES_REQUESTED,
            actor=requester,
            resource_type="pull_request",
            resource_id=str(pr_number),
            repository=repository,
            details={"reason": reason} if reason else {},
        )

    def generate_report(
        self,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        repository: str | None = None,
    ) -> dict:
        """Generate an audit report.

        Args:
            start_date: Start of reporting period.
            end_date: End of reporting period.
            repository: Filter by repository.

        Returns:
            Report dictionary with statistics and events.
        """
        if not self.log_file or not os.path.exists(self.log_file):
            return {"events": [], "statistics": {}}

        events = []
        with open(self.log_file) as f:
            for line in f:
                try:
                    event = json.loads(line.strip())
                    event_time = datetime.fromisoformat(
                        event["timestamp"].replace("Z", "+00:00")
                    )

                    # Apply filters
                    if start_date and event_time < start_date:
                        continue
                    if end_date and event_time > end_date:
                        continue
                    if repository and event.get("repository") != repository:
                        continue

                    events.append(event)
                except (json.JSONDecodeError, KeyError, ValueError):
                    # Skip malformed lines, missing fields, or bad timestamps
                    continue

        # Calculate statistics
        action_counts = {}
        outcome_counts = {"success": 0, "failure": 0, "partial": 0}
        security_findings = 0

        for event in events:
            action = event.get("action", "unknown")
            action_counts[action] = action_counts.get(action, 0) + 1

            outcome = event.get("outcome", "success")
            if outcome in outcome_counts:
                outcome_counts[outcome] += 1

            if action == "security_finding_detected":
                security_findings += 1

        return {
            "events": events,
            "statistics": {
                "total_events": len(events),
                "action_counts": action_counts,
                "outcome_counts": outcome_counts,
                "security_findings": security_findings,
            },
            "period": {
                "start": start_date.isoformat() if start_date else None,
                "end": end_date.isoformat() if end_date else None,
            },
        }
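The per-event checksum makes each log line tamper-evident: verification is just recomputing the hash over the audited fields and comparing. A standalone sketch mirroring `_calculate_checksum` (the helper names here are illustrative, not part of the module):

```python
import hashlib
import json


def checksum(event: dict) -> str:
    """SHA-256 over the canonical (sorted-key) JSON of the audited fields."""
    fields = ("action", "timestamp", "actor", "resource_type",
              "resource_id", "repository", "details", "outcome", "error")
    data = {k: event.get(k) for k in fields}
    return hashlib.sha256(json.dumps(data, sort_keys=True).encode()).hexdigest()


def verify(logged: dict) -> bool:
    """True if the stored checksum matches a recomputation of the fields."""
    return logged.get("checksum") == checksum(logged)
```

Because the checksum field itself is excluded from the hash, any edit to an audited field after logging breaks verification.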
tools/ai-review/compliance/codeowners.py (new file, 314 lines)
@@ -0,0 +1,314 @@
"""CODEOWNERS Checker

Parses and validates CODEOWNERS files for compliance enforcement.
"""

import fnmatch
import logging
import os
import re
from dataclasses import dataclass


@dataclass
class CodeOwnerRule:
    """A CODEOWNERS rule."""

    pattern: str
    owners: list[str]
    line_number: int
    is_negation: bool = False

    def matches(self, path: str) -> bool:
        """Check if a path matches this rule.

        Args:
            path: File path to check.

        Returns:
            True if the path matches.
        """
        path = path.lstrip("/")
        pattern = self.pattern.lstrip("/")

        # Handle directory patterns
        if pattern.endswith("/"):
            return path.startswith(pattern) or fnmatch.fnmatch(path, pattern + "*")

        # Handle ** patterns
        if "**" in pattern:
            regex = pattern.replace("**", ".*").replace("*", "[^/]*")
            return bool(re.match(f"^{regex}$", path))

        # Standard fnmatch
        return fnmatch.fnmatch(path, pattern) or fnmatch.fnmatch(path, f"**/{pattern}")


class CodeownersChecker:
    """Checker for CODEOWNERS file compliance."""

    CODEOWNERS_LOCATIONS = [
        "CODEOWNERS",
        ".github/CODEOWNERS",
        ".gitea/CODEOWNERS",
        "docs/CODEOWNERS",
    ]

    def __init__(self, repo_root: str | None = None):
        """Initialize the CODEOWNERS checker.

        Args:
            repo_root: Repository root path.
        """
        self.repo_root = repo_root or os.getcwd()
        self.rules: list[CodeOwnerRule] = []
        self.codeowners_path: str | None = None
        self.logger = logging.getLogger(__name__)

        self._load_codeowners()

    def _load_codeowners(self):
        """Load the CODEOWNERS file from the repository."""
        for location in self.CODEOWNERS_LOCATIONS:
            path = os.path.join(self.repo_root, location)
            if os.path.exists(path):
                self.codeowners_path = path
                self._parse_codeowners(path)
                break

    def _parse_codeowners(self, path: str):
        """Parse a CODEOWNERS file.

        Args:
            path: Path to the CODEOWNERS file.
        """
        with open(path) as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()

                # Skip empty lines and comments
                if not line or line.startswith("#"):
                    continue

                # Parse pattern and owners
                parts = line.split()
                if len(parts) < 2:
                    continue

                pattern = parts[0]
                owners = parts[1:]

                # Check for negation (optional syntax)
                is_negation = pattern.startswith("!")
                if is_negation:
                    pattern = pattern[1:]

                self.rules.append(
                    CodeOwnerRule(
                        pattern=pattern,
                        owners=owners,
                        line_number=line_num,
                        is_negation=is_negation,
                    )
                )

    def get_owners(self, path: str) -> list[str]:
        """Get owners for a file path.

        Args:
            path: File path to check.

        Returns:
            List of owner usernames/teams.
        """
        owners = []

        # Apply rules in order (later rules override earlier ones)
        for rule in self.rules:
            if rule.matches(path):
                if rule.is_negation:
                    owners = []  # Clear owners for negation
                else:
                    owners = rule.owners

        return owners

    def get_owners_for_files(self, files: list[str]) -> dict[str, list[str]]:
        """Get owners for multiple files.

        Args:
            files: List of file paths.

        Returns:
            Dict mapping file paths to owner lists.
        """
        return {f: self.get_owners(f) for f in files}

    def get_required_reviewers(self, files: list[str]) -> set[str]:
        """Get all required reviewers for a set of files.

        Args:
            files: List of file paths.

        Returns:
            Set of all required reviewer usernames/teams.
        """
        reviewers = set()
        for f in files:
            reviewers.update(self.get_owners(f))
        return reviewers

    def check_approval(
        self,
        files: list[str],
        approvers: list[str],
    ) -> dict:
        """Check if files have the required approvals.

        Args:
            files: List of changed files.
            approvers: List of users who approved.

        Returns:
            Dict with approval status and missing approvers.
        """
        required = self.get_required_reviewers(files)
        approvers_set = set(approvers)

        # Normalize @ prefixes
        required_normalized = {r.lstrip("@") for r in required}
        approvers_normalized = {a.lstrip("@") for a in approvers_set}

        missing = required_normalized - approvers_normalized

        # Check for team approvals (simplified - the actual implementation
        # would need API calls to check team membership)
        teams = {r for r in missing if "/" in r}
        missing_users = missing - teams

        return {
            "approved": len(missing_users) == 0,
            "required_reviewers": list(required_normalized),
            "actual_approvers": list(approvers_normalized),
            "missing_approvers": list(missing_users),
            "pending_teams": list(teams),
        }

def get_coverage_report(self, files: list[str]) -> dict:
|
||||
"""Generate a coverage report for files.
|
||||
|
||||
Args:
|
||||
files: List of file paths.
|
||||
|
||||
Returns:
|
||||
Coverage report with owned and unowned files.
|
||||
"""
|
||||
owned = []
|
||||
unowned = []
|
||||
|
||||
for f in files:
|
||||
owners = self.get_owners(f)
|
||||
if owners:
|
||||
owned.append({"file": f, "owners": owners})
|
||||
else:
|
||||
unowned.append(f)
|
||||
|
||||
return {
|
||||
"total_files": len(files),
|
||||
"owned_files": len(owned),
|
||||
"unowned_files": len(unowned),
|
||||
"coverage_percent": (len(owned) / len(files) * 100) if files else 0,
|
||||
"owned": owned,
|
||||
"unowned": unowned,
|
||||
}
|
||||
|
||||
def validate_codeowners(self) -> dict:
|
||||
"""Validate the CODEOWNERS file.
|
||||
|
||||
Returns:
|
||||
Validation result with warnings and errors.
|
||||
"""
|
||||
if not self.codeowners_path:
|
||||
return {
|
||||
"valid": False,
|
||||
"errors": ["No CODEOWNERS file found"],
|
||||
"warnings": [],
|
||||
}
|
||||
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
# Check for empty rules
|
||||
for rule in self.rules:
|
||||
if not rule.owners:
|
||||
errors.append(
|
||||
f"Line {rule.line_number}: Pattern '{rule.pattern}' has no owners"
|
||||
)
|
||||
|
||||
# Check for invalid owner formats
|
||||
for rule in self.rules:
|
||||
for owner in rule.owners:
|
||||
if not owner.startswith("@") and "/" not in owner:
|
||||
warnings.append(
|
||||
f"Line {rule.line_number}: Owner '{owner}' should start with @ or be a team (org/team)"
|
||||
)
|
||||
|
||||
# Check for overlapping patterns
|
||||
patterns_seen = {}
|
||||
for rule in self.rules:
|
||||
if rule.pattern in patterns_seen:
|
||||
warnings.append(
|
||||
f"Line {rule.line_number}: Pattern '{rule.pattern}' duplicates line {patterns_seen[rule.pattern]}"
|
||||
)
|
||||
patterns_seen[rule.pattern] = rule.line_number
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
"rules_count": len(self.rules),
|
||||
"file_path": self.codeowners_path,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_content(cls, content: str) -> "CodeownersChecker":
|
||||
"""Create checker from CODEOWNERS content string.
|
||||
|
||||
Args:
|
||||
content: CODEOWNERS file content.
|
||||
|
||||
Returns:
|
||||
CodeownersChecker instance.
|
||||
"""
|
||||
checker = cls.__new__(cls)
|
||||
checker.repo_root = None
|
||||
checker.rules = []
|
||||
checker.codeowners_path = "<string>"
|
||||
checker.logger = logging.getLogger(__name__)
|
||||
|
||||
for line_num, line in enumerate(content.split("\n"), 1):
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parts = line.split()
|
||||
if len(parts) < 2:
|
||||
continue
|
||||
|
||||
pattern = parts[0]
|
||||
owners = parts[1:]
|
||||
is_negation = pattern.startswith("!")
|
||||
if is_negation:
|
||||
pattern = pattern[1:]
|
||||
|
||||
checker.rules.append(
|
||||
CodeOwnerRule(
|
||||
pattern=pattern,
|
||||
owners=owners,
|
||||
line_number=line_num,
|
||||
is_negation=is_negation,
|
||||
)
|
||||
)
|
||||
|
||||
return checker
|
||||
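The last-match-wins resolution in `get_owners` can be illustrated standalone. `CodeOwnerRule.matches` is not shown above, so this sketch assumes fnmatch-style glob matching; the rule tuples and owner names here are illustrative only.

```python
from fnmatch import fnmatch


def get_owners(rules, path):
    # Last matching rule wins, mirroring CODEOWNERS semantics;
    # a negation rule clears any previously assigned owners.
    owners = []
    for pattern, rule_owners, is_negation in rules:
        if fnmatch(path, pattern):
            owners = [] if is_negation else rule_owners
    return owners


# (pattern, owners, is_negation) triples, most general first
rules = [
    ("*", ["@org/maintainers"], False),
    ("docs/*", ["@alice"], False),
    ("docs/generated/*", [], True),  # negation: no owner required
]

print(get_owners(rules, "src/main.py"))        # ['@org/maintainers']
print(get_owners(rules, "docs/guide.md"))      # ['@alice']
print(get_owners(rules, "docs/generated/api.md"))  # []
```

Note that the later `docs/*` rule replaces, rather than extends, the catch-all's owners, which matches how GitHub resolves CODEOWNERS entries.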
@@ -1,233 +1,355 @@
# OpenRabbit AI Code Review Configuration
# =========================================

# LLM Provider Configuration
# --------------------------
# Available providers: openai | openrouter | ollama | anthropic | azure | gemini
provider: openai

model:
  openai: gpt-4.1-mini
  openrouter: anthropic/claude-3.5-sonnet
  ollama: codellama:13b
  anthropic: claude-3-5-sonnet-20241022
  azure: gpt-4  # Deployment name
  gemini: gemini-1.5-pro

temperature: 0
max_tokens: 4096

# Azure OpenAI specific settings (when provider: azure)
azure:
  endpoint: ""    # Set via AZURE_OPENAI_ENDPOINT env var
  deployment: ""  # Set via AZURE_OPENAI_DEPLOYMENT env var
  api_version: "2024-02-15-preview"

# Google Gemini specific settings (when provider: gemini)
gemini:
  project: ""  # For Vertex AI, set via GOOGLE_CLOUD_PROJECT env var
  region: "us-central1"

# Rate Limits and Timeouts
# ------------------------
rate_limits:
  min_interval: 1.0  # Minimum seconds between API requests

timeouts:
  llm: 120     # LLM API timeout in seconds (OpenAI, OpenRouter, Anthropic, etc.)
  ollama: 300  # Ollama timeout (longer for local models)
  gitea: 30    # Gitea/GitHub API timeout

# Review Settings
# ---------------
review:
  fail_on_severity: HIGH
  max_diff_lines: 800
  inline_comments: true
  security_scan: true

# File Ignore Patterns
# --------------------
# Similar to .gitignore, controls which files are excluded from review
ignore:
  use_defaults: true        # Include default patterns (node_modules, .git, etc.)
  file: ".ai-reviewignore"  # Custom ignore file name
  patterns: []              # Additional patterns to ignore

# Agent Configuration
# -------------------
agents:
  issue:
    enabled: true
    auto_label: true
    auto_triage: true
    duplicate_threshold: 0.85
    events:
      - opened
      - labeled
  pr:
    enabled: true
    inline_comments: true
    security_scan: true
    events:
      - opened
      - synchronize
    auto_summary:
      enabled: true          # Auto-generate summary for PRs with empty descriptions
      post_as_comment: true  # true = post as comment, false = update PR description
  codebase:
    enabled: true
    schedule: "0 0 * * 0"  # Weekly on Sunday
  chat:
    enabled: true
    name: "Bartender"
    max_iterations: 5  # Max tool call iterations per chat
    tools:
      - search_codebase
      - read_file
      - search_web
    searxng_url: ""  # Set via SEARXNG_URL env var or here

  # Dependency Security Agent
  dependency:
    enabled: true
    scan_on_pr: true                   # Auto-scan PRs that modify dependency files
    vulnerability_threshold: "medium"  # low | medium | high | critical
    update_suggestions: true           # Suggest version updates

  # Test Coverage Agent
  test_coverage:
    enabled: true
    suggest_tests: true
    min_coverage_percent: 80  # Warn if coverage below this

  # Architecture Compliance Agent
  architecture:
    enabled: true
    layers:
      api:
        can_import_from: [utils, models, services]
        cannot_import_from: [db, repositories]
      services:
        can_import_from: [utils, models, repositories]
        cannot_import_from: [api]
      repositories:
        can_import_from: [utils, models, db]
        cannot_import_from: [api, services]

# Interaction Settings
# --------------------
# CUSTOMIZE YOUR BOT NAME HERE!
# Change mention_prefix to your preferred bot name:
#   "@ai-bot"    - Default
#   "@bartender" - Friendly bar theme
#   "@uni"       - Short and simple
#   "@joey"      - Personal assistant name
#   "@codebot"   - Code-focused name
# NOTE: Also update the workflow files (.github/workflows/ or .gitea/workflows/)
# to match this prefix in the 'if: contains(...)' condition
interaction:
  respond_to_mentions: true
  mention_prefix: "@codebot"  # Change this to customize your bot's name!
  commands:
    - help
    - explain
    - suggest
    - security
    - summarize     # Generate PR summary (works on both issues and PRs)
    - changelog     # Generate Keep a Changelog format entries (PR comments only)
    - explain-diff  # Explain code changes in plain language (PR comments only)
    - triage
    - review-again
    # New commands
    - check-deps        # Check dependencies for vulnerabilities
    - suggest-tests     # Suggest test cases
    - refactor-suggest  # Suggest refactoring opportunities
    - architecture      # Check architecture compliance
    - arch-check        # Alias for architecture

# Security Scanning
# -----------------
security:
  enabled: true
  fail_on_high: true
  rules_file: "security/security_rules.yml"

  # SAST Integration
  sast:
    enabled: true
    bandit: true   # Python AST-based security scanner
    semgrep: true  # Polyglot security scanner with custom rules
    trivy: false   # Container/filesystem scanner (requires trivy installed)

# Notifications
# -------------
notifications:
  enabled: false
  threshold: "warning"  # info | warning | error | critical

  slack:
    enabled: false
    webhook_url: ""  # Set via SLACK_WEBHOOK_URL env var
    channel: ""      # Override channel (optional)
    username: "OpenRabbit"

  discord:
    enabled: false
    webhook_url: ""  # Set via DISCORD_WEBHOOK_URL env var
    username: "OpenRabbit"
    avatar_url: ""

  # Custom webhooks for other integrations
  webhooks: []
  # Example:
  # - url: "https://your-webhook.example.com/notify"
  #   enabled: true
  #   headers:
  #     Authorization: "Bearer your-token"

# Compliance & Audit
# ------------------
compliance:
  enabled: false

  # Audit Trail
  audit:
    enabled: false
    log_file: "audit.log"
    log_to_stdout: false
    retention_days: 90

  # CODEOWNERS Enforcement
  codeowners:
    enabled: false
    require_approval: true  # Require approval from code owners

  # Regulatory Compliance
  regulations:
    hipaa: false
    soc2: false
    pci_dss: false
    gdpr: false

# Enterprise Settings
# -------------------
enterprise:
  audit_log: true
  audit_path: "/var/log/ai-review/"
  metrics_enabled: true
  rate_limit:
    requests_per_minute: 30
    max_concurrent: 4

# Label Mappings
# --------------
# Each label has:
#   name: The label name to use/create (string) or full config (dict)
#   aliases: Alternative names for auto-detection (optional)
#   color: Hex color code without # (optional, for label creation)
#   description: Label description (optional, for label creation)
labels:
  priority:
    critical:
      name: "priority: critical"
      color: "b60205"  # Dark Red
      description: "Critical priority - immediate attention required"
      aliases: ["Priority - Critical", "P0", "critical", "Priority/Critical"]
    high:
      name: "priority: high"
      color: "d73a4a"  # Red
      description: "High priority issue"
      aliases: ["Priority - High", "P1", "high", "Priority/High"]
    medium:
      name: "priority: medium"
      color: "fbca04"  # Yellow
      description: "Medium priority issue"
      aliases: ["Priority - Medium", "P2", "medium", "Priority/Medium"]
    low:
      name: "priority: low"
      color: "28a745"  # Green
      description: "Low priority issue"
      aliases: ["Priority - Low", "P3", "low", "Priority/Low"]
  type:
    bug:
      name: "type: bug"
      color: "d73a4a"  # Red
      description: "Something isn't working"
      aliases: ["Kind/Bug", "bug", "Type: Bug", "Type/Bug", "Kind - Bug"]
    feature:
      name: "type: feature"
      color: "1d76db"  # Blue
      description: "New feature request"
      aliases:
        [
          "Kind/Feature",
          "feature",
          "enhancement",
          "Kind/Enhancement",
          "Type: Feature",
          "Type/Feature",
          "Kind - Feature",
        ]
    question:
      name: "type: question"
      color: "cc317c"  # Purple
      description: "Further information is requested"
      aliases:
        [
          "Kind/Question",
          "question",
          "Type: Question",
          "Type/Question",
          "Kind - Question",
        ]
    docs:
      name: "type: documentation"
      color: "0075ca"  # Light Blue
      description: "Documentation improvements"
      aliases:
        [
          "Kind/Documentation",
          "documentation",
          "docs",
          "Type: Documentation",
          "Type/Documentation",
          "Kind - Documentation",
        ]
    security:
      name: "type: security"
      color: "b60205"  # Dark Red
      description: "Security vulnerability or concern"
      aliases:
        [
          "Kind/Security",
          "security",
          "Type: Security",
          "Type/Security",
          "Kind - Security",
        ]
    testing:
      name: "type: testing"
      color: "0e8a16"  # Green
      description: "Related to testing"
      aliases:
        [
          "Kind/Testing",
          "testing",
          "tests",
          "Type: Testing",
          "Type/Testing",
          "Kind - Testing",
        ]
  status:
    ai_approved:
      name: "ai-approved"
      color: "28a745"  # Green
      description: "AI review approved this PR"
      aliases:
        [
          "Status - Approved",
          "approved",
          "Status/Approved",
          "Status - AI Approved",
        ]
    ai_changes_required:
      name: "ai-changes-required"
      color: "d73a4a"  # Red
      description: "AI review found issues requiring changes"
      aliases:
        [
          "Status - Changes Required",
          "changes-required",
          "Status/Changes Required",
          "Status - AI Changes Required",
        ]
    ai_reviewed:
      name: "ai-reviewed"
      color: "1d76db"  # Blue
      description: "This issue/PR has been reviewed by AI"
      aliases:
        [
          "Reviewed - Confirmed",
          "reviewed",
          "Status/Reviewed",
          "Reviewed/Confirmed",
          "Status - Reviewed",
        ]

# Label Pattern Detection
# -----------------------
# Used by setup-labels command to detect existing naming conventions
label_patterns:
  # Detect prefix-based naming (e.g., Kind/Bug, Type/Feature)
  prefix_slash: "^(Kind|Type|Category)/(.+)$"
  # Detect dash-separated naming (e.g., Priority - High, Status - Blocked)
  prefix_dash: "^(Priority|Status|Reviewed) - (.+)$"
  # Detect colon-separated naming (e.g., type: bug, priority: high)
  colon: "^(type|priority|status): (.+)$"
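The `architecture.layers` mapping can be evaluated with a simple allow/deny check per import edge. This is a hypothetical sketch of the enforcement rule (deny list wins, then the allow list), not the ArchitectureAgent's actual code; `import_allowed` and its semantics are assumptions.

```python
# Mirror of the agents.architecture.layers config (subset for brevity)
LAYERS = {
    "api": {
        "can_import_from": ["utils", "models", "services"],
        "cannot_import_from": ["db", "repositories"],
    },
    "services": {
        "can_import_from": ["utils", "models", "repositories"],
        "cannot_import_from": ["api"],
    },
}


def import_allowed(layers: dict, importer: str, imported: str) -> bool:
    """Return True if `importer` may import from `imported`."""
    rules = layers.get(importer)
    if rules is None:
        return True  # layers without rules are unconstrained
    if imported in rules.get("cannot_import_from", []):
        return False  # explicit deny wins
    return imported in rules.get("can_import_from", [])


print(import_allowed(LAYERS, "api", "services"))  # True
print(import_allowed(LAYERS, "services", "api"))  # False: lower layers never import upward
```

With this reading, anything not on a layer's allow list is rejected, so adding a new layer requires listing it explicitly wherever it may be imported.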
20
tools/ai-review/notifications/__init__.py
Normal file
@@ -0,0 +1,20 @@
"""Notifications Package

Provides webhook-based notifications for Slack, Discord, and other platforms.
"""

from notifications.notifier import (
    DiscordNotifier,
    Notifier,
    NotifierFactory,
    SlackNotifier,
    WebhookNotifier,
)

__all__ = [
    "Notifier",
    "SlackNotifier",
    "DiscordNotifier",
    "WebhookNotifier",
    "NotifierFactory",
]
542
tools/ai-review/notifications/notifier.py
Normal file
@@ -0,0 +1,542 @@
"""Notification System

Provides webhook-based notifications for Slack, Discord, and other platforms.
Supports critical security findings, review summaries, and custom alerts.
"""

import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any

import requests


class NotificationLevel(Enum):
    """Notification severity levels."""

    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"


@dataclass
class NotificationMessage:
    """A notification message."""

    title: str
    message: str
    level: NotificationLevel = NotificationLevel.INFO
    fields: dict[str, str] | None = None
    url: str | None = None
    footer: str | None = None


class Notifier(ABC):
    """Abstract base class for notification providers."""

    @abstractmethod
    def send(self, message: NotificationMessage) -> bool:
        """Send a notification.

        Args:
            message: The notification message to send.

        Returns:
            True if sent successfully, False otherwise.
        """

    @abstractmethod
    def send_raw(self, payload: dict) -> bool:
        """Send a raw payload to the webhook.

        Args:
            payload: Raw payload in provider-specific format.

        Returns:
            True if sent successfully, False otherwise.
        """


class SlackNotifier(Notifier):
    """Slack webhook notifier."""

    # Color mapping for different levels
    LEVEL_COLORS = {
        NotificationLevel.INFO: "#36a64f",  # Green
        NotificationLevel.WARNING: "#ffcc00",  # Yellow
        NotificationLevel.ERROR: "#ff6600",  # Orange
        NotificationLevel.CRITICAL: "#cc0000",  # Red
    }

    LEVEL_EMOJIS = {
        NotificationLevel.INFO: ":information_source:",
        NotificationLevel.WARNING: ":warning:",
        NotificationLevel.ERROR: ":x:",
        NotificationLevel.CRITICAL: ":rotating_light:",
    }

    def __init__(
        self,
        webhook_url: str | None = None,
        channel: str | None = None,
        username: str = "OpenRabbit",
        icon_emoji: str = ":robot_face:",
        timeout: int = 10,
    ):
        """Initialize Slack notifier.

        Args:
            webhook_url: Slack incoming webhook URL.
                Defaults to SLACK_WEBHOOK_URL env var.
            channel: Override channel (optional).
            username: Bot username to display.
            icon_emoji: Bot icon emoji.
            timeout: Request timeout in seconds.
        """
        self.webhook_url = webhook_url or os.environ.get("SLACK_WEBHOOK_URL", "")
        self.channel = channel
        self.username = username
        self.icon_emoji = icon_emoji
        self.timeout = timeout
        self.logger = logging.getLogger(__name__)

    def send(self, message: NotificationMessage) -> bool:
        """Send a notification to Slack."""
        if not self.webhook_url:
            self.logger.warning("Slack webhook URL not configured")
            return False

        # Build attachment
        attachment = {
            "color": self.LEVEL_COLORS.get(message.level, "#36a64f"),
            "title": f"{self.LEVEL_EMOJIS.get(message.level, '')} {message.title}",
            "text": message.message,
            "mrkdwn_in": ["text", "fields"],
        }

        if message.url:
            attachment["title_link"] = message.url

        if message.fields:
            attachment["fields"] = [
                {"title": k, "value": v, "short": len(v) < 40}
                for k, v in message.fields.items()
            ]

        if message.footer:
            attachment["footer"] = message.footer
            attachment["footer_icon"] = "https://github.com/favicon.ico"

        payload = {
            "username": self.username,
            "icon_emoji": self.icon_emoji,
            "attachments": [attachment],
        }

        if self.channel:
            payload["channel"] = self.channel

        return self.send_raw(payload)

    def send_raw(self, payload: dict) -> bool:
        """Send raw payload to Slack webhook."""
        if not self.webhook_url:
            return False

        try:
            response = requests.post(
                self.webhook_url,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return True
        except requests.RequestException as e:
            self.logger.error(f"Failed to send Slack notification: {e}")
            return False


class DiscordNotifier(Notifier):
    """Discord webhook notifier."""

    # Color mapping for different levels (Discord uses decimal colors)
    LEVEL_COLORS = {
        NotificationLevel.INFO: 3066993,  # Green
        NotificationLevel.WARNING: 16776960,  # Yellow
        NotificationLevel.ERROR: 16744448,  # Orange
        NotificationLevel.CRITICAL: 13369344,  # Red
    }

    def __init__(
        self,
        webhook_url: str | None = None,
        username: str = "OpenRabbit",
        avatar_url: str | None = None,
        timeout: int = 10,
    ):
        """Initialize Discord notifier.

        Args:
            webhook_url: Discord webhook URL.
                Defaults to DISCORD_WEBHOOK_URL env var.
            username: Bot username to display.
            avatar_url: Bot avatar URL.
            timeout: Request timeout in seconds.
        """
        self.webhook_url = webhook_url or os.environ.get("DISCORD_WEBHOOK_URL", "")
        self.username = username
        self.avatar_url = avatar_url
        self.timeout = timeout
        self.logger = logging.getLogger(__name__)

    def send(self, message: NotificationMessage) -> bool:
        """Send a notification to Discord."""
        if not self.webhook_url:
            self.logger.warning("Discord webhook URL not configured")
            return False

        # Build embed
        embed = {
            "title": message.title,
            "description": message.message,
            "color": self.LEVEL_COLORS.get(message.level, 3066993),
        }

        if message.url:
            embed["url"] = message.url

        if message.fields:
            embed["fields"] = [
                {"name": k, "value": v, "inline": len(v) < 40}
                for k, v in message.fields.items()
            ]

        if message.footer:
            embed["footer"] = {"text": message.footer}

        payload = {
            "username": self.username,
            "embeds": [embed],
        }

        if self.avatar_url:
            payload["avatar_url"] = self.avatar_url

        return self.send_raw(payload)

    def send_raw(self, payload: dict) -> bool:
        """Send raw payload to Discord webhook."""
        if not self.webhook_url:
            return False

        try:
            response = requests.post(
                self.webhook_url,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return True
        except requests.RequestException as e:
            self.logger.error(f"Failed to send Discord notification: {e}")
            return False


class WebhookNotifier(Notifier):
    """Generic webhook notifier for custom integrations."""

    def __init__(
        self,
        webhook_url: str,
        headers: dict[str, str] | None = None,
        timeout: int = 10,
    ):
        """Initialize generic webhook notifier.

        Args:
            webhook_url: Webhook URL.
            headers: Custom headers to include.
            timeout: Request timeout in seconds.
        """
        self.webhook_url = webhook_url
        self.headers = headers or {"Content-Type": "application/json"}
        self.timeout = timeout
        self.logger = logging.getLogger(__name__)

    def send(self, message: NotificationMessage) -> bool:
        """Send a notification as JSON payload."""
        payload = {
            "title": message.title,
            "message": message.message,
            "level": message.level.value,
            "fields": message.fields or {},
            "url": message.url,
            "footer": message.footer,
        }
        return self.send_raw(payload)

    def send_raw(self, payload: dict) -> bool:
        """Send raw payload to webhook."""
        try:
            response = requests.post(
                self.webhook_url,
                json=payload,
                headers=self.headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return True
        except requests.RequestException as e:
            self.logger.error(f"Failed to send webhook notification: {e}")
            return False


class NotifierFactory:
    """Factory for creating notifier instances from config."""

    @staticmethod
    def create_from_config(config: dict) -> list[Notifier]:
        """Create notifier instances from configuration.

        Args:
            config: Configuration dictionary with 'notifications' section.

        Returns:
            List of configured notifier instances.
        """
        notifiers = []
        notifications_config = config.get("notifications", {})

        if not notifications_config.get("enabled", False):
            return notifiers

        # Slack
        slack_config = notifications_config.get("slack", {})
        if slack_config.get("enabled", False):
            webhook_url = slack_config.get("webhook_url") or os.environ.get(
                "SLACK_WEBHOOK_URL"
            )
            if webhook_url:
                notifiers.append(
                    SlackNotifier(
                        webhook_url=webhook_url,
                        channel=slack_config.get("channel"),
                        username=slack_config.get("username", "OpenRabbit"),
                    )
                )

        # Discord
        discord_config = notifications_config.get("discord", {})
        if discord_config.get("enabled", False):
            webhook_url = discord_config.get("webhook_url") or os.environ.get(
                "DISCORD_WEBHOOK_URL"
            )
            if webhook_url:
                notifiers.append(
                    DiscordNotifier(
                        webhook_url=webhook_url,
                        username=discord_config.get("username", "OpenRabbit"),
                        avatar_url=discord_config.get("avatar_url"),
                    )
                )

        # Generic webhooks
        webhooks_config = notifications_config.get("webhooks", [])
        for webhook in webhooks_config:
            if webhook.get("enabled", True) and webhook.get("url"):
                notifiers.append(
                    WebhookNotifier(
                        webhook_url=webhook["url"],
                        headers=webhook.get("headers"),
                    )
                )

        return notifiers

    @staticmethod
    def should_notify(
        level: NotificationLevel,
        config: dict,
    ) -> bool:
        """Check if notification should be sent based on level threshold.

        Args:
            level: Notification level.
            config: Configuration dictionary.

        Returns:
            True if notification should be sent.
        """
        notifications_config = config.get("notifications", {})
        if not notifications_config.get("enabled", False):
            return False

        threshold = notifications_config.get("threshold", "warning")
        level_order = ["info", "warning", "error", "critical"]

        try:
            threshold_idx = level_order.index(threshold)
            level_idx = level_order.index(level.value)
            return level_idx >= threshold_idx
        except ValueError:
            return True


class NotificationService:
"""High-level notification service for sending alerts."""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
"""Initialize notification service.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary.
|
||||
"""
|
||||
self.config = config
|
||||
self.notifiers = NotifierFactory.create_from_config(config)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def notify(self, message: NotificationMessage) -> bool:
|
||||
"""Send notification to all configured notifiers.
|
||||
|
||||
Args:
|
||||
message: Notification message.
|
||||
|
||||
Returns:
|
||||
True if at least one notification succeeded.
|
||||
"""
|
||||
if not NotifierFactory.should_notify(message.level, self.config):
|
||||
return True # Not an error, just below threshold
|
||||
|
||||
if not self.notifiers:
|
||||
self.logger.debug("No notifiers configured")
|
||||
return True
|
||||
|
||||
success = False
|
||||
for notifier in self.notifiers:
|
||||
try:
|
||||
if notifier.send(message):
|
||||
success = True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Notifier failed: {e}")
|
||||
|
||||
return success
|
||||
|
||||
def notify_security_finding(
|
||||
self,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
finding: dict,
|
||||
pr_url: str | None = None,
|
||||
) -> bool:
|
||||
"""Send notification for a security finding.
|
||||
|
||||
Args:
|
||||
repo: Repository name (owner/repo).
|
||||
pr_number: Pull request number.
|
||||
finding: Security finding dict with severity, description, etc.
|
||||
pr_url: URL to the pull request.
|
||||
|
||||
Returns:
|
||||
True if notification succeeded.
|
||||
"""
|
||||
severity = finding.get("severity", "MEDIUM").upper()
|
||||
level_map = {
|
||||
"HIGH": NotificationLevel.CRITICAL,
|
||||
"MEDIUM": NotificationLevel.WARNING,
|
||||
"LOW": NotificationLevel.INFO,
|
||||
}
|
||||
level = level_map.get(severity, NotificationLevel.WARNING)
|
||||
|
||||
message = NotificationMessage(
|
||||
title=f"Security Finding in {repo} PR #{pr_number}",
|
||||
message=finding.get("description", "Security issue detected"),
|
||||
level=level,
|
||||
fields={
|
||||
"Severity": severity,
|
||||
"Category": finding.get("category", "Unknown"),
|
||||
"File": finding.get("file", "N/A"),
|
||||
"Line": str(finding.get("line", "N/A")),
|
||||
},
|
||||
url=pr_url,
|
||||
footer=f"CWE: {finding.get('cwe', 'N/A')}",
|
||||
)
|
||||
|
||||
return self.notify(message)
|
||||
|
||||
def notify_review_complete(
|
||||
self,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
summary: dict,
|
||||
pr_url: str | None = None,
|
||||
) -> bool:
|
||||
"""Send notification for completed review.
|
||||
|
||||
Args:
|
||||
repo: Repository name (owner/repo).
|
||||
pr_number: Pull request number.
|
||||
summary: Review summary dict.
|
||||
pr_url: URL to the pull request.
|
||||
|
||||
Returns:
|
||||
True if notification succeeded.
|
||||
"""
|
||||
recommendation = summary.get("recommendation", "COMMENT")
|
||||
level_map = {
|
||||
"APPROVE": NotificationLevel.INFO,
|
||||
"COMMENT": NotificationLevel.INFO,
|
||||
"REQUEST_CHANGES": NotificationLevel.WARNING,
|
||||
}
|
||||
level = level_map.get(recommendation, NotificationLevel.INFO)
|
||||
|
||||
high_issues = summary.get("high_severity_count", 0)
|
||||
if high_issues > 0:
|
||||
level = NotificationLevel.ERROR
|
||||
|
||||
message = NotificationMessage(
|
||||
title=f"Review Complete: {repo} PR #{pr_number}",
|
||||
message=summary.get("summary", "AI review completed"),
|
||||
level=level,
|
||||
fields={
|
||||
"Recommendation": recommendation,
|
||||
"High Severity": str(high_issues),
|
||||
"Medium Severity": str(summary.get("medium_severity_count", 0)),
|
||||
"Files Reviewed": str(summary.get("files_reviewed", 0)),
|
||||
},
|
||||
url=pr_url,
|
||||
footer="OpenRabbit AI Review",
|
||||
)
|
||||
|
||||
return self.notify(message)
|
||||
|
||||
def notify_error(
|
||||
self,
|
||||
repo: str,
|
||||
error: str,
|
||||
context: str | None = None,
|
||||
) -> bool:
|
||||
"""Send notification for an error.
|
||||
|
||||
Args:
|
||||
repo: Repository name.
|
||||
error: Error message.
|
||||
context: Additional context.
|
||||
|
||||
Returns:
|
||||
True if notification succeeded.
|
||||
"""
|
||||
message = NotificationMessage(
|
||||
title=f"Error in {repo}",
|
||||
message=error,
|
||||
level=NotificationLevel.ERROR,
|
||||
fields={"Context": context} if context else None,
|
||||
footer="OpenRabbit Error Alert",
|
||||
)
|
||||
|
||||
return self.notify(message)
|
||||
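The threshold gate in `NotifierFactory.should_notify` is the piece most worth sanity-checking: a notification is sent only when its level sits at or above the configured threshold, and unknown levels fail open. This standalone sketch mirrors that ordering logic (the config shape is the same dict `create_from_config` reads; the plain-string `level` argument is a simplification of the `NotificationLevel` enum):

```python
# Standalone sketch of the level-threshold check in NotifierFactory.should_notify.
LEVEL_ORDER = ["info", "warning", "error", "critical"]

def should_notify(level: str, config: dict) -> bool:
    notifications = config.get("notifications", {})
    if not notifications.get("enabled", False):
        return False
    threshold = notifications.get("threshold", "warning")
    try:
        # Send only when the message level is at or above the threshold.
        return LEVEL_ORDER.index(level) >= LEVEL_ORDER.index(threshold)
    except ValueError:
        return True  # Unknown level or threshold: fail open, as the source does.

config = {"notifications": {"enabled": True, "threshold": "warning"}}
print(should_notify("info", config))      # False: below threshold
print(should_notify("critical", config))  # True
```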
431  tools/ai-review/security/sast_scanner.py  Normal file
@@ -0,0 +1,431 @@
"""SAST Scanner Integration

Integrates with external SAST tools like Bandit and Semgrep
to provide comprehensive security analysis.
"""

import json
import logging
import os
import shutil
import subprocess
import tempfile
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class SASTFinding:
    """A finding from a SAST tool."""

    tool: str
    rule_id: str
    severity: str  # CRITICAL, HIGH, MEDIUM, LOW
    file: str
    line: int
    message: str
    code_snippet: str | None = None
    cwe: str | None = None
    owasp: str | None = None
    fix_recommendation: str | None = None


@dataclass
class SASTReport:
    """Combined report from all SAST tools."""

    total_findings: int
    findings_by_severity: dict[str, int]
    findings_by_tool: dict[str, int]
    findings: list[SASTFinding]
    tools_run: list[str]
    errors: list[str] = field(default_factory=list)


class SASTScanner:
    """Aggregator for multiple SAST tools."""

    def __init__(self, config: dict | None = None):
        """Initialize the SAST scanner.

        Args:
            config: Configuration dictionary with tool settings.
        """
        self.config = config or {}
        self.logger = logging.getLogger(self.__class__.__name__)

    def scan_directory(self, path: str) -> SASTReport:
        """Scan a directory with all enabled SAST tools.

        Args:
            path: Path to the directory to scan.

        Returns:
            Combined SASTReport from all tools.
        """
        all_findings = []
        tools_run = []
        errors = []

        sast_config = self.config.get("security", {}).get("sast", {})

        # Run Bandit (Python)
        if sast_config.get("bandit", True):
            if self._is_tool_available("bandit"):
                try:
                    findings = self._run_bandit(path)
                    all_findings.extend(findings)
                    tools_run.append("bandit")
                except Exception as e:
                    errors.append(f"Bandit error: {e}")
            else:
                self.logger.debug("Bandit not installed, skipping")

        # Run Semgrep
        if sast_config.get("semgrep", True):
            if self._is_tool_available("semgrep"):
                try:
                    findings = self._run_semgrep(path)
                    all_findings.extend(findings)
                    tools_run.append("semgrep")
                except Exception as e:
                    errors.append(f"Semgrep error: {e}")
            else:
                self.logger.debug("Semgrep not installed, skipping")

        # Run Trivy (if enabled for filesystem scanning)
        if sast_config.get("trivy", False):
            if self._is_tool_available("trivy"):
                try:
                    findings = self._run_trivy(path)
                    all_findings.extend(findings)
                    tools_run.append("trivy")
                except Exception as e:
                    errors.append(f"Trivy error: {e}")
            else:
                self.logger.debug("Trivy not installed, skipping")

        # Calculate statistics
        by_severity = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0}
        by_tool = {}

        for finding in all_findings:
            sev = finding.severity.upper()
            if sev in by_severity:
                by_severity[sev] += 1
            tool = finding.tool
            by_tool[tool] = by_tool.get(tool, 0) + 1

        return SASTReport(
            total_findings=len(all_findings),
            findings_by_severity=by_severity,
            findings_by_tool=by_tool,
            findings=all_findings,
            tools_run=tools_run,
            errors=errors,
        )

    def scan_content(self, content: str, filename: str) -> list[SASTFinding]:
        """Scan file content with SAST tools.

        Args:
            content: File content to scan.
            filename: Name of the file (for language detection).

        Returns:
            List of SASTFinding objects.
        """
        # Write the content into a dedicated temporary directory so the
        # scan cannot pick up unrelated files sitting in the shared temp dir.
        with tempfile.TemporaryDirectory() as tmpdir:
            temp_path = os.path.join(tmpdir, os.path.basename(filename))
            with open(temp_path, "w") as f:
                f.write(content)

            report = self.scan_directory(tmpdir)
            # Filter findings for our specific file
            findings = [
                f
                for f in report.findings
                if os.path.basename(f.file) == os.path.basename(temp_path)
            ]
            # Update file path to original filename
            for finding in findings:
                finding.file = filename
            return findings

    def scan_diff(self, diff: str) -> list[SASTFinding]:
        """Scan a diff for security issues.

        Only scans added/modified lines.

        Args:
            diff: Git diff content.

        Returns:
            List of SASTFinding objects.
        """
        findings = []

        # Parse diff and extract added content per file
        files_content = {}
        current_file = None
        current_content = []

        for line in diff.splitlines():
            if line.startswith("diff --git"):
                if current_file and current_content:
                    files_content[current_file] = "\n".join(current_content)
                current_file = None
                current_content = []
                # Extract filename
                match = line.split(" b/")
                if len(match) > 1:
                    current_file = match[1]
            elif line.startswith("+") and not line.startswith("+++"):
                if current_file:
                    current_content.append(line[1:])  # Remove + prefix

        # Don't forget the last file
        if current_file and current_content:
            files_content[current_file] = "\n".join(current_content)

        # Scan each file's content
        for filename, content in files_content.items():
            if content.strip():
                file_findings = self.scan_content(content, filename)
                findings.extend(file_findings)

        return findings

    def _is_tool_available(self, tool: str) -> bool:
        """Check if a tool is installed and available."""
        return shutil.which(tool) is not None

    def _run_bandit(self, path: str) -> list[SASTFinding]:
        """Run Bandit security scanner.

        Args:
            path: Path to scan.

        Returns:
            List of SASTFinding objects.
        """
        findings = []

        try:
            result = subprocess.run(
                [
                    "bandit",
                    "-r",
                    path,
                    "-f",
                    "json",
                    "-ll",  # Only high and medium severity
                    "--quiet",
                ],
                capture_output=True,
                text=True,
                timeout=120,
            )

            if result.stdout:
                data = json.loads(result.stdout)

                for issue in data.get("results", []):
                    severity = issue.get("issue_severity", "MEDIUM").upper()

                    findings.append(
                        SASTFinding(
                            tool="bandit",
                            rule_id=issue.get("test_id", ""),
                            severity=severity,
                            file=issue.get("filename", ""),
                            line=issue.get("line_number", 0),
                            message=issue.get("issue_text", ""),
                            code_snippet=issue.get("code", ""),
                            cwe=f"CWE-{issue.get('issue_cwe', {}).get('id', '')}"
                            if issue.get("issue_cwe")
                            else None,
                            fix_recommendation=issue.get("more_info", ""),
                        )
                    )

        except subprocess.TimeoutExpired:
            self.logger.warning("Bandit scan timed out")
        except json.JSONDecodeError as e:
            self.logger.warning(f"Failed to parse Bandit output: {e}")
        except Exception as e:
            self.logger.warning(f"Bandit scan failed: {e}")

        return findings

    def _run_semgrep(self, path: str) -> list[SASTFinding]:
        """Run Semgrep security scanner.

        Args:
            path: Path to scan.

        Returns:
            List of SASTFinding objects.
        """
        findings = []

        # Get Semgrep config from settings
        sast_config = self.config.get("security", {}).get("sast", {})
        semgrep_rules = sast_config.get("semgrep_rules", "p/security-audit")

        try:
            result = subprocess.run(
                [
                    "semgrep",
                    "--config",
                    semgrep_rules,
                    "--json",
                    "--quiet",
                    path,
                ],
                capture_output=True,
                text=True,
                timeout=180,
            )

            if result.stdout:
                data = json.loads(result.stdout)

                for finding in data.get("results", []):
                    # Map Semgrep severity to our scale
                    sev_map = {
                        "ERROR": "HIGH",
                        "WARNING": "MEDIUM",
                        "INFO": "LOW",
                    }
                    severity = sev_map.get(
                        finding.get("extra", {}).get("severity", "WARNING"), "MEDIUM"
                    )

                    metadata = finding.get("extra", {}).get("metadata", {})

                    findings.append(
                        SASTFinding(
                            tool="semgrep",
                            rule_id=finding.get("check_id", ""),
                            severity=severity,
                            file=finding.get("path", ""),
                            line=finding.get("start", {}).get("line", 0),
                            message=finding.get("extra", {}).get("message", ""),
                            code_snippet=finding.get("extra", {}).get("lines", ""),
                            cwe=metadata.get("cwe", [None])[0]
                            if metadata.get("cwe")
                            else None,
                            owasp=metadata.get("owasp", [None])[0]
                            if metadata.get("owasp")
                            else None,
                            fix_recommendation=metadata.get("fix", ""),
                        )
                    )

        except subprocess.TimeoutExpired:
            self.logger.warning("Semgrep scan timed out")
        except json.JSONDecodeError as e:
            self.logger.warning(f"Failed to parse Semgrep output: {e}")
        except Exception as e:
            self.logger.warning(f"Semgrep scan failed: {e}")

        return findings

    def _run_trivy(self, path: str) -> list[SASTFinding]:
        """Run Trivy filesystem scanner.

        Args:
            path: Path to scan.

        Returns:
            List of SASTFinding objects.
        """
        findings = []

        try:
            result = subprocess.run(
                [
                    "trivy",
                    "fs",
                    "--format",
                    "json",
                    "--security-checks",
                    "vuln,secret,config",
                    path,
                ],
                capture_output=True,
                text=True,
                timeout=180,
            )

            if result.stdout:
                data = json.loads(result.stdout)

                for result_item in data.get("Results", []):
                    target = result_item.get("Target", "")

                    # Process vulnerabilities
                    for vuln in result_item.get("Vulnerabilities", []):
                        severity = vuln.get("Severity", "MEDIUM").upper()

                        findings.append(
                            SASTFinding(
                                tool="trivy",
                                rule_id=vuln.get("VulnerabilityID", ""),
                                severity=severity,
                                file=target,
                                line=0,
                                message=vuln.get("Title", ""),
                                cwe=vuln.get("CweIDs", [None])[0]
                                if vuln.get("CweIDs")
                                else None,
                                fix_recommendation=f"Upgrade to {vuln.get('FixedVersion', 'latest')}"
                                if vuln.get("FixedVersion")
                                else None,
                            )
                        )

                    # Process secrets
                    for secret in result_item.get("Secrets", []):
                        findings.append(
                            SASTFinding(
                                tool="trivy",
                                rule_id=secret.get("RuleID", ""),
                                severity="HIGH",
                                file=target,
                                line=secret.get("StartLine", 0),
                                message=f"Secret detected: {secret.get('Title', '')}",
                                code_snippet=secret.get("Match", ""),
                            )
                        )

        except subprocess.TimeoutExpired:
            self.logger.warning("Trivy scan timed out")
        except json.JSONDecodeError as e:
            self.logger.warning(f"Failed to parse Trivy output: {e}")
        except Exception as e:
            self.logger.warning(f"Trivy scan failed: {e}")

        return findings


def get_sast_scanner(config: dict | None = None) -> SASTScanner:
    """Get a configured SAST scanner instance.

    Args:
        config: Configuration dictionary.

    Returns:
        Configured SASTScanner instance.
    """
    return SASTScanner(config=config)
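The diff handling in `SASTScanner.scan_diff` is the only non-trivial parsing in the module: it collects just the added ("+") lines of each file, keyed by the `b/` path, and skips the `+++` header. That extraction can be exercised in isolation; this sketch reproduces the same loop as a standalone function (the sample diff below is hypothetical):

```python
# Standalone sketch of the added-line extraction in SASTScanner.scan_diff.
def added_lines_by_file(diff: str) -> dict[str, str]:
    files: dict[str, str] = {}
    current_file = None
    current: list[str] = []
    for line in diff.splitlines():
        if line.startswith("diff --git"):
            # Flush the previous file before starting a new one.
            if current_file and current:
                files[current_file] = "\n".join(current)
            current_file, current = None, []
            parts = line.split(" b/")
            if len(parts) > 1:
                current_file = parts[1]
        elif line.startswith("+") and not line.startswith("+++"):
            if current_file:
                current.append(line[1:])  # Strip the "+" prefix.
    if current_file and current:  # Don't forget the last file.
        files[current_file] = "\n".join(current)
    return files

diff = (
    "diff --git a/app.py b/app.py\n"
    "+++ b/app.py\n"
    "+import os\n"
    "+print(os.name)\n"
)
print(added_lines_by_file(diff))  # {'app.py': 'import os\nprint(os.name)'}
```

Note that context and removed lines never reach the scanners, so line numbers in the resulting findings refer to the reconstructed added-only content, not to positions in the original file.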
@@ -1,9 +1,14 @@
 """Utility Functions Package

 This package contains utility functions for webhook sanitization,
-safe event dispatching, and other helper functions.
+safe event dispatching, ignore patterns, and other helper functions.
 """

+from utils.ignore_patterns import (
+    IgnorePatterns,
+    get_ignore_patterns,
+    should_ignore_file,
+)
 from utils.webhook_sanitizer import (
     extract_minimal_context,
     sanitize_webhook_data,
@@ -16,4 +21,7 @@ __all__ = [
     "validate_repository_format",
     "extract_minimal_context",
     "validate_webhook_signature",
+    "IgnorePatterns",
+    "get_ignore_patterns",
+    "should_ignore_file",
 ]
358  tools/ai-review/utils/ignore_patterns.py  Normal file
@@ -0,0 +1,358 @@
"""AI Review Ignore Patterns

Provides .gitignore-style pattern matching for excluding files from AI review.
Reads patterns from .ai-reviewignore files in the repository.
"""

import os
import re
from dataclasses import dataclass


@dataclass
class IgnoreRule:
    """A single ignore rule."""

    pattern: str
    negation: bool = False
    directory_only: bool = False
    anchored: bool = False
    regex: re.Pattern | None = None

    def __post_init__(self):
        """Compile the pattern to regex."""
        self.regex = self._compile_pattern()

    def _compile_pattern(self) -> re.Pattern:
        """Convert gitignore pattern to regex."""
        pattern = self.pattern

        # Handle directory-only patterns
        if pattern.endswith("/"):
            pattern = pattern[:-1]
            self.directory_only = True

        # Handle anchored patterns (starting with /)
        if pattern.startswith("/"):
            pattern = pattern[1:]
            self.anchored = True

        # Escape special regex characters except * and ?
        regex_pattern = ""
        i = 0
        while i < len(pattern):
            c = pattern[i]
            if c == "*":
                if i + 1 < len(pattern) and pattern[i + 1] == "*":
                    # ** matches everything including /
                    if i + 2 < len(pattern) and pattern[i + 2] == "/":
                        regex_pattern += "(?:.*/)?"
                        i += 3
                        continue
                    else:
                        regex_pattern += ".*"
                        i += 2
                        continue
                else:
                    # * matches everything except /
                    regex_pattern += "[^/]*"
            elif c == "?":
                regex_pattern += "[^/]"
            elif c == "[":
                # Character class
                j = i + 1
                if j < len(pattern) and pattern[j] == "!":
                    regex_pattern += "[^"
                    j += 1
                else:
                    regex_pattern += "["
                while j < len(pattern) and pattern[j] != "]":
                    regex_pattern += pattern[j]
                    j += 1
                if j < len(pattern):
                    regex_pattern += "]"
                i = j
            elif c in ".^$+{}|()":
                regex_pattern += "\\" + c
            else:
                regex_pattern += c
            i += 1

        # Anchor pattern
        if self.anchored:
            regex_pattern = "^" + regex_pattern
        else:
            regex_pattern = "(?:^|/)" + regex_pattern

        # Match to end or as directory prefix
        regex_pattern += "(?:$|/)"

        return re.compile(regex_pattern)

    def matches(self, path: str, is_directory: bool = False) -> bool:
        """Check if a path matches this rule.

        Args:
            path: Relative path to check.
            is_directory: Whether the path is a directory.

        Returns:
            True if the path matches.
        """
        # Normalize path without a leading slash, so anchored ("^") patterns
        # can actually match the first path segment.
        path = path.replace("\\", "/").lstrip("/")

        match = self.regex.search(path)
        if not match:
            return False

        if self.directory_only and not is_directory:
            # A directory-only pattern still matches files *inside* the
            # matched directory (the match continues past a "/"); it only
            # rejects a plain file whose name equals the directory pattern.
            return match.end() < len(path)

        return True


class IgnorePatterns:
    """Manages .ai-reviewignore patterns for a repository."""

    DEFAULT_PATTERNS = [
        # Version control
        ".git/",
        ".svn/",
        ".hg/",
        # Dependencies
        "node_modules/",
        "vendor/",
        "venv/",
        ".venv/",
        "__pycache__/",
        "*.pyc",
        # Build outputs
        "dist/",
        "build/",
        "out/",
        "target/",
        "*.min.js",
        "*.min.css",
        "*.bundle.js",
        # IDE/Editor
        ".idea/",
        ".vscode/",
        "*.swp",
        "*.swo",
        # Generated files
        "*.lock",
        "package-lock.json",
        "yarn.lock",
        "poetry.lock",
        "Pipfile.lock",
        "Cargo.lock",
        "go.sum",
        # Binary files
        "*.exe",
        "*.dll",
        "*.so",
        "*.dylib",
        "*.bin",
        "*.o",
        "*.a",
        # Media files
        "*.png",
        "*.jpg",
        "*.jpeg",
        "*.gif",
        "*.svg",
        "*.ico",
        "*.mp3",
        "*.mp4",
        "*.wav",
        "*.pdf",
        # Archives
        "*.zip",
        "*.tar",
        "*.gz",
        "*.rar",
        "*.7z",
        # Large data files
        "*.csv",
        "*.json.gz",
        "*.sql",
        "*.sqlite",
        "*.db",
    ]

    def __init__(
        self,
        repo_root: str | None = None,
        ignore_file: str = ".ai-reviewignore",
        use_defaults: bool = True,
    ):
        """Initialize ignore patterns.

        Args:
            repo_root: Repository root path.
            ignore_file: Name of ignore file to read.
            use_defaults: Whether to include default patterns.
        """
        self.repo_root = repo_root or os.getcwd()
        self.ignore_file = ignore_file
        self.rules: list[IgnoreRule] = []

        # Load default patterns
        if use_defaults:
            for pattern in self.DEFAULT_PATTERNS:
                self._add_pattern(pattern)

        # Load patterns from ignore file
        self._load_ignore_file()

    def _load_ignore_file(self):
        """Load patterns from .ai-reviewignore file."""
        ignore_path = os.path.join(self.repo_root, self.ignore_file)

        if not os.path.exists(ignore_path):
            return

        try:
            with open(ignore_path) as f:
                for line in f:
                    line = line.rstrip("\n\r")

                    # Skip empty lines and comments
                    if not line or line.startswith("#"):
                        continue

                    self._add_pattern(line)
        except OSError:
            pass  # Ignore errors reading the file

    def _add_pattern(self, pattern: str):
        """Add a pattern to the rules list.

        Args:
            pattern: Pattern string (gitignore format).
        """
        # Check for negation
        negation = False
        if pattern.startswith("!"):
            negation = True
            pattern = pattern[1:]

        if not pattern:
            return

        self.rules.append(IgnoreRule(pattern=pattern, negation=negation))

    def is_ignored(self, path: str, is_directory: bool = False) -> bool:
        """Check if a path should be ignored.

        Args:
            path: Relative path to check.
            is_directory: Whether the path is a directory.

        Returns:
            True if the path should be ignored.
        """
        # Normalize path
        path = path.replace("\\", "/").lstrip("/")

        # Check each rule in order (later rules override earlier ones)
        ignored = False
        for rule in self.rules:
            if rule.matches(path, is_directory):
                ignored = not rule.negation

        return ignored

    def filter_paths(self, paths: list[str]) -> list[str]:
        """Filter a list of paths, removing ignored ones.

        Args:
            paths: List of relative paths.

        Returns:
            Filtered list of paths.
        """
        return [p for p in paths if not self.is_ignored(p)]

    def filter_diff_files(self, files: list[dict]) -> list[dict]:
        """Filter diff file objects, removing ignored ones.

        Args:
            files: List of file dicts with 'filename' or 'path' key.

        Returns:
            Filtered list of file dicts.
        """
        result = []
        for f in files:
            path = f.get("filename") or f.get("path") or f.get("name", "")
            if not self.is_ignored(path):
                result.append(f)
        return result

    def should_review_file(self, filename: str) -> bool:
        """Check if a file should be reviewed.

        Args:
            filename: File path to check.

        Returns:
            True if the file should be reviewed.
        """
        return not self.is_ignored(filename)

    @classmethod
    def from_config(
        cls, config: dict, repo_root: str | None = None
    ) -> "IgnorePatterns":
        """Create IgnorePatterns from config.

        Args:
            config: Configuration dictionary.
            repo_root: Repository root path.

        Returns:
            IgnorePatterns instance.
        """
        ignore_config = config.get("ignore", {})
        use_defaults = ignore_config.get("use_defaults", True)
        ignore_file = ignore_config.get("file", ".ai-reviewignore")

        instance = cls(
            repo_root=repo_root,
            ignore_file=ignore_file,
            use_defaults=use_defaults,
        )

        # Add any extra patterns from config
        extra_patterns = ignore_config.get("patterns", [])
        for pattern in extra_patterns:
            instance._add_pattern(pattern)

        return instance


def get_ignore_patterns(repo_root: str | None = None) -> IgnorePatterns:
    """Get ignore patterns for a repository.

    Args:
        repo_root: Repository root path.

    Returns:
        IgnorePatterns instance.
    """
    return IgnorePatterns(repo_root=repo_root)


def should_ignore_file(filename: str, repo_root: str | None = None) -> bool:
    """Quick check if a file should be ignored.

    Args:
        filename: File path to check.
        repo_root: Repository root path.

    Returns:
        True if the file should be ignored.
    """
    return get_ignore_patterns(repo_root).is_ignored(filename)
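The heart of the ignore machinery is the gitignore-to-regex translation: `*` must not cross a path separator, a leading `**/` may match any number of directories, and unanchored patterns match at any depth. This simplified sketch captures just those three rules (no character classes, negation, or trailing-slash directory semantics) to show the shape of the compiled regex:

```python
import re

# Simplified sketch of the gitignore-style compilation in IgnoreRule._compile_pattern:
# handles "*", "?", and a leading "**/" only.
def compile_ignore(pattern: str) -> re.Pattern:
    anchored = pattern.startswith("/")
    pattern = pattern.strip("/")
    out = ""
    i = 0
    while i < len(pattern):
        c = pattern[i]
        if pattern.startswith("**/", i):
            out += "(?:.*/)?"   # "**/" may span any number of directories
            i += 3
            continue
        if c == "*":
            out += "[^/]*"      # "*" never crosses a separator
        elif c == "?":
            out += "[^/]"
        elif c in ".^$+{}|()":
            out += "\\" + c     # escape regex metacharacters
        else:
            out += c
        i += 1
    prefix = "^" if anchored else "(?:^|/)"
    return re.compile(prefix + out + "(?:$|/)")  # match whole segment or prefix

rule = compile_ignore("**/node_modules/")
print(bool(rule.search("src/node_modules/pkg")))  # True
print(bool(rule.search("src/app.py")))            # False
```

The trailing `(?:$|/)` is what lets a directory pattern also match everything nested under that directory, which is why `filter_paths` can call `is_ignored` on plain file paths.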