just why not
2026-01-07 21:19:46 +01:00
parent a1fe47cdf4
commit e8d28225e0
24 changed files with 6431 additions and 250 deletions


@@ -69,11 +69,17 @@ The codebase uses an **agent-based architecture** where specialized agents handl
- `execute(context)` - Main execution logic
- Returns `AgentResult` with success status, message, data, and actions taken

**Core Agents:**
- **PRAgent** - Reviews pull requests with inline comments and security scanning
- **IssueAgent** - Triages issues and responds to @codebot commands
- **CodebaseAgent** - Analyzes entire codebase health and tech debt
- **ChatAgent** - Interactive assistant with tool calling (search_codebase, read_file, search_web)

**Specialized Agents:**
- **DependencyAgent** - Scans dependencies for security vulnerabilities (Python, JavaScript)
- **TestCoverageAgent** - Analyzes code for test coverage gaps and suggests test cases
- **ArchitectureAgent** - Enforces layer separation and detects architecture violations

3. **Dispatcher** (`dispatcher.py`) - Routes events to appropriate agents:
- Registers agents at startup
- Determines which agents can handle each event
@@ -84,14 +90,23 @@ The codebase uses an **agent-based architecture** where specialized agents handl
The `LLMClient` (`clients/llm_client.py`) provides a unified interface for multiple LLM providers:

**Core Providers (in llm_client.py):**
- **OpenAI** - Primary provider (gpt-4.1-mini default)
- **OpenRouter** - Multi-provider access (claude-3.5-sonnet)
- **Ollama** - Self-hosted models (codellama:13b)

**Additional Providers (in clients/providers/):**
- **AnthropicProvider** - Direct Anthropic Claude API (claude-3.5-sonnet)
- **AzureOpenAIProvider** - Azure OpenAI Service with API key auth
- **AzureOpenAIWithAADProvider** - Azure OpenAI with Azure AD authentication
- **GeminiProvider** - Google Gemini API (public)
- **VertexAIGeminiProvider** - Google Vertex AI Gemini (enterprise GCP)

Key features:
- Tool/function calling support via `call_with_tools(messages, tools)`
- JSON response parsing with fallback extraction
- Provider-specific configuration via `config.yml`
- Configurable timeouts per provider
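
The "JSON response parsing with fallback extraction" feature can be pictured with a minimal sketch. This is not the `_extract_json` implementation from `clients/llm_client.py`; the function name `extract_json` and the order of the strategies are assumptions, chosen so that the behaviour lines up with the cases exercised in `tests/test_llm_client.py` (direct JSON, fenced blocks, surrounding prose, and a `ValueError` on failure).

```python
import json
import re

# Three backticks, built at runtime so this example can sit inside a fenced block.
FENCE = "`" * 3
FENCED_BLOCK = re.compile(FENCE + r"(?:json)?\s*(.*?)" + FENCE, re.DOTALL)


def extract_json(content: str) -> dict:
    """Parse a JSON object from an LLM reply (hypothetical helper, not the project's code)."""
    # Strategy 1: the whole reply is already JSON.
    candidates = [content.strip()]

    # Strategy 2: the body of the first fenced block, with or without a "json" tag.
    fenced = FENCED_BLOCK.search(content)
    if fenced:
        candidates.append(fenced.group(1).strip())

    # Strategy 3: everything between the first "{" and the last "}".
    start, end = content.find("{"), content.rfind("}")
    if start != -1 and end > start:
        candidates.append(content[start : end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    raise ValueError("Failed to parse JSON from response")
```

Trying the strictest strategy first keeps well-formed replies on the fast path and only falls back to brace scanning when the model wraps its answer in prose.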
### Platform Abstraction

README.md

@@ -1,6 +1,6 @@
# OpenRabbit

Enterprise-grade AI code review system for **Gitea** and **GitHub** with automated PR review, issue triage, interactive chat, and codebase analysis.

---
@@ -14,9 +14,15 @@ Enterprise-grade AI code review system for **Gitea** with automated PR review, i
| **Chat** | Interactive AI chat with codebase search and web search tools |
| **@codebot Commands** | `@codebot summarize`, `changelog`, `explain-diff`, `explain`, `suggest`, `triage`, `review-again` in comments |
| **Codebase Analysis** | Health scores, tech debt tracking, weekly reports |
| **Security Scanner** | 17 OWASP-aligned rules + SAST integration (Bandit, Semgrep) |
| **Dependency Scanning** | Vulnerability detection for Python, JavaScript dependencies |
| **Test Coverage** | AI-powered test suggestions for untested code |
| **Architecture Compliance** | Layer separation enforcement, circular dependency detection |
| **Notifications** | Slack/Discord alerts for security findings and reviews |
| **Compliance** | Audit trail, CODEOWNERS enforcement, regulatory support |
| **Multi-Provider LLM** | OpenAI, Anthropic Claude, Azure OpenAI, Google Gemini, Ollama |
| **Enterprise Ready** | Audit logging, metrics, Prometheus export |
| **Gitea Native** | Built for Gitea workflows and API (also works with GitHub) |

---
@@ -116,12 +122,28 @@ tools/ai-review/
│   ├── issue_agent.py              # Issue triage & @codebot commands
│   ├── pr_agent.py                 # PR review with security scan
│   ├── codebase_agent.py           # Codebase health analysis
│   ├── chat_agent.py               # Interactive chat with tool calling
│   ├── dependency_agent.py         # Dependency vulnerability scanning
│   ├── test_coverage_agent.py      # Test coverage analysis
│   └── architecture_agent.py       # Architecture compliance checking
├── clients/                        # API clients
│   ├── gitea_client.py             # Gitea REST API wrapper
│   ├── llm_client.py               # Multi-provider LLM client with tool support
│   └── providers/                  # Additional LLM providers
│       ├── anthropic_provider.py   # Direct Anthropic Claude API
│       ├── azure_provider.py       # Azure OpenAI Service
│       └── gemini_provider.py      # Google Gemini API
├── security/                       # Security scanning
│   ├── security_scanner.py         # 17 OWASP-aligned rules
│   └── sast_scanner.py             # Bandit, Semgrep, Trivy integration
├── notifications/                  # Alerting system
│   └── notifier.py                 # Slack, Discord, webhook notifications
├── compliance/                     # Compliance & audit
│   ├── audit_trail.py              # Audit logging with integrity verification
│   └── codeowners.py               # CODEOWNERS enforcement
├── utils/                          # Utility functions
│   ├── ignore_patterns.py          # .ai-reviewignore support
│   └── webhook_sanitizer.py        # Input validation
├── enterprise/                     # Enterprise features
│   ├── audit_logger.py             # JSONL audit logging
│   └── metrics.py                  # Prometheus-compatible metrics
@@ -182,6 +204,10 @@ In any issue comment:
| `@codebot summarize` | Summarize the issue in 2-3 sentences |
| `@codebot explain` | Explain what the issue is about |
| `@codebot suggest` | Suggest solutions or next steps |
| `@codebot check-deps` | Scan dependencies for security vulnerabilities |
| `@codebot suggest-tests` | Suggest test cases for changed code |
| `@codebot refactor-suggest` | Suggest refactoring opportunities |
| `@codebot architecture` | Check architecture compliance (alias: `arch-check`) |
| `@codebot` (any question) | Chat with AI using codebase/web search tools |
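
As a rough illustration of how a comment could be mapped onto the commands in this table, here is a small sketch. The helper `parse_command`, the `KNOWN_COMMANDS` set, and the `"chat"` fallback label are hypothetical names introduced for this example; the actual agents decide for themselves in `can_handle`, so treat this as one possible reading of the table rather than the project's routing code.

```python
from __future__ import annotations

import re

# Command names copied from the table above; the alias map mirrors "alias: arch-check".
KNOWN_COMMANDS = {
    "summarize",
    "explain",
    "suggest",
    "check-deps",
    "suggest-tests",
    "refactor-suggest",
    "architecture",
}
ALIASES = {"arch-check": "architecture"}


def parse_command(comment_body: str, mention_prefix: str = "@codebot") -> tuple[str, str] | None:
    """Return (command, remaining text), or None when the bot is not mentioned."""
    mention = re.search(re.escape(mention_prefix), comment_body, re.IGNORECASE)
    if mention is None:
        return None

    rest = comment_body[mention.end():].strip()
    parts = rest.split(None, 1)
    if parts:
        word = ALIASES.get(parts[0].lower(), parts[0].lower())
        if word in KNOWN_COMMANDS:
            return word, parts[1].strip() if len(parts) > 1 else ""

    # Anything else after the mention is treated as a free-form chat question.
    return "chat", rest
```

For example, `parse_command("@codebot arch-check")` resolves the alias and yields the `architecture` command, while an unrecognised first word falls through to chat.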
### Pull Request Commands
@@ -522,19 +548,91 @@ Replace `'Bartender'` with your bot's Gitea username. This prevents the bot from
| Provider | Model | Use Case |
|----------|-------|----------|
| OpenAI | gpt-4.1-mini | Fast, reliable, default |
| Anthropic | claude-3.5-sonnet | Direct Claude API access |
| Azure OpenAI | gpt-4 (deployment) | Enterprise Azure deployments |
| Google Gemini | gemini-1.5-pro | GCP customers, Vertex AI |
| OpenRouter | claude-3.5-sonnet | Multi-provider access |
| Ollama | codellama:13b | Self-hosted, private |
### Provider Configuration
```yaml
# In config.yml
provider: anthropic  # openai | anthropic | azure | gemini | openrouter | ollama

# Azure OpenAI
azure:
  endpoint: ""  # Set via AZURE_OPENAI_ENDPOINT env var
  deployment: "gpt-4"
  api_version: "2024-02-15-preview"

# Google Gemini (Vertex AI)
gemini:
  project: ""  # Set via GOOGLE_CLOUD_PROJECT env var
  region: "us-central1"
```
### Environment Variables
| Variable | Provider | Description |
|----------|----------|-------------|
| `OPENAI_API_KEY` | OpenAI | API key |
| `ANTHROPIC_API_KEY` | Anthropic | API key |
| `AZURE_OPENAI_ENDPOINT` | Azure | Service endpoint URL |
| `AZURE_OPENAI_API_KEY` | Azure | API key |
| `AZURE_OPENAI_DEPLOYMENT` | Azure | Deployment name |
| `GOOGLE_API_KEY` | Gemini | API key (public API) |
| `GOOGLE_CLOUD_PROJECT` | Vertex AI | GCP project ID |
| `OPENROUTER_API_KEY` | OpenRouter | API key |
| `OLLAMA_HOST` | Ollama | Server URL (default: localhost:11434) |
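
The config comments say that some values are "set via" environment variables. One way to picture that is the sketch below, which fills config values from the environment. The function `apply_env_overrides` and the `ENV_OVERRIDES` mapping are hypothetical, and the precedence (a non-empty environment variable wins over the file value) is an assumption; the project's own loader may behave differently.

```python
from __future__ import annotations

import os
from collections.abc import Mapping

# (config section, key) -> environment variable, taken from the table above.
ENV_OVERRIDES = {
    ("azure", "endpoint"): "AZURE_OPENAI_ENDPOINT",
    ("azure", "deployment"): "AZURE_OPENAI_DEPLOYMENT",
    ("gemini", "project"): "GOOGLE_CLOUD_PROJECT",
}


def apply_env_overrides(config: dict, env: Mapping[str, str] = os.environ) -> dict:
    """Return a copy of config with non-empty environment values filled in."""
    # Copy one level deep so the caller's config is left untouched.
    resolved = {
        section: dict(values) if isinstance(values, dict) else values
        for section, values in config.items()
    }
    for (section, key), variable in ENV_OVERRIDES.items():
        value = env.get(variable, "")
        if value:  # assumption: empty or unset variables never override the file
            resolved.setdefault(section, {})[key] = value
    return resolved
```

Passing the environment as a parameter keeps the helper easy to test without touching the real process environment.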
---

## Enterprise Features

- **Audit Logging**: JSONL logs with integrity checksums and daily rotation
- **Compliance**: HIPAA, SOC2, PCI-DSS, GDPR support with configurable rules
- **CODEOWNERS Enforcement**: Validate approvals against CODEOWNERS file
- **Notifications**: Slack/Discord webhooks for critical findings
- **SAST Integration**: Bandit, Semgrep, Trivy for advanced security scanning
- **Metrics**: Prometheus-compatible export
- **Rate Limiting**: Configurable request limits and timeouts
- **Custom Security Rules**: Define your own patterns via YAML
- **Tool Calling**: LLM function calling for interactive chat
- **Ignore Patterns**: `.ai-reviewignore` for excluding files from review
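
To make "JSONL logs with integrity checksums" more concrete, here is one common way such a log can be built: each line stores a SHA-256 checksum computed over the previous checksum plus the current event, so editing any earlier line invalidates everything after it. The hash-chaining scheme and the helper names `append_entry` and `verify_log` are assumptions for illustration; the README does not say how `audit_trail.py` computes its checksums.

```python
from __future__ import annotations

import hashlib
import json


def _checksum(previous: str, event: dict) -> str:
    # Canonical JSON (sorted keys) so the same event always hashes the same way.
    payload = json.dumps(event, sort_keys=True)
    return hashlib.sha256((previous + payload).encode("utf-8")).hexdigest()


def append_entry(log_lines: list[str], event: dict) -> str:
    """Append one JSONL record whose checksum chains to the previous record."""
    previous = json.loads(log_lines[-1])["checksum"] if log_lines else ""
    line = json.dumps({"event": event, "checksum": _checksum(previous, event)}, sort_keys=True)
    log_lines.append(line)
    return line


def verify_log(log_lines: list[str]) -> bool:
    """Recompute the chain and report whether every stored checksum still matches."""
    previous = ""
    for line in log_lines:
        record = json.loads(line)
        expected = _checksum(previous, record["event"])
        if record["checksum"] != expected:
            return False
        previous = expected
    return True
```

The in-memory list stands in for a log file here; a real logger would append each returned line to the current day's file.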
### Notifications Configuration
```yaml
# In config.yml
notifications:
  enabled: true
  threshold: "warning"  # info | warning | error | critical

  slack:
    enabled: true
    webhook_url: ""  # Set via SLACK_WEBHOOK_URL env var
    channel: "#code-review"

  discord:
    enabled: true
    webhook_url: ""  # Set via DISCORD_WEBHOOK_URL env var
```
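
The `threshold` setting suggests that only findings at or above a given level trigger a notification. A minimal sketch of that comparison follows; the helper name `should_notify` and the inclusive comparison (a `warning` finding passes a `warning` threshold) are assumptions, while the level names and their order come from the config comment above.

```python
# Levels in ascending order of severity, as listed in the config comment.
SEVERITY_ORDER = ["info", "warning", "error", "critical"]


def should_notify(severity: str, threshold: str = "warning") -> bool:
    """Return True when severity is at or above threshold.

    Raises ValueError for a level that is not in SEVERITY_ORDER.
    """
    return SEVERITY_ORDER.index(severity.lower()) >= SEVERITY_ORDER.index(threshold.lower())
```

With the sample configuration, an `info` finding would be dropped and a `critical` one would be sent to both Slack and Discord.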
### Compliance Configuration
```yaml
compliance:
  enabled: true
  audit:
    enabled: true
    log_file: "audit.log"
    retention_days: 90
  codeowners:
    enabled: true
    require_approval: true
```
---

tests/test_dispatcher.py Normal file

@@ -0,0 +1,296 @@
"""Test Suite for Dispatcher
Tests for event routing and agent execution.
"""

import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review"))

from unittest.mock import MagicMock, Mock, patch

import pytest
class TestDispatcherCreation:
    """Test dispatcher initialization."""

    def test_create_dispatcher(self):
        """Test creating dispatcher."""
        from dispatcher import Dispatcher

        dispatcher = Dispatcher()
        assert dispatcher is not None
        assert dispatcher.agents == []

    def test_create_dispatcher_with_config(self):
        """Test creating dispatcher with config."""
        from dispatcher import Dispatcher

        config = {"dispatcher": {"max_workers": 4}}
        dispatcher = Dispatcher(config=config)
        assert dispatcher.config == config
class TestAgentRegistration:
    """Test agent registration."""

    def test_register_agent(self):
        """Test registering an agent."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "test"

            def execute(self, context):
                return AgentResult(success=True, message="done")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)
        assert len(dispatcher.agents) == 1
        assert dispatcher.agents[0] == agent

    def test_register_multiple_agents(self):
        """Test registering multiple agents."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent1(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "type1"

            def execute(self, context):
                return AgentResult(success=True, message="agent1")

        class MockAgent2(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "type2"

            def execute(self, context):
                return AgentResult(success=True, message="agent2")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent1(config={}, gitea_client=None, llm_client=None)
        )
        dispatcher.register_agent(
            MockAgent2(config={}, gitea_client=None, llm_client=None)
        )
        assert len(dispatcher.agents) == 2
class TestEventRouting:
    """Test event routing to agents."""

    def test_route_to_matching_agent(self):
        """Test that events are routed to matching agents."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="handled")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)
        result = dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )
        assert len(result.agents_run) == 1
        assert result.results[0].success is True

    def test_no_matching_agent(self):
        """Test dispatch when no agent matches."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="handled")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)
        result = dispatcher.dispatch(
            event_type="pull_request",  # Different event type
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )
        assert len(result.agents_run) == 0

    def test_multiple_matching_agents(self):
        """Test dispatch when multiple agents match."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent1(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="agent1")

        class MockAgent2(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="agent2")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent1(config={}, gitea_client=None, llm_client=None)
        )
        dispatcher.register_agent(
            MockAgent2(config={}, gitea_client=None, llm_client=None)
        )
        result = dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )
        assert len(result.agents_run) == 2
class TestDispatchResult:
    """Test dispatch result structure."""

    def test_result_structure(self):
        """Test DispatchResult has correct structure."""
        from dispatcher import DispatchResult

        result = DispatchResult(
            agents_run=["Agent1", "Agent2"],
            results=[],
            errors=[],
        )
        assert result.agents_run == ["Agent1", "Agent2"]
        assert result.results == []
        assert result.errors == []

    def test_result_with_errors(self):
        """Test DispatchResult with errors."""
        from dispatcher import DispatchResult

        result = DispatchResult(
            agents_run=["Agent1"],
            results=[],
            errors=["Error 1", "Error 2"],
        )
        assert len(result.errors) == 2
class TestAgentExecution:
    """Test agent execution through dispatcher."""

    def test_agent_receives_context(self):
        """Test that agents receive proper context."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        received_context = None

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return True

            def execute(self, context):
                nonlocal received_context
                received_context = context
                return AgentResult(success=True, message="done")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent(config={}, gitea_client=None, llm_client=None)
        )
        dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened", "issue": {"number": 123}},
            owner="testowner",
            repo="testrepo",
        )
        assert received_context is not None
        assert received_context.owner == "testowner"
        assert received_context.repo == "testrepo"
        assert received_context.event_type == "issues"
        assert received_context.event_data["action"] == "opened"

    def test_agent_failure_captured(self):
        """Test that agent failures are captured in results."""
        from agents.base_agent import AgentContext, AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class FailingAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return True

            def execute(self, context):
                raise Exception("Test error")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            FailingAgent(config={}, gitea_client=None, llm_client=None)
        )
        result = dispatcher.dispatch(
            event_type="issues",
            event_data={},
            owner="test",
            repo="repo",
        )
        # Agent should still be in agents_run
        assert len(result.agents_run) == 1
        # Result should indicate failure
        assert result.results[0].success is False
class TestGetDispatcher:
    """Test get_dispatcher factory function."""

    def test_get_dispatcher_returns_singleton(self):
        """Test that get_dispatcher returns configured dispatcher."""
        from dispatcher import get_dispatcher

        dispatcher = get_dispatcher()
        assert dispatcher is not None

    def test_get_dispatcher_with_config(self):
        """Test get_dispatcher with custom config."""
        from dispatcher import get_dispatcher

        config = {"test": "value"}
        dispatcher = get_dispatcher(config=config)
        assert dispatcher.config.get("test") == "value"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

tests/test_llm_client.py Normal file

@@ -0,0 +1,456 @@
"""Test Suite for LLM Client
Tests for LLM client functionality including provider support,
tool calling, and JSON parsing.
"""

import json
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review"))

from unittest.mock import MagicMock, Mock, patch

import pytest
class TestLLMClientCreation:
    """Test LLM client initialization."""

    def test_create_openai_client(self):
        """Test creating OpenAI client."""
        from clients.llm_client import LLMClient

        client = LLMClient(
            provider="openai",
            config={"model": "gpt-4o-mini", "api_key": "test-key"},
        )
        assert client.provider_name == "openai"

    def test_create_openrouter_client(self):
        """Test creating OpenRouter client."""
        from clients.llm_client import LLMClient

        client = LLMClient(
            provider="openrouter",
            config={"model": "anthropic/claude-3.5-sonnet", "api_key": "test-key"},
        )
        assert client.provider_name == "openrouter"

    def test_create_ollama_client(self):
        """Test creating Ollama client."""
        from clients.llm_client import LLMClient

        client = LLMClient(
            provider="ollama",
            config={"model": "codellama:13b", "host": "http://localhost:11434"},
        )
        assert client.provider_name == "ollama"

    def test_invalid_provider_raises_error(self):
        """Test that invalid provider raises ValueError."""
        from clients.llm_client import LLMClient

        with pytest.raises(ValueError, match="Unknown provider"):
            LLMClient(provider="invalid_provider")

    def test_from_config_openai(self):
        """Test creating client from config dict."""
        from clients.llm_client import LLMClient

        config = {
            "provider": "openai",
            "model": {"openai": "gpt-4o-mini"},
            "temperature": 0,
            "max_tokens": 4096,
        }
        client = LLMClient.from_config(config)
        assert client.provider_name == "openai"
class TestLLMResponse:
    """Test LLM response dataclass."""

    def test_response_creation(self):
        """Test creating LLMResponse."""
        from clients.llm_client import LLMResponse

        response = LLMResponse(
            content="Test response",
            model="gpt-4o-mini",
            provider="openai",
            tokens_used=100,
            finish_reason="stop",
        )
        assert response.content == "Test response"
        assert response.model == "gpt-4o-mini"
        assert response.provider == "openai"
        assert response.tokens_used == 100
        assert response.finish_reason == "stop"
        assert response.tool_calls is None

    def test_response_with_tool_calls(self):
        """Test LLMResponse with tool calls."""
        from clients.llm_client import LLMResponse, ToolCall

        tool_calls = [ToolCall(id="call_1", name="search", arguments={"query": "test"})]
        response = LLMResponse(
            content="",
            model="gpt-4o-mini",
            provider="openai",
            tool_calls=tool_calls,
        )
        assert response.tool_calls is not None
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "search"
class TestToolCall:
    """Test ToolCall dataclass."""

    def test_tool_call_creation(self):
        """Test creating ToolCall."""
        from clients.llm_client import ToolCall

        tool_call = ToolCall(
            id="call_123",
            name="search_codebase",
            arguments={"query": "authentication", "file_pattern": "*.py"},
        )
        assert tool_call.id == "call_123"
        assert tool_call.name == "search_codebase"
        assert tool_call.arguments["query"] == "authentication"
        assert tool_call.arguments["file_pattern"] == "*.py"
class TestJSONParsing:
    """Test JSON extraction and parsing."""

    def test_parse_direct_json(self):
        """Test parsing direct JSON response."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)
        content = '{"key": "value", "number": 42}'
        result = client._extract_json(content)
        assert result["key"] == "value"
        assert result["number"] == 42

    def test_parse_json_in_code_block(self):
        """Test parsing JSON in markdown code block."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)
        content = """Here is the analysis:
```json
{
    "type": "bug",
    "priority": "high"
}
```
That's my analysis."""
        result = client._extract_json(content)
        assert result["type"] == "bug"
        assert result["priority"] == "high"

    def test_parse_json_in_plain_code_block(self):
        """Test parsing JSON in plain code block (no json specifier)."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)
        content = """Analysis:
```
{"status": "success", "count": 5}
```
"""
        result = client._extract_json(content)
        assert result["status"] == "success"
        assert result["count"] == 5

    def test_parse_json_with_preamble(self):
        """Test parsing JSON with text before it."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)
        content = """Based on my analysis, here is the result:
{"findings": ["issue1", "issue2"], "severity": "medium"}
"""
        result = client._extract_json(content)
        assert result["findings"] == ["issue1", "issue2"]
        assert result["severity"] == "medium"

    def test_parse_json_with_postamble(self):
        """Test parsing JSON with text after it."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)
        content = """{"result": true}
Let me know if you need more details."""
        result = client._extract_json(content)
        assert result["result"] is True

    def test_parse_nested_json(self):
        """Test parsing nested JSON objects."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)
        content = """{
    "outer": {
        "inner": {
            "value": "deep"
        }
    },
    "array": [1, 2, 3]
}"""
        result = client._extract_json(content)
        assert result["outer"]["inner"]["value"] == "deep"
        assert result["array"] == [1, 2, 3]

    def test_parse_invalid_json_raises_error(self):
        """Test that invalid JSON raises ValueError."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)
        content = "This is not JSON at all"
        with pytest.raises(ValueError, match="Failed to parse JSON"):
            client._extract_json(content)

    def test_parse_truncated_json_raises_error(self):
        """Test that truncated JSON raises ValueError."""
        from clients.llm_client import LLMClient

        client = LLMClient.__new__(LLMClient)
        content = '{"key": "value", "incomplete'
        with pytest.raises(ValueError):
            client._extract_json(content)
class TestOpenAIProvider:
    """Test OpenAI provider."""

    def test_provider_creation(self):
        """Test OpenAI provider initialization."""
        from clients.llm_client import OpenAIProvider

        provider = OpenAIProvider(
            api_key="test-key",
            model="gpt-4o-mini",
            temperature=0.5,
            max_tokens=2048,
        )
        assert provider.model == "gpt-4o-mini"
        assert provider.temperature == 0.5
        assert provider.max_tokens == 2048

    def test_provider_requires_api_key(self):
        """Test that calling without API key raises error."""
        from clients.llm_client import OpenAIProvider

        provider = OpenAIProvider(api_key="")
        with pytest.raises(ValueError, match="API key is required"):
            provider.call("test prompt")

    @patch("clients.llm_client.requests.post")
    def test_provider_call_success(self, mock_post):
        """Test successful API call."""
        from clients.llm_client import OpenAIProvider

        mock_response = Mock()
        mock_response.json.return_value = {
            "choices": [
                {
                    "message": {"content": "Test response"},
                    "finish_reason": "stop",
                }
            ],
            "model": "gpt-4o-mini",
            "usage": {"total_tokens": 50},
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response
        provider = OpenAIProvider(api_key="test-key")
        response = provider.call("Hello")
        assert response.content == "Test response"
        assert response.provider == "openai"
        assert response.tokens_used == 50
class TestOpenRouterProvider:
    """Test OpenRouter provider."""

    def test_provider_creation(self):
        """Test OpenRouter provider initialization."""
        from clients.llm_client import OpenRouterProvider

        provider = OpenRouterProvider(
            api_key="test-key",
            model="anthropic/claude-3.5-sonnet",
        )
        assert provider.model == "anthropic/claude-3.5-sonnet"

    @patch("clients.llm_client.requests.post")
    def test_provider_call_success(self, mock_post):
        """Test successful API call."""
        from clients.llm_client import OpenRouterProvider

        mock_response = Mock()
        mock_response.json.return_value = {
            "choices": [
                {
                    "message": {"content": "Claude response"},
                    "finish_reason": "stop",
                }
            ],
            "model": "anthropic/claude-3.5-sonnet",
            "usage": {"total_tokens": 75},
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response
        provider = OpenRouterProvider(api_key="test-key")
        response = provider.call("Hello")
        assert response.content == "Claude response"
        assert response.provider == "openrouter"
class TestOllamaProvider:
    """Test Ollama provider."""

    def test_provider_creation(self):
        """Test Ollama provider initialization."""
        from clients.llm_client import OllamaProvider

        provider = OllamaProvider(
            host="http://localhost:11434",
            model="codellama:13b",
        )
        assert provider.model == "codellama:13b"
        assert provider.host == "http://localhost:11434"

    @patch("clients.llm_client.requests.post")
    def test_provider_call_success(self, mock_post):
        """Test successful API call."""
        from clients.llm_client import OllamaProvider

        mock_response = Mock()
        mock_response.json.return_value = {
            "response": "Ollama response",
            "model": "codellama:13b",
            "done": True,
            "eval_count": 30,
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response
        provider = OllamaProvider()
        response = provider.call("Hello")
        assert response.content == "Ollama response"
        assert response.provider == "ollama"
class TestToolCalling:
    """Test tool/function calling support."""

    @patch("clients.llm_client.requests.post")
    def test_openai_tool_calling(self, mock_post):
        """Test OpenAI tool calling."""
        from clients.llm_client import OpenAIProvider

        mock_response = Mock()
        mock_response.json.return_value = {
            "choices": [
                {
                    "message": {
                        "content": None,
                        "tool_calls": [
                            {
                                "id": "call_abc123",
                                "function": {
                                    "name": "search_codebase",
                                    "arguments": '{"query": "auth"}',
                                },
                            }
                        ],
                    },
                    "finish_reason": "tool_calls",
                }
            ],
            "model": "gpt-4o-mini",
            "usage": {"total_tokens": 100},
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response
        provider = OpenAIProvider(api_key="test-key")
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "search_codebase",
                    "description": "Search the codebase",
                    "parameters": {
                        "type": "object",
                        "properties": {"query": {"type": "string"}},
                    },
                },
            }
        ]
        response = provider.call_with_tools(
            messages=[{"role": "user", "content": "Search for auth"}],
            tools=tools,
        )
        assert response.tool_calls is not None
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "search_codebase"
        assert response.tool_calls[0].arguments["query"] == "auth"

    def test_ollama_tool_calling_not_supported(self):
        """Test that Ollama raises NotImplementedError for tool calling."""
        from clients.llm_client import OllamaProvider

        provider = OllamaProvider()
        with pytest.raises(NotImplementedError):
            provider.call_with_tools(
                messages=[{"role": "user", "content": "test"}],
                tools=[],
            )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])


@@ -2,20 +2,40 @@
This package contains the modular agent implementations for the
enterprise AI code review system.

Core Agents:
- PRAgent: Pull request review and analysis
- IssueAgent: Issue triage and response
- CodebaseAgent: Codebase health analysis
- ChatAgent: Interactive chat with tool calling

Specialized Agents:
- DependencyAgent: Dependency vulnerability scanning
- TestCoverageAgent: Test coverage analysis and suggestions
- ArchitectureAgent: Architecture compliance checking
"""

from agents.architecture_agent import ArchitectureAgent
from agents.base_agent import AgentContext, AgentResult, BaseAgent
from agents.chat_agent import ChatAgent
from agents.codebase_agent import CodebaseAgent
from agents.dependency_agent import DependencyAgent
from agents.issue_agent import IssueAgent
from agents.pr_agent import PRAgent
from agents.test_coverage_agent import TestCoverageAgent

__all__ = [
    # Base
    "BaseAgent",
    "AgentContext",
    "AgentResult",
    # Core Agents
    "IssueAgent",
    "PRAgent",
    "CodebaseAgent",
    "ChatAgent",
    # Specialized Agents
    "DependencyAgent",
    "TestCoverageAgent",
    "ArchitectureAgent",
]


@@ -0,0 +1,547 @@
"""Architecture Compliance Agent
AI agent for enforcing architectural patterns and layer separation.
Detects cross-layer violations and circular dependencies.
"""
import base64
import os
import re
from dataclasses import dataclass, field
from agents.base_agent import AgentContext, AgentResult, BaseAgent
@dataclass
class ArchitectureViolation:
"""An architecture violation."""
file: str
line: int
violation_type: str # cross_layer, circular, naming, structure
severity: str # HIGH, MEDIUM, LOW
description: str
recommendation: str
source_layer: str | None = None
target_layer: str | None = None
@dataclass
class ArchitectureReport:
"""Report of architecture analysis."""
violations: list[ArchitectureViolation]
layers_detected: dict[str, list[str]]
circular_dependencies: list[tuple[str, str]]
compliance_score: float
recommendations: list[str]
class ArchitectureAgent(BaseAgent):
"""Agent for enforcing architectural compliance."""
# Marker for architecture comments
ARCH_AI_MARKER = "<!-- AI_ARCHITECTURE_CHECK -->"
# Default layer definitions
DEFAULT_LAYERS = {
"api": {
"patterns": ["api/", "routes/", "controllers/", "handlers/", "views/"],
"can_import": ["services", "models", "utils", "config"],
"cannot_import": ["db", "repositories", "infrastructure"],
},
"services": {
"patterns": ["services/", "usecases/", "application/"],
"can_import": ["models", "repositories", "utils", "config"],
"cannot_import": ["api", "controllers", "handlers"],
},
"repositories": {
"patterns": ["repositories/", "repos/", "data/"],
"can_import": ["models", "db", "utils", "config"],
"cannot_import": ["api", "services", "controllers"],
},
"models": {
"patterns": ["models/", "entities/", "domain/", "schemas/"],
"can_import": ["utils", "config"],
"cannot_import": ["api", "services", "repositories", "db"],
},
"db": {
"patterns": ["db/", "database/", "infrastructure/"],
"can_import": ["models", "config"],
"cannot_import": ["api", "services"],
},
"utils": {
"patterns": ["utils/", "helpers/", "common/", "lib/"],
"can_import": ["config"],
"cannot_import": [],
},
}
def can_handle(self, event_type: str, event_data: dict) -> bool:
"""Check if this agent handles the given event."""
agent_config = self.config.get("agents", {}).get("architecture", {})
if not agent_config.get("enabled", False):
return False
# Handle PR events
if event_type == "pull_request":
action = event_data.get("action", "")
if action in ["opened", "synchronize"]:
return True
# Handle @codebot architecture command
if event_type == "issue_comment":
comment_body = event_data.get("comment", {}).get("body", "")
mention_prefix = self.config.get("interaction", {}).get(
"mention_prefix", "@codebot"
)
# Lowercase the prefix as well, so a mixed-case configured prefix still matches
if f"{mention_prefix.lower()} architecture" in comment_body.lower():
return True
if f"{mention_prefix.lower()} arch-check" in comment_body.lower():
return True
return False
def execute(self, context: AgentContext) -> AgentResult:
"""Execute the architecture agent."""
self.logger.info(f"Checking architecture for {context.owner}/{context.repo}")
actions_taken = []
# Get layer configuration
agent_config = self.config.get("agents", {}).get("architecture", {})
layers = agent_config.get("layers", self.DEFAULT_LAYERS)
# Determine issue number
if context.event_type == "issue_comment":
issue = context.event_data.get("issue", {})
issue_number = issue.get("number")
comment_author = (
context.event_data.get("comment", {})
.get("user", {})
.get("login", "user")
)
is_pr = issue.get("pull_request") is not None
else:
pr = context.event_data.get("pull_request", {})
issue_number = pr.get("number")
comment_author = None
is_pr = True
if is_pr and issue_number:
# Analyze PR diff
diff = self._get_pr_diff(context.owner, context.repo, issue_number)
report = self._analyze_diff(diff, layers)
actions_taken.append("Analyzed PR diff for architecture violations")
else:
# Analyze repository structure
report = self._analyze_repository(context.owner, context.repo, layers)
actions_taken.append("Analyzed repository architecture")
# Post report
if issue_number:
comment = self._format_architecture_report(report, comment_author)
self.upsert_comment(
context.owner,
context.repo,
issue_number,
comment,
marker=self.ARCH_AI_MARKER,
)
actions_taken.append("Posted architecture report")
return AgentResult(
success=len(report.violations) == 0 or report.compliance_score >= 0.8,
message=f"Architecture check: {len(report.violations)} violations, {report.compliance_score:.0%} compliance",
data={
"violations_count": len(report.violations),
"compliance_score": report.compliance_score,
"circular_dependencies": len(report.circular_dependencies),
},
actions_taken=actions_taken,
)
def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
"""Get the PR diff."""
try:
return self.gitea.get_pull_request_diff(owner, repo, pr_number)
except Exception as e:
self.logger.error(f"Failed to get PR diff: {e}")
return ""
def _analyze_diff(self, diff: str, layers: dict) -> ArchitectureReport:
"""Analyze PR diff for architecture violations."""
violations = []
imports_by_file = {}
current_file = None
current_language = None
for line in diff.splitlines():
# Track current file
if line.startswith("diff --git"):
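# Anchor on " b/" so that a directory such as lib/ or db/ in the a/ path is not mistaken for the b/ prefix
match = re.search(r" b/(.+)$", line)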
if match:
current_file = match.group(1)
current_language = self._detect_language(current_file)
imports_by_file[current_file] = []
# Look for import statements in added lines
if line.startswith("+") and not line.startswith("+++"):
if current_file and current_language:
imports = self._extract_imports(line[1:], current_language)
imports_by_file.setdefault(current_file, []).extend(imports)
# Check for violations
for file_path, imports in imports_by_file.items():
source_layer = self._detect_layer(file_path, layers)
if not source_layer:
continue
layer_config = layers.get(source_layer, {})
cannot_import = layer_config.get("cannot_import", [])
for imp in imports:
target_layer = self._detect_layer_from_import(imp, layers)
if target_layer and target_layer in cannot_import:
violations.append(
ArchitectureViolation(
file=file_path,
line=0, # Line number not tracked in this simple implementation
violation_type="cross_layer",
severity="HIGH",
description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
recommendation="Move this import to an allowed layer or refactor the dependency",
source_layer=source_layer,
target_layer=target_layer,
)
)
# Detect circular dependencies
circular = self._detect_circular_dependencies(imports_by_file, layers)
# Calculate compliance score
total_imports = sum(len(imps) for imps in imports_by_file.values())
if total_imports > 0:
compliance = 1.0 - (len(violations) / max(total_imports, 1))
else:
compliance = 1.0
return ArchitectureReport(
violations=violations,
layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
circular_dependencies=circular,
compliance_score=max(0.0, compliance),
recommendations=self._generate_recommendations(violations),
)
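The compliance score computed above is the share of detected imports that do not break a layer rule. A minimal standalone sketch of that arithmetic follows; the function name `compliance_score` is hypothetical and is not part of the commit.

```python
def compliance_score(violation_count: int, total_imports: int) -> float:
    """Share of imports that do not violate a layer rule, clamped to [0, 1]."""
    if total_imports <= 0:
        # No imports were seen, so there is nothing that could violate a rule.
        return 1.0
    return max(0.0, 1.0 - violation_count / total_imports)


if __name__ == "__main__":
    # One violation among four imports.
    print(compliance_score(1, 4))
    # No imports at all.
    print(compliance_score(0, 0))
```

Because the score is relative to the number of imports in the diff, a small PR with a single forbidden import can score much lower than a large PR with several.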
def _analyze_repository(
self, owner: str, repo: str, layers: dict
) -> ArchitectureReport:
"""Analyze repository structure for architecture compliance."""
violations = []
imports_by_file = {}
# Collect files from each layer
for layer_name, layer_config in layers.items():
for pattern in layer_config.get("patterns", []):
try:
path = pattern.rstrip("/")
contents = self.gitea.get_file_contents(owner, repo, path)
if isinstance(contents, list):
for item in contents[:20]: # Limit files per layer
if item.get("type") == "file":
filepath = item.get("path", "")
imports = self._get_file_imports(owner, repo, filepath)
imports_by_file[filepath] = imports
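except Exception as e:
# Directory missing or API error: skip this pattern but leave a trace in the log
self.logger.debug(f"Skipping layer path '{pattern}': {e}")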
# Check for violations
for file_path, imports in imports_by_file.items():
source_layer = self._detect_layer(file_path, layers)
if not source_layer:
continue
layer_config = layers.get(source_layer, {})
cannot_import = layer_config.get("cannot_import", [])
for imp in imports:
target_layer = self._detect_layer_from_import(imp, layers)
if target_layer and target_layer in cannot_import:
violations.append(
ArchitectureViolation(
file=file_path,
line=0,
violation_type="cross_layer",
severity="HIGH",
description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'",
recommendation=f"Refactor to remove dependency on '{target_layer}'",
source_layer=source_layer,
target_layer=target_layer,
)
)
# Detect circular dependencies
circular = self._detect_circular_dependencies(imports_by_file, layers)
# Calculate compliance
total_imports = sum(len(imps) for imps in imports_by_file.values())
if total_imports > 0:
compliance = 1.0 - (len(violations) / max(total_imports, 1))
else:
compliance = 1.0
return ArchitectureReport(
violations=violations,
layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers),
circular_dependencies=circular,
compliance_score=max(0.0, compliance),
recommendations=self._generate_recommendations(violations),
)
def _get_file_imports(self, owner: str, repo: str, filepath: str) -> list[str]:
"""Get imports from a file."""
imports = []
language = self._detect_language(filepath)
if not language:
return imports
try:
content_data = self.gitea.get_file_contents(owner, repo, filepath)
if content_data.get("content"):
content = base64.b64decode(content_data["content"]).decode(
"utf-8", errors="ignore"
)
for line in content.splitlines():
imports.extend(self._extract_imports(line, language))
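except Exception as e:
# Unreadable or missing file: treat it as having no imports
self.logger.debug(f"Could not read imports from '{filepath}': {e}")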
return imports
def _detect_language(self, filepath: str) -> str | None:
"""Detect programming language from file path."""
ext_map = {
".py": "python",
".js": "javascript",
".ts": "typescript",
".go": "go",
".java": "java",
".rb": "ruby",
}
ext = os.path.splitext(filepath)[1]
return ext_map.get(ext)
def _extract_imports(self, line: str, language: str) -> list[str]:
"""Extract import statements from a line of code."""
imports = []
line = line.strip()
if language == "python":
# from x import y, import x
match = re.match(r"^(?:from\s+(\S+)|import\s+(\S+))", line)
if match:
imp = match.group(1) or match.group(2)
if imp:
imports.append(imp.split(".")[0])
elif language in ("javascript", "typescript"):
# import x from 'y', require('y')
match = re.search(
r"(?:from\s+['\"]([^'\"]+)['\"]|require\(['\"]([^'\"]+)['\"]\))", line
)
if match:
imp = match.group(1) or match.group(2)
if imp and not imp.startswith("."):
imports.append(imp.split("/")[0])
elif imp:
imports.append(imp)
elif language == "go":
# import "package"
match = re.search(r'import\s+["\']([^"\']+)["\']', line)
if match:
imports.append(match.group(1).split("/")[-1])
elif language == "java":
# import package.Class
match = re.match(r"^import\s+(?:static\s+)?([^;]+);", line)
if match:
parts = match.group(1).split(".")
if len(parts) > 1:
imports.append(parts[-2]) # Package name
return imports
    def _detect_layer(self, filepath: str, layers: dict) -> str | None:
        """Detect which layer a file belongs to."""
        # Match whole path segments so that "api/" does not also match "rapid/"
        normalized = f"/{filepath.lstrip('/')}"
        for layer_name, layer_config in layers.items():
            for pattern in layer_config.get("patterns", []):
                if f"/{pattern.strip('/')}/" in normalized:
                    return layer_name
        return None
def _detect_layer_from_import(self, import_path: str, layers: dict) -> str | None:
"""Detect which layer an import refers to."""
for layer_name, layer_config in layers.items():
for pattern in layer_config.get("patterns", []):
pattern_name = pattern.rstrip("/").split("/")[-1]
if pattern_name in import_path or import_path.startswith(pattern_name):
return layer_name
return None
def _detect_circular_dependencies(
self, imports_by_file: dict, layers: dict
) -> list[tuple[str, str]]:
"""Detect circular dependencies between layers."""
circular = []
# Build layer dependency graph
layer_deps = {}
for file_path, imports in imports_by_file.items():
source_layer = self._detect_layer(file_path, layers)
if not source_layer:
continue
if source_layer not in layer_deps:
layer_deps[source_layer] = set()
for imp in imports:
target_layer = self._detect_layer_from_import(imp, layers)
if target_layer and target_layer != source_layer:
layer_deps[source_layer].add(target_layer)
# Check for circular dependencies
for layer_a, deps_a in layer_deps.items():
for layer_b in deps_a:
if layer_b in layer_deps and layer_a in layer_deps.get(layer_b, set()):
pair = tuple(sorted([layer_a, layer_b]))
if pair not in circular:
circular.append(pair)
return circular
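The check above only reports pairs of layers that import from each other directly. A minimal sketch of that pairwise test on a prebuilt dependency graph is shown here; `find_mutual_dependencies` and the sample graph are hypothetical names used for illustration. A longer chain such as a → b → c → a would not be reported by this approach.

```python
def find_mutual_dependencies(
    layer_deps: dict[str, set[str]],
) -> list[tuple[str, str]]:
    """Return each pair of layers that depend on one another, once per pair."""
    pairs = set()
    for layer_a, deps_a in layer_deps.items():
        for layer_b in deps_a:
            # A two-layer cycle exists when B also lists A as a dependency.
            if layer_a in layer_deps.get(layer_b, set()):
                pairs.add(tuple(sorted((layer_a, layer_b))))
    return sorted(pairs)


if __name__ == "__main__":
    graph = {
        "api": {"services"},
        "services": {"api", "models"},
        "models": set(),
    }
    print(find_mutual_dependencies(graph))
```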
def _group_files_by_layer(
self, files: list[str], layers: dict
) -> dict[str, list[str]]:
"""Group files by their layer."""
grouped = {}
for filepath in files:
layer = self._detect_layer(filepath, layers)
if layer:
if layer not in grouped:
grouped[layer] = []
grouped[layer].append(filepath)
return grouped
def _generate_recommendations(
self, violations: list[ArchitectureViolation]
) -> list[str]:
"""Generate recommendations based on violations."""
recommendations = []
# Count violations by type
cross_layer = sum(1 for v in violations if v.violation_type == "cross_layer")
if cross_layer > 0:
recommendations.append(
f"Fix {cross_layer} cross-layer violations by moving imports or creating interfaces"
)
if cross_layer > 5:
recommendations.append(
"Consider using dependency injection to reduce coupling between layers"
)
return recommendations
def _format_architecture_report(
self, report: ArchitectureReport, user: str | None
) -> str:
"""Format the architecture report as a comment."""
lines = []
if user:
lines.append(f"@{user}")
lines.append("")
lines.extend(
[
f"{self.AI_DISCLAIMER}",
"",
"## 🏗️ Architecture Compliance Check",
"",
"### Summary",
"",
f"| Metric | Value |",
f"|--------|-------|",
f"| Compliance Score | {report.compliance_score:.0%} |",
f"| Violations | {len(report.violations)} |",
f"| Circular Dependencies | {len(report.circular_dependencies)} |",
f"| Layers Detected | {len(report.layers_detected)} |",
"",
]
)
# Compliance bar
filled = int(report.compliance_score * 10)
bar = "█" * filled + "░" * (10 - filled)
lines.append(f"`[{bar}]` {report.compliance_score:.0%}")
lines.append("")
# Violations
if report.violations:
lines.append("### 🚨 Violations")
lines.append("")
for v in report.violations[:10]: # Limit display
severity_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}
lines.append(
f"{severity_emoji.get(v.severity, '')} **{v.violation_type.upper()}** in `{v.file}`"
)
lines.append(f" - {v.description}")
lines.append(f" - 💡 {v.recommendation}")
lines.append("")
if len(report.violations) > 10:
lines.append(f"*... and {len(report.violations) - 10} more violations*")
lines.append("")
# Circular dependencies
if report.circular_dependencies:
lines.append("### 🔄 Circular Dependencies")
lines.append("")
for a, b in report.circular_dependencies:
lines.append(f"- `{a}` ↔ `{b}`")
lines.append("")
# Layers detected
if report.layers_detected:
lines.append("### 📁 Layers Detected")
lines.append("")
for layer, files in report.layers_detected.items():
lines.append(f"- **{layer}**: {len(files)} files")
lines.append("")
# Recommendations
if report.recommendations:
lines.append("### 💡 Recommendations")
lines.append("")
for rec in report.recommendations:
lines.append(f"- {rec}")
lines.append("")
# Overall status
if report.compliance_score >= 0.9:
lines.append("---")
lines.append("✅ **Excellent architecture compliance!**")
elif report.compliance_score >= 0.7:
lines.append("---")
lines.append("⚠️ **Some architectural issues to address**")
else:
lines.append("---")
lines.append("❌ **Significant architectural violations detected**")
return "\n".join(lines)


@@ -65,9 +65,10 @@ class BaseAgent(ABC):
 self.llm = llm_client or LLMClient.from_config(self.config)
 self.logger = logging.getLogger(self.__class__.__name__)
-# Rate limiting
+# Rate limiting - now configurable
 self._last_request_time = 0.0
-self._min_request_interval = 1.0 # seconds
+rate_limits = self.config.get("rate_limits", {})
+self._min_request_interval = rate_limits.get("min_interval", 1.0) # seconds
 @staticmethod
 def _load_config() -> dict:
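The hunk above only shows where the interval is read from configuration, not how it is enforced. One way such a value might be applied is sketched below; the `Throttle` class and its `wait` method are hypothetical and are not taken from `BaseAgent`. Only the `rate_limits.min_interval` key and its default of 1.0 seconds come from the diff.

```python
import time


class Throttle:
    """Hypothetical helper that enforces a minimum gap between requests."""

    def __init__(self, config: dict) -> None:
        rate_limits = config.get("rate_limits", {})
        self.min_interval = float(rate_limits.get("min_interval", 1.0))
        self._last_request_time = None  # no request has been made yet

    def wait(self) -> float:
        """Sleep if the previous request was too recent; return seconds slept."""
        slept = 0.0
        if self._last_request_time is not None:
            elapsed = time.monotonic() - self._last_request_time
            slept = max(0.0, self.min_interval - elapsed)
        if slept > 0:
            time.sleep(slept)
        self._last_request_time = time.monotonic()
        return slept


if __name__ == "__main__":
    throttle = Throttle({"rate_limits": {"min_interval": 0.05}})
    throttle.wait()  # first call returns immediately
    throttle.wait()  # second call sleeps for roughly the configured interval
```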


@@ -0,0 +1,548 @@
"""Dependency Security Agent
AI agent for scanning dependency files for known vulnerabilities
and outdated packages. Supports multiple package managers.
"""
import base64
import json
import logging
import os
import re
import subprocess
from dataclasses import dataclass, field
from typing import Any
from agents.base_agent import AgentContext, AgentResult, BaseAgent
@dataclass
class DependencyFinding:
    """A security finding in a dependency."""

    package: str
    version: str
    severity: str  # CRITICAL, HIGH, MEDIUM, LOW
    vulnerability_id: str  # CVE, GHSA, etc.
    title: str
    description: str
    fixed_version: str | None = None
    references: list[str] = field(default_factory=list)


@dataclass
class DependencyReport:
    """Report of dependency analysis."""

    total_packages: int
    vulnerable_packages: int
    outdated_packages: int
    findings: list[DependencyFinding]
    recommendations: list[str]
    files_scanned: list[str]
class DependencyAgent(BaseAgent):
"""Agent for scanning dependencies for security vulnerabilities."""
# Marker for dependency comments
DEP_AI_MARKER = "<!-- AI_DEPENDENCY_SCAN -->"
# Supported dependency files
DEPENDENCY_FILES = {
"python": ["requirements.txt", "Pipfile", "pyproject.toml", "setup.py"],
"javascript": ["package.json", "package-lock.json", "yarn.lock"],
"ruby": ["Gemfile", "Gemfile.lock"],
"go": ["go.mod", "go.sum"],
"rust": ["Cargo.toml", "Cargo.lock"],
"java": ["pom.xml", "build.gradle", "build.gradle.kts"],
"php": ["composer.json", "composer.lock"],
"dotnet": ["*.csproj", "packages.config", "*.fsproj"],
}
# Common vulnerable package patterns
KNOWN_VULNERABILITIES = {
"python": {
"requests": {
"< 2.31.0": "CVE-2023-32681 - Proxy-Authorization header leak"
},
"urllib3": {
"< 2.0.7": "CVE-2023-45803 - Request body not stripped on redirects"
},
"cryptography": {"< 41.0.0": "Multiple CVEs - Update recommended"},
"pillow": {"< 10.0.0": "CVE-2023-4863 - WebP vulnerability"},
"django": {"< 4.2.0": "Multiple security fixes"},
"flask": {"< 2.3.0": "Security improvements"},
"pyyaml": {"< 6.0": "CVE-2020-14343 - Arbitrary code execution"},
"jinja2": {"< 3.1.0": "Security fixes"},
},
"javascript": {
"lodash": {"< 4.17.21": "CVE-2021-23337 - Prototype pollution"},
"axios": {"< 1.6.0": "CVE-2023-45857 - CSRF vulnerability"},
"express": {"< 4.18.0": "Security updates"},
"jquery": {"< 3.5.0": "XSS vulnerabilities"},
"minimist": {"< 1.2.6": "Prototype pollution"},
"node-fetch": {"< 3.3.0": "Security fixes"},
},
}
def can_handle(self, event_type: str, event_data: dict) -> bool:
"""Check if this agent handles the given event."""
agent_config = self.config.get("agents", {}).get("dependency", {})
if not agent_config.get("enabled", True):
return False
# Handle PR events that modify dependency files
if event_type == "pull_request":
action = event_data.get("action", "")
if action in ["opened", "synchronize"]:
# Check if any dependency files are modified
files = event_data.get("files", [])
for f in files:
if self._is_dependency_file(f.get("filename", "")):
return True
# Handle @codebot check-deps command
if event_type == "issue_comment":
comment_body = event_data.get("comment", {}).get("body", "")
mention_prefix = self.config.get("interaction", {}).get(
"mention_prefix", "@codebot"
)
# Lowercase the prefix as well, so a mixed-case configured prefix still matches
if f"{mention_prefix.lower()} check-deps" in comment_body.lower():
return True
return False
def _is_dependency_file(self, filename: str) -> bool:
"""Check if a file is a dependency file."""
basename = os.path.basename(filename)
for lang, files in self.DEPENDENCY_FILES.items():
for pattern in files:
if pattern.startswith("*"):
if basename.endswith(pattern[1:]):
return True
elif basename == pattern:
return True
return False
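The method above handles the `*.csproj`-style entries with a manual suffix check. As an alternative sketch, the standard library's `fnmatch.fnmatchcase` can match exact names and glob patterns in the same way; the function name `is_dependency_file` and the short pattern list are illustrative assumptions, not the commit's code.

```python
import fnmatch
import os


def is_dependency_file(filename: str, patterns: list[str]) -> bool:
    """Return True if the file's basename matches any exact name or glob."""
    basename = os.path.basename(filename)
    # fnmatchcase treats a pattern without wildcards as an exact match.
    return any(fnmatch.fnmatchcase(basename, pattern) for pattern in patterns)


if __name__ == "__main__":
    patterns = ["requirements.txt", "package.json", "*.csproj"]
    print(is_dependency_file("backend/requirements.txt", patterns))
    print(is_dependency_file("src/App/App.csproj", patterns))
    print(is_dependency_file("docs/readme.md", patterns))
```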
def execute(self, context: AgentContext) -> AgentResult:
"""Execute the dependency agent."""
self.logger.info(f"Scanning dependencies for {context.owner}/{context.repo}")
actions_taken = []
# Determine if this is a command or PR event
if context.event_type == "issue_comment":
issue = context.event_data.get("issue", {})
issue_number = issue.get("number")
comment_author = (
context.event_data.get("comment", {})
.get("user", {})
.get("login", "user")
)
else:
pr = context.event_data.get("pull_request", {})
issue_number = pr.get("number")
comment_author = None
# Collect dependency files
dep_files = self._collect_dependency_files(context.owner, context.repo)
if not dep_files:
message = "No dependency files found in repository."
if issue_number:
self.gitea.create_issue_comment(
context.owner,
context.repo,
issue_number,
f"{self.AI_DISCLAIMER}\n\n{message}",
)
return AgentResult(
success=True,
message=message,
)
actions_taken.append(f"Found {len(dep_files)} dependency files")
# Analyze dependencies
report = self._analyze_dependencies(context.owner, context.repo, dep_files)
actions_taken.append(f"Analyzed {report.total_packages} packages")
# Run external scanners if available
external_findings = self._run_external_scanners(context.owner, context.repo)
if external_findings:
report.findings.extend(external_findings)
actions_taken.append(
f"External scanner found {len(external_findings)} issues"
)
# Generate and post report
if issue_number:
comment = self._format_dependency_report(report, comment_author)
self.upsert_comment(
context.owner,
context.repo,
issue_number,
comment,
marker=self.DEP_AI_MARKER,
)
actions_taken.append("Posted dependency report")
return AgentResult(
success=True,
message=f"Dependency scan complete: {report.vulnerable_packages} vulnerable, {report.outdated_packages} outdated",
data={
"total_packages": report.total_packages,
"vulnerable_packages": report.vulnerable_packages,
"outdated_packages": report.outdated_packages,
"findings_count": len(report.findings),
},
actions_taken=actions_taken,
)
def _collect_dependency_files(
self, owner: str, repo: str
) -> dict[str, dict[str, Any]]:
"""Collect all dependency files from the repository."""
dep_files = {}
# Common paths to check
paths_to_check = [
"", # Root
"backend/",
"frontend/",
"api/",
"services/",
]
for base_path in paths_to_check:
for lang, filenames in self.DEPENDENCY_FILES.items():
for filename in filenames:
if filename.startswith("*"):
continue # Skip glob patterns for now
filepath = f"{base_path}{filename}".lstrip("/")
try:
content_data = self.gitea.get_file_contents(
owner, repo, filepath
)
if content_data.get("content"):
content = base64.b64decode(content_data["content"]).decode(
"utf-8", errors="ignore"
)
dep_files[filepath] = {
"language": lang,
"content": content,
}
except Exception:
pass # File doesn't exist
return dep_files
def _analyze_dependencies(
self, owner: str, repo: str, dep_files: dict
) -> DependencyReport:
"""Analyze dependency files for vulnerabilities."""
findings = []
total_packages = 0
vulnerable_count = 0
outdated_count = 0
recommendations = []
files_scanned = list(dep_files.keys())
for filepath, file_info in dep_files.items():
lang = file_info["language"]
content = file_info["content"]
if lang == "python":
packages = self._parse_python_deps(content, filepath)
elif lang == "javascript":
packages = self._parse_javascript_deps(content, filepath)
else:
packages = []
total_packages += len(packages)
# Check for known vulnerabilities
known_vulns = self.KNOWN_VULNERABILITIES.get(lang, {})
for pkg_name, version in packages:
if pkg_name.lower() in known_vulns:
vuln_info = known_vulns[pkg_name.lower()]
for version_constraint, vuln_desc in vuln_info.items():
if self._version_matches_constraint(
version, version_constraint
):
findings.append(
DependencyFinding(
package=pkg_name,
version=version or "unknown",
severity="HIGH",
vulnerability_id=vuln_desc.split(" - ")[0]
if " - " in vuln_desc
else "VULN",
title=vuln_desc,
description=f"Package {pkg_name} version {version} has known vulnerabilities",
fixed_version=version_constraint.replace("< ", ""),
)
)
vulnerable_count += 1
# Add recommendations
if vulnerable_count > 0:
recommendations.append(
f"Update {vulnerable_count} packages with known vulnerabilities"
)
if total_packages > 50:
recommendations.append(
"Consider auditing dependencies to reduce attack surface"
)
return DependencyReport(
total_packages=total_packages,
vulnerable_packages=vulnerable_count,
outdated_packages=outdated_count,
findings=findings,
recommendations=recommendations,
files_scanned=files_scanned,
)
def _parse_python_deps(
self, content: str, filepath: str
) -> list[tuple[str, str | None]]:
"""Parse Python dependency file."""
packages = []
if "requirements" in filepath.lower():
# requirements.txt format
for line in content.splitlines():
line = line.strip()
if not line or line.startswith("#") or line.startswith("-"):
continue
# Parse package==version, package>=version, package
match = re.match(r"([a-zA-Z0-9_-]+)([<>=!]+)?(.+)?", line)
if match:
pkg_name = match.group(1)
version = match.group(3) if match.group(3) else None
packages.append((pkg_name, version))
elif filepath.endswith("pyproject.toml"):
# pyproject.toml format
in_deps = False
for line in content.splitlines():
if (
"[project.dependencies]" in line
or "[tool.poetry.dependencies]" in line
):
in_deps = True
continue
if in_deps:
if line.startswith("["):
in_deps = False
continue
match = re.match(r'"?([a-zA-Z0-9_-]+)"?\s*[=<>]', line)
if match:
packages.append((match.group(1), None))
return packages
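The requirements.txt branch above keeps everything after the operator as the version, which can include a second constraint or a trailing comment. A stricter single-line parser might look like the sketch below; `parse_requirement` and its regular expression are hypothetical and cover only simple `name`, `name[extras]` and `name<op>version` lines, not URLs or environment markers.

```python
import re
from typing import Optional

# Name, optional [extras], then an optional operator and a version that
# stops at whitespace, a comma, a semicolon or a comment marker.
_REQUIREMENT = re.compile(
    r"^([A-Za-z0-9._-]+)(?:\[[^\]]*\])?\s*"
    r"(?:(==|>=|<=|~=|!=|<|>)\s*([0-9][^\s,;#]*))?"
)


def parse_requirement(line: str) -> Optional[tuple[str, Optional[str]]]:
    """Parse one requirements.txt line into (package, version or None)."""
    line = line.split("#", 1)[0].strip()  # drop trailing comments
    if not line or line.startswith("-"):
        return None  # blank line, comment, or an option such as "-r base.txt"
    match = _REQUIREMENT.match(line)
    if not match:
        return None
    name, _operator, version = match.groups()
    return name, version


if __name__ == "__main__":
    for raw in ["requests==2.31.0", "flask>=2.0,<3", "django", "-r base.txt"]:
        print(raw, "->", parse_requirement(raw))
```

Note that a `>=` constraint states a minimum, not the installed version, so treating it as the version to compare can over-report vulnerabilities.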
def _parse_javascript_deps(
self, content: str, filepath: str
) -> list[tuple[str, str | None]]:
"""Parse JavaScript dependency file."""
packages = []
if filepath.endswith("package.json"):
try:
data = json.loads(content)
for dep_type in ["dependencies", "devDependencies"]:
deps = data.get(dep_type, {})
for name, version in deps.items():
# Strip version prefixes like ^, ~, >=
clean_version = re.sub(r"^[\^~>=<]+", "", version)
packages.append((name, clean_version))
except json.JSONDecodeError:
pass
return packages
def _version_matches_constraint(self, version: str | None, constraint: str) -> bool:
"""Check if version matches a vulnerability constraint."""
if not version:
return True # Assume vulnerable if version unknown
# Simple version comparison
if constraint.startswith("< "):
target = constraint[2:]
try:
return self._compare_versions(version, target) < 0
except Exception:
return False
return False
def _compare_versions(self, v1: str, v2: str) -> int:
"""Compare two version strings. Returns -1, 0, or 1."""
def normalize(v):
return [int(x) for x in re.sub(r"[^0-9.]", "", v).split(".") if x]
try:
parts1 = normalize(v1)
parts2 = normalize(v2)
for i in range(max(len(parts1), len(parts2))):
p1 = parts1[i] if i < len(parts1) else 0
p2 = parts2[i] if i < len(parts2) else 0
if p1 < p2:
return -1
if p1 > p2:
return 1
return 0
except Exception:
return 0
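The comparison above is numeric per component rather than lexical, so `10.0.0` sorts after `9.9.9`. The same idea can be sketched with padded tuples; `parse_version` and `is_below` are hypothetical names. Like the method above, this sketch drops non-numeric characters, so pre-release tags such as `1.0rc1` are not handled correctly; `packaging.version.Version` would be one option if PEP 440 semantics are needed.

```python
import re


def parse_version(version: str) -> tuple[int, ...]:
    """Keep only the dotted numeric part of a version string."""
    cleaned = re.sub(r"[^0-9.]", "", version)
    return tuple(int(part) for part in cleaned.split(".") if part)


def is_below(version: str, fixed_in: str) -> bool:
    """Return True if `version` is numerically lower than `fixed_in`."""
    left, right = parse_version(version), parse_version(fixed_in)
    width = max(len(left), len(right))
    # Pad the shorter version with zeros so "2.31" compares equal to "2.31.0".
    left += (0,) * (width - len(left))
    right += (0,) * (width - len(right))
    return left < right


if __name__ == "__main__":
    print(is_below("2.30.0", "2.31.0"))
    print(is_below("10.0.0", "9.9.9"))
```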
def _run_external_scanners(self, owner: str, repo: str) -> list[DependencyFinding]:
"""Run external vulnerability scanners if available."""
findings = []
agent_config = self.config.get("agents", {}).get("dependency", {})
# Try pip-audit for Python
if agent_config.get("pip_audit", False):
try:
result = subprocess.run(
["pip-audit", "--format", "json"],
capture_output=True,
text=True,
timeout=60,
)
if result.returncode == 0:
data = json.loads(result.stdout)
for vuln in data.get("vulnerabilities", []):
findings.append(
DependencyFinding(
package=vuln.get("name", ""),
version=vuln.get("version", ""),
severity=vuln.get("severity", "MEDIUM"),
vulnerability_id=vuln.get("id", ""),
title=vuln.get("description", "")[:100],
description=vuln.get("description", ""),
# An empty fix_versions list would raise IndexError, so fall back to [None]
fixed_version=(vuln.get("fix_versions") or [None])[0],
)
)
except Exception as e:
self.logger.debug(f"pip-audit not available: {e}")
# Try npm audit for JavaScript
if agent_config.get("npm_audit", False):
try:
result = subprocess.run(
["npm", "audit", "--json"],
capture_output=True,
text=True,
timeout=60,
)
data = json.loads(result.stdout)
for vuln_id, vuln in data.get("vulnerabilities", {}).items():
findings.append(
DependencyFinding(
package=vuln.get("name", vuln_id),
version=vuln.get("range", ""),
severity=vuln.get("severity", "moderate").upper(),
vulnerability_id=vuln_id,
title=vuln.get("title", ""),
description=vuln.get("overview", ""),
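# fixAvailable may be a boolean or an object; only the object form carries a version
fixed_version=(
vuln["fixAvailable"].get("version")
if isinstance(vuln.get("fixAvailable"), dict)
else None
),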
)
)
except Exception as e:
self.logger.debug(f"npm audit not available: {e}")
return findings
def _format_dependency_report(
self, report: DependencyReport, user: str | None = None
) -> str:
"""Format the dependency report as a comment."""
lines = []
if user:
lines.append(f"@{user}")
lines.append("")
lines.extend(
[
f"{self.AI_DISCLAIMER}",
"",
"## 🔍 Dependency Security Scan",
"",
"### Summary",
"",
f"| Metric | Value |",
f"|--------|-------|",
f"| Total Packages | {report.total_packages} |",
f"| Vulnerable | {report.vulnerable_packages} |",
f"| Outdated | {report.outdated_packages} |",
f"| Files Scanned | {len(report.files_scanned)} |",
"",
]
)
# Findings by severity
if report.findings:
lines.append("### 🚨 Security Findings")
lines.append("")
# Group by severity
by_severity = {"CRITICAL": [], "HIGH": [], "MEDIUM": [], "LOW": []}
for finding in report.findings:
sev = finding.severity.upper()
if sev in by_severity:
by_severity[sev].append(finding)
severity_emoji = {
"CRITICAL": "🔴",
"HIGH": "🟠",
"MEDIUM": "🟡",
"LOW": "🔵",
}
for severity in ["CRITICAL", "HIGH", "MEDIUM", "LOW"]:
findings = by_severity[severity]
if findings:
lines.append(f"#### {severity_emoji[severity]} {severity}")
lines.append("")
for f in findings[:10]: # Limit display
lines.append(f"- **{f.package}** `{f.version}`")
lines.append(f" - {f.vulnerability_id}: {f.title}")
if f.fixed_version:
lines.append(f" - ✅ Fix: Upgrade to `{f.fixed_version}`")
if len(findings) > 10:
lines.append(f" - ... and {len(findings) - 10} more")
lines.append("")
# Files scanned
lines.append("### 📁 Files Scanned")
lines.append("")
for f in report.files_scanned:
lines.append(f"- `{f}`")
lines.append("")
# Recommendations
if report.recommendations:
lines.append("### 💡 Recommendations")
lines.append("")
for rec in report.recommendations:
lines.append(f"- {rec}")
lines.append("")
# Overall status
if report.vulnerable_packages == 0:
lines.append("---")
lines.append("✅ **No known vulnerabilities detected**")
else:
lines.append("---")
lines.append(
f"⚠️ **{report.vulnerable_packages} vulnerable packages require attention**"
)
return "\n".join(lines)


@@ -365,9 +365,20 @@ class IssueAgent(BaseAgent):
 "commands", ["explain", "suggest", "security", "summarize", "triage"]
 )
-# Also check for setup-labels command (not in config since it's a setup command)
-if f"{mention_prefix} setup-labels" in body.lower():
-return "setup-labels"
+# Built-in commands not in config
+builtin_commands = [
+"setup-labels",
+"check-deps",
+"suggest-tests",
+"refactor-suggest",
+"architecture",
+"arch-check",
+]
+# Check built-in commands first
+for command in builtin_commands:
+if f"{mention_prefix} {command}" in body.lower():
+return command
 for command in commands:
 if f"{mention_prefix} {command}" in body.lower():
@@ -392,6 +403,14 @@ class IssueAgent(BaseAgent):
return self._command_triage(context, issue)
elif command == "setup-labels":
return self._command_setup_labels(context, issue)
elif command == "check-deps":
return self._command_check_deps(context)
elif command == "suggest-tests":
return self._command_suggest_tests(context)
elif command == "refactor-suggest":
return self._command_refactor_suggest(context)
elif command == "architecture" or command == "arch-check":
return self._command_architecture(context)
return f"{self.AI_DISCLAIMER}\n\nSorry, I don't understand the command `{command}`."
@@ -464,6 +483,12 @@ Be practical and concise."""
- `{mention_prefix} suggest` - Solution suggestions or next steps
- `{mention_prefix} security` - Security-focused analysis of the issue
### Code Quality & Security
- `{mention_prefix} check-deps` - Scan dependencies for security vulnerabilities
- `{mention_prefix} suggest-tests` - Suggest test cases for changed/new code
- `{mention_prefix} refactor-suggest` - Suggest refactoring opportunities
- `{mention_prefix} architecture` - Check architecture compliance (alias: `arch-check`)
### Interactive Chat
- `{mention_prefix} [question]` - Ask questions about the codebase (uses search & file reading tools)
- Example: `{mention_prefix} how does authentication work?`
@@ -494,9 +519,19 @@ PR reviews run automatically when you open or update a pull request. The bot pro
{mention_prefix} triage {mention_prefix} triage
``` ```
**Get help understanding:** **Check for dependency vulnerabilities:**
``` ```
{mention_prefix} explain {mention_prefix} check-deps
```
**Get test suggestions:**
```
{mention_prefix} suggest-tests
```
**Check architecture compliance:**
```
{mention_prefix} architecture
``` ```
**Ask about the codebase:** **Ask about the codebase:**
@@ -504,11 +539,6 @@ PR reviews run automatically when you open or update a pull request. The bot pro
{mention_prefix} how does the authentication system work? {mention_prefix} how does the authentication system work?
``` ```
**Setup repository labels:**
```
{mention_prefix} setup-labels
```
--- ---
*For full documentation, see the [README](https://github.com/YourOrg/OpenRabbit/blob/main/README.md)* *For full documentation, see the [README](https://github.com/YourOrg/OpenRabbit/blob/main/README.md)*
@@ -854,3 +884,145 @@ PR reviews run automatically when you open or update a pull request. The bot pro
return f"{prefix} - {value}" return f"{prefix} - {value}"
else: # colon or unknown else: # colon or unknown
return base_name return base_name
def _command_check_deps(self, context: AgentContext) -> str:
"""Check dependencies for security vulnerabilities."""
try:
from agents.dependency_agent import DependencyAgent
agent = DependencyAgent(config=self.config)
result = agent.run(context)
if result.success:
return result.data.get(
"report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
)
else:
return f"{self.AI_DISCLAIMER}\n\n**Dependency Check Failed**\n\n{result.error or result.message}"
except ImportError:
return f"{self.AI_DISCLAIMER}\n\n**Dependency Agent Not Available**\n\nThe dependency security agent is not installed."
except Exception as e:
self.logger.error(f"Dependency check failed: {e}")
return f"{self.AI_DISCLAIMER}\n\n**Dependency Check Error**\n\n{e}"
def _command_suggest_tests(self, context: AgentContext) -> str:
"""Suggest tests for changed or new code."""
try:
from agents.test_coverage_agent import TestCoverageAgent
agent = TestCoverageAgent(config=self.config)
result = agent.run(context)
if result.success:
return result.data.get(
"report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
)
else:
return f"{self.AI_DISCLAIMER}\n\n**Test Suggestion Failed**\n\n{result.error or result.message}"
except ImportError:
return f"{self.AI_DISCLAIMER}\n\n**Test Coverage Agent Not Available**\n\nThe test coverage agent is not installed."
except Exception as e:
self.logger.error(f"Test suggestion failed: {e}")
return f"{self.AI_DISCLAIMER}\n\n**Test Suggestion Error**\n\n{e}"
def _command_architecture(self, context: AgentContext) -> str:
"""Check architecture compliance."""
try:
from agents.architecture_agent import ArchitectureAgent
agent = ArchitectureAgent(config=self.config)
result = agent.run(context)
if result.success:
return result.data.get(
"report", f"{self.AI_DISCLAIMER}\n\n{result.message}"
)
else:
return f"{self.AI_DISCLAIMER}\n\n**Architecture Check Failed**\n\n{result.error or result.message}"
except ImportError:
return f"{self.AI_DISCLAIMER}\n\n**Architecture Agent Not Available**\n\nThe architecture compliance agent is not installed."
except Exception as e:
self.logger.error(f"Architecture check failed: {e}")
return f"{self.AI_DISCLAIMER}\n\n**Architecture Check Error**\n\n{e}"
def _command_refactor_suggest(self, context: AgentContext) -> str:
"""Suggest refactoring opportunities."""
issue = context.event_data.get("issue", {})
title = issue.get("title", "")
body = issue.get("body", "")
# Use LLM to analyze for refactoring opportunities
prompt = f"""Analyze the following issue/context and suggest refactoring opportunities:
Issue Title: {title}
Issue Body: {body}
Based on common refactoring patterns, suggest:
1. Code smell detection (if any code is referenced)
2. Design pattern opportunities
3. Complexity reduction suggestions
4. DRY principle violations
5. SOLID principle improvements
Format your response as a structured report with actionable recommendations.
If no code is referenced in the issue, provide general refactoring guidance based on the context.
Return as JSON:
{{
"summary": "Brief summary of refactoring opportunities",
"suggestions": [
{{
"category": "Code Smell | Design Pattern | Complexity | DRY | SOLID",
"title": "Short title",
"description": "Detailed description",
"priority": "high | medium | low",
"effort": "small | medium | large"
}}
],
"general_advice": "Any general refactoring advice"
}}"""
try:
result = self.call_llm_json(prompt)
lines = [f"{self.AI_DISCLAIMER}\n"]
lines.append("## Refactoring Suggestions\n")
if result.get("summary"):
lines.append(f"**Summary:** {result['summary']}\n")
suggestions = result.get("suggestions", [])
if suggestions:
lines.append("### Recommendations\n")
lines.append("| Priority | Category | Suggestion | Effort |")
lines.append("|----------|----------|------------|--------|")
for s in suggestions:
priority = s.get("priority", "medium").upper()
priority_icon = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🟢"}.get(
priority, ""
)
lines.append(
f"| {priority_icon} {priority} | {s.get('category', 'General')} | "
f"**{s.get('title', 'Suggestion')}** | {s.get('effort', 'medium')} |"
)
lines.append("")
# Detailed descriptions
lines.append("### Details\n")
for i, s in enumerate(suggestions, 1):
lines.append(f"**{i}. {s.get('title', 'Suggestion')}**")
lines.append(f"{s.get('description', 'No description')}\n")
if result.get("general_advice"):
lines.append("### General Advice\n")
lines.append(result["general_advice"])
return "\n".join(lines)
except Exception as e:
self.logger.error(f"Refactor suggestion failed: {e}")
return (
f"{self.AI_DISCLAIMER}\n\n**Refactor Suggestion Failed**\n\nError: {e}"
)


@@ -0,0 +1,480 @@
"""Test Coverage Agent
AI agent for analyzing code changes and suggesting test cases.
Helps improve test coverage by identifying untested code paths.
"""
import base64
import os
import re
from dataclasses import dataclass, field
from agents.base_agent import AgentContext, AgentResult, BaseAgent
@dataclass
class TestSuggestion:
    """A suggested test case."""

    function_name: str
    file_path: str
    test_type: str  # unit, integration, edge_case
    description: str
    example_code: str | None = None
    priority: str = "MEDIUM"  # HIGH, MEDIUM, LOW


@dataclass
class CoverageReport:
    """Report of test coverage analysis."""

    functions_analyzed: int
    functions_with_tests: int
    functions_without_tests: int
    suggestions: list[TestSuggestion]
    existing_tests: list[str]
    coverage_estimate: float
class TestCoverageAgent(BaseAgent):
"""Agent for analyzing test coverage and suggesting tests."""
# Marker for test coverage comments
TEST_AI_MARKER = "<!-- AI_TEST_COVERAGE -->"
# Test file patterns by language
TEST_PATTERNS = {
"python": [r"test_.*\.py$", r".*_test\.py$", r"tests?/.*\.py$"],
"javascript": [
r".*\.test\.[jt]sx?$",
r".*\.spec\.[jt]sx?$",
r"__tests__/.*\.[jt]sx?$",
],
"go": [r".*_test\.go$"],
"rust": [r"tests?/.*\.rs$"],
"java": [r".*Test\.java$", r".*Tests\.java$"],
"ruby": [r".*_spec\.rb$", r"test_.*\.rb$"],
}
# Function/method patterns by language
FUNCTION_PATTERNS = {
"python": r"^\s*(?:async\s+)?def\s+(\w+)\s*\(",
"javascript": r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>))",
"go": r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(",
"rust": r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)",
"java": r"(?:public|private|protected)\s+(?:static\s+)?(?:\w+\s+)?(\w+)\s*\([^)]*\)\s*\{",
"ruby": r"^\s*def\s+(\w+)",
}
    def can_handle(self, event_type: str, event_data: dict) -> bool:
        """Check if this agent handles the given event."""
        agent_config = self.config.get("agents", {}).get("test_coverage", {})
        if not agent_config.get("enabled", True):
            return False

        # Handle the "<mention_prefix> suggest-tests" command (default prefix: @codebot)
        if event_type == "issue_comment":
            comment_body = event_data.get("comment", {}).get("body", "")
            mention_prefix = self.config.get("interaction", {}).get(
                "mention_prefix", "@codebot"
            )
            # Lowercase both sides so a mixed-case prefix still matches
            if f"{mention_prefix} suggest-tests".lower() in comment_body.lower():
                return True

        return False
def execute(self, context: AgentContext) -> AgentResult:
"""Execute the test coverage agent."""
self.logger.info(f"Analyzing test coverage for {context.owner}/{context.repo}")
actions_taken = []
# Get issue/PR number and author
issue = context.event_data.get("issue", {})
issue_number = issue.get("number")
comment_author = (
context.event_data.get("comment", {}).get("user", {}).get("login", "user")
)
# Check if this is a PR
is_pr = issue.get("pull_request") is not None
if is_pr:
# Analyze PR diff for changed functions
diff = self._get_pr_diff(context.owner, context.repo, issue_number)
changed_functions = self._extract_changed_functions(diff)
actions_taken.append(f"Analyzed {len(changed_functions)} changed functions")
else:
# Analyze entire repository
changed_functions = self._analyze_repository(context.owner, context.repo)
actions_taken.append(
f"Analyzed {len(changed_functions)} functions in repository"
)
# Find existing tests
existing_tests = self._find_existing_tests(context.owner, context.repo)
actions_taken.append(f"Found {len(existing_tests)} existing test files")
# Generate test suggestions using LLM
report = self._generate_suggestions(
context.owner, context.repo, changed_functions, existing_tests
)
# Post report
if issue_number:
comment = self._format_coverage_report(report, comment_author, is_pr)
self.upsert_comment(
context.owner,
context.repo,
issue_number,
comment,
marker=self.TEST_AI_MARKER,
)
actions_taken.append("Posted test coverage report")
return AgentResult(
success=True,
message=f"Generated {len(report.suggestions)} test suggestions",
data={
"functions_analyzed": report.functions_analyzed,
"suggestions_count": len(report.suggestions),
"coverage_estimate": report.coverage_estimate,
},
actions_taken=actions_taken,
)
def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str:
"""Get the PR diff."""
try:
return self.gitea.get_pull_request_diff(owner, repo, pr_number)
except Exception as e:
self.logger.error(f"Failed to get PR diff: {e}")
return ""
    def _extract_changed_functions(self, diff: str) -> list[dict]:
        """Extract changed functions from diff."""
        functions = []
        current_file = None
        current_language = None

        for line in diff.splitlines():
            # Track current file
            if line.startswith("diff --git"):
                # Anchor on " b/" so a path such as "lib/x.py" (which itself
                # contains "b/") is not split in the wrong place
                match = re.search(r" b/(.+)$", line)
                if match:
                    current_file = match.group(1)
                    current_language = self._detect_language(current_file)

            # Look for function definitions in added lines
            if line.startswith("+") and not line.startswith("+++"):
                if current_language and current_language in self.FUNCTION_PATTERNS:
                    pattern = self.FUNCTION_PATTERNS[current_language]
                    match = re.search(pattern, line[1:])  # Skip the + prefix
                    if match:
                        # Default to None so an empty capture cannot raise StopIteration
                        func_name = next((g for g in match.groups() if g), None)
                        if func_name:
                            functions.append(
                                {
                                    "name": func_name,
                                    "file": current_file,
                                    "language": current_language,
                                    "line": line[1:].strip(),
                                }
                            )

        return functions
def _analyze_repository(self, owner: str, repo: str) -> list[dict]:
"""Analyze repository for functions without tests."""
functions = []
code_extensions = {".py", ".js", ".ts", ".go", ".rs", ".java", ".rb"}
# Get repository contents (limited to avoid API exhaustion)
try:
contents = self.gitea.get_file_contents(owner, repo, "")
if isinstance(contents, list):
for item in contents[:50]: # Limit files
if item.get("type") == "file":
filepath = item.get("path", "")
ext = os.path.splitext(filepath)[1]
if ext in code_extensions:
file_functions = self._extract_functions_from_file(
owner, repo, filepath
)
functions.extend(file_functions)
except Exception as e:
self.logger.warning(f"Failed to analyze repository: {e}")
return functions[:100] # Limit total functions
def _extract_functions_from_file(
self, owner: str, repo: str, filepath: str
) -> list[dict]:
"""Extract function definitions from a file."""
functions = []
language = self._detect_language(filepath)
if not language or language not in self.FUNCTION_PATTERNS:
return functions
try:
content_data = self.gitea.get_file_contents(owner, repo, filepath)
if content_data.get("content"):
content = base64.b64decode(content_data["content"]).decode(
"utf-8", errors="ignore"
)
pattern = self.FUNCTION_PATTERNS[language]
for i, line in enumerate(content.splitlines(), 1):
match = re.search(pattern, line)
if match:
func_name = next((g for g in match.groups() if g), None)
if func_name and not func_name.startswith("_"):
functions.append(
{
"name": func_name,
"file": filepath,
"language": language,
"line_number": i,
}
)
except Exception:
pass
return functions
def _detect_language(self, filepath: str) -> str | None:
"""Detect programming language from file path."""
ext_map = {
".py": "python",
".js": "javascript",
".jsx": "javascript",
".ts": "javascript",
".tsx": "javascript",
".go": "go",
".rs": "rust",
".java": "java",
".rb": "ruby",
}
ext = os.path.splitext(filepath)[1]
return ext_map.get(ext)
def _find_existing_tests(self, owner: str, repo: str) -> list[str]:
"""Find existing test files in the repository."""
test_files = []
# Common test directories
test_dirs = ["tests", "test", "__tests__", "spec"]
for test_dir in test_dirs:
try:
contents = self.gitea.get_file_contents(owner, repo, test_dir)
if isinstance(contents, list):
for item in contents:
if item.get("type") == "file":
test_files.append(item.get("path", ""))
except Exception:
pass
# Also check root for test files
try:
contents = self.gitea.get_file_contents(owner, repo, "")
if isinstance(contents, list):
for item in contents:
if item.get("type") == "file":
filepath = item.get("path", "")
if self._is_test_file(filepath):
test_files.append(filepath)
except Exception:
pass
return test_files
    def _is_test_file(self, filepath: str) -> bool:
        """Check if a file is a test file."""
        # The language key is irrelevant here; any matching pattern counts
        return any(
            re.search(pattern, filepath)
            for patterns in self.TEST_PATTERNS.values()
            for pattern in patterns
        )
def _generate_suggestions(
self,
owner: str,
repo: str,
functions: list[dict],
existing_tests: list[str],
) -> CoverageReport:
"""Generate test suggestions using LLM."""
suggestions = []
# Build prompt for LLM
if functions:
functions_text = "\n".join(
[
f"- {f['name']} in {f['file']} ({f['language']})"
for f in functions[:20] # Limit for prompt size
]
)
prompt = f"""Analyze these functions and suggest test cases:
Functions to test:
{functions_text}
Existing test files:
{", ".join(existing_tests[:10]) if existing_tests else "None found"}
For each function, suggest:
1. What to test (happy path, edge cases, error handling)
2. Priority (HIGH for public APIs, MEDIUM for internal, LOW for utilities)
3. Brief example test code if possible
Respond in JSON format:
{{
"suggestions": [
{{
"function_name": "function_name",
"file_path": "path/to/file",
"test_type": "unit|integration|edge_case",
"description": "What to test",
"example_code": "brief example or null",
"priority": "HIGH|MEDIUM|LOW"
}}
],
"coverage_estimate": 0.0 to 1.0
}}
"""
try:
result = self.call_llm_json(prompt)
for s in result.get("suggestions", []):
suggestions.append(
TestSuggestion(
function_name=s.get("function_name", ""),
file_path=s.get("file_path", ""),
test_type=s.get("test_type", "unit"),
description=s.get("description", ""),
example_code=s.get("example_code"),
priority=s.get("priority", "MEDIUM"),
)
)
coverage_estimate = result.get("coverage_estimate", 0.5)
except Exception as e:
self.logger.warning(f"LLM suggestion failed: {e}")
# Generate basic suggestions without LLM
for f in functions[:10]:
suggestions.append(
TestSuggestion(
function_name=f["name"],
file_path=f["file"],
test_type="unit",
description=f"Add unit tests for {f['name']}",
priority="MEDIUM",
)
)
coverage_estimate = 0.5
else:
coverage_estimate = 1.0 if existing_tests else 0.0
# Estimate functions with tests
functions_with_tests = int(len(functions) * coverage_estimate)
return CoverageReport(
functions_analyzed=len(functions),
functions_with_tests=functions_with_tests,
functions_without_tests=len(functions) - functions_with_tests,
suggestions=suggestions,
existing_tests=existing_tests,
coverage_estimate=coverage_estimate,
)
    def _format_coverage_report(
        self, report: CoverageReport, user: str | None, is_pr: bool
    ) -> str:
        """Format the coverage report as a comment."""
        lines = []

        if user:
            lines.append(f"@{user}")
            lines.append("")

        lines.extend(
            [
                f"{self.AI_DISCLAIMER}",
                "",
                "## 🧪 Test Coverage Analysis",
                "",
                "### Summary",
                "",
                "| Metric | Value |",
                "|--------|-------|",
                f"| Functions Analyzed | {report.functions_analyzed} |",
                f"| Estimated Coverage | {report.coverage_estimate:.0%} |",
                f"| Test Files Found | {len(report.existing_tests)} |",
                f"| Suggestions | {len(report.suggestions)} |",
                "",
            ]
        )

        # Suggestions by priority
        if report.suggestions:
            lines.append("### 💡 Test Suggestions")
            lines.append("")

            # Group by priority
            by_priority = {"HIGH": [], "MEDIUM": [], "LOW": []}
            for s in report.suggestions:
                if s.priority in by_priority:
                    by_priority[s.priority].append(s)

            priority_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"}

            for priority in ["HIGH", "MEDIUM", "LOW"]:
                suggestions = by_priority[priority]
                if suggestions:
                    lines.append(f"#### {priority_emoji[priority]} {priority} Priority")
                    lines.append("")
                    for s in suggestions[:5]:  # Limit display
                        lines.append(f"**`{s.function_name}`** in `{s.file_path}`")
                        lines.append(f"- Type: {s.test_type}")
                        lines.append(f"- {s.description}")
                        if s.example_code:
                            lines.append("```")
                            lines.append(s.example_code[:200])
                            lines.append("```")
                        lines.append("")
                    if len(suggestions) > 5:
                        lines.append(f"*... and {len(suggestions) - 5} more*")
                        lines.append("")

        # Existing test files
        if report.existing_tests:
            lines.append("### 📁 Existing Test Files")
            lines.append("")
            for f in report.existing_tests[:10]:
                lines.append(f"- `{f}`")
            if len(report.existing_tests) > 10:
                lines.append(f"- *... and {len(report.existing_tests) - 10} more*")
            lines.append("")

        # Coverage bar
        lines.append("### 📊 Coverage Estimate")
        lines.append("")
        filled = int(report.coverage_estimate * 10)
        # Filled/empty block glyphs (assumed; any two distinct characters work)
        bar = "█" * filled + "░" * (10 - filled)
        lines.append(f"`[{bar}]` {report.coverage_estimate:.0%}")
        lines.append("")

        # Recommendations
        lines.append("---")
        if report.coverage_estimate < 0.8:
            lines.append("⚠️ **Coverage below 80%** - Consider adding more tests")
        else:
            lines.append("✅ **Good test coverage!**")

        return "\n".join(lines)
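
Review note: the PR path in this agent scans only the added lines of a unified diff and applies a per-language regular expression to each one. A minimal standalone sketch of that technique for Python files follows. The pattern is the `python` entry of `FUNCTION_PATTERNS`; the helper name `extract_added_python_functions` and the sample diff are made up for illustration, and the sketch assumes file paths without spaces.

```python
import re

# Same expression as FUNCTION_PATTERNS["python"] in the agent.
PYTHON_FUNCTION_PATTERN = r"^\s*(?:async\s+)?def\s+(\w+)\s*\("


def extract_added_python_functions(diff: str) -> list[dict]:
    """Return the functions defined on added lines of .py files in a diff."""
    functions = []
    current_file = None
    for line in diff.splitlines():
        if line.startswith("diff --git"):
            # " b/" (with the leading space) marks the start of the new path.
            match = re.search(r" b/(.+)$", line)
            current_file = match.group(1) if match else None
            continue
        # Added lines start with "+"; "+++" is the file header, not content.
        if line.startswith("+") and not line.startswith("+++"):
            match = re.search(PYTHON_FUNCTION_PATTERN, line[1:])
            if match and current_file and current_file.endswith(".py"):
                functions.append({"name": match.group(1), "file": current_file})
    return functions


SAMPLE_DIFF = """\
diff --git a/lib/utils.py b/lib/utils.py
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -1,3 +1,9 @@
 import os
+def slugify(text):
+    return text.lower().replace(" ", "-")
+
+async def fetch(url):
+    return url
-def old_helper():
"""

found = extract_added_python_functions(SAMPLE_DIFF)
```

Removed definitions (`-def old_helper():`) are ignored, and a function whose body changed but whose `def` line did not is not reported either, since only added lines are matched.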


@@ -77,11 +77,13 @@ class OpenAIProvider(BaseLLMProvider):
model: str = "gpt-4o-mini", model: str = "gpt-4o-mini",
temperature: float = 0, temperature: float = 0,
max_tokens: int = 4096, max_tokens: int = 4096,
timeout: int = 120,
): ):
self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "") self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
self.model = model self.model = model
self.temperature = temperature self.temperature = temperature
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.timeout = timeout
self.api_url = "https://api.openai.com/v1/chat/completions" self.api_url = "https://api.openai.com/v1/chat/completions"
def call(self, prompt: str, **kwargs) -> LLMResponse: def call(self, prompt: str, **kwargs) -> LLMResponse:
@@ -101,7 +103,7 @@ class OpenAIProvider(BaseLLMProvider):
"max_tokens": kwargs.get("max_tokens", self.max_tokens), "max_tokens": kwargs.get("max_tokens", self.max_tokens),
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
}, },
timeout=120, timeout=self.timeout,
) )
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
@@ -145,7 +147,7 @@ class OpenAIProvider(BaseLLMProvider):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
json=request_body, json=request_body,
timeout=120, timeout=self.timeout,
) )
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
@@ -186,11 +188,13 @@ class OpenRouterProvider(BaseLLMProvider):
model: str = "anthropic/claude-3.5-sonnet", model: str = "anthropic/claude-3.5-sonnet",
temperature: float = 0, temperature: float = 0,
max_tokens: int = 4096, max_tokens: int = 4096,
timeout: int = 120,
): ):
self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "") self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "")
self.model = model self.model = model
self.temperature = temperature self.temperature = temperature
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.timeout = timeout
self.api_url = "https://openrouter.ai/api/v1/chat/completions" self.api_url = "https://openrouter.ai/api/v1/chat/completions"
def call(self, prompt: str, **kwargs) -> LLMResponse: def call(self, prompt: str, **kwargs) -> LLMResponse:
@@ -210,7 +214,7 @@ class OpenRouterProvider(BaseLLMProvider):
"max_tokens": kwargs.get("max_tokens", self.max_tokens), "max_tokens": kwargs.get("max_tokens", self.max_tokens),
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
}, },
timeout=120, timeout=self.timeout,
) )
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
@@ -254,7 +258,7 @@ class OpenRouterProvider(BaseLLMProvider):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
json=request_body, json=request_body,
timeout=120, timeout=self.timeout,
) )
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
@@ -294,10 +298,12 @@ class OllamaProvider(BaseLLMProvider):
host: str | None = None, host: str | None = None,
model: str = "codellama:13b", model: str = "codellama:13b",
temperature: float = 0, temperature: float = 0,
timeout: int = 300,
): ):
self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434") self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434")
self.model = model self.model = model
self.temperature = temperature self.temperature = temperature
self.timeout = timeout
def call(self, prompt: str, **kwargs) -> LLMResponse: def call(self, prompt: str, **kwargs) -> LLMResponse:
"""Call Ollama API.""" """Call Ollama API."""
@@ -311,7 +317,7 @@ class OllamaProvider(BaseLLMProvider):
"temperature": kwargs.get("temperature", self.temperature), "temperature": kwargs.get("temperature", self.temperature),
}, },
}, },
timeout=300, # Longer timeout for local models timeout=self.timeout,
) )
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
@@ -477,12 +483,18 @@ class LLMClient:
provider = config.get("provider", "openai") provider = config.get("provider", "openai")
provider_config = {} provider_config = {}
# Get timeout configuration
timeouts = config.get("timeouts", {})
llm_timeout = timeouts.get("llm", 120)
ollama_timeout = timeouts.get("ollama", 300)
# Map config keys to provider-specific settings # Map config keys to provider-specific settings
if provider == "openai": if provider == "openai":
provider_config = { provider_config = {
"model": config.get("model", {}).get("openai", "gpt-4o-mini"), "model": config.get("model", {}).get("openai", "gpt-4o-mini"),
"temperature": config.get("temperature", 0), "temperature": config.get("temperature", 0),
"max_tokens": config.get("max_tokens", 16000), "max_tokens": config.get("max_tokens", 16000),
"timeout": llm_timeout,
} }
elif provider == "openrouter": elif provider == "openrouter":
provider_config = { provider_config = {
@@ -491,11 +503,13 @@ class LLMClient:
), ),
"temperature": config.get("temperature", 0), "temperature": config.get("temperature", 0),
"max_tokens": config.get("max_tokens", 16000), "max_tokens": config.get("max_tokens", 16000),
"timeout": llm_timeout,
} }
elif provider == "ollama": elif provider == "ollama":
provider_config = { provider_config = {
"model": config.get("model", {}).get("ollama", "codellama:13b"), "model": config.get("model", {}).get("ollama", "codellama:13b"),
"temperature": config.get("temperature", 0), "temperature": config.get("temperature", 0),
"timeout": ollama_timeout,
} }
return cls(provider=provider, config=provider_config) return cls(provider=provider, config=provider_config)
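
Review note: the hunks above replace the hard-coded request timeouts with values read from a `timeouts` config section, with `llm` defaulting to 120 seconds and `ollama` to 300. The lookup can be sketched on its own as below. `resolve_timeout` is a hypothetical helper written for this note; the commit inlines the same logic in the `LLMClient` factory method.

```python
def resolve_timeout(config: dict, provider: str) -> int:
    """Pick the request timeout (seconds) for a provider from the config."""
    timeouts = config.get("timeouts", {})
    if provider == "ollama":
        # Local models are given a longer default than hosted APIs.
        return timeouts.get("ollama", 300)
    return timeouts.get("llm", 120)


# Only the hosted-API timeout is overridden here; Ollama keeps its default.
config = {"provider": "ollama", "timeouts": {"llm": 60}}

ollama_timeout = resolve_timeout(config, "ollama")
openai_timeout = resolve_timeout(config, "openai")
default_timeout = resolve_timeout({}, "openrouter")
```

The two keys are independent, so lowering `llm` does not shorten the Ollama timeout.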


@@ -0,0 +1,27 @@
"""LLM Providers Package
This package contains additional LLM provider implementations
beyond the core providers in llm_client.py.
Providers:
- AnthropicProvider: Direct Anthropic Claude API
- AzureOpenAIProvider: Azure OpenAI Service with API key auth
- AzureOpenAIWithAADProvider: Azure OpenAI with Azure AD auth
- GeminiProvider: Google Gemini API (public)
- VertexAIGeminiProvider: Google Vertex AI Gemini (enterprise GCP)
"""
from clients.providers.anthropic_provider import AnthropicProvider
from clients.providers.azure_provider import (
AzureOpenAIProvider,
AzureOpenAIWithAADProvider,
)
from clients.providers.gemini_provider import GeminiProvider, VertexAIGeminiProvider
__all__ = [
"AnthropicProvider",
"AzureOpenAIProvider",
"AzureOpenAIWithAADProvider",
"GeminiProvider",
"VertexAIGeminiProvider",
]


@@ -0,0 +1,249 @@
"""Anthropic Claude Provider
Direct integration with Anthropic's Claude API.
Supports Claude 3.5 Sonnet, Claude 3 Opus, and other models.
"""
import json
import os
import sys

import requests

# Import base classes from parent module
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
class AnthropicProvider(BaseLLMProvider):
"""Anthropic Claude API provider.
Provides direct integration with Anthropic's Claude models
without going through OpenRouter.
Supports:
- Claude 3.5 Sonnet (claude-3-5-sonnet-20241022)
- Claude 3 Opus (claude-3-opus-20240229)
- Claude 3 Sonnet (claude-3-sonnet-20240229)
- Claude 3 Haiku (claude-3-haiku-20240307)
"""
API_URL = "https://api.anthropic.com/v1/messages"
API_VERSION = "2023-06-01"
def __init__(
self,
api_key: str | None = None,
model: str = "claude-3-5-sonnet-20241022",
temperature: float = 0,
max_tokens: int = 4096,
):
"""Initialize the Anthropic provider.
Args:
api_key: Anthropic API key. Defaults to ANTHROPIC_API_KEY env var.
model: Model to use. Defaults to Claude 3.5 Sonnet.
temperature: Sampling temperature (0-1).
max_tokens: Maximum tokens in response.
"""
self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY", "")
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
def call(self, prompt: str, **kwargs) -> LLMResponse:
"""Make a call to the Anthropic API.
Args:
prompt: The prompt to send.
**kwargs: Additional options (model, temperature, max_tokens).
Returns:
LLMResponse with the generated content.
Raises:
ValueError: If API key is not set.
requests.HTTPError: If the API request fails.
"""
if not self.api_key:
raise ValueError("Anthropic API key is required")
response = requests.post(
self.API_URL,
headers={
"x-api-key": self.api_key,
"anthropic-version": self.API_VERSION,
"Content-Type": "application/json",
},
json={
"model": kwargs.get("model", self.model),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"temperature": kwargs.get("temperature", self.temperature),
"messages": [{"role": "user", "content": prompt}],
},
timeout=120,
)
response.raise_for_status()
data = response.json()
# Extract content from response
content = ""
for block in data.get("content", []):
if block.get("type") == "text":
content += block.get("text", "")
return LLMResponse(
content=content,
model=data.get("model", self.model),
provider="anthropic",
tokens_used=data.get("usage", {}).get("input_tokens", 0)
+ data.get("usage", {}).get("output_tokens", 0),
finish_reason=data.get("stop_reason"),
)
def call_with_tools(
self,
messages: list[dict],
tools: list[dict] | None = None,
**kwargs,
) -> LLMResponse:
"""Make a call to the Anthropic API with tool support.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: List of tool definitions in OpenAI format.
**kwargs: Additional options.
Returns:
LLMResponse with content and/or tool_calls.
"""
if not self.api_key:
raise ValueError("Anthropic API key is required")
# Convert OpenAI-style messages to Anthropic format
anthropic_messages = []
system_content = None
for msg in messages:
role = msg.get("role", "user")
if role == "system":
system_content = msg.get("content", "")
elif role == "assistant":
# Handle assistant messages with tool calls
if msg.get("tool_calls"):
content = []
if msg.get("content"):
content.append({"type": "text", "text": msg["content"]})
for tc in msg["tool_calls"]:
content.append(
{
"type": "tool_use",
"id": tc["id"],
"name": tc["function"]["name"],
"input": json.loads(tc["function"]["arguments"])
if isinstance(tc["function"]["arguments"], str)
else tc["function"]["arguments"],
}
)
anthropic_messages.append({"role": "assistant", "content": content})
else:
anthropic_messages.append(
{
"role": "assistant",
"content": msg.get("content", ""),
}
)
elif role == "tool":
# Tool response
anthropic_messages.append(
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg.get("tool_call_id", ""),
"content": msg.get("content", ""),
}
],
}
)
else:
anthropic_messages.append(
{
"role": "user",
"content": msg.get("content", ""),
}
)
# Convert OpenAI-style tools to Anthropic format
anthropic_tools = None
if tools:
anthropic_tools = []
for tool in tools:
if tool.get("type") == "function":
func = tool["function"]
anthropic_tools.append(
{
"name": func["name"],
"description": func.get("description", ""),
"input_schema": func.get("parameters", {}),
}
)
request_body = {
"model": kwargs.get("model", self.model),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"temperature": kwargs.get("temperature", self.temperature),
"messages": anthropic_messages,
}
if system_content:
request_body["system"] = system_content
if anthropic_tools:
request_body["tools"] = anthropic_tools
response = requests.post(
self.API_URL,
headers={
"x-api-key": self.api_key,
"anthropic-version": self.API_VERSION,
"Content-Type": "application/json",
},
json=request_body,
timeout=120,
)
response.raise_for_status()
data = response.json()
# Parse response
content = ""
tool_calls = None
for block in data.get("content", []):
if block.get("type") == "text":
content += block.get("text", "")
elif block.get("type") == "tool_use":
if tool_calls is None:
tool_calls = []
tool_calls.append(
ToolCall(
id=block.get("id", ""),
name=block.get("name", ""),
arguments=block.get("input", {}),
)
)
return LLMResponse(
content=content,
model=data.get("model", self.model),
provider="anthropic",
tokens_used=data.get("usage", {}).get("input_tokens", 0)
+ data.get("usage", {}).get("output_tokens", 0),
finish_reason=data.get("stop_reason"),
tool_calls=tool_calls,
)
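
Review note: `call_with_tools` accepts tool definitions in OpenAI format and rewrites them before sending: the nested `function` object is flattened and `parameters` becomes `input_schema`. The sketch below shows that mapping in isolation. `to_anthropic_tools` is a hypothetical helper name, and the `read_file` schema is invented for the example.

```python
def to_anthropic_tools(tools: list[dict]) -> list[dict]:
    """Convert OpenAI-style function tools to the shape the provider sends."""
    converted = []
    for tool in tools:
        # Entries that are not function tools are skipped, as in the provider.
        if tool.get("type") != "function":
            continue
        func = tool["function"]
        converted.append(
            {
                "name": func["name"],
                "description": func.get("description", ""),
                "input_schema": func.get("parameters", {}),
            }
        )
    return converted


openai_tools = [
    {
        "type": "function",
        "function": {
            "name": "read_file",
            "description": "Read a file from the repository",
            "parameters": {
                "type": "object",
                "properties": {"path": {"type": "string"}},
                "required": ["path"],
            },
        },
    },
    {"type": "retrieval"},  # not a function tool, so it is dropped
]

anthropic_tools = to_anthropic_tools(openai_tools)
```

The JSON Schema under `parameters` is passed through unchanged, so callers can keep a single tool definition for every provider.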


@@ -0,0 +1,420 @@
"""Azure OpenAI Provider
Integration with Azure OpenAI Service for enterprise deployments.
Supports custom deployments, regional endpoints, and Azure AD auth.
"""
import json
import os
import sys
import requests
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
class AzureOpenAIProvider(BaseLLMProvider):
"""Azure OpenAI Service provider.
Provides integration with Azure-hosted OpenAI models for
enterprise customers with Azure deployments.
Supports:
- GPT-4, GPT-4 Turbo, GPT-4o
- GPT-3.5 Turbo
- Custom fine-tuned models
Environment Variables:
- AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
- AZURE_OPENAI_API_KEY: API key for authentication
- AZURE_OPENAI_DEPLOYMENT: Default deployment name
- AZURE_OPENAI_API_VERSION: API version (default: 2024-02-15-preview)
"""
DEFAULT_API_VERSION = "2024-02-15-preview"
def __init__(
self,
endpoint: str | None = None,
api_key: str | None = None,
deployment: str | None = None,
api_version: str | None = None,
temperature: float = 0,
max_tokens: int = 4096,
):
"""Initialize the Azure OpenAI provider.
Args:
endpoint: Azure OpenAI endpoint URL.
Defaults to AZURE_OPENAI_ENDPOINT env var.
api_key: API key for authentication.
Defaults to AZURE_OPENAI_API_KEY env var.
deployment: Deployment name to use.
Defaults to AZURE_OPENAI_DEPLOYMENT env var.
api_version: API version string.
Defaults to AZURE_OPENAI_API_VERSION env var, then DEFAULT_API_VERSION.
temperature: Sampling temperature (0-2).
max_tokens: Maximum tokens in response.
"""
self.endpoint = (
endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT", "")
).rstrip("/")
self.api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY", "")
self.deployment = deployment or os.environ.get("AZURE_OPENAI_DEPLOYMENT", "")
self.api_version = api_version or os.environ.get(
"AZURE_OPENAI_API_VERSION", self.DEFAULT_API_VERSION
)
self.temperature = temperature
self.max_tokens = max_tokens
def _get_api_url(self, deployment: str | None = None) -> str:
"""Build the API URL for a given deployment.
Args:
deployment: Deployment name. Uses default if not specified.
Returns:
Full API URL for chat completions.
"""
deploy = deployment or self.deployment
return (
f"{self.endpoint}/openai/deployments/{deploy}"
f"/chat/completions?api-version={self.api_version}"
)
def call(self, prompt: str, **kwargs) -> LLMResponse:
"""Make a call to the Azure OpenAI API.
Args:
prompt: The prompt to send.
**kwargs: Additional options (deployment, temperature, max_tokens).
Returns:
LLMResponse with the generated content.
Raises:
ValueError: If required configuration is missing.
requests.HTTPError: If the API request fails.
"""
if not self.endpoint:
raise ValueError("Azure OpenAI endpoint is required")
if not self.api_key:
raise ValueError("Azure OpenAI API key is required")
if not self.deployment and not kwargs.get("deployment"):
raise ValueError("Azure OpenAI deployment name is required")
deployment = kwargs.get("deployment", self.deployment)
response = requests.post(
self._get_api_url(deployment),
headers={
"api-key": self.api_key,
"Content-Type": "application/json",
},
json={
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"temperature": kwargs.get("temperature", self.temperature),
},
timeout=120,
)
response.raise_for_status()
data = response.json()
choice = (data.get("choices") or [{}])[0]  # tolerate a missing or empty list
message = choice.get("message", {})
return LLMResponse(
content=message.get("content", ""),
model=data.get("model", deployment),
provider="azure",
tokens_used=data.get("usage", {}).get("total_tokens", 0),
finish_reason=choice.get("finish_reason"),
)
def call_with_tools(
self,
messages: list[dict],
tools: list[dict] | None = None,
**kwargs,
) -> LLMResponse:
"""Make a call to the Azure OpenAI API with tool support.
Azure OpenAI uses the same format as OpenAI for tools.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: List of tool definitions in OpenAI format.
**kwargs: Additional options.
Returns:
LLMResponse with content and/or tool_calls.
"""
if not self.endpoint:
raise ValueError("Azure OpenAI endpoint is required")
if not self.api_key:
raise ValueError("Azure OpenAI API key is required")
if not self.deployment and not kwargs.get("deployment"):
raise ValueError("Azure OpenAI deployment name is required")
deployment = kwargs.get("deployment", self.deployment)
request_body = {
"messages": messages,
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"temperature": kwargs.get("temperature", self.temperature),
}
if tools:
request_body["tools"] = tools
request_body["tool_choice"] = kwargs.get("tool_choice", "auto")
response = requests.post(
self._get_api_url(deployment),
headers={
"api-key": self.api_key,
"Content-Type": "application/json",
},
json=request_body,
timeout=120,
)
response.raise_for_status()
data = response.json()
choice = (data.get("choices") or [{}])[0]  # tolerate a missing or empty list
message = choice.get("message", {})
# Parse tool calls if present
tool_calls = None
if message.get("tool_calls"):
tool_calls = []
for tc in message["tool_calls"]:
func = tc.get("function", {})
args = func.get("arguments", "{}")
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
tool_calls.append(
ToolCall(
id=tc.get("id", ""),
name=func.get("name", ""),
arguments=args,
)
)
return LLMResponse(
content=message.get("content", "") or "",
model=data.get("model", deployment),
provider="azure",
tokens_used=data.get("usage", {}).get("total_tokens", 0),
finish_reason=choice.get("finish_reason"),
tool_calls=tool_calls,
)
class AzureOpenAIWithAADProvider(AzureOpenAIProvider):
"""Azure OpenAI provider with Azure Active Directory authentication.
Uses Azure AD tokens instead of API keys for authentication.
Requires azure-identity package for token acquisition.
Environment Variables:
- AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
- AZURE_OPENAI_DEPLOYMENT: Default deployment name
- AZURE_TENANT_ID: Azure AD tenant ID (optional)
- AZURE_CLIENT_ID: Azure AD client ID (optional)
- AZURE_CLIENT_SECRET: Azure AD client secret (optional)
"""
SCOPE = "https://cognitiveservices.azure.com/.default"
def __init__(
self,
endpoint: str | None = None,
deployment: str | None = None,
api_version: str | None = None,
temperature: float = 0,
max_tokens: int = 4096,
credential=None,
):
"""Initialize the Azure OpenAI AAD provider.
Args:
endpoint: Azure OpenAI endpoint URL.
deployment: Deployment name to use.
api_version: API version string.
temperature: Sampling temperature (0-2).
max_tokens: Maximum tokens in response.
credential: Azure credential object. If not provided,
uses DefaultAzureCredential.
"""
super().__init__(
endpoint=endpoint,
api_key="", # Not used with AAD
deployment=deployment,
api_version=api_version,
temperature=temperature,
max_tokens=max_tokens,
)
self._credential = credential
self._token = None
self._token_expires_at = 0
def _get_token(self) -> str:
"""Get an Azure AD token for authentication.
Returns:
Bearer token string.
Raises:
ImportError: If azure-identity is not installed.
"""
import time
# Return cached token if still valid (with 5 min buffer)
if self._token and self._token_expires_at > time.time() + 300:
return self._token
try:
from azure.identity import DefaultAzureCredential
except ImportError as exc:
raise ImportError(
"azure-identity package is required for AAD authentication. "
"Install with: pip install azure-identity"
) from exc
if self._credential is None:
self._credential = DefaultAzureCredential()
token = self._credential.get_token(self.SCOPE)
self._token = token.token
self._token_expires_at = token.expires_on
return self._token
def call(self, prompt: str, **kwargs) -> LLMResponse:
"""Make a call to the Azure OpenAI API using AAD auth.
Args:
prompt: The prompt to send.
**kwargs: Additional options.
Returns:
LLMResponse with the generated content.
"""
if not self.endpoint:
raise ValueError("Azure OpenAI endpoint is required")
if not self.deployment and not kwargs.get("deployment"):
raise ValueError("Azure OpenAI deployment name is required")
deployment = kwargs.get("deployment", self.deployment)
token = self._get_token()
response = requests.post(
self._get_api_url(deployment),
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
json={
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"temperature": kwargs.get("temperature", self.temperature),
},
timeout=120,
)
response.raise_for_status()
data = response.json()
choice = (data.get("choices") or [{}])[0]  # tolerate a missing or empty list
message = choice.get("message", {})
return LLMResponse(
content=message.get("content", ""),
model=data.get("model", deployment),
provider="azure",
tokens_used=data.get("usage", {}).get("total_tokens", 0),
finish_reason=choice.get("finish_reason"),
)
def call_with_tools(
self,
messages: list[dict],
tools: list[dict] | None = None,
**kwargs,
) -> LLMResponse:
"""Make a call to the Azure OpenAI API with tool support using AAD auth.
Args:
messages: List of message dicts.
tools: List of tool definitions.
**kwargs: Additional options.
Returns:
LLMResponse with content and/or tool_calls.
"""
if not self.endpoint:
raise ValueError("Azure OpenAI endpoint is required")
if not self.deployment and not kwargs.get("deployment"):
raise ValueError("Azure OpenAI deployment name is required")
deployment = kwargs.get("deployment", self.deployment)
token = self._get_token()
request_body = {
"messages": messages,
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"temperature": kwargs.get("temperature", self.temperature),
}
if tools:
request_body["tools"] = tools
request_body["tool_choice"] = kwargs.get("tool_choice", "auto")
response = requests.post(
self._get_api_url(deployment),
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
json=request_body,
timeout=120,
)
response.raise_for_status()
data = response.json()
choice = (data.get("choices") or [{}])[0]  # tolerate a missing or empty list
message = choice.get("message", {})
# Parse tool calls if present
tool_calls = None
if message.get("tool_calls"):
tool_calls = []
for tc in message["tool_calls"]:
func = tc.get("function", {})
args = func.get("arguments", "{}")
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
tool_calls.append(
ToolCall(
id=tc.get("id", ""),
name=func.get("name", ""),
arguments=args,
)
)
return LLMResponse(
content=message.get("content", "") or "",
model=data.get("model", deployment),
provider="azure",
tokens_used=data.get("usage", {}).get("total_tokens", 0),
finish_reason=choice.get("finish_reason"),
tool_calls=tool_calls,
)
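
Both `call_with_tools` implementations above repeat the same tool-call parsing loop. A minimal sketch of factoring it into one helper follows. The name `parse_openai_tool_calls` is hypothetical, and the local `ToolCall` dataclass is only a stand-in whose shape is assumed from how `clients.llm_client.ToolCall` is constructed in this file.

```python
from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToolCall:
    """Stand-in for clients.llm_client.ToolCall (shape assumed, not confirmed)."""

    id: str
    name: str
    arguments: dict[str, Any] = field(default_factory=dict)


def parse_openai_tool_calls(message: dict) -> list[ToolCall] | None:
    """Convert an OpenAI-style message's `tool_calls` into ToolCall objects.

    Returns None when the message carries no tool calls, mirroring the
    behaviour of the provider methods above.
    """
    raw_calls = message.get("tool_calls")
    if not raw_calls:
        return None

    parsed = []
    for tc in raw_calls:
        func = tc.get("function", {})
        args = func.get("arguments", "{}")
        if isinstance(args, str):
            # The API returns arguments as a JSON-encoded string.
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                args = {}  # malformed JSON falls back to empty arguments
        parsed.append(
            ToolCall(id=tc.get("id", ""), name=func.get("name", ""), arguments=args)
        )
    return parsed
```

With a helper like this, the API-key and AAD variants would differ only in how they build the request headers.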


@@ -0,0 +1,599 @@
"""Google Gemini Provider
Integration with Google's Gemini API for GCP customers.
Supports Gemini Pro, Gemini Ultra, and other models.
"""
import json
import os
import sys
import requests
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
class GeminiProvider(BaseLLMProvider):
"""Google Gemini API provider.
Provides integration with Google's Gemini models.
Supports:
- Gemini 1.5 Pro (gemini-1.5-pro)
- Gemini 1.5 Flash (gemini-1.5-flash)
- Gemini 1.0 Pro (gemini-pro)
Environment Variables:
- GOOGLE_API_KEY: Google AI API key
- GEMINI_MODEL: Default model (optional)
"""
API_URL = "https://generativelanguage.googleapis.com/v1beta/models"
def __init__(
self,
api_key: str | None = None,
model: str | None = None,
temperature: float = 0,
max_tokens: int = 4096,
):
"""Initialize the Gemini provider.
Args:
api_key: Google API key. Defaults to GOOGLE_API_KEY env var.
model: Model to use. Defaults to GEMINI_MODEL env var, then gemini-1.5-pro.
temperature: Sampling temperature (0-1).
max_tokens: Maximum tokens in response.
"""
self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "")
self.model = model or os.environ.get("GEMINI_MODEL", "gemini-1.5-pro")
self.temperature = temperature
self.max_tokens = max_tokens
def _get_api_url(self, model: str | None = None, stream: bool = False) -> str:
"""Build the API URL for a given model.
Args:
model: Model name. Uses default if not specified.
stream: Whether to use streaming endpoint.
Returns:
Full API URL.
"""
m = model or self.model
action = "streamGenerateContent" if stream else "generateContent"
return f"{self.API_URL}/{m}:{action}?key={self.api_key}"
def call(self, prompt: str, **kwargs) -> LLMResponse:
"""Make a call to the Gemini API.
Args:
prompt: The prompt to send.
**kwargs: Additional options (model, temperature, max_tokens).
Returns:
LLMResponse with the generated content.
Raises:
ValueError: If API key is not set.
requests.HTTPError: If the API request fails.
"""
if not self.api_key:
raise ValueError("Google API key is required")
model = kwargs.get("model", self.model)
response = requests.post(
self._get_api_url(model),
headers={"Content-Type": "application/json"},
json={
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"temperature": kwargs.get("temperature", self.temperature),
"maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
},
},
timeout=120,
)
response.raise_for_status()
data = response.json()
# Extract content from response
content = ""
candidates = data.get("candidates", [])
if candidates:
parts = candidates[0].get("content", {}).get("parts", [])
for part in parts:
if "text" in part:
content += part["text"]
# Get token counts
usage = data.get("usageMetadata", {})
tokens_used = usage.get("promptTokenCount", 0) + usage.get(
"candidatesTokenCount", 0
)
finish_reason = None
if candidates:
finish_reason = candidates[0].get("finishReason")
return LLMResponse(
content=content,
model=model,
provider="gemini",
tokens_used=tokens_used,
finish_reason=finish_reason,
)
def call_with_tools(
self,
messages: list[dict],
tools: list[dict] | None = None,
**kwargs,
) -> LLMResponse:
"""Make a call to the Gemini API with tool support.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: List of tool definitions in OpenAI format.
**kwargs: Additional options.
Returns:
LLMResponse with content and/or tool_calls.
"""
if not self.api_key:
raise ValueError("Google API key is required")
model = kwargs.get("model", self.model)
# Convert OpenAI-style messages to Gemini format
gemini_contents = []
system_instruction = None
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "system":
system_instruction = content
elif role == "assistant":
# Handle assistant messages with tool calls
parts = []
if content:
parts.append({"text": content})
if msg.get("tool_calls"):
for tc in msg["tool_calls"]:
func = tc.get("function", {})
args = func.get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
parts.append(
{
"functionCall": {
"name": func.get("name", ""),
"args": args,
}
}
)
gemini_contents.append({"role": "model", "parts": parts})
elif role == "tool":
# Tool response in Gemini format
gemini_contents.append(
{
"role": "function",
"parts": [
{
"functionResponse": {
"name": msg.get("name", ""),
"response": {"result": content},
}
}
],
}
)
else:
# User message
gemini_contents.append({"role": "user", "parts": [{"text": content}]})
# Convert OpenAI-style tools to Gemini format
gemini_tools = None
if tools:
function_declarations = []
for tool in tools:
if tool.get("type") == "function":
func = tool["function"]
function_declarations.append(
{
"name": func["name"],
"description": func.get("description", ""),
"parameters": func.get("parameters", {}),
}
)
if function_declarations:
gemini_tools = [{"functionDeclarations": function_declarations}]
request_body = {
"contents": gemini_contents,
"generationConfig": {
"temperature": kwargs.get("temperature", self.temperature),
"maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
},
}
if system_instruction:
request_body["systemInstruction"] = {
"parts": [{"text": system_instruction}]
}
if gemini_tools:
request_body["tools"] = gemini_tools
response = requests.post(
self._get_api_url(model),
headers={"Content-Type": "application/json"},
json=request_body,
timeout=120,
)
response.raise_for_status()
data = response.json()
# Parse response
content = ""
tool_calls = None
candidates = data.get("candidates", [])
if candidates:
parts = candidates[0].get("content", {}).get("parts", [])
for part in parts:
if "text" in part:
content += part["text"]
elif "functionCall" in part:
if tool_calls is None:
tool_calls = []
fc = part["functionCall"]
tool_calls.append(
ToolCall(
id=f"call_{len(tool_calls)}", # Gemini doesn't provide IDs
name=fc.get("name", ""),
arguments=fc.get("args", {}),
)
)
# Get token counts
usage = data.get("usageMetadata", {})
tokens_used = usage.get("promptTokenCount", 0) + usage.get(
"candidatesTokenCount", 0
)
finish_reason = None
if candidates:
finish_reason = candidates[0].get("finishReason")
return LLMResponse(
content=content,
model=model,
provider="gemini",
tokens_used=tokens_used,
finish_reason=finish_reason,
tool_calls=tool_calls,
)
class VertexAIGeminiProvider(BaseLLMProvider):
"""Google Vertex AI Gemini provider for enterprise GCP deployments.
Uses Vertex AI endpoints instead of the public Gemini API.
Supports regional deployments and IAM authentication.
Environment Variables:
- GOOGLE_CLOUD_PROJECT: GCP project ID
- GOOGLE_CLOUD_REGION: GCP region (default: us-central1)
- VERTEX_AI_MODEL: Default model (optional)
"""
def __init__(
self,
project: str | None = None,
region: str | None = None,
model: str | None = None,
temperature: float = 0,
max_tokens: int = 4096,
credentials=None,
):
"""Initialize the Vertex AI Gemini provider.
Args:
project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var.
region: GCP region. Defaults to GOOGLE_CLOUD_REGION env var, then us-central1.
model: Model to use. Defaults to VERTEX_AI_MODEL env var, then gemini-1.5-pro.
temperature: Sampling temperature (0-1).
max_tokens: Maximum tokens in response.
credentials: Google credentials object. If not provided,
uses Application Default Credentials.
"""
self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "")
self.region = region or os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")
self.model = model or os.environ.get("VERTEX_AI_MODEL", "gemini-1.5-pro")
self.temperature = temperature
self.max_tokens = max_tokens
self._credentials = credentials
self._token = None
self._token_expires_at = 0
def _get_token(self) -> str:
"""Get a Google Cloud access token.
Returns:
Access token string.
Raises:
ImportError: If google-auth is not installed.
"""
import time
# Return cached token if still valid (with 5 min buffer)
if self._token and self._token_expires_at > time.time() + 300:
return self._token
try:
import google.auth
from google.auth.transport.requests import Request
except ImportError as exc:
raise ImportError(
"google-auth package is required for Vertex AI authentication. "
"Install with: pip install google-auth"
) from exc
if self._credentials is None:
self._credentials, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
if not self._credentials.valid:
self._credentials.refresh(Request())
self._token = self._credentials.token
# Tokens typically expire in 1 hour
self._token_expires_at = time.time() + 3500
return self._token
def _get_api_url(self, model: str | None = None) -> str:
"""Build the Vertex AI API URL.
Args:
model: Model name. Uses default if not specified.
Returns:
Full API URL.
"""
m = model or self.model
return (
f"https://{self.region}-aiplatform.googleapis.com/v1/"
f"projects/{self.project}/locations/{self.region}/"
f"publishers/google/models/{m}:generateContent"
)
def call(self, prompt: str, **kwargs) -> LLMResponse:
"""Make a call to Vertex AI Gemini.
Args:
prompt: The prompt to send.
**kwargs: Additional options.
Returns:
LLMResponse with the generated content.
"""
if not self.project:
raise ValueError("GCP project ID is required")
model = kwargs.get("model", self.model)
token = self._get_token()
response = requests.post(
self._get_api_url(model),
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
json={
"contents": [{"parts": [{"text": prompt}]}],
"generationConfig": {
"temperature": kwargs.get("temperature", self.temperature),
"maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
},
},
timeout=120,
)
response.raise_for_status()
data = response.json()
# Extract content from response
content = ""
candidates = data.get("candidates", [])
if candidates:
parts = candidates[0].get("content", {}).get("parts", [])
for part in parts:
if "text" in part:
content += part["text"]
# Get token counts
usage = data.get("usageMetadata", {})
tokens_used = usage.get("promptTokenCount", 0) + usage.get(
"candidatesTokenCount", 0
)
finish_reason = None
if candidates:
finish_reason = candidates[0].get("finishReason")
return LLMResponse(
content=content,
model=model,
provider="vertex-ai",
tokens_used=tokens_used,
finish_reason=finish_reason,
)
def call_with_tools(
self,
messages: list[dict],
tools: list[dict] | None = None,
**kwargs,
) -> LLMResponse:
"""Make a call to Vertex AI Gemini with tool support.
Args:
messages: List of message dicts.
tools: List of tool definitions.
**kwargs: Additional options.
Returns:
LLMResponse with content and/or tool_calls.
"""
if not self.project:
raise ValueError("GCP project ID is required")
model = kwargs.get("model", self.model)
token = self._get_token()
# Convert messages to Gemini format (same as GeminiProvider)
gemini_contents = []
system_instruction = None
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if role == "system":
system_instruction = content
elif role == "assistant":
parts = []
if content:
parts.append({"text": content})
if msg.get("tool_calls"):
for tc in msg["tool_calls"]:
func = tc.get("function", {})
args = func.get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
parts.append(
{
"functionCall": {
"name": func.get("name", ""),
"args": args,
}
}
)
gemini_contents.append({"role": "model", "parts": parts})
elif role == "tool":
gemini_contents.append(
{
"role": "function",
"parts": [
{
"functionResponse": {
"name": msg.get("name", ""),
"response": {"result": content},
}
}
],
}
)
else:
gemini_contents.append({"role": "user", "parts": [{"text": content}]})
# Convert tools to Gemini format
gemini_tools = None
if tools:
function_declarations = []
for tool in tools:
if tool.get("type") == "function":
func = tool["function"]
function_declarations.append(
{
"name": func["name"],
"description": func.get("description", ""),
"parameters": func.get("parameters", {}),
}
)
if function_declarations:
gemini_tools = [{"functionDeclarations": function_declarations}]
request_body = {
"contents": gemini_contents,
"generationConfig": {
"temperature": kwargs.get("temperature", self.temperature),
"maxOutputTokens": kwargs.get("max_tokens", self.max_tokens),
},
}
if system_instruction:
request_body["systemInstruction"] = {
"parts": [{"text": system_instruction}]
}
if gemini_tools:
request_body["tools"] = gemini_tools
response = requests.post(
self._get_api_url(model),
headers={
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
},
json=request_body,
timeout=120,
)
response.raise_for_status()
data = response.json()
# Parse response
content = ""
tool_calls = None
candidates = data.get("candidates", [])
if candidates:
parts = candidates[0].get("content", {}).get("parts", [])
for part in parts:
if "text" in part:
content += part["text"]
elif "functionCall" in part:
if tool_calls is None:
tool_calls = []
fc = part["functionCall"]
tool_calls.append(
ToolCall(
id=f"call_{len(tool_calls)}",
name=fc.get("name", ""),
arguments=fc.get("args", {}),
)
)
usage = data.get("usageMetadata", {})
tokens_used = usage.get("promptTokenCount", 0) + usage.get(
"candidatesTokenCount", 0
)
finish_reason = None
if candidates:
finish_reason = candidates[0].get("finishReason")
return LLMResponse(
content=content,
model=model,
provider="vertex-ai",
tokens_used=tokens_used,
finish_reason=finish_reason,
tool_calls=tool_calls,
)
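
`GeminiProvider` and `VertexAIGeminiProvider` both convert OpenAI-style tool definitions into Gemini `functionDeclarations` with identical code. The conversion could be sketched as a standalone function along these lines. `to_gemini_tools` is a hypothetical name, not something defined in this commit.

```python
from __future__ import annotations


def to_gemini_tools(tools: list[dict] | None) -> list[dict] | None:
    """Convert OpenAI-style tool definitions to Gemini's tool format.

    Entries whose type is not "function" are skipped. Returns None when
    nothing is left, so the caller can omit the "tools" key entirely.
    """
    declarations = [
        {
            "name": tool["function"]["name"],
            "description": tool["function"].get("description", ""),
            "parameters": tool["function"].get("parameters", {}),
        }
        for tool in tools or []
        if tool.get("type") == "function"
    ]
    return [{"functionDeclarations": declarations}] if declarations else None


# Example input in the OpenAI format used by call_with_tools().
example = to_gemini_tools(
    [
        {
            "type": "function",
            "function": {
                "name": "read_file",
                "description": "Read a file from the repository",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]
)
```

Sharing one such function between the two providers would keep the public-API and Vertex AI code paths from drifting apart.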


@@ -0,0 +1,14 @@
"""Compliance Module
Provides audit trail, compliance reporting, and regulatory checks.
"""
from compliance.audit_trail import AuditEvent, AuditLogger, AuditTrail
from compliance.codeowners import CodeownersChecker
__all__ = [
"AuditTrail",
"AuditLogger",
"AuditEvent",
"CodeownersChecker",
]


@@ -0,0 +1,430 @@
"""Audit Trail
Provides comprehensive audit logging for compliance requirements.
Supports HIPAA, SOC2, and other regulatory frameworks.
"""
import hashlib
import json
import logging
import os
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any
class AuditAction(Enum):
"""Types of auditable actions."""
# Review actions
REVIEW_STARTED = "review_started"
REVIEW_COMPLETED = "review_completed"
REVIEW_FAILED = "review_failed"
# Security actions
SECURITY_SCAN_STARTED = "security_scan_started"
SECURITY_SCAN_COMPLETED = "security_scan_completed"
SECURITY_FINDING_DETECTED = "security_finding_detected"
SECURITY_FINDING_RESOLVED = "security_finding_resolved"
# Comment actions
COMMENT_POSTED = "comment_posted"
COMMENT_UPDATED = "comment_updated"
COMMENT_DELETED = "comment_deleted"
# Label actions
LABEL_ADDED = "label_added"
LABEL_REMOVED = "label_removed"
# Configuration actions
CONFIG_LOADED = "config_loaded"
CONFIG_CHANGED = "config_changed"
# Access actions
API_CALL = "api_call"
AUTHENTICATION = "authentication"
# Approval actions
APPROVAL_GRANTED = "approval_granted"
APPROVAL_REVOKED = "approval_revoked"
CHANGES_REQUESTED = "changes_requested"
@dataclass
class AuditEvent:
"""An auditable event."""
action: AuditAction
timestamp: str
actor: str
resource_type: str
resource_id: str
repository: str
details: dict[str, Any] = field(default_factory=dict)
outcome: str = "success"
error: str | None = None
correlation_id: str | None = None
checksum: str | None = None
def __post_init__(self):
"""Calculate checksum for integrity verification."""
if not self.checksum:
self.checksum = self._calculate_checksum()
def _calculate_checksum(self) -> str:
"""Calculate SHA-256 checksum of event data."""
data = {
"action": self.action.value
if isinstance(self.action, AuditAction)
else self.action,
"timestamp": self.timestamp,
"actor": self.actor,
"resource_type": self.resource_type,
"resource_id": self.resource_id,
"repository": self.repository,
"details": self.details,
"outcome": self.outcome,
"error": self.error,
}
json_str = json.dumps(data, sort_keys=True)
return hashlib.sha256(json_str.encode()).hexdigest()
def to_dict(self) -> dict:
"""Convert event to dictionary."""
data = asdict(self)
if isinstance(self.action, AuditAction):
data["action"] = self.action.value
return data
def to_json(self) -> str:
"""Convert event to JSON string."""
return json.dumps(self.to_dict())
class AuditLogger:
"""Logger for audit events."""
def __init__(
self,
log_file: str | None = None,
log_to_stdout: bool = False,
log_level: str = "INFO",
):
"""Initialize audit logger.
Args:
log_file: Path to audit log file.
log_to_stdout: Also log to stdout.
log_level: Logging level.
"""
self.log_file = log_file
self.log_to_stdout = log_to_stdout
self.logger = logging.getLogger("audit")
self.logger.setLevel(getattr(logging, log_level.upper(), logging.INFO))
# Clear existing handlers
self.logger.handlers = []
# Add file handler if specified
if log_file:
log_dir = os.path.dirname(log_file)
if log_dir:
os.makedirs(log_dir, exist_ok=True)
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(
logging.Formatter("%(message)s") # JSON lines format
)
self.logger.addHandler(file_handler)
# Add stdout handler if requested
if log_to_stdout:
stdout_handler = logging.StreamHandler()
stdout_handler.setFormatter(logging.Formatter("[AUDIT] %(message)s"))
self.logger.addHandler(stdout_handler)
def log(self, event: AuditEvent):
"""Log an audit event.
Args:
event: The audit event to log.
"""
self.logger.info(event.to_json())
def log_action(
self,
action: AuditAction,
actor: str,
resource_type: str,
resource_id: str,
repository: str,
details: dict | None = None,
outcome: str = "success",
error: str | None = None,
correlation_id: str | None = None,
):
"""Log an action as an audit event.
Args:
action: The action being performed.
actor: Who performed the action.
resource_type: Type of resource affected.
resource_id: ID of the resource.
repository: Repository context.
details: Additional details.
outcome: success, failure, or partial.
error: Error message if failed.
correlation_id: ID to correlate related events.
"""
event = AuditEvent(
action=action,
timestamp=datetime.now(timezone.utc).isoformat(),
actor=actor,
resource_type=resource_type,
resource_id=resource_id,
repository=repository,
details=details or {},
outcome=outcome,
error=error,
correlation_id=correlation_id,
)
self.log(event)
class AuditTrail:
"""High-level audit trail management."""
def __init__(self, config: dict):
"""Initialize audit trail.
Args:
config: Configuration dictionary.
"""
self.config = config
compliance_config = config.get("compliance", {})
audit_config = compliance_config.get("audit", {})
self.enabled = audit_config.get("enabled", False)
self.log_file = audit_config.get("log_file", "audit.log")
self.log_to_stdout = audit_config.get("log_to_stdout", False)
self.retention_days = audit_config.get("retention_days", 90)
if self.enabled:
self.logger = AuditLogger(
log_file=self.log_file,
log_to_stdout=self.log_to_stdout,
)
else:
self.logger = None
self._correlation_id = None
def set_correlation_id(self, correlation_id: str):
"""Set correlation ID for subsequent events.
Args:
correlation_id: ID to correlate related events.
"""
self._correlation_id = correlation_id
def log(
self,
action: AuditAction,
actor: str,
resource_type: str,
resource_id: str,
repository: str,
details: dict | None = None,
outcome: str = "success",
error: str | None = None,
):
"""Log an audit event.
Args:
action: The action being performed.
actor: Who performed the action.
resource_type: Type of resource (pr, issue, comment, etc).
resource_id: ID of the resource.
repository: Repository (owner/repo).
details: Additional details.
outcome: success, failure, or partial.
error: Error message if failed.
"""
if not self.enabled or not self.logger:
return
self.logger.log_action(
action=action,
actor=actor,
resource_type=resource_type,
resource_id=resource_id,
repository=repository,
details=details,
outcome=outcome,
error=error,
correlation_id=self._correlation_id,
)
def log_review_started(
self,
repository: str,
pr_number: int,
reviewer: str = "openrabbit",
):
"""Log that a review has started."""
self.log(
action=AuditAction.REVIEW_STARTED,
actor=reviewer,
resource_type="pull_request",
resource_id=str(pr_number),
repository=repository,
)
def log_review_completed(
self,
repository: str,
pr_number: int,
recommendation: str,
findings_count: int,
reviewer: str = "openrabbit",
):
"""Log that a review has completed."""
self.log(
action=AuditAction.REVIEW_COMPLETED,
actor=reviewer,
resource_type="pull_request",
resource_id=str(pr_number),
repository=repository,
details={
"recommendation": recommendation,
"findings_count": findings_count,
},
)
def log_security_finding(
self,
repository: str,
pr_number: int,
finding: dict,
scanner: str = "openrabbit",
):
"""Log a security finding."""
self.log(
action=AuditAction.SECURITY_FINDING_DETECTED,
actor=scanner,
resource_type="pull_request",
resource_id=str(pr_number),
repository=repository,
details={
"severity": finding.get("severity"),
"category": finding.get("category"),
"file": finding.get("file"),
"line": finding.get("line"),
"cwe": finding.get("cwe"),
},
)
def log_approval(
self,
repository: str,
pr_number: int,
approver: str,
approval_type: str = "ai",
):
"""Log an approval action."""
self.log(
action=AuditAction.APPROVAL_GRANTED,
actor=approver,
resource_type="pull_request",
resource_id=str(pr_number),
repository=repository,
details={"approval_type": approval_type},
)
def log_changes_requested(
self,
repository: str,
pr_number: int,
requester: str,
reason: str | None = None,
):
"""Log a changes requested action."""
self.log(
action=AuditAction.CHANGES_REQUESTED,
actor=requester,
resource_type="pull_request",
resource_id=str(pr_number),
repository=repository,
details={"reason": reason} if reason else {},
)
def generate_report(
self,
start_date: datetime | None = None,
end_date: datetime | None = None,
repository: str | None = None,
) -> dict:
"""Generate an audit report.
Args:
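start_date: Start of reporting period. Must be timezone-aware,
because logged timestamps are UTC and are compared directly.
end_date: End of reporting period. Must be timezone-aware.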
repository: Filter by repository.
Returns:
Report dictionary with statistics and events.
"""
if not self.log_file or not os.path.exists(self.log_file):
return {"events": [], "statistics": {}}
events = []
with open(self.log_file) as f:
for line in f:
try:
event = json.loads(line.strip())
event_time = datetime.fromisoformat(
event["timestamp"].replace("Z", "+00:00")
)
# Apply filters
if start_date and event_time < start_date:
continue
if end_date and event_time > end_date:
continue
if repository and event.get("repository") != repository:
continue
events.append(event)
except (json.JSONDecodeError, KeyError, ValueError):  # ValueError: bad timestamp
continue
# Calculate statistics
action_counts = {}
outcome_counts = {"success": 0, "failure": 0, "partial": 0}
security_findings = 0
for event in events:
action = event.get("action", "unknown")
action_counts[action] = action_counts.get(action, 0) + 1
outcome = event.get("outcome", "success")
if outcome in outcome_counts:
outcome_counts[outcome] += 1
if action == "security_finding_detected":
security_findings += 1
return {
"events": events,
"statistics": {
"total_events": len(events),
"action_counts": action_counts,
"outcome_counts": outcome_counts,
"security_findings": security_findings,
},
"period": {
"start": start_date.isoformat() if start_date else None,
"end": end_date.isoformat() if end_date else None,
},
}
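
The date filter in `generate_report` can be sketched in isolation as follows. This is a minimal sketch, assuming the log is JSON Lines with an ISO 8601 `timestamp` field; `filter_events` and the sample log lines are hypothetical and not part of the commit. Because a trailing `Z` is mapped to `+00:00`, the parsed timestamps are timezone-aware, so the bounds passed in should be timezone-aware too; comparing them with naive datetimes raises `TypeError`.

```python
import json
from datetime import datetime, timezone


def filter_events(lines, start=None, end=None):
    """Return parsed events whose timestamp falls within [start, end]."""
    kept = []
    for line in lines:
        try:
            event = json.loads(line)
            # Mirror the source: map a trailing "Z" to an explicit UTC offset.
            when = datetime.fromisoformat(event["timestamp"].replace("Z", "+00:00"))
        except (KeyError, ValueError):
            # JSONDecodeError is a ValueError subclass, so malformed JSON and
            # unparsable timestamps are both skipped here.
            continue
        if start and when < start:
            continue
        if end and when > end:
            continue
        kept.append(event)
    return kept


# Hypothetical sample log; the first action name is illustrative only.
log_lines = [
    '{"timestamp": "2026-01-05T10:00:00Z", "action": "approval_granted"}',
    '{"timestamp": "2026-01-07T09:30:00Z", "action": "security_finding_detected"}',
    "not json",
]

# The bound is timezone-aware because the parsed timestamps are.
start = datetime(2026, 1, 6, tzinfo=timezone.utc)
recent = filter_events(log_lines, start=start)
```

Callers that build `start_date` with `datetime.now()` or `datetime(…)` without a `tzinfo` would hit the aware/naive mismatch, so passing `tzinfo=timezone.utc` is the safer habit.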


@@ -0,0 +1,314 @@
"""CODEOWNERS Checker
Parses and validates CODEOWNERS files for compliance enforcement.
"""
import fnmatch
import logging
import os
import re
from dataclasses import dataclass
from pathlib import Path
@dataclass
class CodeOwnerRule:
    """A CODEOWNERS rule."""

    pattern: str
    owners: list[str]
    line_number: int
    is_negation: bool = False

    def matches(self, path: str) -> bool:
        """Check if a path matches this rule.

        Args:
            path: File path to check.

        Returns:
            True if the path matches.
        """
        path = path.lstrip("/")
        pattern = self.pattern.lstrip("/")
        # Handle directory patterns
        if pattern.endswith("/"):
            return path.startswith(pattern) or fnmatch.fnmatch(path, pattern + "*")
        # Handle ** patterns. Each segment between "**" markers is translated
        # on its own, so the ".*" emitted for "**" is never rewritten by the
        # single-"*" substitution, and literal characters such as "." are escaped.
        if "**" in pattern:
            segments = [
                re.escape(segment).replace(r"\*", "[^/]*")
                for segment in pattern.split("**")
            ]
            return re.fullmatch(".*".join(segments), path) is not None
        # Standard fnmatch
        return fnmatch.fnmatch(path, pattern) or fnmatch.fnmatch(path, f"**/{pattern}")
class CodeownersChecker:
"""Checker for CODEOWNERS file compliance."""
CODEOWNERS_LOCATIONS = [
"CODEOWNERS",
".github/CODEOWNERS",
".gitea/CODEOWNERS",
"docs/CODEOWNERS",
]
def __init__(self, repo_root: str | None = None):
"""Initialize CODEOWNERS checker.
Args:
repo_root: Repository root path.
"""
self.repo_root = repo_root or os.getcwd()
self.rules: list[CodeOwnerRule] = []
self.codeowners_path: str | None = None
self.logger = logging.getLogger(__name__)
self._load_codeowners()
def _load_codeowners(self):
"""Load CODEOWNERS file from repository."""
for location in self.CODEOWNERS_LOCATIONS:
path = os.path.join(self.repo_root, location)
if os.path.exists(path):
self.codeowners_path = path
self._parse_codeowners(path)
break
def _parse_codeowners(self, path: str):
"""Parse a CODEOWNERS file.
Args:
path: Path to CODEOWNERS file.
"""
with open(path) as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
# Skip empty lines and comments
if not line or line.startswith("#"):
continue
# Parse pattern and owners
parts = line.split()
if len(parts) < 2:
continue
pattern = parts[0]
owners = parts[1:]
# Check for negation (optional syntax)
is_negation = pattern.startswith("!")
if is_negation:
pattern = pattern[1:]
self.rules.append(
CodeOwnerRule(
pattern=pattern,
owners=owners,
line_number=line_num,
is_negation=is_negation,
)
)
def get_owners(self, path: str) -> list[str]:
"""Get owners for a file path.
Args:
path: File path to check.
Returns:
List of owner usernames/teams.
"""
owners = []
# Apply rules in order (later rules override earlier ones)
for rule in self.rules:
if rule.matches(path):
if rule.is_negation:
owners = [] # Clear owners for negation
else:
owners = rule.owners
return owners
def get_owners_for_files(self, files: list[str]) -> dict[str, list[str]]:
"""Get owners for multiple files.
Args:
files: List of file paths.
Returns:
Dict mapping file paths to owner lists.
"""
return {f: self.get_owners(f) for f in files}
def get_required_reviewers(self, files: list[str]) -> set[str]:
"""Get all required reviewers for a set of files.
Args:
files: List of file paths.
Returns:
Set of all required reviewer usernames/teams.
"""
reviewers = set()
for f in files:
reviewers.update(self.get_owners(f))
return reviewers
def check_approval(
self,
files: list[str],
approvers: list[str],
) -> dict:
"""Check if files have required approvals.
Args:
files: List of changed files.
approvers: List of users who approved.
Returns:
Dict with approval status and missing approvers.
"""
required = self.get_required_reviewers(files)
approvers_set = set(approvers)
# Normalize @ prefixes
required_normalized = {r.lstrip("@") for r in required}
approvers_normalized = {a.lstrip("@") for a in approvers_set}
missing = required_normalized - approvers_normalized
# Check for team approvals (simplified - actual implementation
# would need API calls to check team membership)
teams = {r for r in missing if "/" in r}
missing_users = missing - teams
return {
"approved": len(missing_users) == 0,
"required_reviewers": list(required_normalized),
"actual_approvers": list(approvers_normalized),
"missing_approvers": list(missing_users),
"pending_teams": list(teams),
}
def get_coverage_report(self, files: list[str]) -> dict:
"""Generate a coverage report for files.
Args:
files: List of file paths.
Returns:
Coverage report with owned and unowned files.
"""
owned = []
unowned = []
for f in files:
owners = self.get_owners(f)
if owners:
owned.append({"file": f, "owners": owners})
else:
unowned.append(f)
return {
"total_files": len(files),
"owned_files": len(owned),
"unowned_files": len(unowned),
"coverage_percent": (len(owned) / len(files) * 100) if files else 0,
"owned": owned,
"unowned": unowned,
}
def validate_codeowners(self) -> dict:
"""Validate the CODEOWNERS file.
Returns:
Validation result with warnings and errors.
"""
if not self.codeowners_path:
return {
"valid": False,
"errors": ["No CODEOWNERS file found"],
"warnings": [],
}
errors = []
warnings = []
# Check for empty rules
for rule in self.rules:
if not rule.owners:
errors.append(
f"Line {rule.line_number}: Pattern '{rule.pattern}' has no owners"
)
# Check for invalid owner formats
for rule in self.rules:
for owner in rule.owners:
if not owner.startswith("@") and "/" not in owner:
warnings.append(
f"Line {rule.line_number}: Owner '{owner}' should start with @ or be a team (org/team)"
)
# Check for overlapping patterns
patterns_seen = {}
for rule in self.rules:
if rule.pattern in patterns_seen:
warnings.append(
f"Line {rule.line_number}: Pattern '{rule.pattern}' duplicates line {patterns_seen[rule.pattern]}"
)
patterns_seen[rule.pattern] = rule.line_number
return {
"valid": len(errors) == 0,
"errors": errors,
"warnings": warnings,
"rules_count": len(self.rules),
"file_path": self.codeowners_path,
}
@classmethod
def from_content(cls, content: str) -> "CodeownersChecker":
"""Create checker from CODEOWNERS content string.
Args:
content: CODEOWNERS file content.
Returns:
CodeownersChecker instance.
"""
checker = cls.__new__(cls)
checker.repo_root = None
checker.rules = []
checker.codeowners_path = "<string>"
checker.logger = logging.getLogger(__name__)
for line_num, line in enumerate(content.split("\n"), 1):
line = line.strip()
if not line or line.startswith("#"):
continue
parts = line.split()
if len(parts) < 2:
continue
pattern = parts[0]
owners = parts[1:]
is_negation = pattern.startswith("!")
if is_negation:
pattern = pattern[1:]
checker.rules.append(
CodeOwnerRule(
pattern=pattern,
owners=owners,
line_number=line_num,
is_negation=is_negation,
)
)
return checker
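
The "later rules override earlier ones" behaviour of `get_owners` can be illustrated on its own. This is a minimal sketch, assuming plain `fnmatch` globs and ignoring the directory, `**` and negation handling above; `resolve_owners` and the rule list are hypothetical.

```python
import fnmatch


def resolve_owners(rules, path):
    """Return the owners of the last rule whose pattern matches path."""
    owners = []
    for pattern, rule_owners in rules:
        # Rules are scanned top to bottom; a later match replaces an earlier one.
        if fnmatch.fnmatch(path, pattern):
            owners = rule_owners
    return owners


# Hypothetical rules, listed in file order.
rules = [
    ("*", ["@org/maintainers"]),
    ("*.py", ["@alice"]),
    ("docs/*", ["@bob"]),
]

py_owners = resolve_owners(rules, "tools/scanner.py")
doc_owners = resolve_owners(rules, "docs/guide.md")
other_owners = resolve_owners(rules, "Makefile")
```

Note that `fnmatch` lets `*` match `/` as well, which is why `*.py` picks up `tools/scanner.py` here; gitignore-style matchers treat `*` as stopping at directory separators.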


@@ -1,21 +1,61 @@
-provider: openai # openai | openrouter | ollama
+# OpenRabbit AI Code Review Configuration
+# =========================================
+# LLM Provider Configuration
+# --------------------------
+# Available providers: openai | openrouter | ollama | anthropic | azure | gemini
+provider: openai
 model:
   openai: gpt-4.1-mini
   openrouter: anthropic/claude-3.5-sonnet
   ollama: codellama:13b
+  anthropic: claude-3-5-sonnet-20241022
+  azure: gpt-4 # Deployment name
+  gemini: gemini-1.5-pro
 temperature: 0
 max_tokens: 4096
+# Azure OpenAI specific settings (when provider: azure)
+azure:
+  endpoint: "" # Set via AZURE_OPENAI_ENDPOINT env var
+  deployment: "" # Set via AZURE_OPENAI_DEPLOYMENT env var
+  api_version: "2024-02-15-preview"
+# Google Gemini specific settings (when provider: gemini)
+gemini:
+  project: "" # For Vertex AI, set via GOOGLE_CLOUD_PROJECT env var
+  region: "us-central1"
+# Rate Limits and Timeouts
+# ------------------------
+rate_limits:
+  min_interval: 1.0 # Minimum seconds between API requests
+timeouts:
+  llm: 120 # LLM API timeout in seconds (OpenAI, OpenRouter, Anthropic, etc.)
+  ollama: 300 # Ollama timeout (longer for local models)
+  gitea: 30 # Gitea/GitHub API timeout
 # Review settings
+# ---------------
 review:
   fail_on_severity: HIGH
   max_diff_lines: 800
   inline_comments: true
   security_scan: true
-# Agent settings
+# File Ignore Patterns
+# --------------------
+# Similar to .gitignore, controls which files are excluded from review
+ignore:
+  use_defaults: true # Include default patterns (node_modules, .git, etc.)
+  file: ".ai-reviewignore" # Custom ignore file name
+  patterns: [] # Additional patterns to ignore
+# Agent Configuration
+# -------------------
 agents:
   issue:
     enabled: true
@@ -33,46 +73,137 @@ agents:
- opened - opened
- synchronize - synchronize
auto_summary: auto_summary:
enabled: true # Auto-generate summary for PRs with empty descriptions enabled: true
post_as_comment: true # true = post as comment, false = update PR description post_as_comment: true
codebase: codebase:
enabled: true enabled: true
schedule: "0 0 * * 0" # Weekly on Sunday schedule: "0 0 * * 0" # Weekly on Sunday
chat: chat:
enabled: true enabled: true
name: "Bartender" name: "Bartender"
max_iterations: 5 # Max tool call iterations per chat max_iterations: 5
tools: tools:
- search_codebase - search_codebase
- read_file - read_file
- search_web - search_web
searxng_url: "" # Set via SEARXNG_URL env var or here searxng_url: "" # Set via SEARXNG_URL env var
# Interaction settings # Dependency Security Agent
dependency:
enabled: true
scan_on_pr: true # Auto-scan PRs that modify dependency files
vulnerability_threshold: "medium" # low | medium | high | critical
update_suggestions: true # Suggest version updates
# Test Coverage Agent
test_coverage:
enabled: true
suggest_tests: true
min_coverage_percent: 80 # Warn if coverage below this
# Architecture Compliance Agent
architecture:
enabled: true
layers:
api:
can_import_from: [utils, models, services]
cannot_import_from: [db, repositories]
services:
can_import_from: [utils, models, repositories]
cannot_import_from: [api]
repositories:
can_import_from: [utils, models, db]
cannot_import_from: [api, services]
# Interaction Settings
# --------------------
# CUSTOMIZE YOUR BOT NAME HERE! # CUSTOMIZE YOUR BOT NAME HERE!
# Change mention_prefix to your preferred bot name:
# "@ai-bot" - Default
# "@bartender" - Friendly bar theme
# "@uni" - Short and simple
# "@joey" - Personal assistant name
# "@codebot" - Code-focused name
# NOTE: Also update the workflow files (.github/workflows/ or .gitea/workflows/)
# to match this prefix in the 'if: contains(...)' condition
interaction: interaction:
respond_to_mentions: true respond_to_mentions: true
mention_prefix: "@codebot" # Change this to customize your bot's name! mention_prefix: "@codebot"
commands: commands:
- help - help
- explain - explain
- suggest - suggest
- security - security
- summarize # Generate PR summary (works on both issues and PRs) - summarize
- changelog # Generate Keep a Changelog format entries (PR comments only) - changelog
- explain-diff # Explain code changes in plain language (PR comments only) - explain-diff
- triage - triage
- review-again - review-again
# New commands
- check-deps # Check dependencies for vulnerabilities
- suggest-tests # Suggest test cases
- refactor-suggest # Suggest refactoring opportunities
- architecture # Check architecture compliance
- arch-check # Alias for architecture
# Enterprise settings # Security Scanning
# -----------------
security:
enabled: true
fail_on_high: true
rules_file: "security/security_rules.yml"
# SAST Integration
sast:
enabled: true
bandit: true # Python AST-based security scanner
semgrep: true # Polyglot security scanner with custom rules
trivy: false # Container/filesystem scanner (requires trivy installed)
# Notifications
# -------------
notifications:
enabled: false
threshold: "warning" # info | warning | error | critical
slack:
enabled: false
webhook_url: "" # Set via SLACK_WEBHOOK_URL env var
channel: "" # Override channel (optional)
username: "OpenRabbit"
discord:
enabled: false
webhook_url: "" # Set via DISCORD_WEBHOOK_URL env var
username: "OpenRabbit"
avatar_url: ""
# Custom webhooks for other integrations
webhooks: []
# Example:
# - url: "https://your-webhook.example.com/notify"
# enabled: true
# headers:
# Authorization: "Bearer your-token"
# Compliance & Audit
# ------------------
compliance:
enabled: false
# Audit Trail
audit:
enabled: false
log_file: "audit.log"
log_to_stdout: false
retention_days: 90
# CODEOWNERS Enforcement
codeowners:
enabled: false
require_approval: true # Require approval from code owners
# Regulatory Compliance
regulations:
hipaa: false
soc2: false
pci_dss: false
gdpr: false
# Enterprise Settings
# -------------------
enterprise: enterprise:
audit_log: true audit_log: true
audit_path: "/var/log/ai-review/" audit_path: "/var/log/ai-review/"
@@ -81,44 +212,44 @@ enterprise:
requests_per_minute: 30 requests_per_minute: 30
max_concurrent: 4 max_concurrent: 4
# Label mappings for auto-labeling # Label Mappings
# --------------
# Each label has: # Each label has:
# name: The label name to use/create (string) or full config (dict) # name: The label name to use/create
# aliases: Alternative names for auto-detection (optional) # aliases: Alternative names for auto-detection
# color: Hex color code without # (optional, for label creation) # color: Hex color code without #
# description: Label description (optional, for label creation) # description: Label description
labels: labels:
priority: priority:
critical: critical:
name: "priority: critical" name: "priority: critical"
color: "b60205" # Dark Red color: "b60205"
description: "Critical priority - immediate attention required" description: "Critical priority - immediate attention required"
aliases: aliases: ["Priority - Critical", "P0", "critical", "Priority/Critical"]
["Priority - Critical", "P0", "critical", "Priority/Critical"]
high: high:
name: "priority: high" name: "priority: high"
color: "d73a4a" # Red color: "d73a4a"
description: "High priority issue" description: "High priority issue"
aliases: ["Priority - High", "P1", "high", "Priority/High"] aliases: ["Priority - High", "P1", "high", "Priority/High"]
medium: medium:
name: "priority: medium" name: "priority: medium"
color: "fbca04" # Yellow color: "fbca04"
description: "Medium priority issue" description: "Medium priority issue"
aliases: ["Priority - Medium", "P2", "medium", "Priority/Medium"] aliases: ["Priority - Medium", "P2", "medium", "Priority/Medium"]
low: low:
name: "priority: low" name: "priority: low"
color: "28a745" # Green color: "28a745"
description: "Low priority issue" description: "Low priority issue"
aliases: ["Priority - Low", "P3", "low", "Priority/Low"] aliases: ["Priority - Low", "P3", "low", "Priority/Low"]
type: type:
bug: bug:
name: "type: bug" name: "type: bug"
color: "d73a4a" # Red color: "d73a4a"
description: "Something isn't working" description: "Something isn't working"
aliases: ["Kind/Bug", "bug", "Type: Bug", "Type/Bug", "Kind - Bug"] aliases: ["Kind/Bug", "bug", "Type: Bug", "Type/Bug", "Kind - Bug"]
feature: feature:
name: "type: feature" name: "type: feature"
color: "1d76db" # Blue color: "1d76db"
description: "New feature request" description: "New feature request"
aliases: aliases:
[ [
@@ -132,7 +263,7 @@ labels:
] ]
question: question:
name: "type: question" name: "type: question"
color: "cc317c" # Purple color: "cc317c"
description: "Further information is requested" description: "Further information is requested"
aliases: aliases:
[ [
@@ -144,7 +275,7 @@ labels:
] ]
docs: docs:
name: "type: documentation" name: "type: documentation"
color: "0075ca" # Light Blue color: "0075ca"
description: "Documentation improvements" description: "Documentation improvements"
aliases: aliases:
[ [
@@ -157,7 +288,7 @@ labels:
] ]
security: security:
name: "type: security" name: "type: security"
color: "b60205" # Dark Red color: "b60205"
description: "Security vulnerability or concern" description: "Security vulnerability or concern"
aliases: aliases:
[ [
@@ -169,7 +300,7 @@ labels:
] ]
testing: testing:
name: "type: testing" name: "type: testing"
color: "0e8a16" # Green color: "0e8a16"
description: "Related to testing" description: "Related to testing"
aliases: aliases:
[ [
@@ -183,7 +314,7 @@ labels:
status: status:
ai_approved: ai_approved:
name: "ai-approved" name: "ai-approved"
color: "28a745" # Green color: "28a745"
description: "AI review approved this PR" description: "AI review approved this PR"
aliases: aliases:
[ [
@@ -194,7 +325,7 @@ labels:
] ]
ai_changes_required: ai_changes_required:
name: "ai-changes-required" name: "ai-changes-required"
color: "d73a4a" # Red color: "d73a4a"
description: "AI review found issues requiring changes" description: "AI review found issues requiring changes"
aliases: aliases:
[ [
@@ -205,7 +336,7 @@ labels:
] ]
ai_reviewed: ai_reviewed:
name: "ai-reviewed" name: "ai-reviewed"
color: "1d76db" # Blue color: "1d76db"
description: "This issue/PR has been reviewed by AI" description: "This issue/PR has been reviewed by AI"
aliases: aliases:
[ [
@@ -216,18 +347,9 @@ labels:
"Status - Reviewed", "Status - Reviewed",
] ]
# Label schema detection patterns # Label Pattern Detection
# Used by setup-labels command to detect existing naming conventions # -----------------------
label_patterns: label_patterns:
# Detect prefix-based naming (e.g., Kind/Bug, Type/Feature)
prefix_slash: "^(Kind|Type|Category)/(.+)$" prefix_slash: "^(Kind|Type|Category)/(.+)$"
# Detect dash-separated naming (e.g., Priority - High, Status - Blocked)
prefix_dash: "^(Priority|Status|Reviewed) - (.+)$" prefix_dash: "^(Priority|Status|Reviewed) - (.+)$"
# Detect colon-separated naming (e.g., type: bug, priority: high)
colon: "^(type|priority|status): (.+)$" colon: "^(type|priority|status): (.+)$"
# Security scanning rules
security:
enabled: true
fail_on_high: true
rules_file: "security/security_rules.yml"
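
The `architecture.layers` block in this configuration pairs each layer with `can_import_from` and `cannot_import_from` lists. One way to read those rules is sketched below. This is an assumption about the semantics, not the ArchitectureAgent's actual logic: the deny list is checked first, anything not on the allow list is rejected, and same-layer or unconfigured-layer imports pass. `is_import_allowed` is a hypothetical helper.

```python
# Layer rules copied from the configuration above.
LAYERS = {
    "api": {
        "can_import_from": ["utils", "models", "services"],
        "cannot_import_from": ["db", "repositories"],
    },
    "services": {
        "can_import_from": ["utils", "models", "repositories"],
        "cannot_import_from": ["api"],
    },
    "repositories": {
        "can_import_from": ["utils", "models", "db"],
        "cannot_import_from": ["api", "services"],
    },
}


def is_import_allowed(layers, source_layer, target_layer):
    """Check a single import edge against the layer rules."""
    rules = layers.get(source_layer)
    # Assumption: layers without rules and imports within a layer always pass.
    if rules is None or source_layer == target_layer:
        return True
    # The explicit deny list takes precedence over the allow list.
    if target_layer in rules.get("cannot_import_from", []):
        return False
    # Assumption: anything not explicitly allowed is treated as a violation.
    return target_layer in rules.get("can_import_from", [])


api_to_services = is_import_allowed(LAYERS, "api", "services")
api_to_db = is_import_allowed(LAYERS, "api", "db")
repos_to_services = is_import_allowed(LAYERS, "repositories", "services")
```

A more permissive reading would allow any target that is not on the deny list; which of the two the agent applies cannot be determined from the configuration alone.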


@@ -0,0 +1,20 @@
"""Notifications Package
Provides webhook-based notifications for Slack, Discord, and other platforms.
"""
from notifications.notifier import (
DiscordNotifier,
Notifier,
NotifierFactory,
SlackNotifier,
WebhookNotifier,
)
__all__ = [
"Notifier",
"SlackNotifier",
"DiscordNotifier",
"WebhookNotifier",
"NotifierFactory",
]


@@ -0,0 +1,542 @@
"""Notification System
Provides webhook-based notifications for Slack, Discord, and other platforms.
Supports critical security findings, review summaries, and custom alerts.
"""
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any
import requests
class NotificationLevel(Enum):
"""Notification severity levels."""
INFO = "info"
WARNING = "warning"
ERROR = "error"
CRITICAL = "critical"
@dataclass
class NotificationMessage:
"""A notification message."""
title: str
message: str
level: NotificationLevel = NotificationLevel.INFO
fields: dict[str, str] | None = None
url: str | None = None
footer: str | None = None
class Notifier(ABC):
"""Abstract base class for notification providers."""
@abstractmethod
def send(self, message: NotificationMessage) -> bool:
"""Send a notification.
Args:
message: The notification message to send.
Returns:
True if sent successfully, False otherwise.
"""
pass
@abstractmethod
def send_raw(self, payload: dict) -> bool:
"""Send a raw payload to the webhook.
Args:
payload: Raw payload in provider-specific format.
Returns:
True if sent successfully, False otherwise.
"""
pass
class SlackNotifier(Notifier):
"""Slack webhook notifier."""
# Color mapping for different levels
LEVEL_COLORS = {
NotificationLevel.INFO: "#36a64f", # Green
NotificationLevel.WARNING: "#ffcc00", # Yellow
NotificationLevel.ERROR: "#ff6600", # Orange
NotificationLevel.CRITICAL: "#cc0000", # Red
}
LEVEL_EMOJIS = {
NotificationLevel.INFO: ":information_source:",
NotificationLevel.WARNING: ":warning:",
NotificationLevel.ERROR: ":x:",
NotificationLevel.CRITICAL: ":rotating_light:",
}
def __init__(
self,
webhook_url: str | None = None,
channel: str | None = None,
username: str = "OpenRabbit",
icon_emoji: str = ":robot_face:",
timeout: int = 10,
):
"""Initialize Slack notifier.
Args:
webhook_url: Slack incoming webhook URL.
Defaults to SLACK_WEBHOOK_URL env var.
channel: Override channel (optional).
username: Bot username to display.
icon_emoji: Bot icon emoji.
timeout: Request timeout in seconds.
"""
self.webhook_url = webhook_url or os.environ.get("SLACK_WEBHOOK_URL", "")
self.channel = channel
self.username = username
self.icon_emoji = icon_emoji
self.timeout = timeout
self.logger = logging.getLogger(__name__)
def send(self, message: NotificationMessage) -> bool:
"""Send a notification to Slack."""
if not self.webhook_url:
self.logger.warning("Slack webhook URL not configured")
return False
# Build attachment
attachment = {
"color": self.LEVEL_COLORS.get(message.level, "#36a64f"),
"title": f"{self.LEVEL_EMOJIS.get(message.level, '')} {message.title}",
"text": message.message,
"mrkdwn_in": ["text", "fields"],
}
if message.url:
attachment["title_link"] = message.url
if message.fields:
attachment["fields"] = [
{"title": k, "value": v, "short": len(v) < 40}
for k, v in message.fields.items()
]
if message.footer:
attachment["footer"] = message.footer
attachment["footer_icon"] = "https://github.com/favicon.ico"
payload = {
"username": self.username,
"icon_emoji": self.icon_emoji,
"attachments": [attachment],
}
if self.channel:
payload["channel"] = self.channel
return self.send_raw(payload)
def send_raw(self, payload: dict) -> bool:
"""Send raw payload to Slack webhook."""
if not self.webhook_url:
return False
try:
response = requests.post(
self.webhook_url,
json=payload,
timeout=self.timeout,
)
response.raise_for_status()
return True
except requests.RequestException as e:
self.logger.error(f"Failed to send Slack notification: {e}")
return False
class DiscordNotifier(Notifier):
"""Discord webhook notifier."""
# Color mapping for different levels (Discord uses decimal colors)
LEVEL_COLORS = {
NotificationLevel.INFO: 3066993, # Green
NotificationLevel.WARNING: 16776960, # Yellow
NotificationLevel.ERROR: 16744448, # Orange
NotificationLevel.CRITICAL: 13369344, # Red
}
def __init__(
self,
webhook_url: str | None = None,
username: str = "OpenRabbit",
avatar_url: str | None = None,
timeout: int = 10,
):
"""Initialize Discord notifier.
Args:
webhook_url: Discord webhook URL.
Defaults to DISCORD_WEBHOOK_URL env var.
username: Bot username to display.
avatar_url: Bot avatar URL.
timeout: Request timeout in seconds.
"""
self.webhook_url = webhook_url or os.environ.get("DISCORD_WEBHOOK_URL", "")
self.username = username
self.avatar_url = avatar_url
self.timeout = timeout
self.logger = logging.getLogger(__name__)
def send(self, message: NotificationMessage) -> bool:
"""Send a notification to Discord."""
if not self.webhook_url:
self.logger.warning("Discord webhook URL not configured")
return False
# Build embed
embed = {
"title": message.title,
"description": message.message,
"color": self.LEVEL_COLORS.get(message.level, 3066993),
}
if message.url:
embed["url"] = message.url
if message.fields:
embed["fields"] = [
{"name": k, "value": v, "inline": len(v) < 40}
for k, v in message.fields.items()
]
if message.footer:
embed["footer"] = {"text": message.footer}
payload = {
"username": self.username,
"embeds": [embed],
}
if self.avatar_url:
payload["avatar_url"] = self.avatar_url
return self.send_raw(payload)
def send_raw(self, payload: dict) -> bool:
"""Send raw payload to Discord webhook."""
if not self.webhook_url:
return False
try:
response = requests.post(
self.webhook_url,
json=payload,
timeout=self.timeout,
)
response.raise_for_status()
return True
except requests.RequestException as e:
self.logger.error(f"Failed to send Discord notification: {e}")
return False
class WebhookNotifier(Notifier):
"""Generic webhook notifier for custom integrations."""
def __init__(
self,
webhook_url: str,
headers: dict[str, str] | None = None,
timeout: int = 10,
):
"""Initialize generic webhook notifier.
Args:
webhook_url: Webhook URL.
headers: Custom headers to include.
timeout: Request timeout in seconds.
"""
self.webhook_url = webhook_url
self.headers = headers or {"Content-Type": "application/json"}
self.timeout = timeout
self.logger = logging.getLogger(__name__)
def send(self, message: NotificationMessage) -> bool:
"""Send a notification as JSON payload."""
payload = {
"title": message.title,
"message": message.message,
"level": message.level.value,
"fields": message.fields or {},
"url": message.url,
"footer": message.footer,
}
return self.send_raw(payload)
def send_raw(self, payload: dict) -> bool:
"""Send raw payload to webhook."""
try:
response = requests.post(
self.webhook_url,
json=payload,
headers=self.headers,
timeout=self.timeout,
)
response.raise_for_status()
return True
except requests.RequestException as e:
self.logger.error(f"Failed to send webhook notification: {e}")
return False
class NotifierFactory:
"""Factory for creating notifier instances from config."""
@staticmethod
def create_from_config(config: dict) -> list[Notifier]:
"""Create notifier instances from configuration.
Args:
config: Configuration dictionary with 'notifications' section.
Returns:
List of configured notifier instances.
"""
notifiers = []
notifications_config = config.get("notifications", {})
if not notifications_config.get("enabled", False):
return notifiers
# Slack
slack_config = notifications_config.get("slack", {})
if slack_config.get("enabled", False):
webhook_url = slack_config.get("webhook_url") or os.environ.get(
"SLACK_WEBHOOK_URL"
)
if webhook_url:
notifiers.append(
SlackNotifier(
webhook_url=webhook_url,
channel=slack_config.get("channel"),
username=slack_config.get("username", "OpenRabbit"),
)
)
# Discord
discord_config = notifications_config.get("discord", {})
if discord_config.get("enabled", False):
webhook_url = discord_config.get("webhook_url") or os.environ.get(
"DISCORD_WEBHOOK_URL"
)
if webhook_url:
notifiers.append(
DiscordNotifier(
webhook_url=webhook_url,
username=discord_config.get("username", "OpenRabbit"),
avatar_url=discord_config.get("avatar_url"),
)
)
# Generic webhooks
webhooks_config = notifications_config.get("webhooks", [])
for webhook in webhooks_config:
if webhook.get("enabled", True) and webhook.get("url"):
notifiers.append(
WebhookNotifier(
webhook_url=webhook["url"],
headers=webhook.get("headers"),
)
)
return notifiers
@staticmethod
def should_notify(
level: NotificationLevel,
config: dict,
) -> bool:
"""Check if notification should be sent based on level threshold.
Args:
level: Notification level.
config: Configuration dictionary.
Returns:
True if notification should be sent.
"""
notifications_config = config.get("notifications", {})
if not notifications_config.get("enabled", False):
return False
threshold = notifications_config.get("threshold", "warning")
level_order = ["info", "warning", "error", "critical"]
try:
threshold_idx = level_order.index(threshold)
level_idx = level_order.index(level.value)
return level_idx >= threshold_idx
except ValueError:
return True
class NotificationService:
"""High-level notification service for sending alerts."""
def __init__(self, config: dict):
"""Initialize notification service.
Args:
config: Configuration dictionary.
"""
self.config = config
self.notifiers = NotifierFactory.create_from_config(config)
self.logger = logging.getLogger(__name__)
def notify(self, message: NotificationMessage) -> bool:
"""Send notification to all configured notifiers.
Args:
message: Notification message.
Returns:
True if at least one notification succeeded.
"""
if not NotifierFactory.should_notify(message.level, self.config):
return True # Not an error, just below threshold
if not self.notifiers:
self.logger.debug("No notifiers configured")
return True
success = False
for notifier in self.notifiers:
try:
if notifier.send(message):
success = True
except Exception as e:
self.logger.error(f"Notifier failed: {e}")
return success
    def notify_security_finding(
        self,
        repo: str,
        pr_number: int,
        finding: dict,
        pr_url: str | None = None,
    ) -> bool:
        """Send notification for a security finding.

        Args:
            repo: Repository name (owner/repo).
            pr_number: Pull request number.
            finding: Security finding dict with severity, description, etc.
            pr_url: URL to the pull request.

        Returns:
            True if notification succeeded.
        """
        # Fall back to MEDIUM when the key is missing or explicitly None.
        severity = str(finding.get("severity") or "MEDIUM").upper()
        level_map = {
            # SAST findings can report CRITICAL; without this entry they would
            # fall through to the WARNING default below.
            "CRITICAL": NotificationLevel.CRITICAL,
            "HIGH": NotificationLevel.CRITICAL,
            "MEDIUM": NotificationLevel.WARNING,
            "LOW": NotificationLevel.INFO,
        }
        level = level_map.get(severity, NotificationLevel.WARNING)
        message = NotificationMessage(
            title=f"Security Finding in {repo} PR #{pr_number}",
            message=finding.get("description", "Security issue detected"),
            level=level,
            fields={
                "Severity": severity,
                "Category": finding.get("category", "Unknown"),
                "File": finding.get("file", "N/A"),
                "Line": str(finding.get("line", "N/A")),
            },
            url=pr_url,
            footer=f"CWE: {finding.get('cwe', 'N/A')}",
        )
        return self.notify(message)
def notify_review_complete(
self,
repo: str,
pr_number: int,
summary: dict,
pr_url: str | None = None,
) -> bool:
"""Send notification for completed review.
Args:
repo: Repository name (owner/repo).
pr_number: Pull request number.
summary: Review summary dict.
pr_url: URL to the pull request.
Returns:
True if notification succeeded.
"""
recommendation = summary.get("recommendation", "COMMENT")
level_map = {
"APPROVE": NotificationLevel.INFO,
"COMMENT": NotificationLevel.INFO,
"REQUEST_CHANGES": NotificationLevel.WARNING,
}
level = level_map.get(recommendation, NotificationLevel.INFO)
high_issues = summary.get("high_severity_count", 0)
if high_issues > 0:
level = NotificationLevel.ERROR
message = NotificationMessage(
title=f"Review Complete: {repo} PR #{pr_number}",
message=summary.get("summary", "AI review completed"),
level=level,
fields={
"Recommendation": recommendation,
"High Severity": str(high_issues),
"Medium Severity": str(summary.get("medium_severity_count", 0)),
"Files Reviewed": str(summary.get("files_reviewed", 0)),
},
url=pr_url,
footer="OpenRabbit AI Review",
)
return self.notify(message)
def notify_error(
self,
repo: str,
error: str,
context: str | None = None,
) -> bool:
"""Send notification for an error.
Args:
repo: Repository name.
error: Error message.
context: Additional context.
Returns:
True if notification succeeded.
"""
message = NotificationMessage(
title=f"Error in {repo}",
message=error,
level=NotificationLevel.ERROR,
fields={"Context": context} if context else None,
footer="OpenRabbit Error Alert",
)
return self.notify(message)
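
The threshold gate in `NotifierFactory.should_notify` reduces to an index comparison over an ordered list of level names. A minimal standalone sketch, assuming levels are passed as plain strings rather than `NotificationLevel` members; `meets_threshold` is a hypothetical name.

```python
LEVEL_ORDER = ["info", "warning", "error", "critical"]


def meets_threshold(level, threshold="warning"):
    """Return True when level is at or above threshold in LEVEL_ORDER."""
    try:
        return LEVEL_ORDER.index(level) >= LEVEL_ORDER.index(threshold)
    except ValueError:
        # Unknown names fall through to "send", matching the source behaviour.
        return True


# With threshold "error", only "error" and "critical" pass the gate.
decisions = {name: meets_threshold(name, "error") for name in LEVEL_ORDER}
```

Failing open on an unrecognised threshold means a typo in the config produces extra notifications instead of silently suppressing them.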


@@ -0,0 +1,431 @@
"""SAST Scanner Integration
Integrates with external SAST tools like Bandit and Semgrep
to provide comprehensive security analysis.
"""
import json
import logging
import os
import shutil
import subprocess
import tempfile
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class SASTFinding:
"""A finding from a SAST tool."""
tool: str
rule_id: str
severity: str # CRITICAL, HIGH, MEDIUM, LOW
file: str
line: int
message: str
code_snippet: str | None = None
cwe: str | None = None
owasp: str | None = None
fix_recommendation: str | None = None
@dataclass
class SASTReport:
"""Combined report from all SAST tools."""
total_findings: int
findings_by_severity: dict[str, int]
findings_by_tool: dict[str, int]
findings: list[SASTFinding]
tools_run: list[str]
errors: list[str] = field(default_factory=list)
class SASTScanner:
"""Aggregator for multiple SAST tools."""
def __init__(self, config: dict | None = None):
"""Initialize the SAST scanner.
Args:
config: Configuration dictionary with tool settings.
"""
self.config = config or {}
self.logger = logging.getLogger(self.__class__.__name__)
def scan_directory(self, path: str) -> SASTReport:
"""Scan a directory with all enabled SAST tools.
Args:
path: Path to the directory to scan.
Returns:
Combined SASTReport from all tools.
"""
all_findings = []
tools_run = []
errors = []
sast_config = self.config.get("security", {}).get("sast", {})
# Run Bandit (Python)
if sast_config.get("bandit", True):
if self._is_tool_available("bandit"):
try:
findings = self._run_bandit(path)
all_findings.extend(findings)
tools_run.append("bandit")
except Exception as e:
errors.append(f"Bandit error: {e}")
else:
self.logger.debug("Bandit not installed, skipping")
# Run Semgrep
if sast_config.get("semgrep", True):
if self._is_tool_available("semgrep"):
try:
findings = self._run_semgrep(path)
all_findings.extend(findings)
tools_run.append("semgrep")
except Exception as e:
errors.append(f"Semgrep error: {e}")
else:
self.logger.debug("Semgrep not installed, skipping")
# Run Trivy (if enabled for filesystem scanning)
if sast_config.get("trivy", False):
if self._is_tool_available("trivy"):
try:
findings = self._run_trivy(path)
all_findings.extend(findings)
tools_run.append("trivy")
except Exception as e:
errors.append(f"Trivy error: {e}")
else:
self.logger.debug("Trivy not installed, skipping")
# Calculate statistics
by_severity = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0}
by_tool = {}
for finding in all_findings:
sev = finding.severity.upper()
if sev in by_severity:
by_severity[sev] += 1
tool = finding.tool
by_tool[tool] = by_tool.get(tool, 0) + 1
return SASTReport(
total_findings=len(all_findings),
findings_by_severity=by_severity,
findings_by_tool=by_tool,
findings=all_findings,
tools_run=tools_run,
errors=errors,
)
    def scan_content(self, content: str, filename: str) -> list[SASTFinding]:
        """Scan file content with SAST tools.
        Args:
            content: File content to scan.
            filename: Name of the file (for language detection).
        Returns:
            List of SASTFinding objects.
        """
        # Write the content into its own temporary directory so the scan covers
        # only this file, not everything else in the shared system temp dir.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_name = os.path.basename(filename) or "content"
            temp_path = os.path.join(temp_dir, temp_name)
            with open(temp_path, "w", encoding="utf-8") as f:
                f.write(content)
            report = self.scan_directory(temp_dir)
        # Keep only findings for our file
        findings = [
            f for f in report.findings if os.path.basename(f.file) == temp_name
        ]
        # Update file path to original filename
        for finding in findings:
            finding.file = filename
        return findings
def scan_diff(self, diff: str) -> list[SASTFinding]:
"""Scan a diff for security issues.
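Only added lines (including the new side of modified lines) are scanned.
Reported line numbers are relative to the extracted added lines, not to the original file.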
Args:
diff: Git diff content.
Returns:
List of SASTFinding objects.
"""
findings = []
# Parse diff and extract added content per file
files_content = {}
current_file = None
current_content = []
for line in diff.splitlines():
if line.startswith("diff --git"):
if current_file and current_content:
files_content[current_file] = "\n".join(current_content)
current_file = None
current_content = []
# Extract filename
match = line.split(" b/")
if len(match) > 1:
current_file = match[1]
elif line.startswith("+") and not line.startswith("+++"):
if current_file:
current_content.append(line[1:]) # Remove + prefix
# Don't forget last file
if current_file and current_content:
files_content[current_file] = "\n".join(current_content)
# Scan each file's content
for filename, content in files_content.items():
if content.strip():
file_findings = self.scan_content(content, filename)
findings.extend(file_findings)
return findings
def _is_tool_available(self, tool: str) -> bool:
"""Check if a tool is installed and available."""
return shutil.which(tool) is not None
def _run_bandit(self, path: str) -> list[SASTFinding]:
"""Run Bandit security scanner.
Args:
path: Path to scan.
Returns:
List of SASTFinding objects.
"""
findings = []
try:
result = subprocess.run(
[
"bandit",
"-r",
path,
"-f",
"json",
"-ll", # Only high and medium severity
"--quiet",
],
capture_output=True,
text=True,
timeout=120,
)
if result.stdout:
data = json.loads(result.stdout)
for issue in data.get("results", []):
severity = issue.get("issue_severity", "MEDIUM").upper()
findings.append(
SASTFinding(
tool="bandit",
rule_id=issue.get("test_id", ""),
severity=severity,
file=issue.get("filename", ""),
line=issue.get("line_number", 0),
message=issue.get("issue_text", ""),
code_snippet=issue.get("code", ""),
cwe=f"CWE-{issue.get('issue_cwe', {}).get('id', '')}"
if issue.get("issue_cwe")
else None,
fix_recommendation=issue.get("more_info", ""),
)
)
except subprocess.TimeoutExpired:
self.logger.warning("Bandit scan timed out")
except json.JSONDecodeError as e:
self.logger.warning(f"Failed to parse Bandit output: {e}")
except Exception as e:
self.logger.warning(f"Bandit scan failed: {e}")
return findings
def _run_semgrep(self, path: str) -> list[SASTFinding]:
"""Run Semgrep security scanner.
Args:
path: Path to scan.
Returns:
List of SASTFinding objects.
"""
findings = []
# Get Semgrep config from settings
sast_config = self.config.get("security", {}).get("sast", {})
semgrep_rules = sast_config.get("semgrep_rules", "p/security-audit")
try:
result = subprocess.run(
[
"semgrep",
"--config",
semgrep_rules,
"--json",
"--quiet",
path,
],
capture_output=True,
text=True,
timeout=180,
)
if result.stdout:
data = json.loads(result.stdout)
for finding in data.get("results", []):
# Map Semgrep severity to our scale
sev_map = {
"ERROR": "HIGH",
"WARNING": "MEDIUM",
"INFO": "LOW",
}
severity = sev_map.get(
finding.get("extra", {}).get("severity", "WARNING"), "MEDIUM"
)
metadata = finding.get("extra", {}).get("metadata", {})
findings.append(
SASTFinding(
tool="semgrep",
rule_id=finding.get("check_id", ""),
severity=severity,
file=finding.get("path", ""),
line=finding.get("start", {}).get("line", 0),
message=finding.get("extra", {}).get("message", ""),
code_snippet=finding.get("extra", {}).get("lines", ""),
cwe=metadata.get("cwe", [None])[0]
if metadata.get("cwe")
else None,
owasp=metadata.get("owasp", [None])[0]
if metadata.get("owasp")
else None,
fix_recommendation=metadata.get("fix", ""),
)
)
except subprocess.TimeoutExpired:
self.logger.warning("Semgrep scan timed out")
except json.JSONDecodeError as e:
self.logger.warning(f"Failed to parse Semgrep output: {e}")
except Exception as e:
self.logger.warning(f"Semgrep scan failed: {e}")
return findings
def _run_trivy(self, path: str) -> list[SASTFinding]:
"""Run Trivy filesystem scanner.
Args:
path: Path to scan.
Returns:
List of SASTFinding objects.
"""
findings = []
try:
result = subprocess.run(
[
"trivy",
"fs",
"--format",
"json",
"--security-checks", # Deprecated in newer Trivy releases in favor of --scanners
"vuln,secret,config",
path,
],
capture_output=True,
text=True,
timeout=180,
)
if result.stdout:
data = json.loads(result.stdout)
for result_item in data.get("Results", []):
target = result_item.get("Target", "")
# Process vulnerabilities
for vuln in result_item.get("Vulnerabilities", []):
severity = vuln.get("Severity", "MEDIUM").upper()
findings.append(
SASTFinding(
tool="trivy",
rule_id=vuln.get("VulnerabilityID", ""),
severity=severity,
file=target,
line=0,
message=vuln.get("Title", ""),
cwe=vuln.get("CweIDs", [None])[0]
if vuln.get("CweIDs")
else None,
fix_recommendation=f"Upgrade to {vuln.get('FixedVersion', 'latest')}"
if vuln.get("FixedVersion")
else None,
)
)
# Process secrets
for secret in result_item.get("Secrets", []):
findings.append(
SASTFinding(
tool="trivy",
rule_id=secret.get("RuleID", ""),
severity="HIGH",
file=target,
line=secret.get("StartLine", 0),
message=f"Secret detected: {secret.get('Title', '')}",
code_snippet=secret.get("Match", ""),
)
)
except subprocess.TimeoutExpired:
self.logger.warning("Trivy scan timed out")
except json.JSONDecodeError as e:
self.logger.warning(f"Failed to parse Trivy output: {e}")
except Exception as e:
self.logger.warning(f"Trivy scan failed: {e}")
return findings
def get_sast_scanner(config: dict | None = None) -> SASTScanner:
"""Get a configured SAST scanner instance.
Args:
config: Configuration dictionary.
Returns:
Configured SASTScanner instance.
"""
return SASTScanner(config=config)
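
The diff handling in `scan_diff` above can be hard to follow in flattened form. The following is a minimal standalone sketch of the same idea (collect the added lines of a unified diff, grouped by file), not the module's actual code. The helper name `added_lines_by_file` and the sample diff are assumptions made for illustration.

```python
from collections import defaultdict


def added_lines_by_file(diff: str) -> dict[str, str]:
    """Collect the added ('+') lines of a unified diff, grouped by file.

    Hypothetical helper that mirrors the parsing loop in scan_diff.
    """
    added: dict[str, list[str]] = defaultdict(list)
    current_file = None
    for line in diff.splitlines():
        if line.startswith("diff --git"):
            # Header looks like: diff --git a/<path> b/<path>
            parts = line.split(" b/", 1)
            current_file = parts[1] if len(parts) > 1 else None
        elif line.startswith("+") and not line.startswith("+++"):
            # Skip the "+++ b/<path>" header, keep real additions
            if current_file is not None:
                added[current_file].append(line[1:])
    return {name: "\n".join(lines) for name, lines in added.items()}


sample_diff = """\
diff --git a/app.py b/app.py
--- a/app.py
+++ b/app.py
@@ -1,2 +1,3 @@
 import os
-password = "x"
+import subprocess
+subprocess.call(cmd, shell=True)
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -1 +1 @@
-old
+new
"""

result = added_lines_by_file(sample_diff)
for name, content in result.items():
    print(name, repr(content))
```

Two limitations carry over from this approach: an added line whose own content begins with `++` is mistaken for a file header and dropped, and because only added lines are kept, any line number a scanner reports refers to the extracted text rather than to the file in the repository.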

@@ -1,9 +1,14 @@
 """Utility Functions Package
 This package contains utility functions for webhook sanitization,
-safe event dispatching, and other helper functions.
+safe event dispatching, ignore patterns, and other helper functions.
 """
+from utils.ignore_patterns import (
+    IgnorePatterns,
+    get_ignore_patterns,
+    should_ignore_file,
+)
 from utils.webhook_sanitizer import (
     extract_minimal_context,
     sanitize_webhook_data,
@@ -16,4 +21,7 @@ __all__ = [
     "validate_repository_format",
     "extract_minimal_context",
     "validate_webhook_signature",
+    "IgnorePatterns",
+    "get_ignore_patterns",
+    "should_ignore_file",
 ]

@@ -0,0 +1,358 @@
"""AI Review Ignore Patterns
Provides .gitignore-style pattern matching for excluding files from AI review.
Reads patterns from .ai-reviewignore files in the repository.
"""
import fnmatch
import os
import re
from dataclasses import dataclass
from pathlib import Path
@dataclass
class IgnoreRule:
"""A single ignore rule."""
pattern: str
negation: bool = False
directory_only: bool = False
anchored: bool = False
regex: re.Pattern | None = None # Compiled in __post_init__
def __post_init__(self):
"""Compile the pattern to regex."""
self.regex = self._compile_pattern()
def _compile_pattern(self) -> re.Pattern:
"""Convert gitignore pattern to regex."""
pattern = self.pattern
# Handle directory-only patterns
if pattern.endswith("/"):
pattern = pattern[:-1]
self.directory_only = True
# Handle anchored patterns (starting with /)
if pattern.startswith("/"):
pattern = pattern[1:]
self.anchored = True
# Escape special regex characters except * and ?
regex_pattern = ""
i = 0
while i < len(pattern):
c = pattern[i]
if c == "*":
if i + 1 < len(pattern) and pattern[i + 1] == "*":
# ** matches everything including /
if i + 2 < len(pattern) and pattern[i + 2] == "/":
regex_pattern += "(?:.*/)?"
i += 3
continue
else:
regex_pattern += ".*"
i += 2
continue
else:
# * matches everything except /
regex_pattern += "[^/]*"
elif c == "?":
regex_pattern += "[^/]"
elif c == "[":
# Character class
j = i + 1
if j < len(pattern) and pattern[j] == "!":
regex_pattern += "[^"
j += 1
else:
regex_pattern += "["
while j < len(pattern) and pattern[j] != "]":
regex_pattern += pattern[j]
j += 1
if j < len(pattern):
regex_pattern += "]"
i = j
elif c in ".^$+{}|()":
regex_pattern += "\\" + c
else:
regex_pattern += c
i += 1
# Anchor pattern
if self.anchored:
regex_pattern = "^" + regex_pattern
else:
regex_pattern = "(?:^|/)" + regex_pattern
# Match to end or as directory prefix
regex_pattern += "(?:$|/)"
return re.compile(regex_pattern)
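    def matches(self, path: str, is_directory: bool = False) -> bool:
        """Check if a path matches this rule.
        Args:
            path: Relative path to check.
            is_directory: Whether the path is a directory.
        Returns:
            True if the path matches.
        """
        # Normalize to a repo-relative path without a leading slash, so that
        # anchored patterns (compiled with a leading "^") can match.
        path = path.replace("\\", "/").lstrip("/")
        if self.directory_only and not is_directory:
            # A file matches a directory-only rule only through one of its
            # parent directories (e.g. "node_modules/" vs "node_modules/a.js").
            path = os.path.dirname(path)
            if not path:
                return False
        return bool(self.regex.search(path))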
class IgnorePatterns:
"""Manages .ai-reviewignore patterns for a repository."""
DEFAULT_PATTERNS = [
# Version control
".git/",
".svn/",
".hg/",
# Dependencies
"node_modules/",
"vendor/",
"venv/",
".venv/",
"__pycache__/",
"*.pyc",
# Build outputs
"dist/",
"build/",
"out/",
"target/",
"*.min.js",
"*.min.css",
"*.bundle.js",
# IDE/Editor
".idea/",
".vscode/",
"*.swp",
"*.swo",
# Generated files
"*.lock",
"package-lock.json",
"yarn.lock",
"poetry.lock",
"Pipfile.lock",
"Cargo.lock",
"go.sum",
# Binary files
"*.exe",
"*.dll",
"*.so",
"*.dylib",
"*.bin",
"*.o",
"*.a",
# Media files
"*.png",
"*.jpg",
"*.jpeg",
"*.gif",
"*.svg",
"*.ico",
"*.mp3",
"*.mp4",
"*.wav",
"*.pdf",
# Archives
"*.zip",
"*.tar",
"*.gz",
"*.rar",
"*.7z",
# Large data files
"*.csv",
"*.json.gz",
"*.sql",
"*.sqlite",
"*.db",
]
def __init__(
self,
repo_root: str | None = None,
ignore_file: str = ".ai-reviewignore",
use_defaults: bool = True,
):
"""Initialize ignore patterns.
Args:
repo_root: Repository root path.
ignore_file: Name of ignore file to read.
use_defaults: Whether to include default patterns.
"""
self.repo_root = repo_root or os.getcwd()
self.ignore_file = ignore_file
self.rules: list[IgnoreRule] = []
# Load default patterns
if use_defaults:
for pattern in self.DEFAULT_PATTERNS:
self._add_pattern(pattern)
# Load patterns from ignore file
self._load_ignore_file()
def _load_ignore_file(self):
"""Load patterns from .ai-reviewignore file."""
ignore_path = os.path.join(self.repo_root, self.ignore_file)
if not os.path.exists(ignore_path):
return
try:
with open(ignore_path) as f:
for line in f:
line = line.rstrip("\n\r")
# Skip empty lines and comments
if not line or line.startswith("#"):
continue
self._add_pattern(line)
except (OSError, UnicodeDecodeError):
pass # Treat an unreadable ignore file as if it were absent
def _add_pattern(self, pattern: str):
"""Add a pattern to the rules list.
Args:
pattern: Pattern string (gitignore format).
"""
# Check for negation
negation = False
if pattern.startswith("!"):
negation = True
pattern = pattern[1:]
if not pattern:
return
self.rules.append(IgnoreRule(pattern=pattern, negation=negation))
def is_ignored(self, path: str, is_directory: bool = False) -> bool:
"""Check if a path should be ignored.
Args:
path: Relative path to check.
is_directory: Whether the path is a directory.
Returns:
True if the path should be ignored.
"""
# Normalize path
path = path.replace("\\", "/").lstrip("/")
# Check each rule in order (later rules override earlier ones)
ignored = False
for rule in self.rules:
if rule.matches(path, is_directory):
ignored = not rule.negation
return ignored
def filter_paths(self, paths: list[str]) -> list[str]:
"""Filter a list of paths, removing ignored ones.
Args:
paths: List of relative paths.
Returns:
Filtered list of paths.
"""
return [p for p in paths if not self.is_ignored(p)]
def filter_diff_files(self, files: list[dict]) -> list[dict]:
"""Filter diff file objects, removing ignored ones.
Args:
files: List of file dicts with 'filename' or 'path' key.
Returns:
Filtered list of file dicts.
"""
result = []
for f in files:
path = f.get("filename") or f.get("path") or f.get("name", "")
if not self.is_ignored(path):
result.append(f)
return result
def should_review_file(self, filename: str) -> bool:
"""Check if a file should be reviewed.
Args:
filename: File path to check.
Returns:
True if the file should be reviewed.
"""
return not self.is_ignored(filename)
@classmethod
def from_config(
cls, config: dict, repo_root: str | None = None
) -> "IgnorePatterns":
"""Create IgnorePatterns from config.
Args:
config: Configuration dictionary.
repo_root: Repository root path.
Returns:
IgnorePatterns instance.
"""
ignore_config = config.get("ignore", {})
use_defaults = ignore_config.get("use_defaults", True)
ignore_file = ignore_config.get("file", ".ai-reviewignore")
instance = cls(
repo_root=repo_root,
ignore_file=ignore_file,
use_defaults=use_defaults,
)
# Add any extra patterns from config
extra_patterns = ignore_config.get("patterns", [])
for pattern in extra_patterns:
instance._add_pattern(pattern)
return instance
def get_ignore_patterns(repo_root: str | None = None) -> IgnorePatterns:
"""Get ignore patterns for a repository.
Args:
repo_root: Repository root path.
Returns:
IgnorePatterns instance.
"""
return IgnorePatterns(repo_root=repo_root)
def should_ignore_file(filename: str, repo_root: str | None = None) -> bool:
"""Quick check if a file should be ignored.
Args:
filename: File path to check.
repo_root: Repository root path.
Returns:
True if the file should be ignored.
"""
return get_ignore_patterns(repo_root).is_ignored(filename)
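
The rule ordering in `IgnorePatterns.is_ignored` (later rules override earlier ones, and a `!` prefix re-includes a path) can be illustrated with a much smaller sketch. This is a simplified stand-in and not the module's matcher: it uses `fnmatch` against the file name only, and the function name `is_ignored_simple` and the sample patterns are assumptions chosen for the example.

```python
import fnmatch


def is_ignored_simple(path: str, patterns: list[str]) -> bool:
    """Apply patterns in order; the last pattern that matches decides.

    Simplification: globs are matched against the file name only, so
    directory rules, anchoring and "**" are not handled here.
    """
    name = path.rsplit("/", 1)[-1]
    ignored = False
    for pattern in patterns:
        negated = pattern.startswith("!")
        glob = pattern[1:] if negated else pattern
        if fnmatch.fnmatchcase(name, glob):
            # A plain pattern ignores the file, a negated one re-includes it
            ignored = not negated
    return ignored


patterns = ["*.json", "!package.json", "*.min.js"]
for path in ["data/big.json", "package.json", "static/app.min.js", "src/main.py"]:
    print(path, is_ignored_simple(path, patterns))
```

Because the last match wins, the order of lines in an ignore file matters: placing `!package.json` before `*.json` would leave `package.json` ignored.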