diff --git a/CLAUDE.md b/CLAUDE.md index a3275a4..8a99bbb 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -69,11 +69,17 @@ The codebase uses an **agent-based architecture** where specialized agents handl - `execute(context)` - Main execution logic - Returns `AgentResult` with success status, message, data, and actions taken + **Core Agents:** - **PRAgent** - Reviews pull requests with inline comments and security scanning - - **IssueAgent** - Triages issues and responds to @ai-bot commands + - **IssueAgent** - Triages issues and responds to @codebot commands - **CodebaseAgent** - Analyzes entire codebase health and tech debt - **ChatAgent** - Interactive assistant with tool calling (search_codebase, read_file, search_web) + **Specialized Agents:** + - **DependencyAgent** - Scans dependencies for security vulnerabilities (Python, JavaScript) + - **TestCoverageAgent** - Analyzes code for test coverage gaps and suggests test cases + - **ArchitectureAgent** - Enforces layer separation and detects architecture violations + 3. 
**Dispatcher** (`dispatcher.py`) - Routes events to appropriate agents: - Registers agents at startup - Determines which agents can handle each event @@ -84,14 +90,23 @@ The codebase uses an **agent-based architecture** where specialized agents handl The `LLMClient` (`clients/llm_client.py`) provides a unified interface for multiple LLM providers: +**Core Providers (in llm_client.py):** - **OpenAI** - Primary provider (gpt-4.1-mini default) - **OpenRouter** - Multi-provider access (claude-3.5-sonnet) - **Ollama** - Self-hosted models (codellama:13b) +**Additional Providers (in clients/providers/):** +- **AnthropicProvider** - Direct Anthropic Claude API (claude-3.5-sonnet) +- **AzureOpenAIProvider** - Azure OpenAI Service with API key auth +- **AzureOpenAIWithAADProvider** - Azure OpenAI with Azure AD authentication +- **GeminiProvider** - Google Gemini API (public) +- **VertexAIGeminiProvider** - Google Vertex AI Gemini (enterprise GCP) + Key features: - Tool/function calling support via `call_with_tools(messages, tools)` - JSON response parsing with fallback extraction - Provider-specific configuration via `config.yml` +- Configurable timeouts per provider ### Platform Abstraction diff --git a/README.md b/README.md index 2707c29..3769002 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # OpenRabbit -Enterprise-grade AI code review system for **Gitea** with automated PR review, issue triage, interactive chat, and codebase analysis. +Enterprise-grade AI code review system for **Gitea** and **GitHub** with automated PR review, issue triage, interactive chat, and codebase analysis. 
--- @@ -14,9 +14,15 @@ Enterprise-grade AI code review system for **Gitea** with automated PR review, i | **Chat** | Interactive AI chat with codebase search and web search tools | | **@codebot Commands** | `@codebot summarize`, `changelog`, `explain-diff`, `explain`, `suggest`, `triage`, `review-again` in comments | | **Codebase Analysis** | Health scores, tech debt tracking, weekly reports | -| **Security Scanner** | 17 OWASP-aligned rules for vulnerability detection | +| **Security Scanner** | 17 OWASP-aligned rules + SAST integration (Bandit, Semgrep) | +| **Dependency Scanning** | Vulnerability detection for Python, JavaScript dependencies | +| **Test Coverage** | AI-powered test suggestions for untested code | +| **Architecture Compliance** | Layer separation enforcement, circular dependency detection | +| **Notifications** | Slack/Discord alerts for security findings and reviews | +| **Compliance** | Audit trail, CODEOWNERS enforcement, regulatory support | +| **Multi-Provider LLM** | OpenAI, Anthropic Claude, Azure OpenAI, Google Gemini, Ollama | | **Enterprise Ready** | Audit logging, metrics, Prometheus export | -| **Gitea Native** | Built for Gitea workflows and API | +| **Gitea Native** | Built for Gitea workflows and API (also works with GitHub) | --- @@ -116,12 +122,28 @@ tools/ai-review/ │ ├── issue_agent.py # Issue triage & @codebot commands │ ├── pr_agent.py # PR review with security scan │ ├── codebase_agent.py # Codebase health analysis -│ └── chat_agent.py # Interactive chat with tool calling +│ ├── chat_agent.py # Interactive chat with tool calling +│ ├── dependency_agent.py # Dependency vulnerability scanning +│ ├── test_coverage_agent.py # Test coverage analysis +│ └── architecture_agent.py # Architecture compliance checking ├── clients/ # API clients │ ├── gitea_client.py # Gitea REST API wrapper -│ └── llm_client.py # Multi-provider LLM client with tool support +│ ├── llm_client.py # Multi-provider LLM client with tool support +│ └── 
providers/ # Additional LLM providers +│ ├── anthropic_provider.py # Direct Anthropic Claude API +│ ├── azure_provider.py # Azure OpenAI Service +│ └── gemini_provider.py # Google Gemini API ├── security/ # Security scanning -│ └── security_scanner.py # 17 OWASP-aligned rules +│ ├── security_scanner.py # 17 OWASP-aligned rules +│ └── sast_scanner.py # Bandit, Semgrep, Trivy integration +├── notifications/ # Alerting system +│ └── notifier.py # Slack, Discord, webhook notifications +├── compliance/ # Compliance & audit +│ ├── audit_trail.py # Audit logging with integrity verification +│ └── codeowners.py # CODEOWNERS enforcement +├── utils/ # Utility functions +│ ├── ignore_patterns.py # .ai-reviewignore support +│ └── webhook_sanitizer.py # Input validation ├── enterprise/ # Enterprise features │ ├── audit_logger.py # JSONL audit logging │ └── metrics.py # Prometheus-compatible metrics @@ -182,6 +204,10 @@ In any issue comment: | `@codebot summarize` | Summarize the issue in 2-3 sentences | | `@codebot explain` | Explain what the issue is about | | `@codebot suggest` | Suggest solutions or next steps | +| `@codebot check-deps` | Scan dependencies for security vulnerabilities | +| `@codebot suggest-tests` | Suggest test cases for changed code | +| `@codebot refactor-suggest` | Suggest refactoring opportunities | +| `@codebot architecture` | Check architecture compliance (alias: `arch-check`) | | `@codebot` (any question) | Chat with AI using codebase/web search tools | ### Pull Request Commands @@ -522,19 +548,91 @@ Replace `'Bartender'` with your bot's Gitea username. 
This prevents the bot from | Provider | Model | Use Case | |----------|-------|----------| -| OpenAI | gpt-4.1-mini | Fast, reliable | +| OpenAI | gpt-4.1-mini | Fast, reliable, default | +| Anthropic | claude-3.5-sonnet | Direct Claude API access | +| Azure OpenAI | gpt-4 (deployment) | Enterprise Azure deployments | +| Google Gemini | gemini-1.5-pro | GCP customers, Vertex AI | | OpenRouter | claude-3.5-sonnet | Multi-provider access | | Ollama | codellama:13b | Self-hosted, private | +### Provider Configuration + +```yaml +# In config.yml +provider: anthropic # openai | anthropic | azure | gemini | openrouter | ollama + +# Azure OpenAI +azure: + endpoint: "" # Set via AZURE_OPENAI_ENDPOINT env var + deployment: "gpt-4" + api_version: "2024-02-15-preview" + +# Google Gemini (Vertex AI) +gemini: + project: "" # Set via GOOGLE_CLOUD_PROJECT env var + region: "us-central1" +``` + +### Environment Variables + +| Variable | Provider | Description | +|----------|----------|-------------| +| `OPENAI_API_KEY` | OpenAI | API key | +| `ANTHROPIC_API_KEY` | Anthropic | API key | +| `AZURE_OPENAI_ENDPOINT` | Azure | Service endpoint URL | +| `AZURE_OPENAI_API_KEY` | Azure | API key | +| `AZURE_OPENAI_DEPLOYMENT` | Azure | Deployment name | +| `GOOGLE_API_KEY` | Gemini | API key (public API) | +| `GOOGLE_CLOUD_PROJECT` | Vertex AI | GCP project ID | +| `OPENROUTER_API_KEY` | OpenRouter | API key | +| `OLLAMA_HOST` | Ollama | Server URL (default: localhost:11434) | + --- ## Enterprise Features -- **Audit Logging**: JSONL logs with daily rotation +- **Audit Logging**: JSONL logs with integrity checksums and daily rotation +- **Compliance**: HIPAA, SOC2, PCI-DSS, GDPR support with configurable rules +- **CODEOWNERS Enforcement**: Validate approvals against CODEOWNERS file +- **Notifications**: Slack/Discord webhooks for critical findings +- **SAST Integration**: Bandit, Semgrep, Trivy for advanced security scanning - **Metrics**: Prometheus-compatible export -- **Rate 
Limiting**: Configurable request limits +- **Rate Limiting**: Configurable request limits and timeouts - **Custom Security Rules**: Define your own patterns via YAML - **Tool Calling**: LLM function calling for interactive chat +- **Ignore Patterns**: `.ai-reviewignore` for excluding files from review + +### Notifications Configuration + +```yaml +# In config.yml +notifications: + enabled: true + threshold: "warning" # info | warning | error | critical + + slack: + enabled: true + webhook_url: "" # Set via SLACK_WEBHOOK_URL env var + channel: "#code-review" + + discord: + enabled: true + webhook_url: "" # Set via DISCORD_WEBHOOK_URL env var +``` + +### Compliance Configuration + +```yaml +compliance: + enabled: true + audit: + enabled: true + log_file: "audit.log" + retention_days: 90 + codeowners: + enabled: true + require_approval: true +``` --- diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py new file mode 100644 index 0000000..bbcd0e5 --- /dev/null +++ b/tests/test_dispatcher.py @@ -0,0 +1,296 @@ +"""Test Suite for Dispatcher + +Tests for event routing and agent execution. 
+""" + +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review")) + +from unittest.mock import MagicMock, Mock, patch + +import pytest + + +class TestDispatcherCreation: + """Test dispatcher initialization.""" + + def test_create_dispatcher(self): + """Test creating dispatcher.""" + from dispatcher import Dispatcher + + dispatcher = Dispatcher() + assert dispatcher is not None + assert dispatcher.agents == [] + + def test_create_dispatcher_with_config(self): + """Test creating dispatcher with config.""" + from dispatcher import Dispatcher + + config = {"dispatcher": {"max_workers": 4}} + dispatcher = Dispatcher(config=config) + assert dispatcher.config == config + + +class TestAgentRegistration: + """Test agent registration.""" + + def test_register_agent(self): + """Test registering an agent.""" + from agents.base_agent import AgentContext, AgentResult, BaseAgent + from dispatcher import Dispatcher + + class MockAgent(BaseAgent): + def can_handle(self, event_type, event_data): + return event_type == "test" + + def execute(self, context): + return AgentResult(success=True, message="done") + + dispatcher = Dispatcher() + agent = MockAgent(config={}, gitea_client=None, llm_client=None) + dispatcher.register_agent(agent) + + assert len(dispatcher.agents) == 1 + assert dispatcher.agents[0] == agent + + def test_register_multiple_agents(self): + """Test registering multiple agents.""" + from agents.base_agent import AgentContext, AgentResult, BaseAgent + from dispatcher import Dispatcher + + class MockAgent1(BaseAgent): + def can_handle(self, event_type, event_data): + return event_type == "type1" + + def execute(self, context): + return AgentResult(success=True, message="agent1") + + class MockAgent2(BaseAgent): + def can_handle(self, event_type, event_data): + return event_type == "type2" + + def execute(self, context): + return AgentResult(success=True, message="agent2") + + dispatcher = Dispatcher() + 
dispatcher.register_agent( + MockAgent1(config={}, gitea_client=None, llm_client=None) + ) + dispatcher.register_agent( + MockAgent2(config={}, gitea_client=None, llm_client=None) + ) + + assert len(dispatcher.agents) == 2 + + +class TestEventRouting: + """Test event routing to agents.""" + + def test_route_to_matching_agent(self): + """Test that events are routed to matching agents.""" + from agents.base_agent import AgentContext, AgentResult, BaseAgent + from dispatcher import Dispatcher + + class MockAgent(BaseAgent): + def can_handle(self, event_type, event_data): + return event_type == "issues" + + def execute(self, context): + return AgentResult(success=True, message="handled") + + dispatcher = Dispatcher() + agent = MockAgent(config={}, gitea_client=None, llm_client=None) + dispatcher.register_agent(agent) + + result = dispatcher.dispatch( + event_type="issues", + event_data={"action": "opened"}, + owner="test", + repo="repo", + ) + + assert len(result.agents_run) == 1 + assert result.results[0].success is True + + def test_no_matching_agent(self): + """Test dispatch when no agent matches.""" + from agents.base_agent import AgentContext, AgentResult, BaseAgent + from dispatcher import Dispatcher + + class MockAgent(BaseAgent): + def can_handle(self, event_type, event_data): + return event_type == "issues" + + def execute(self, context): + return AgentResult(success=True, message="handled") + + dispatcher = Dispatcher() + agent = MockAgent(config={}, gitea_client=None, llm_client=None) + dispatcher.register_agent(agent) + + result = dispatcher.dispatch( + event_type="pull_request", # Different event type + event_data={"action": "opened"}, + owner="test", + repo="repo", + ) + + assert len(result.agents_run) == 0 + + def test_multiple_matching_agents(self): + """Test dispatch when multiple agents match.""" + from agents.base_agent import AgentContext, AgentResult, BaseAgent + from dispatcher import Dispatcher + + class MockAgent1(BaseAgent): + def 
can_handle(self, event_type, event_data): + return event_type == "issues" + + def execute(self, context): + return AgentResult(success=True, message="agent1") + + class MockAgent2(BaseAgent): + def can_handle(self, event_type, event_data): + return event_type == "issues" + + def execute(self, context): + return AgentResult(success=True, message="agent2") + + dispatcher = Dispatcher() + dispatcher.register_agent( + MockAgent1(config={}, gitea_client=None, llm_client=None) + ) + dispatcher.register_agent( + MockAgent2(config={}, gitea_client=None, llm_client=None) + ) + + result = dispatcher.dispatch( + event_type="issues", + event_data={"action": "opened"}, + owner="test", + repo="repo", + ) + + assert len(result.agents_run) == 2 + + +class TestDispatchResult: + """Test dispatch result structure.""" + + def test_result_structure(self): + """Test DispatchResult has correct structure.""" + from dispatcher import DispatchResult + + result = DispatchResult( + agents_run=["Agent1", "Agent2"], + results=[], + errors=[], + ) + + assert result.agents_run == ["Agent1", "Agent2"] + assert result.results == [] + assert result.errors == [] + + def test_result_with_errors(self): + """Test DispatchResult with errors.""" + from dispatcher import DispatchResult + + result = DispatchResult( + agents_run=["Agent1"], + results=[], + errors=["Error 1", "Error 2"], + ) + + assert len(result.errors) == 2 + + +class TestAgentExecution: + """Test agent execution through dispatcher.""" + + def test_agent_receives_context(self): + """Test that agents receive proper context.""" + from agents.base_agent import AgentContext, AgentResult, BaseAgent + from dispatcher import Dispatcher + + received_context = None + + class MockAgent(BaseAgent): + def can_handle(self, event_type, event_data): + return True + + def execute(self, context): + nonlocal received_context + received_context = context + return AgentResult(success=True, message="done") + + dispatcher = Dispatcher() + 
dispatcher.register_agent( + MockAgent(config={}, gitea_client=None, llm_client=None) + ) + + dispatcher.dispatch( + event_type="issues", + event_data={"action": "opened", "issue": {"number": 123}}, + owner="testowner", + repo="testrepo", + ) + + assert received_context is not None + assert received_context.owner == "testowner" + assert received_context.repo == "testrepo" + assert received_context.event_type == "issues" + assert received_context.event_data["action"] == "opened" + + def test_agent_failure_captured(self): + """Test that agent failures are captured in results.""" + from agents.base_agent import AgentContext, AgentResult, BaseAgent + from dispatcher import Dispatcher + + class FailingAgent(BaseAgent): + def can_handle(self, event_type, event_data): + return True + + def execute(self, context): + raise Exception("Test error") + + dispatcher = Dispatcher() + dispatcher.register_agent( + FailingAgent(config={}, gitea_client=None, llm_client=None) + ) + + result = dispatcher.dispatch( + event_type="issues", + event_data={}, + owner="test", + repo="repo", + ) + + # Agent should still be in agents_run + assert len(result.agents_run) == 1 + # Result should indicate failure + assert result.results[0].success is False + + +class TestGetDispatcher: + """Test get_dispatcher factory function.""" + + def test_get_dispatcher_returns_singleton(self): + """Test that get_dispatcher returns configured dispatcher.""" + from dispatcher import get_dispatcher + + dispatcher = get_dispatcher() + assert dispatcher is not None + + def test_get_dispatcher_with_config(self): + """Test get_dispatcher with custom config.""" + from dispatcher import get_dispatcher + + config = {"test": "value"} + dispatcher = get_dispatcher(config=config) + assert dispatcher.config.get("test") == "value" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_llm_client.py b/tests/test_llm_client.py new file mode 100644 index 0000000..dc1de11 --- /dev/null +++ 
b/tests/test_llm_client.py @@ -0,0 +1,456 @@ +"""Test Suite for LLM Client + +Tests for LLM client functionality including provider support, +tool calling, and JSON parsing. +""" + +import json +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review")) + +from unittest.mock import MagicMock, Mock, patch + +import pytest + + +class TestLLMClientCreation: + """Test LLM client initialization.""" + + def test_create_openai_client(self): + """Test creating OpenAI client.""" + from clients.llm_client import LLMClient + + client = LLMClient( + provider="openai", + config={"model": "gpt-4o-mini", "api_key": "test-key"}, + ) + assert client.provider_name == "openai" + + def test_create_openrouter_client(self): + """Test creating OpenRouter client.""" + from clients.llm_client import LLMClient + + client = LLMClient( + provider="openrouter", + config={"model": "anthropic/claude-3.5-sonnet", "api_key": "test-key"}, + ) + assert client.provider_name == "openrouter" + + def test_create_ollama_client(self): + """Test creating Ollama client.""" + from clients.llm_client import LLMClient + + client = LLMClient( + provider="ollama", + config={"model": "codellama:13b", "host": "http://localhost:11434"}, + ) + assert client.provider_name == "ollama" + + def test_invalid_provider_raises_error(self): + """Test that invalid provider raises ValueError.""" + from clients.llm_client import LLMClient + + with pytest.raises(ValueError, match="Unknown provider"): + LLMClient(provider="invalid_provider") + + def test_from_config_openai(self): + """Test creating client from config dict.""" + from clients.llm_client import LLMClient + + config = { + "provider": "openai", + "model": {"openai": "gpt-4o-mini"}, + "temperature": 0, + "max_tokens": 4096, + } + client = LLMClient.from_config(config) + assert client.provider_name == "openai" + + +class TestLLMResponse: + """Test LLM response dataclass.""" + + def test_response_creation(self): + 
"""Test creating LLMResponse.""" + from clients.llm_client import LLMResponse + + response = LLMResponse( + content="Test response", + model="gpt-4o-mini", + provider="openai", + tokens_used=100, + finish_reason="stop", + ) + + assert response.content == "Test response" + assert response.model == "gpt-4o-mini" + assert response.provider == "openai" + assert response.tokens_used == 100 + assert response.finish_reason == "stop" + assert response.tool_calls is None + + def test_response_with_tool_calls(self): + """Test LLMResponse with tool calls.""" + from clients.llm_client import LLMResponse, ToolCall + + tool_calls = [ToolCall(id="call_1", name="search", arguments={"query": "test"})] + + response = LLMResponse( + content="", + model="gpt-4o-mini", + provider="openai", + tool_calls=tool_calls, + ) + + assert response.tool_calls is not None + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].name == "search" + + +class TestToolCall: + """Test ToolCall dataclass.""" + + def test_tool_call_creation(self): + """Test creating ToolCall.""" + from clients.llm_client import ToolCall + + tool_call = ToolCall( + id="call_123", + name="search_codebase", + arguments={"query": "authentication", "file_pattern": "*.py"}, + ) + + assert tool_call.id == "call_123" + assert tool_call.name == "search_codebase" + assert tool_call.arguments["query"] == "authentication" + assert tool_call.arguments["file_pattern"] == "*.py" + + +class TestJSONParsing: + """Test JSON extraction and parsing.""" + + def test_parse_direct_json(self): + """Test parsing direct JSON response.""" + from clients.llm_client import LLMClient + + client = LLMClient.__new__(LLMClient) + + content = '{"key": "value", "number": 42}' + result = client._extract_json(content) + + assert result["key"] == "value" + assert result["number"] == 42 + + def test_parse_json_in_code_block(self): + """Test parsing JSON in markdown code block.""" + from clients.llm_client import LLMClient + + client = 
LLMClient.__new__(LLMClient) + + content = """Here is the analysis: + +```json +{ + "type": "bug", + "priority": "high" +} +``` + +That's my analysis.""" + + result = client._extract_json(content) + assert result["type"] == "bug" + assert result["priority"] == "high" + + def test_parse_json_in_plain_code_block(self): + """Test parsing JSON in plain code block (no json specifier).""" + from clients.llm_client import LLMClient + + client = LLMClient.__new__(LLMClient) + + content = """Analysis: + +``` +{"status": "success", "count": 5} +``` +""" + + result = client._extract_json(content) + assert result["status"] == "success" + assert result["count"] == 5 + + def test_parse_json_with_preamble(self): + """Test parsing JSON with text before it.""" + from clients.llm_client import LLMClient + + client = LLMClient.__new__(LLMClient) + + content = """Based on my analysis, here is the result: +{"findings": ["issue1", "issue2"], "severity": "medium"} +""" + + result = client._extract_json(content) + assert result["findings"] == ["issue1", "issue2"] + assert result["severity"] == "medium" + + def test_parse_json_with_postamble(self): + """Test parsing JSON with text after it.""" + from clients.llm_client import LLMClient + + client = LLMClient.__new__(LLMClient) + + content = """{"result": true} + +Let me know if you need more details.""" + + result = client._extract_json(content) + assert result["result"] is True + + def test_parse_nested_json(self): + """Test parsing nested JSON objects.""" + from clients.llm_client import LLMClient + + client = LLMClient.__new__(LLMClient) + + content = """{ + "outer": { + "inner": { + "value": "deep" + } + }, + "array": [1, 2, 3] +}""" + + result = client._extract_json(content) + assert result["outer"]["inner"]["value"] == "deep" + assert result["array"] == [1, 2, 3] + + def test_parse_invalid_json_raises_error(self): + """Test that invalid JSON raises ValueError.""" + from clients.llm_client import LLMClient + + client = 
LLMClient.__new__(LLMClient) + + content = "This is not JSON at all" + + with pytest.raises(ValueError, match="Failed to parse JSON"): + client._extract_json(content) + + def test_parse_truncated_json_raises_error(self): + """Test that truncated JSON raises ValueError.""" + from clients.llm_client import LLMClient + + client = LLMClient.__new__(LLMClient) + + content = '{"key": "value", "incomplete' + + with pytest.raises(ValueError): + client._extract_json(content) + + +class TestOpenAIProvider: + """Test OpenAI provider.""" + + def test_provider_creation(self): + """Test OpenAI provider initialization.""" + from clients.llm_client import OpenAIProvider + + provider = OpenAIProvider( + api_key="test-key", + model="gpt-4o-mini", + temperature=0.5, + max_tokens=2048, + ) + + assert provider.model == "gpt-4o-mini" + assert provider.temperature == 0.5 + assert provider.max_tokens == 2048 + + def test_provider_requires_api_key(self): + """Test that calling without API key raises error.""" + from clients.llm_client import OpenAIProvider + + provider = OpenAIProvider(api_key="") + + with pytest.raises(ValueError, match="API key is required"): + provider.call("test prompt") + + @patch("clients.llm_client.requests.post") + def test_provider_call_success(self, mock_post): + """Test successful API call.""" + from clients.llm_client import OpenAIProvider + + mock_response = Mock() + mock_response.json.return_value = { + "choices": [ + { + "message": {"content": "Test response"}, + "finish_reason": "stop", + } + ], + "model": "gpt-4o-mini", + "usage": {"total_tokens": 50}, + } + mock_response.raise_for_status = Mock() + mock_post.return_value = mock_response + + provider = OpenAIProvider(api_key="test-key") + response = provider.call("Hello") + + assert response.content == "Test response" + assert response.provider == "openai" + assert response.tokens_used == 50 + + +class TestOpenRouterProvider: + """Test OpenRouter provider.""" + + def test_provider_creation(self): + """Test 
OpenRouter provider initialization.""" + from clients.llm_client import OpenRouterProvider + + provider = OpenRouterProvider( + api_key="test-key", + model="anthropic/claude-3.5-sonnet", + ) + + assert provider.model == "anthropic/claude-3.5-sonnet" + + @patch("clients.llm_client.requests.post") + def test_provider_call_success(self, mock_post): + """Test successful API call.""" + from clients.llm_client import OpenRouterProvider + + mock_response = Mock() + mock_response.json.return_value = { + "choices": [ + { + "message": {"content": "Claude response"}, + "finish_reason": "stop", + } + ], + "model": "anthropic/claude-3.5-sonnet", + "usage": {"total_tokens": 75}, + } + mock_response.raise_for_status = Mock() + mock_post.return_value = mock_response + + provider = OpenRouterProvider(api_key="test-key") + response = provider.call("Hello") + + assert response.content == "Claude response" + assert response.provider == "openrouter" + + +class TestOllamaProvider: + """Test Ollama provider.""" + + def test_provider_creation(self): + """Test Ollama provider initialization.""" + from clients.llm_client import OllamaProvider + + provider = OllamaProvider( + host="http://localhost:11434", + model="codellama:13b", + ) + + assert provider.model == "codellama:13b" + assert provider.host == "http://localhost:11434" + + @patch("clients.llm_client.requests.post") + def test_provider_call_success(self, mock_post): + """Test successful API call.""" + from clients.llm_client import OllamaProvider + + mock_response = Mock() + mock_response.json.return_value = { + "response": "Ollama response", + "model": "codellama:13b", + "done": True, + "eval_count": 30, + } + mock_response.raise_for_status = Mock() + mock_post.return_value = mock_response + + provider = OllamaProvider() + response = provider.call("Hello") + + assert response.content == "Ollama response" + assert response.provider == "ollama" + + +class TestToolCalling: + """Test tool/function calling support.""" + + 
@patch("clients.llm_client.requests.post") + def test_openai_tool_calling(self, mock_post): + """Test OpenAI tool calling.""" + from clients.llm_client import OpenAIProvider + + mock_response = Mock() + mock_response.json.return_value = { + "choices": [ + { + "message": { + "content": None, + "tool_calls": [ + { + "id": "call_abc123", + "function": { + "name": "search_codebase", + "arguments": '{"query": "auth"}', + }, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "model": "gpt-4o-mini", + "usage": {"total_tokens": 100}, + } + mock_response.raise_for_status = Mock() + mock_post.return_value = mock_response + + provider = OpenAIProvider(api_key="test-key") + tools = [ + { + "type": "function", + "function": { + "name": "search_codebase", + "description": "Search the codebase", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + }, + }, + } + ] + + response = provider.call_with_tools( + messages=[{"role": "user", "content": "Search for auth"}], + tools=tools, + ) + + assert response.tool_calls is not None + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].name == "search_codebase" + assert response.tool_calls[0].arguments["query"] == "auth" + + def test_ollama_tool_calling_not_supported(self): + """Test that Ollama raises NotImplementedError for tool calling.""" + from clients.llm_client import OllamaProvider + + provider = OllamaProvider() + + with pytest.raises(NotImplementedError): + provider.call_with_tools( + messages=[{"role": "user", "content": "test"}], + tools=[], + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tools/ai-review/agents/__init__.py b/tools/ai-review/agents/__init__.py index 324ac8c..e2452ef 100644 --- a/tools/ai-review/agents/__init__.py +++ b/tools/ai-review/agents/__init__.py @@ -2,20 +2,40 @@ This package contains the modular agent implementations for the enterprise AI code review system. 
+ +Core Agents: +- PRAgent: Pull request review and analysis +- IssueAgent: Issue triage and response +- CodebaseAgent: Codebase health analysis +- ChatAgent: Interactive chat with tool calling + +Specialized Agents: +- DependencyAgent: Dependency vulnerability scanning +- TestCoverageAgent: Test coverage analysis and suggestions +- ArchitectureAgent: Architecture compliance checking """ +from agents.architecture_agent import ArchitectureAgent from agents.base_agent import AgentContext, AgentResult, BaseAgent from agents.chat_agent import ChatAgent from agents.codebase_agent import CodebaseAgent +from agents.dependency_agent import DependencyAgent from agents.issue_agent import IssueAgent from agents.pr_agent import PRAgent +from agents.test_coverage_agent import TestCoverageAgent __all__ = [ + # Base "BaseAgent", "AgentContext", "AgentResult", + # Core Agents "IssueAgent", "PRAgent", "CodebaseAgent", "ChatAgent", + # Specialized Agents + "DependencyAgent", + "TestCoverageAgent", + "ArchitectureAgent", ] diff --git a/tools/ai-review/agents/architecture_agent.py b/tools/ai-review/agents/architecture_agent.py new file mode 100644 index 0000000..b14c009 --- /dev/null +++ b/tools/ai-review/agents/architecture_agent.py @@ -0,0 +1,547 @@ +"""Architecture Compliance Agent + +AI agent for enforcing architectural patterns and layer separation. +Detects cross-layer violations and circular dependencies. 
+""" + +import base64 +import os +import re +from dataclasses import dataclass, field + +from agents.base_agent import AgentContext, AgentResult, BaseAgent + + +@dataclass +class ArchitectureViolation: + """An architecture violation.""" + + file: str + line: int + violation_type: str # cross_layer, circular, naming, structure + severity: str # HIGH, MEDIUM, LOW + description: str + recommendation: str + source_layer: str | None = None + target_layer: str | None = None + + +@dataclass +class ArchitectureReport: + """Report of architecture analysis.""" + + violations: list[ArchitectureViolation] + layers_detected: dict[str, list[str]] + circular_dependencies: list[tuple[str, str]] + compliance_score: float + recommendations: list[str] + + +class ArchitectureAgent(BaseAgent): + """Agent for enforcing architectural compliance.""" + + # Marker for architecture comments + ARCH_AI_MARKER = "" + + # Default layer definitions + DEFAULT_LAYERS = { + "api": { + "patterns": ["api/", "routes/", "controllers/", "handlers/", "views/"], + "can_import": ["services", "models", "utils", "config"], + "cannot_import": ["db", "repositories", "infrastructure"], + }, + "services": { + "patterns": ["services/", "usecases/", "application/"], + "can_import": ["models", "repositories", "utils", "config"], + "cannot_import": ["api", "controllers", "handlers"], + }, + "repositories": { + "patterns": ["repositories/", "repos/", "data/"], + "can_import": ["models", "db", "utils", "config"], + "cannot_import": ["api", "services", "controllers"], + }, + "models": { + "patterns": ["models/", "entities/", "domain/", "schemas/"], + "can_import": ["utils", "config"], + "cannot_import": ["api", "services", "repositories", "db"], + }, + "db": { + "patterns": ["db/", "database/", "infrastructure/"], + "can_import": ["models", "config"], + "cannot_import": ["api", "services"], + }, + "utils": { + "patterns": ["utils/", "helpers/", "common/", "lib/"], + "can_import": ["config"], + "cannot_import": [], + }, 
+ } + + def can_handle(self, event_type: str, event_data: dict) -> bool: + """Check if this agent handles the given event.""" + agent_config = self.config.get("agents", {}).get("architecture", {}) + if not agent_config.get("enabled", False): + return False + + # Handle PR events + if event_type == "pull_request": + action = event_data.get("action", "") + if action in ["opened", "synchronize"]: + return True + + # Handle @codebot architecture command + if event_type == "issue_comment": + comment_body = event_data.get("comment", {}).get("body", "") + mention_prefix = self.config.get("interaction", {}).get( + "mention_prefix", "@codebot" + ) + if f"{mention_prefix} architecture" in comment_body.lower(): + return True + if f"{mention_prefix} arch-check" in comment_body.lower(): + return True + + return False + + def execute(self, context: AgentContext) -> AgentResult: + """Execute the architecture agent.""" + self.logger.info(f"Checking architecture for {context.owner}/{context.repo}") + + actions_taken = [] + + # Get layer configuration + agent_config = self.config.get("agents", {}).get("architecture", {}) + layers = agent_config.get("layers", self.DEFAULT_LAYERS) + + # Determine issue number + if context.event_type == "issue_comment": + issue = context.event_data.get("issue", {}) + issue_number = issue.get("number") + comment_author = ( + context.event_data.get("comment", {}) + .get("user", {}) + .get("login", "user") + ) + is_pr = issue.get("pull_request") is not None + else: + pr = context.event_data.get("pull_request", {}) + issue_number = pr.get("number") + comment_author = None + is_pr = True + + if is_pr and issue_number: + # Analyze PR diff + diff = self._get_pr_diff(context.owner, context.repo, issue_number) + report = self._analyze_diff(diff, layers) + actions_taken.append(f"Analyzed PR diff for architecture violations") + else: + # Analyze repository structure + report = self._analyze_repository(context.owner, context.repo, layers) + 
actions_taken.append(f"Analyzed repository architecture") + + # Post report + if issue_number: + comment = self._format_architecture_report(report, comment_author) + self.upsert_comment( + context.owner, + context.repo, + issue_number, + comment, + marker=self.ARCH_AI_MARKER, + ) + actions_taken.append("Posted architecture report") + + return AgentResult( + success=len(report.violations) == 0 or report.compliance_score >= 0.8, + message=f"Architecture check: {len(report.violations)} violations, {report.compliance_score:.0%} compliance", + data={ + "violations_count": len(report.violations), + "compliance_score": report.compliance_score, + "circular_dependencies": len(report.circular_dependencies), + }, + actions_taken=actions_taken, + ) + + def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str: + """Get the PR diff.""" + try: + return self.gitea.get_pull_request_diff(owner, repo, pr_number) + except Exception as e: + self.logger.error(f"Failed to get PR diff: {e}") + return "" + + def _analyze_diff(self, diff: str, layers: dict) -> ArchitectureReport: + """Analyze PR diff for architecture violations.""" + violations = [] + imports_by_file = {} + + current_file = None + current_language = None + + for line in diff.splitlines(): + # Track current file + if line.startswith("diff --git"): + match = re.search(r"b/(.+)$", line) + if match: + current_file = match.group(1) + current_language = self._detect_language(current_file) + imports_by_file[current_file] = [] + + # Look for import statements in added lines + if line.startswith("+") and not line.startswith("+++"): + if current_file and current_language: + imports = self._extract_imports(line[1:], current_language) + imports_by_file.setdefault(current_file, []).extend(imports) + + # Check for violations + for file_path, imports in imports_by_file.items(): + source_layer = self._detect_layer(file_path, layers) + if not source_layer: + continue + + layer_config = layers.get(source_layer, {}) + 
cannot_import = layer_config.get("cannot_import", []) + + for imp in imports: + target_layer = self._detect_layer_from_import(imp, layers) + if target_layer and target_layer in cannot_import: + violations.append( + ArchitectureViolation( + file=file_path, + line=0, # Line number not tracked in this simple implementation + violation_type="cross_layer", + severity="HIGH", + description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'", + recommendation=f"Move this import to an allowed layer or refactor the dependency", + source_layer=source_layer, + target_layer=target_layer, + ) + ) + + # Detect circular dependencies + circular = self._detect_circular_dependencies(imports_by_file, layers) + + # Calculate compliance score + total_imports = sum(len(imps) for imps in imports_by_file.values()) + if total_imports > 0: + compliance = 1.0 - (len(violations) / max(total_imports, 1)) + else: + compliance = 1.0 + + return ArchitectureReport( + violations=violations, + layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers), + circular_dependencies=circular, + compliance_score=max(0.0, compliance), + recommendations=self._generate_recommendations(violations), + ) + + def _analyze_repository( + self, owner: str, repo: str, layers: dict + ) -> ArchitectureReport: + """Analyze repository structure for architecture compliance.""" + violations = [] + imports_by_file = {} + + # Collect files from each layer + for layer_name, layer_config in layers.items(): + for pattern in layer_config.get("patterns", []): + try: + path = pattern.rstrip("/") + contents = self.gitea.get_file_contents(owner, repo, path) + if isinstance(contents, list): + for item in contents[:20]: # Limit files per layer + if item.get("type") == "file": + filepath = item.get("path", "") + imports = self._get_file_imports(owner, repo, filepath) + imports_by_file[filepath] = imports + except Exception: + pass + + # Check for violations + for file_path, imports in 
imports_by_file.items(): + source_layer = self._detect_layer(file_path, layers) + if not source_layer: + continue + + layer_config = layers.get(source_layer, {}) + cannot_import = layer_config.get("cannot_import", []) + + for imp in imports: + target_layer = self._detect_layer_from_import(imp, layers) + if target_layer and target_layer in cannot_import: + violations.append( + ArchitectureViolation( + file=file_path, + line=0, + violation_type="cross_layer", + severity="HIGH", + description=f"Layer '{source_layer}' imports from forbidden layer '{target_layer}'", + recommendation=f"Refactor to remove dependency on '{target_layer}'", + source_layer=source_layer, + target_layer=target_layer, + ) + ) + + # Detect circular dependencies + circular = self._detect_circular_dependencies(imports_by_file, layers) + + # Calculate compliance + total_imports = sum(len(imps) for imps in imports_by_file.values()) + if total_imports > 0: + compliance = 1.0 - (len(violations) / max(total_imports, 1)) + else: + compliance = 1.0 + + return ArchitectureReport( + violations=violations, + layers_detected=self._group_files_by_layer(imports_by_file.keys(), layers), + circular_dependencies=circular, + compliance_score=max(0.0, compliance), + recommendations=self._generate_recommendations(violations), + ) + + def _get_file_imports(self, owner: str, repo: str, filepath: str) -> list[str]: + """Get imports from a file.""" + imports = [] + language = self._detect_language(filepath) + + if not language: + return imports + + try: + content_data = self.gitea.get_file_contents(owner, repo, filepath) + if content_data.get("content"): + content = base64.b64decode(content_data["content"]).decode( + "utf-8", errors="ignore" + ) + for line in content.splitlines(): + imports.extend(self._extract_imports(line, language)) + except Exception: + pass + + return imports + + def _detect_language(self, filepath: str) -> str | None: + """Detect programming language from file path.""" + ext_map = { + ".py": 
"python", + ".js": "javascript", + ".ts": "typescript", + ".go": "go", + ".java": "java", + ".rb": "ruby", + } + ext = os.path.splitext(filepath)[1] + return ext_map.get(ext) + + def _extract_imports(self, line: str, language: str) -> list[str]: + """Extract import statements from a line of code.""" + imports = [] + line = line.strip() + + if language == "python": + # from x import y, import x + match = re.match(r"^(?:from\s+(\S+)|import\s+(\S+))", line) + if match: + imp = match.group(1) or match.group(2) + if imp: + imports.append(imp.split(".")[0]) + + elif language in ("javascript", "typescript"): + # import x from 'y', require('y') + match = re.search( + r"(?:from\s+['\"]([^'\"]+)['\"]|require\(['\"]([^'\"]+)['\"]\))", line + ) + if match: + imp = match.group(1) or match.group(2) + if imp and not imp.startswith("."): + imports.append(imp.split("/")[0]) + elif imp: + imports.append(imp) + + elif language == "go": + # import "package" + match = re.search(r'import\s+["\']([^"\']+)["\']', line) + if match: + imports.append(match.group(1).split("/")[-1]) + + elif language == "java": + # import package.Class + match = re.match(r"^import\s+(?:static\s+)?([^;]+);", line) + if match: + parts = match.group(1).split(".") + if len(parts) > 1: + imports.append(parts[-2]) # Package name + + return imports + + def _detect_layer(self, filepath: str, layers: dict) -> str | None: + """Detect which layer a file belongs to.""" + for layer_name, layer_config in layers.items(): + for pattern in layer_config.get("patterns", []): + if pattern.rstrip("/") in filepath: + return layer_name + return None + + def _detect_layer_from_import(self, import_path: str, layers: dict) -> str | None: + """Detect which layer an import refers to.""" + for layer_name, layer_config in layers.items(): + for pattern in layer_config.get("patterns", []): + pattern_name = pattern.rstrip("/").split("/")[-1] + if pattern_name in import_path or import_path.startswith(pattern_name): + return layer_name + return 
None + + def _detect_circular_dependencies( + self, imports_by_file: dict, layers: dict + ) -> list[tuple[str, str]]: + """Detect circular dependencies between layers.""" + circular = [] + + # Build layer dependency graph + layer_deps = {} + for file_path, imports in imports_by_file.items(): + source_layer = self._detect_layer(file_path, layers) + if not source_layer: + continue + + if source_layer not in layer_deps: + layer_deps[source_layer] = set() + + for imp in imports: + target_layer = self._detect_layer_from_import(imp, layers) + if target_layer and target_layer != source_layer: + layer_deps[source_layer].add(target_layer) + + # Check for circular dependencies + for layer_a, deps_a in layer_deps.items(): + for layer_b in deps_a: + if layer_b in layer_deps and layer_a in layer_deps.get(layer_b, set()): + pair = tuple(sorted([layer_a, layer_b])) + if pair not in circular: + circular.append(pair) + + return circular + + def _group_files_by_layer( + self, files: list[str], layers: dict + ) -> dict[str, list[str]]: + """Group files by their layer.""" + grouped = {} + for filepath in files: + layer = self._detect_layer(filepath, layers) + if layer: + if layer not in grouped: + grouped[layer] = [] + grouped[layer].append(filepath) + return grouped + + def _generate_recommendations( + self, violations: list[ArchitectureViolation] + ) -> list[str]: + """Generate recommendations based on violations.""" + recommendations = [] + + # Count violations by type + cross_layer = sum(1 for v in violations if v.violation_type == "cross_layer") + + if cross_layer > 0: + recommendations.append( + f"Fix {cross_layer} cross-layer violations by moving imports or creating interfaces" + ) + + if cross_layer > 5: + recommendations.append( + "Consider using dependency injection to reduce coupling between layers" + ) + + return recommendations + + def _format_architecture_report( + self, report: ArchitectureReport, user: str | None + ) -> str: + """Format the architecture report as a 
comment.""" + lines = [] + + if user: + lines.append(f"@{user}") + lines.append("") + + lines.extend( + [ + f"{self.AI_DISCLAIMER}", + "", + "## 🏗️ Architecture Compliance Check", + "", + "### Summary", + "", + f"| Metric | Value |", + f"|--------|-------|", + f"| Compliance Score | {report.compliance_score:.0%} |", + f"| Violations | {len(report.violations)} |", + f"| Circular Dependencies | {len(report.circular_dependencies)} |", + f"| Layers Detected | {len(report.layers_detected)} |", + "", + ] + ) + + # Compliance bar + filled = int(report.compliance_score * 10) + bar = "█" * filled + "░" * (10 - filled) + lines.append(f"`[{bar}]` {report.compliance_score:.0%}") + lines.append("") + + # Violations + if report.violations: + lines.append("### 🚨 Violations") + lines.append("") + + for v in report.violations[:10]: # Limit display + severity_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"} + lines.append( + f"{severity_emoji.get(v.severity, '⚪')} **{v.violation_type.upper()}** in `{v.file}`" + ) + lines.append(f" - {v.description}") + lines.append(f" - 💡 {v.recommendation}") + lines.append("") + + if len(report.violations) > 10: + lines.append(f"*... 
and {len(report.violations) - 10} more violations*") + lines.append("") + + # Circular dependencies + if report.circular_dependencies: + lines.append("### 🔄 Circular Dependencies") + lines.append("") + for a, b in report.circular_dependencies: + lines.append(f"- `{a}` ↔ `{b}`") + lines.append("") + + # Layers detected + if report.layers_detected: + lines.append("### 📁 Layers Detected") + lines.append("") + for layer, files in report.layers_detected.items(): + lines.append(f"- **{layer}**: {len(files)} files") + lines.append("") + + # Recommendations + if report.recommendations: + lines.append("### 💡 Recommendations") + lines.append("") + for rec in report.recommendations: + lines.append(f"- {rec}") + lines.append("") + + # Overall status + if report.compliance_score >= 0.9: + lines.append("---") + lines.append("✅ **Excellent architecture compliance!**") + elif report.compliance_score >= 0.7: + lines.append("---") + lines.append("⚠️ **Some architectural issues to address**") + else: + lines.append("---") + lines.append("❌ **Significant architectural violations detected**") + + return "\n".join(lines) diff --git a/tools/ai-review/agents/base_agent.py b/tools/ai-review/agents/base_agent.py index 7f163d3..b47225e 100644 --- a/tools/ai-review/agents/base_agent.py +++ b/tools/ai-review/agents/base_agent.py @@ -65,9 +65,10 @@ class BaseAgent(ABC): self.llm = llm_client or LLMClient.from_config(self.config) self.logger = logging.getLogger(self.__class__.__name__) - # Rate limiting + # Rate limiting - now configurable self._last_request_time = 0.0 - self._min_request_interval = 1.0 # seconds + rate_limits = self.config.get("rate_limits", {}) + self._min_request_interval = rate_limits.get("min_interval", 1.0) # seconds @staticmethod def _load_config() -> dict: diff --git a/tools/ai-review/agents/dependency_agent.py b/tools/ai-review/agents/dependency_agent.py new file mode 100644 index 0000000..e831305 --- /dev/null +++ b/tools/ai-review/agents/dependency_agent.py @@ -0,0 
+1,548 @@ +"""Dependency Security Agent + +AI agent for scanning dependency files for known vulnerabilities +and outdated packages. Supports multiple package managers. +""" + +import base64 +import json +import logging +import os +import re +import subprocess +from dataclasses import dataclass, field +from typing import Any + +from agents.base_agent import AgentContext, AgentResult, BaseAgent + + +@dataclass +class DependencyFinding: + """A security finding in a dependency.""" + + package: str + version: str + severity: str # CRITICAL, HIGH, MEDIUM, LOW + vulnerability_id: str # CVE, GHSA, etc. + title: str + description: str + fixed_version: str | None = None + references: list[str] = field(default_factory=list) + + +@dataclass +class DependencyReport: + """Report of dependency analysis.""" + + total_packages: int + vulnerable_packages: int + outdated_packages: int + findings: list[DependencyFinding] + recommendations: list[str] + files_scanned: list[str] + + +class DependencyAgent(BaseAgent): + """Agent for scanning dependencies for security vulnerabilities.""" + + # Marker for dependency comments + DEP_AI_MARKER = "" + + # Supported dependency files + DEPENDENCY_FILES = { + "python": ["requirements.txt", "Pipfile", "pyproject.toml", "setup.py"], + "javascript": ["package.json", "package-lock.json", "yarn.lock"], + "ruby": ["Gemfile", "Gemfile.lock"], + "go": ["go.mod", "go.sum"], + "rust": ["Cargo.toml", "Cargo.lock"], + "java": ["pom.xml", "build.gradle", "build.gradle.kts"], + "php": ["composer.json", "composer.lock"], + "dotnet": ["*.csproj", "packages.config", "*.fsproj"], + } + + # Common vulnerable package patterns + KNOWN_VULNERABILITIES = { + "python": { + "requests": { + "< 2.31.0": "CVE-2023-32681 - Proxy-Authorization header leak" + }, + "urllib3": { + "< 2.0.7": "CVE-2023-45803 - Request body not stripped on redirects" + }, + "cryptography": {"< 41.0.0": "Multiple CVEs - Update recommended"}, + "pillow": {"< 10.0.0": "CVE-2023-4863 - WebP 
vulnerability"}, + "django": {"< 4.2.0": "Multiple security fixes"}, + "flask": {"< 2.3.0": "Security improvements"}, + "pyyaml": {"< 6.0": "CVE-2020-14343 - Arbitrary code execution"}, + "jinja2": {"< 3.1.0": "Security fixes"}, + }, + "javascript": { + "lodash": {"< 4.17.21": "CVE-2021-23337 - Prototype pollution"}, + "axios": {"< 1.6.0": "CVE-2023-45857 - CSRF vulnerability"}, + "express": {"< 4.18.0": "Security updates"}, + "jquery": {"< 3.5.0": "XSS vulnerabilities"}, + "minimist": {"< 1.2.6": "Prototype pollution"}, + "node-fetch": {"< 3.3.0": "Security fixes"}, + }, + } + + def can_handle(self, event_type: str, event_data: dict) -> bool: + """Check if this agent handles the given event.""" + agent_config = self.config.get("agents", {}).get("dependency", {}) + if not agent_config.get("enabled", True): + return False + + # Handle PR events that modify dependency files + if event_type == "pull_request": + action = event_data.get("action", "") + if action in ["opened", "synchronize"]: + # Check if any dependency files are modified + files = event_data.get("files", []) + for f in files: + if self._is_dependency_file(f.get("filename", "")): + return True + + # Handle @codebot check-deps command + if event_type == "issue_comment": + comment_body = event_data.get("comment", {}).get("body", "") + mention_prefix = self.config.get("interaction", {}).get( + "mention_prefix", "@codebot" + ) + if f"{mention_prefix} check-deps" in comment_body.lower(): + return True + + return False + + def _is_dependency_file(self, filename: str) -> bool: + """Check if a file is a dependency file.""" + basename = os.path.basename(filename) + for lang, files in self.DEPENDENCY_FILES.items(): + for pattern in files: + if pattern.startswith("*"): + if basename.endswith(pattern[1:]): + return True + elif basename == pattern: + return True + return False + + def execute(self, context: AgentContext) -> AgentResult: + """Execute the dependency agent.""" + self.logger.info(f"Scanning dependencies 
for {context.owner}/{context.repo}") + + actions_taken = [] + + # Determine if this is a command or PR event + if context.event_type == "issue_comment": + issue = context.event_data.get("issue", {}) + issue_number = issue.get("number") + comment_author = ( + context.event_data.get("comment", {}) + .get("user", {}) + .get("login", "user") + ) + else: + pr = context.event_data.get("pull_request", {}) + issue_number = pr.get("number") + comment_author = None + + # Collect dependency files + dep_files = self._collect_dependency_files(context.owner, context.repo) + if not dep_files: + message = "No dependency files found in repository." + if issue_number: + self.gitea.create_issue_comment( + context.owner, + context.repo, + issue_number, + f"{self.AI_DISCLAIMER}\n\n{message}", + ) + return AgentResult( + success=True, + message=message, + ) + + actions_taken.append(f"Found {len(dep_files)} dependency files") + + # Analyze dependencies + report = self._analyze_dependencies(context.owner, context.repo, dep_files) + actions_taken.append(f"Analyzed {report.total_packages} packages") + + # Run external scanners if available + external_findings = self._run_external_scanners(context.owner, context.repo) + if external_findings: + report.findings.extend(external_findings) + actions_taken.append( + f"External scanner found {len(external_findings)} issues" + ) + + # Generate and post report + if issue_number: + comment = self._format_dependency_report(report, comment_author) + self.upsert_comment( + context.owner, + context.repo, + issue_number, + comment, + marker=self.DEP_AI_MARKER, + ) + actions_taken.append("Posted dependency report") + + return AgentResult( + success=True, + message=f"Dependency scan complete: {report.vulnerable_packages} vulnerable, {report.outdated_packages} outdated", + data={ + "total_packages": report.total_packages, + "vulnerable_packages": report.vulnerable_packages, + "outdated_packages": report.outdated_packages, + "findings_count": 
len(report.findings), + }, + actions_taken=actions_taken, + ) + + def _collect_dependency_files( + self, owner: str, repo: str + ) -> dict[str, dict[str, Any]]: + """Collect all dependency files from the repository.""" + dep_files = {} + + # Common paths to check + paths_to_check = [ + "", # Root + "backend/", + "frontend/", + "api/", + "services/", + ] + + for base_path in paths_to_check: + for lang, filenames in self.DEPENDENCY_FILES.items(): + for filename in filenames: + if filename.startswith("*"): + continue # Skip glob patterns for now + + filepath = f"{base_path}{filename}".lstrip("/") + try: + content_data = self.gitea.get_file_contents( + owner, repo, filepath + ) + if content_data.get("content"): + content = base64.b64decode(content_data["content"]).decode( + "utf-8", errors="ignore" + ) + dep_files[filepath] = { + "language": lang, + "content": content, + } + except Exception: + pass # File doesn't exist + + return dep_files + + def _analyze_dependencies( + self, owner: str, repo: str, dep_files: dict + ) -> DependencyReport: + """Analyze dependency files for vulnerabilities.""" + findings = [] + total_packages = 0 + vulnerable_count = 0 + outdated_count = 0 + recommendations = [] + files_scanned = list(dep_files.keys()) + + for filepath, file_info in dep_files.items(): + lang = file_info["language"] + content = file_info["content"] + + if lang == "python": + packages = self._parse_python_deps(content, filepath) + elif lang == "javascript": + packages = self._parse_javascript_deps(content, filepath) + else: + packages = [] + + total_packages += len(packages) + + # Check for known vulnerabilities + known_vulns = self.KNOWN_VULNERABILITIES.get(lang, {}) + for pkg_name, version in packages: + if pkg_name.lower() in known_vulns: + vuln_info = known_vulns[pkg_name.lower()] + for version_constraint, vuln_desc in vuln_info.items(): + if self._version_matches_constraint( + version, version_constraint + ): + findings.append( + DependencyFinding( + 
package=pkg_name, + version=version or "unknown", + severity="HIGH", + vulnerability_id=vuln_desc.split(" - ")[0] + if " - " in vuln_desc + else "VULN", + title=vuln_desc, + description=f"Package {pkg_name} version {version} has known vulnerabilities", + fixed_version=version_constraint.replace("< ", ""), + ) + ) + vulnerable_count += 1 + + # Add recommendations + if vulnerable_count > 0: + recommendations.append( + f"Update {vulnerable_count} packages with known vulnerabilities" + ) + if total_packages > 50: + recommendations.append( + "Consider auditing dependencies to reduce attack surface" + ) + + return DependencyReport( + total_packages=total_packages, + vulnerable_packages=vulnerable_count, + outdated_packages=outdated_count, + findings=findings, + recommendations=recommendations, + files_scanned=files_scanned, + ) + + def _parse_python_deps( + self, content: str, filepath: str + ) -> list[tuple[str, str | None]]: + """Parse Python dependency file.""" + packages = [] + + if "requirements" in filepath.lower(): + # requirements.txt format + for line in content.splitlines(): + line = line.strip() + if not line or line.startswith("#") or line.startswith("-"): + continue + + # Parse package==version, package>=version, package + match = re.match(r"([a-zA-Z0-9_-]+)([<>=!]+)?(.+)?", line) + if match: + pkg_name = match.group(1) + version = match.group(3) if match.group(3) else None + packages.append((pkg_name, version)) + + elif filepath.endswith("pyproject.toml"): + # pyproject.toml format + in_deps = False + for line in content.splitlines(): + if ( + "[project.dependencies]" in line + or "[tool.poetry.dependencies]" in line + ): + in_deps = True + continue + if in_deps: + if line.startswith("["): + in_deps = False + continue + match = re.match(r'"?([a-zA-Z0-9_-]+)"?\s*[=<>]', line) + if match: + packages.append((match.group(1), None)) + + return packages + + def _parse_javascript_deps( + self, content: str, filepath: str + ) -> list[tuple[str, str | None]]: + 
"""Parse JavaScript dependency file.""" + packages = [] + + if filepath.endswith("package.json"): + try: + data = json.loads(content) + for dep_type in ["dependencies", "devDependencies"]: + deps = data.get(dep_type, {}) + for name, version in deps.items(): + # Strip version prefixes like ^, ~, >= + clean_version = re.sub(r"^[\^~>=<]+", "", version) + packages.append((name, clean_version)) + except json.JSONDecodeError: + pass + + return packages + + def _version_matches_constraint(self, version: str | None, constraint: str) -> bool: + """Check if version matches a vulnerability constraint.""" + if not version: + return True # Assume vulnerable if version unknown + + # Simple version comparison + if constraint.startswith("< "): + target = constraint[2:] + try: + return self._compare_versions(version, target) < 0 + except Exception: + return False + + return False + + def _compare_versions(self, v1: str, v2: str) -> int: + """Compare two version strings. Returns -1, 0, or 1.""" + + def normalize(v): + return [int(x) for x in re.sub(r"[^0-9.]", "", v).split(".") if x] + + try: + parts1 = normalize(v1) + parts2 = normalize(v2) + + for i in range(max(len(parts1), len(parts2))): + p1 = parts1[i] if i < len(parts1) else 0 + p2 = parts2[i] if i < len(parts2) else 0 + if p1 < p2: + return -1 + if p1 > p2: + return 1 + return 0 + except Exception: + return 0 + + def _run_external_scanners(self, owner: str, repo: str) -> list[DependencyFinding]: + """Run external vulnerability scanners if available.""" + findings = [] + agent_config = self.config.get("agents", {}).get("dependency", {}) + + # Try pip-audit for Python + if agent_config.get("pip_audit", False): + try: + result = subprocess.run( + ["pip-audit", "--format", "json"], + capture_output=True, + text=True, + timeout=60, + ) + if result.returncode == 0: + data = json.loads(result.stdout) + for vuln in data.get("vulnerabilities", []): + findings.append( + DependencyFinding( + package=vuln.get("name", ""), + 
version=vuln.get("version", ""), + severity=vuln.get("severity", "MEDIUM"), + vulnerability_id=vuln.get("id", ""), + title=vuln.get("description", "")[:100], + description=vuln.get("description", ""), + fixed_version=vuln.get("fix_versions", [None])[0], + ) + ) + except Exception as e: + self.logger.debug(f"pip-audit not available: {e}") + + # Try npm audit for JavaScript + if agent_config.get("npm_audit", False): + try: + result = subprocess.run( + ["npm", "audit", "--json"], + capture_output=True, + text=True, + timeout=60, + ) + data = json.loads(result.stdout) + for vuln_id, vuln in data.get("vulnerabilities", {}).items(): + findings.append( + DependencyFinding( + package=vuln.get("name", vuln_id), + version=vuln.get("range", ""), + severity=vuln.get("severity", "moderate").upper(), + vulnerability_id=vuln_id, + title=vuln.get("title", ""), + description=vuln.get("overview", ""), + fixed_version=vuln.get("fixAvailable", {}).get("version"), + ) + ) + except Exception as e: + self.logger.debug(f"npm audit not available: {e}") + + return findings + + def _format_dependency_report( + self, report: DependencyReport, user: str | None = None + ) -> str: + """Format the dependency report as a comment.""" + lines = [] + + if user: + lines.append(f"@{user}") + lines.append("") + + lines.extend( + [ + f"{self.AI_DISCLAIMER}", + "", + "## 🔍 Dependency Security Scan", + "", + "### Summary", + "", + f"| Metric | Value |", + f"|--------|-------|", + f"| Total Packages | {report.total_packages} |", + f"| Vulnerable | {report.vulnerable_packages} |", + f"| Outdated | {report.outdated_packages} |", + f"| Files Scanned | {len(report.files_scanned)} |", + "", + ] + ) + + # Findings by severity + if report.findings: + lines.append("### 🚨 Security Findings") + lines.append("") + + # Group by severity + by_severity = {"CRITICAL": [], "HIGH": [], "MEDIUM": [], "LOW": []} + for finding in report.findings: + sev = finding.severity.upper() + if sev in by_severity: + 
by_severity[sev].append(finding) + + severity_emoji = { + "CRITICAL": "🔴", + "HIGH": "🟠", + "MEDIUM": "🟡", + "LOW": "🔵", + } + + for severity in ["CRITICAL", "HIGH", "MEDIUM", "LOW"]: + findings = by_severity[severity] + if findings: + lines.append(f"#### {severity_emoji[severity]} {severity}") + lines.append("") + for f in findings[:10]: # Limit display + lines.append(f"- **{f.package}** `{f.version}`") + lines.append(f" - {f.vulnerability_id}: {f.title}") + if f.fixed_version: + lines.append(f" - ✅ Fix: Upgrade to `{f.fixed_version}`") + if len(findings) > 10: + lines.append(f" - ... and {len(findings) - 10} more") + lines.append("") + + # Files scanned + lines.append("### 📁 Files Scanned") + lines.append("") + for f in report.files_scanned: + lines.append(f"- `{f}`") + lines.append("") + + # Recommendations + if report.recommendations: + lines.append("### 💡 Recommendations") + lines.append("") + for rec in report.recommendations: + lines.append(f"- {rec}") + lines.append("") + + # Overall status + if report.vulnerable_packages == 0: + lines.append("---") + lines.append("✅ **No known vulnerabilities detected**") + else: + lines.append("---") + lines.append( + f"⚠️ **{report.vulnerable_packages} vulnerable packages require attention**" + ) + + return "\n".join(lines) diff --git a/tools/ai-review/agents/issue_agent.py b/tools/ai-review/agents/issue_agent.py index 3b4418a..a24216b 100644 --- a/tools/ai-review/agents/issue_agent.py +++ b/tools/ai-review/agents/issue_agent.py @@ -365,9 +365,20 @@ class IssueAgent(BaseAgent): "commands", ["explain", "suggest", "security", "summarize", "triage"] ) - # Also check for setup-labels command (not in config since it's a setup command) - if f"{mention_prefix} setup-labels" in body.lower(): - return "setup-labels" + # Built-in commands not in config + builtin_commands = [ + "setup-labels", + "check-deps", + "suggest-tests", + "refactor-suggest", + "architecture", + "arch-check", + ] + + # Check built-in commands first + for 
command in builtin_commands: + if f"{mention_prefix} {command}" in body.lower(): + return command for command in commands: if f"{mention_prefix} {command}" in body.lower(): @@ -392,6 +403,14 @@ class IssueAgent(BaseAgent): return self._command_triage(context, issue) elif command == "setup-labels": return self._command_setup_labels(context, issue) + elif command == "check-deps": + return self._command_check_deps(context) + elif command == "suggest-tests": + return self._command_suggest_tests(context) + elif command == "refactor-suggest": + return self._command_refactor_suggest(context) + elif command == "architecture" or command == "arch-check": + return self._command_architecture(context) return f"{self.AI_DISCLAIMER}\n\nSorry, I don't understand the command `{command}`." @@ -464,6 +483,12 @@ Be practical and concise.""" - `{mention_prefix} suggest` - Solution suggestions or next steps - `{mention_prefix} security` - Security-focused analysis of the issue +### Code Quality & Security +- `{mention_prefix} check-deps` - Scan dependencies for security vulnerabilities +- `{mention_prefix} suggest-tests` - Suggest test cases for changed/new code +- `{mention_prefix} refactor-suggest` - Suggest refactoring opportunities +- `{mention_prefix} architecture` - Check architecture compliance (alias: `arch-check`) + ### Interactive Chat - `{mention_prefix} [question]` - Ask questions about the codebase (uses search & file reading tools) - Example: `{mention_prefix} how does authentication work?` @@ -494,9 +519,19 @@ PR reviews run automatically when you open or update a pull request. 
The bot pro {mention_prefix} triage ``` -**Get help understanding:** +**Check for dependency vulnerabilities:** ``` -{mention_prefix} explain +{mention_prefix} check-deps +``` + +**Get test suggestions:** +``` +{mention_prefix} suggest-tests +``` + +**Check architecture compliance:** +``` +{mention_prefix} architecture ``` **Ask about the codebase:** @@ -504,11 +539,6 @@ PR reviews run automatically when you open or update a pull request. The bot pro {mention_prefix} how does the authentication system work? ``` -**Setup repository labels:** -``` -{mention_prefix} setup-labels -``` - --- *For full documentation, see the [README](https://github.com/YourOrg/OpenRabbit/blob/main/README.md)* @@ -854,3 +884,145 @@ PR reviews run automatically when you open or update a pull request. The bot pro return f"{prefix} - {value}" else: # colon or unknown return base_name + + def _command_check_deps(self, context: AgentContext) -> str: + """Check dependencies for security vulnerabilities.""" + try: + from agents.dependency_agent import DependencyAgent + + agent = DependencyAgent(config=self.config) + result = agent.run(context) + + if result.success: + return result.data.get( + "report", f"{self.AI_DISCLAIMER}\n\n{result.message}" + ) + else: + return f"{self.AI_DISCLAIMER}\n\n**Dependency Check Failed**\n\n{result.error or result.message}" + except ImportError: + return f"{self.AI_DISCLAIMER}\n\n**Dependency Agent Not Available**\n\nThe dependency security agent is not installed." 
+ except Exception as e: + self.logger.error(f"Dependency check failed: {e}") + return f"{self.AI_DISCLAIMER}\n\n**Dependency Check Error**\n\n{e}" + + def _command_suggest_tests(self, context: AgentContext) -> str: + """Suggest tests for changed or new code.""" + try: + from agents.test_coverage_agent import TestCoverageAgent + + agent = TestCoverageAgent(config=self.config) + result = agent.run(context) + + if result.success: + return result.data.get( + "report", f"{self.AI_DISCLAIMER}\n\n{result.message}" + ) + else: + return f"{self.AI_DISCLAIMER}\n\n**Test Suggestion Failed**\n\n{result.error or result.message}" + except ImportError: + return f"{self.AI_DISCLAIMER}\n\n**Test Coverage Agent Not Available**\n\nThe test coverage agent is not installed." + except Exception as e: + self.logger.error(f"Test suggestion failed: {e}") + return f"{self.AI_DISCLAIMER}\n\n**Test Suggestion Error**\n\n{e}" + + def _command_architecture(self, context: AgentContext) -> str: + """Check architecture compliance.""" + try: + from agents.architecture_agent import ArchitectureAgent + + agent = ArchitectureAgent(config=self.config) + result = agent.run(context) + + if result.success: + return result.data.get( + "report", f"{self.AI_DISCLAIMER}\n\n{result.message}" + ) + else: + return f"{self.AI_DISCLAIMER}\n\n**Architecture Check Failed**\n\n{result.error or result.message}" + except ImportError: + return f"{self.AI_DISCLAIMER}\n\n**Architecture Agent Not Available**\n\nThe architecture compliance agent is not installed." 
+ except Exception as e: + self.logger.error(f"Architecture check failed: {e}") + return f"{self.AI_DISCLAIMER}\n\n**Architecture Check Error**\n\n{e}" + + def _command_refactor_suggest(self, context: AgentContext) -> str: + """Suggest refactoring opportunities.""" + issue = context.event_data.get("issue", {}) + title = issue.get("title", "") + body = issue.get("body", "") + + # Use LLM to analyze for refactoring opportunities + prompt = f"""Analyze the following issue/context and suggest refactoring opportunities: + +Issue Title: {title} +Issue Body: {body} + +Based on common refactoring patterns, suggest: +1. Code smell detection (if any code is referenced) +2. Design pattern opportunities +3. Complexity reduction suggestions +4. DRY principle violations +5. SOLID principle improvements + +Format your response as a structured report with actionable recommendations. +If no code is referenced in the issue, provide general refactoring guidance based on the context. + +Return as JSON: +{{ + "summary": "Brief summary of refactoring opportunities", + "suggestions": [ + {{ + "category": "Code Smell | Design Pattern | Complexity | DRY | SOLID", + "title": "Short title", + "description": "Detailed description", + "priority": "high | medium | low", + "effort": "small | medium | large" + }} + ], + "general_advice": "Any general refactoring advice" +}}""" + + try: + result = self.call_llm_json(prompt) + + lines = [f"{self.AI_DISCLAIMER}\n"] + lines.append("## Refactoring Suggestions\n") + + if result.get("summary"): + lines.append(f"**Summary:** {result['summary']}\n") + + suggestions = result.get("suggestions", []) + if suggestions: + lines.append("### Recommendations\n") + lines.append("| Priority | Category | Suggestion | Effort |") + lines.append("|----------|----------|------------|--------|") + + for s in suggestions: + priority = s.get("priority", "medium").upper() + priority_icon = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🟢"}.get( + priority, "⚪" + ) + lines.append( + 
f"| {priority_icon} {priority} | {s.get('category', 'General')} | " + f"**{s.get('title', 'Suggestion')}** | {s.get('effort', 'medium')} |" + ) + + lines.append("") + + # Detailed descriptions + lines.append("### Details\n") + for i, s in enumerate(suggestions, 1): + lines.append(f"**{i}. {s.get('title', 'Suggestion')}**") + lines.append(f"{s.get('description', 'No description')}\n") + + if result.get("general_advice"): + lines.append("### General Advice\n") + lines.append(result["general_advice"]) + + return "\n".join(lines) + + except Exception as e: + self.logger.error(f"Refactor suggestion failed: {e}") + return ( + f"{self.AI_DISCLAIMER}\n\n**Refactor Suggestion Failed**\n\nError: {e}" + ) diff --git a/tools/ai-review/agents/test_coverage_agent.py b/tools/ai-review/agents/test_coverage_agent.py new file mode 100644 index 0000000..da87cc2 --- /dev/null +++ b/tools/ai-review/agents/test_coverage_agent.py @@ -0,0 +1,480 @@ +"""Test Coverage Agent + +AI agent for analyzing code changes and suggesting test cases. +Helps improve test coverage by identifying untested code paths. 
+""" + +import base64 +import os +import re +from dataclasses import dataclass, field + +from agents.base_agent import AgentContext, AgentResult, BaseAgent + + +@dataclass +class TestSuggestion: + """A suggested test case.""" + + function_name: str + file_path: str + test_type: str # unit, integration, edge_case + description: str + example_code: str | None = None + priority: str = "MEDIUM" # HIGH, MEDIUM, LOW + + +@dataclass +class CoverageReport: + """Report of test coverage analysis.""" + + functions_analyzed: int + functions_with_tests: int + functions_without_tests: int + suggestions: list[TestSuggestion] + existing_tests: list[str] + coverage_estimate: float + + +class TestCoverageAgent(BaseAgent): + """Agent for analyzing test coverage and suggesting tests.""" + + # Marker for test coverage comments + TEST_AI_MARKER = "" + + # Test file patterns by language + TEST_PATTERNS = { + "python": [r"test_.*\.py$", r".*_test\.py$", r"tests?/.*\.py$"], + "javascript": [ + r".*\.test\.[jt]sx?$", + r".*\.spec\.[jt]sx?$", + r"__tests__/.*\.[jt]sx?$", + ], + "go": [r".*_test\.go$"], + "rust": [r"tests?/.*\.rs$"], + "java": [r".*Test\.java$", r".*Tests\.java$"], + "ruby": [r".*_spec\.rb$", r"test_.*\.rb$"], + } + + # Function/method patterns by language + FUNCTION_PATTERNS = { + "python": r"^\s*(?:async\s+)?def\s+(\w+)\s*\(", + "javascript": r"(?:function\s+(\w+)|(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?(?:function|\([^)]*\)\s*=>))", + "go": r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\(", + "rust": r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)", + "java": r"(?:public|private|protected)\s+(?:static\s+)?(?:\w+\s+)?(\w+)\s*\([^)]*\)\s*\{", + "ruby": r"^\s*def\s+(\w+)", + } + + def can_handle(self, event_type: str, event_data: dict) -> bool: + """Check if this agent handles the given event.""" + agent_config = self.config.get("agents", {}).get("test_coverage", {}) + if not agent_config.get("enabled", True): + return False + + # Handle @codebot suggest-tests command + if event_type == 
"issue_comment": + comment_body = event_data.get("comment", {}).get("body", "") + mention_prefix = self.config.get("interaction", {}).get( + "mention_prefix", "@codebot" + ) + if f"{mention_prefix} suggest-tests" in comment_body.lower(): + return True + + return False + + def execute(self, context: AgentContext) -> AgentResult: + """Execute the test coverage agent.""" + self.logger.info(f"Analyzing test coverage for {context.owner}/{context.repo}") + + actions_taken = [] + + # Get issue/PR number and author + issue = context.event_data.get("issue", {}) + issue_number = issue.get("number") + comment_author = ( + context.event_data.get("comment", {}).get("user", {}).get("login", "user") + ) + + # Check if this is a PR + is_pr = issue.get("pull_request") is not None + + if is_pr: + # Analyze PR diff for changed functions + diff = self._get_pr_diff(context.owner, context.repo, issue_number) + changed_functions = self._extract_changed_functions(diff) + actions_taken.append(f"Analyzed {len(changed_functions)} changed functions") + else: + # Analyze entire repository + changed_functions = self._analyze_repository(context.owner, context.repo) + actions_taken.append( + f"Analyzed {len(changed_functions)} functions in repository" + ) + + # Find existing tests + existing_tests = self._find_existing_tests(context.owner, context.repo) + actions_taken.append(f"Found {len(existing_tests)} existing test files") + + # Generate test suggestions using LLM + report = self._generate_suggestions( + context.owner, context.repo, changed_functions, existing_tests + ) + + # Post report + if issue_number: + comment = self._format_coverage_report(report, comment_author, is_pr) + self.upsert_comment( + context.owner, + context.repo, + issue_number, + comment, + marker=self.TEST_AI_MARKER, + ) + actions_taken.append("Posted test coverage report") + + return AgentResult( + success=True, + message=f"Generated {len(report.suggestions)} test suggestions", + data={ + "functions_analyzed": 
report.functions_analyzed, + "suggestions_count": len(report.suggestions), + "coverage_estimate": report.coverage_estimate, + }, + actions_taken=actions_taken, + ) + + def _get_pr_diff(self, owner: str, repo: str, pr_number: int) -> str: + """Get the PR diff.""" + try: + return self.gitea.get_pull_request_diff(owner, repo, pr_number) + except Exception as e: + self.logger.error(f"Failed to get PR diff: {e}") + return "" + + def _extract_changed_functions(self, diff: str) -> list[dict]: + """Extract changed functions from diff.""" + functions = [] + current_file = None + current_language = None + + for line in diff.splitlines(): + # Track current file + if line.startswith("diff --git"): + match = re.search(r"b/(.+)$", line) + if match: + current_file = match.group(1) + current_language = self._detect_language(current_file) + + # Look for function definitions in added lines + if line.startswith("+") and not line.startswith("+++"): + if current_language and current_language in self.FUNCTION_PATTERNS: + pattern = self.FUNCTION_PATTERNS[current_language] + match = re.search(pattern, line[1:]) # Skip the + prefix + if match: + func_name = next(g for g in match.groups() if g) + functions.append( + { + "name": func_name, + "file": current_file, + "language": current_language, + "line": line[1:].strip(), + } + ) + + return functions + + def _analyze_repository(self, owner: str, repo: str) -> list[dict]: + """Analyze repository for functions without tests.""" + functions = [] + code_extensions = {".py", ".js", ".ts", ".go", ".rs", ".java", ".rb"} + + # Get repository contents (limited to avoid API exhaustion) + try: + contents = self.gitea.get_file_contents(owner, repo, "") + if isinstance(contents, list): + for item in contents[:50]: # Limit files + if item.get("type") == "file": + filepath = item.get("path", "") + ext = os.path.splitext(filepath)[1] + if ext in code_extensions: + file_functions = self._extract_functions_from_file( + owner, repo, filepath + ) + 
functions.extend(file_functions) + except Exception as e: + self.logger.warning(f"Failed to analyze repository: {e}") + + return functions[:100] # Limit total functions + + def _extract_functions_from_file( + self, owner: str, repo: str, filepath: str + ) -> list[dict]: + """Extract function definitions from a file.""" + functions = [] + language = self._detect_language(filepath) + + if not language or language not in self.FUNCTION_PATTERNS: + return functions + + try: + content_data = self.gitea.get_file_contents(owner, repo, filepath) + if content_data.get("content"): + content = base64.b64decode(content_data["content"]).decode( + "utf-8", errors="ignore" + ) + + pattern = self.FUNCTION_PATTERNS[language] + for i, line in enumerate(content.splitlines(), 1): + match = re.search(pattern, line) + if match: + func_name = next((g for g in match.groups() if g), None) + if func_name and not func_name.startswith("_"): + functions.append( + { + "name": func_name, + "file": filepath, + "language": language, + "line_number": i, + } + ) + except Exception: + pass + + return functions + + def _detect_language(self, filepath: str) -> str | None: + """Detect programming language from file path.""" + ext_map = { + ".py": "python", + ".js": "javascript", + ".jsx": "javascript", + ".ts": "javascript", + ".tsx": "javascript", + ".go": "go", + ".rs": "rust", + ".java": "java", + ".rb": "ruby", + } + ext = os.path.splitext(filepath)[1] + return ext_map.get(ext) + + def _find_existing_tests(self, owner: str, repo: str) -> list[str]: + """Find existing test files in the repository.""" + test_files = [] + + # Common test directories + test_dirs = ["tests", "test", "__tests__", "spec"] + + for test_dir in test_dirs: + try: + contents = self.gitea.get_file_contents(owner, repo, test_dir) + if isinstance(contents, list): + for item in contents: + if item.get("type") == "file": + test_files.append(item.get("path", "")) + except Exception: + pass + + # Also check root for test files + try: + 
contents = self.gitea.get_file_contents(owner, repo, "") + if isinstance(contents, list): + for item in contents: + if item.get("type") == "file": + filepath = item.get("path", "") + if self._is_test_file(filepath): + test_files.append(filepath) + except Exception: + pass + + return test_files + + def _is_test_file(self, filepath: str) -> bool: + """Check if a file is a test file.""" + for lang, patterns in self.TEST_PATTERNS.items(): + for pattern in patterns: + if re.search(pattern, filepath): + return True + return False + + def _generate_suggestions( + self, + owner: str, + repo: str, + functions: list[dict], + existing_tests: list[str], + ) -> CoverageReport: + """Generate test suggestions using LLM.""" + suggestions = [] + + # Build prompt for LLM + if functions: + functions_text = "\n".join( + [ + f"- {f['name']} in {f['file']} ({f['language']})" + for f in functions[:20] # Limit for prompt size + ] + ) + + prompt = f"""Analyze these functions and suggest test cases: + +Functions to test: +{functions_text} + +Existing test files: +{", ".join(existing_tests[:10]) if existing_tests else "None found"} + +For each function, suggest: +1. What to test (happy path, edge cases, error handling) +2. Priority (HIGH for public APIs, MEDIUM for internal, LOW for utilities) +3. 
Brief example test code if possible + +Respond in JSON format: +{{ + "suggestions": [ + {{ + "function_name": "function_name", + "file_path": "path/to/file", + "test_type": "unit|integration|edge_case", + "description": "What to test", + "example_code": "brief example or null", + "priority": "HIGH|MEDIUM|LOW" + }} + ], + "coverage_estimate": 0.0 to 1.0 +}} +""" + + try: + result = self.call_llm_json(prompt) + + for s in result.get("suggestions", []): + suggestions.append( + TestSuggestion( + function_name=s.get("function_name", ""), + file_path=s.get("file_path", ""), + test_type=s.get("test_type", "unit"), + description=s.get("description", ""), + example_code=s.get("example_code"), + priority=s.get("priority", "MEDIUM"), + ) + ) + + coverage_estimate = result.get("coverage_estimate", 0.5) + + except Exception as e: + self.logger.warning(f"LLM suggestion failed: {e}") + # Generate basic suggestions without LLM + for f in functions[:10]: + suggestions.append( + TestSuggestion( + function_name=f["name"], + file_path=f["file"], + test_type="unit", + description=f"Add unit tests for {f['name']}", + priority="MEDIUM", + ) + ) + coverage_estimate = 0.5 + + else: + coverage_estimate = 1.0 if existing_tests else 0.0 + + # Estimate functions with tests + functions_with_tests = int(len(functions) * coverage_estimate) + + return CoverageReport( + functions_analyzed=len(functions), + functions_with_tests=functions_with_tests, + functions_without_tests=len(functions) - functions_with_tests, + suggestions=suggestions, + existing_tests=existing_tests, + coverage_estimate=coverage_estimate, + ) + + def _format_coverage_report( + self, report: CoverageReport, user: str | None, is_pr: bool + ) -> str: + """Format the coverage report as a comment.""" + lines = [] + + if user: + lines.append(f"@{user}") + lines.append("") + + lines.extend( + [ + f"{self.AI_DISCLAIMER}", + "", + "## 🧪 Test Coverage Analysis", + "", + "### Summary", + "", + f"| Metric | Value |", + 
f"|--------|-------|", + f"| Functions Analyzed | {report.functions_analyzed} |", + f"| Estimated Coverage | {report.coverage_estimate:.0%} |", + f"| Test Files Found | {len(report.existing_tests)} |", + f"| Suggestions | {len(report.suggestions)} |", + "", + ] + ) + + # Suggestions by priority + if report.suggestions: + lines.append("### 💡 Test Suggestions") + lines.append("") + + # Group by priority + by_priority = {"HIGH": [], "MEDIUM": [], "LOW": []} + for s in report.suggestions: + if s.priority in by_priority: + by_priority[s.priority].append(s) + + priority_emoji = {"HIGH": "🔴", "MEDIUM": "🟡", "LOW": "🔵"} + + for priority in ["HIGH", "MEDIUM", "LOW"]: + suggestions = by_priority[priority] + if suggestions: + lines.append(f"#### {priority_emoji[priority]} {priority} Priority") + lines.append("") + for s in suggestions[:5]: # Limit display + lines.append(f"**`{s.function_name}`** in `{s.file_path}`") + lines.append(f"- Type: {s.test_type}") + lines.append(f"- {s.description}") + if s.example_code: + lines.append(f"```") + lines.append(s.example_code[:200]) + lines.append(f"```") + lines.append("") + if len(suggestions) > 5: + lines.append(f"*... and {len(suggestions) - 5} more*") + lines.append("") + + # Existing test files + if report.existing_tests: + lines.append("### 📁 Existing Test Files") + lines.append("") + for f in report.existing_tests[:10]: + lines.append(f"- `{f}`") + if len(report.existing_tests) > 10: + lines.append(f"- *... 
and {len(report.existing_tests) - 10} more*") + lines.append("") + + # Coverage bar + lines.append("### 📊 Coverage Estimate") + lines.append("") + filled = int(report.coverage_estimate * 10) + bar = "█" * filled + "░" * (10 - filled) + lines.append(f"`[{bar}]` {report.coverage_estimate:.0%}") + lines.append("") + + # Recommendations + if report.coverage_estimate < 0.8: + lines.append("---") + lines.append("⚠️ **Coverage below 80%** - Consider adding more tests") + elif report.coverage_estimate >= 0.8: + lines.append("---") + lines.append("✅ **Good test coverage!**") + + return "\n".join(lines) diff --git a/tools/ai-review/clients/llm_client.py b/tools/ai-review/clients/llm_client.py index 2e066fb..3a0a0e4 100644 --- a/tools/ai-review/clients/llm_client.py +++ b/tools/ai-review/clients/llm_client.py @@ -77,11 +77,13 @@ class OpenAIProvider(BaseLLMProvider): model: str = "gpt-4o-mini", temperature: float = 0, max_tokens: int = 4096, + timeout: int = 120, ): self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "") self.model = model self.temperature = temperature self.max_tokens = max_tokens + self.timeout = timeout self.api_url = "https://api.openai.com/v1/chat/completions" def call(self, prompt: str, **kwargs) -> LLMResponse: @@ -101,7 +103,7 @@ class OpenAIProvider(BaseLLMProvider): "max_tokens": kwargs.get("max_tokens", self.max_tokens), "messages": [{"role": "user", "content": prompt}], }, - timeout=120, + timeout=self.timeout, ) response.raise_for_status() data = response.json() @@ -145,7 +147,7 @@ class OpenAIProvider(BaseLLMProvider): "Content-Type": "application/json", }, json=request_body, - timeout=120, + timeout=self.timeout, ) response.raise_for_status() data = response.json() @@ -186,11 +188,13 @@ class OpenRouterProvider(BaseLLMProvider): model: str = "anthropic/claude-3.5-sonnet", temperature: float = 0, max_tokens: int = 4096, + timeout: int = 120, ): self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY", "") self.model = model 
self.temperature = temperature self.max_tokens = max_tokens + self.timeout = timeout self.api_url = "https://openrouter.ai/api/v1/chat/completions" def call(self, prompt: str, **kwargs) -> LLMResponse: @@ -210,7 +214,7 @@ class OpenRouterProvider(BaseLLMProvider): "max_tokens": kwargs.get("max_tokens", self.max_tokens), "messages": [{"role": "user", "content": prompt}], }, - timeout=120, + timeout=self.timeout, ) response.raise_for_status() data = response.json() @@ -254,7 +258,7 @@ class OpenRouterProvider(BaseLLMProvider): "Content-Type": "application/json", }, json=request_body, - timeout=120, + timeout=self.timeout, ) response.raise_for_status() data = response.json() @@ -294,10 +298,12 @@ class OllamaProvider(BaseLLMProvider): host: str | None = None, model: str = "codellama:13b", temperature: float = 0, + timeout: int = 300, ): self.host = host or os.environ.get("OLLAMA_HOST", "http://localhost:11434") self.model = model self.temperature = temperature + self.timeout = timeout def call(self, prompt: str, **kwargs) -> LLMResponse: """Call Ollama API.""" @@ -311,7 +317,7 @@ class OllamaProvider(BaseLLMProvider): "temperature": kwargs.get("temperature", self.temperature), }, }, - timeout=300, # Longer timeout for local models + timeout=self.timeout, ) response.raise_for_status() data = response.json() @@ -477,12 +483,18 @@ class LLMClient: provider = config.get("provider", "openai") provider_config = {} + # Get timeout configuration + timeouts = config.get("timeouts", {}) + llm_timeout = timeouts.get("llm", 120) + ollama_timeout = timeouts.get("ollama", 300) + # Map config keys to provider-specific settings if provider == "openai": provider_config = { "model": config.get("model", {}).get("openai", "gpt-4o-mini"), "temperature": config.get("temperature", 0), "max_tokens": config.get("max_tokens", 16000), + "timeout": llm_timeout, } elif provider == "openrouter": provider_config = { @@ -491,11 +503,13 @@ class LLMClient: ), "temperature": config.get("temperature", 
0), "max_tokens": config.get("max_tokens", 16000), + "timeout": llm_timeout, } elif provider == "ollama": provider_config = { "model": config.get("model", {}).get("ollama", "codellama:13b"), "temperature": config.get("temperature", 0), + "timeout": ollama_timeout, } return cls(provider=provider, config=provider_config) diff --git a/tools/ai-review/clients/providers/__init__.py b/tools/ai-review/clients/providers/__init__.py new file mode 100644 index 0000000..6fdb693 --- /dev/null +++ b/tools/ai-review/clients/providers/__init__.py @@ -0,0 +1,27 @@ +"""LLM Providers Package + +This package contains additional LLM provider implementations +beyond the core providers in llm_client.py. + +Providers: +- AnthropicProvider: Direct Anthropic Claude API +- AzureOpenAIProvider: Azure OpenAI Service with API key auth +- AzureOpenAIWithAADProvider: Azure OpenAI with Azure AD auth +- GeminiProvider: Google Gemini API (public) +- VertexAIGeminiProvider: Google Vertex AI Gemini (enterprise GCP) +""" + +from clients.providers.anthropic_provider import AnthropicProvider +from clients.providers.azure_provider import ( + AzureOpenAIProvider, + AzureOpenAIWithAADProvider, +) +from clients.providers.gemini_provider import GeminiProvider, VertexAIGeminiProvider + +__all__ = [ + "AnthropicProvider", + "AzureOpenAIProvider", + "AzureOpenAIWithAADProvider", + "GeminiProvider", + "VertexAIGeminiProvider", +] diff --git a/tools/ai-review/clients/providers/anthropic_provider.py b/tools/ai-review/clients/providers/anthropic_provider.py new file mode 100644 index 0000000..249b883 --- /dev/null +++ b/tools/ai-review/clients/providers/anthropic_provider.py @@ -0,0 +1,249 @@ +"""Anthropic Claude Provider + +Direct integration with Anthropic's Claude API. +Supports Claude 3.5 Sonnet, Claude 3 Opus, and other models. 
+""" + +import json +import os + +# Import base classes from parent module +import sys +from dataclasses import dataclass + +import requests + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall + + +class AnthropicProvider(BaseLLMProvider): + """Anthropic Claude API provider. + + Provides direct integration with Anthropic's Claude models + without going through OpenRouter. + + Supports: + - Claude 3.5 Sonnet (claude-3-5-sonnet-20241022) + - Claude 3 Opus (claude-3-opus-20240229) + - Claude 3 Sonnet (claude-3-sonnet-20240229) + - Claude 3 Haiku (claude-3-haiku-20240307) + """ + + API_URL = "https://api.anthropic.com/v1/messages" + API_VERSION = "2023-06-01" + + def __init__( + self, + api_key: str | None = None, + model: str = "claude-3-5-sonnet-20241022", + temperature: float = 0, + max_tokens: int = 4096, + ): + """Initialize the Anthropic provider. + + Args: + api_key: Anthropic API key. Defaults to ANTHROPIC_API_KEY env var. + model: Model to use. Defaults to Claude 3.5 Sonnet. + temperature: Sampling temperature (0-1). + max_tokens: Maximum tokens in response. + """ + self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY", "") + self.model = model + self.temperature = temperature + self.max_tokens = max_tokens + + def call(self, prompt: str, **kwargs) -> LLMResponse: + """Make a call to the Anthropic API. + + Args: + prompt: The prompt to send. + **kwargs: Additional options (model, temperature, max_tokens). + + Returns: + LLMResponse with the generated content. + + Raises: + ValueError: If API key is not set. + requests.HTTPError: If the API request fails. 
+ """ + if not self.api_key: + raise ValueError("Anthropic API key is required") + + response = requests.post( + self.API_URL, + headers={ + "x-api-key": self.api_key, + "anthropic-version": self.API_VERSION, + "Content-Type": "application/json", + }, + json={ + "model": kwargs.get("model", self.model), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + "temperature": kwargs.get("temperature", self.temperature), + "messages": [{"role": "user", "content": prompt}], + }, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + # Extract content from response + content = "" + for block in data.get("content", []): + if block.get("type") == "text": + content += block.get("text", "") + + return LLMResponse( + content=content, + model=data.get("model", self.model), + provider="anthropic", + tokens_used=data.get("usage", {}).get("input_tokens", 0) + + data.get("usage", {}).get("output_tokens", 0), + finish_reason=data.get("stop_reason"), + ) + + def call_with_tools( + self, + messages: list[dict], + tools: list[dict] | None = None, + **kwargs, + ) -> LLMResponse: + """Make a call to the Anthropic API with tool support. + + Args: + messages: List of message dicts with 'role' and 'content'. + tools: List of tool definitions in OpenAI format. + **kwargs: Additional options. + + Returns: + LLMResponse with content and/or tool_calls. 
+ """ + if not self.api_key: + raise ValueError("Anthropic API key is required") + + # Convert OpenAI-style messages to Anthropic format + anthropic_messages = [] + system_content = None + + for msg in messages: + role = msg.get("role", "user") + + if role == "system": + system_content = msg.get("content", "") + elif role == "assistant": + # Handle assistant messages with tool calls + if msg.get("tool_calls"): + content = [] + if msg.get("content"): + content.append({"type": "text", "text": msg["content"]}) + for tc in msg["tool_calls"]: + content.append( + { + "type": "tool_use", + "id": tc["id"], + "name": tc["function"]["name"], + "input": json.loads(tc["function"]["arguments"]) + if isinstance(tc["function"]["arguments"], str) + else tc["function"]["arguments"], + } + ) + anthropic_messages.append({"role": "assistant", "content": content}) + else: + anthropic_messages.append( + { + "role": "assistant", + "content": msg.get("content", ""), + } + ) + elif role == "tool": + # Tool response + anthropic_messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + "content": msg.get("content", ""), + } + ], + } + ) + else: + anthropic_messages.append( + { + "role": "user", + "content": msg.get("content", ""), + } + ) + + # Convert OpenAI-style tools to Anthropic format + anthropic_tools = None + if tools: + anthropic_tools = [] + for tool in tools: + if tool.get("type") == "function": + func = tool["function"] + anthropic_tools.append( + { + "name": func["name"], + "description": func.get("description", ""), + "input_schema": func.get("parameters", {}), + } + ) + + request_body = { + "model": kwargs.get("model", self.model), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + "temperature": kwargs.get("temperature", self.temperature), + "messages": anthropic_messages, + } + + if system_content: + request_body["system"] = system_content + + if anthropic_tools: + request_body["tools"] = 
anthropic_tools + + response = requests.post( + self.API_URL, + headers={ + "x-api-key": self.api_key, + "anthropic-version": self.API_VERSION, + "Content-Type": "application/json", + }, + json=request_body, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + # Parse response + content = "" + tool_calls = None + + for block in data.get("content", []): + if block.get("type") == "text": + content += block.get("text", "") + elif block.get("type") == "tool_use": + if tool_calls is None: + tool_calls = [] + tool_calls.append( + ToolCall( + id=block.get("id", ""), + name=block.get("name", ""), + arguments=block.get("input", {}), + ) + ) + + return LLMResponse( + content=content, + model=data.get("model", self.model), + provider="anthropic", + tokens_used=data.get("usage", {}).get("input_tokens", 0) + + data.get("usage", {}).get("output_tokens", 0), + finish_reason=data.get("stop_reason"), + tool_calls=tool_calls, + ) diff --git a/tools/ai-review/clients/providers/azure_provider.py b/tools/ai-review/clients/providers/azure_provider.py new file mode 100644 index 0000000..50543be --- /dev/null +++ b/tools/ai-review/clients/providers/azure_provider.py @@ -0,0 +1,420 @@ +"""Azure OpenAI Provider + +Integration with Azure OpenAI Service for enterprise deployments. +Supports custom deployments, regional endpoints, and Azure AD auth. +""" + +import json +import os +import sys + +import requests + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall + + +class AzureOpenAIProvider(BaseLLMProvider): + """Azure OpenAI Service provider. + + Provides integration with Azure-hosted OpenAI models for + enterprise customers with Azure deployments. 
+ + Supports: + - GPT-4, GPT-4 Turbo, GPT-4o + - GPT-3.5 Turbo + - Custom fine-tuned models + + Environment Variables: + - AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL + - AZURE_OPENAI_API_KEY: API key for authentication + - AZURE_OPENAI_DEPLOYMENT: Default deployment name + - AZURE_OPENAI_API_VERSION: API version (default: 2024-02-15-preview) + """ + + DEFAULT_API_VERSION = "2024-02-15-preview" + + def __init__( + self, + endpoint: str | None = None, + api_key: str | None = None, + deployment: str | None = None, + api_version: str | None = None, + temperature: float = 0, + max_tokens: int = 4096, + ): + """Initialize the Azure OpenAI provider. + + Args: + endpoint: Azure OpenAI endpoint URL. + Defaults to AZURE_OPENAI_ENDPOINT env var. + api_key: API key for authentication. + Defaults to AZURE_OPENAI_API_KEY env var. + deployment: Deployment name to use. + Defaults to AZURE_OPENAI_DEPLOYMENT env var. + api_version: API version string. + Defaults to AZURE_OPENAI_API_VERSION env var or latest. + temperature: Sampling temperature (0-2). + max_tokens: Maximum tokens in response. + """ + self.endpoint = ( + endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT", "") + ).rstrip("/") + self.api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY", "") + self.deployment = deployment or os.environ.get("AZURE_OPENAI_DEPLOYMENT", "") + self.api_version = api_version or os.environ.get( + "AZURE_OPENAI_API_VERSION", self.DEFAULT_API_VERSION + ) + self.temperature = temperature + self.max_tokens = max_tokens + + def _get_api_url(self, deployment: str | None = None) -> str: + """Build the API URL for a given deployment. + + Args: + deployment: Deployment name. Uses default if not specified. + + Returns: + Full API URL for chat completions. 
+ """ + deploy = deployment or self.deployment + return ( + f"{self.endpoint}/openai/deployments/{deploy}" + f"/chat/completions?api-version={self.api_version}" + ) + + def call(self, prompt: str, **kwargs) -> LLMResponse: + """Make a call to the Azure OpenAI API. + + Args: + prompt: The prompt to send. + **kwargs: Additional options (deployment, temperature, max_tokens). + + Returns: + LLMResponse with the generated content. + + Raises: + ValueError: If required configuration is missing. + requests.HTTPError: If the API request fails. + """ + if not self.endpoint: + raise ValueError("Azure OpenAI endpoint is required") + if not self.api_key: + raise ValueError("Azure OpenAI API key is required") + if not self.deployment and not kwargs.get("deployment"): + raise ValueError("Azure OpenAI deployment name is required") + + deployment = kwargs.get("deployment", self.deployment) + + response = requests.post( + self._get_api_url(deployment), + headers={ + "api-key": self.api_key, + "Content-Type": "application/json", + }, + json={ + "messages": [{"role": "user", "content": prompt}], + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + "temperature": kwargs.get("temperature", self.temperature), + }, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + choice = data.get("choices", [{}])[0] + message = choice.get("message", {}) + + return LLMResponse( + content=message.get("content", ""), + model=data.get("model", deployment), + provider="azure", + tokens_used=data.get("usage", {}).get("total_tokens", 0), + finish_reason=choice.get("finish_reason"), + ) + + def call_with_tools( + self, + messages: list[dict], + tools: list[dict] | None = None, + **kwargs, + ) -> LLMResponse: + """Make a call to the Azure OpenAI API with tool support. + + Azure OpenAI uses the same format as OpenAI for tools. + + Args: + messages: List of message dicts with 'role' and 'content'. + tools: List of tool definitions in OpenAI format. 
+ **kwargs: Additional options. + + Returns: + LLMResponse with content and/or tool_calls. + """ + if not self.endpoint: + raise ValueError("Azure OpenAI endpoint is required") + if not self.api_key: + raise ValueError("Azure OpenAI API key is required") + if not self.deployment and not kwargs.get("deployment"): + raise ValueError("Azure OpenAI deployment name is required") + + deployment = kwargs.get("deployment", self.deployment) + + request_body = { + "messages": messages, + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + "temperature": kwargs.get("temperature", self.temperature), + } + + if tools: + request_body["tools"] = tools + request_body["tool_choice"] = kwargs.get("tool_choice", "auto") + + response = requests.post( + self._get_api_url(deployment), + headers={ + "api-key": self.api_key, + "Content-Type": "application/json", + }, + json=request_body, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + choice = data.get("choices", [{}])[0] + message = choice.get("message", {}) + + # Parse tool calls if present + tool_calls = None + if message.get("tool_calls"): + tool_calls = [] + for tc in message["tool_calls"]: + func = tc.get("function", {}) + args = func.get("arguments", "{}") + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + args = {} + tool_calls.append( + ToolCall( + id=tc.get("id", ""), + name=func.get("name", ""), + arguments=args, + ) + ) + + return LLMResponse( + content=message.get("content", "") or "", + model=data.get("model", deployment), + provider="azure", + tokens_used=data.get("usage", {}).get("total_tokens", 0), + finish_reason=choice.get("finish_reason"), + tool_calls=tool_calls, + ) + + +class AzureOpenAIWithAADProvider(AzureOpenAIProvider): + """Azure OpenAI provider with Azure Active Directory authentication. + + Uses Azure AD tokens instead of API keys for authentication. + Requires azure-identity package for token acquisition. 
+ + Environment Variables: + - AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL + - AZURE_OPENAI_DEPLOYMENT: Default deployment name + - AZURE_TENANT_ID: Azure AD tenant ID (optional) + - AZURE_CLIENT_ID: Azure AD client ID (optional) + - AZURE_CLIENT_SECRET: Azure AD client secret (optional) + """ + + SCOPE = "https://cognitiveservices.azure.com/.default" + + def __init__( + self, + endpoint: str | None = None, + deployment: str | None = None, + api_version: str | None = None, + temperature: float = 0, + max_tokens: int = 4096, + credential=None, + ): + """Initialize the Azure OpenAI AAD provider. + + Args: + endpoint: Azure OpenAI endpoint URL. + deployment: Deployment name to use. + api_version: API version string. + temperature: Sampling temperature (0-2). + max_tokens: Maximum tokens in response. + credential: Azure credential object. If not provided, + uses DefaultAzureCredential. + """ + super().__init__( + endpoint=endpoint, + api_key="", # Not used with AAD + deployment=deployment, + api_version=api_version, + temperature=temperature, + max_tokens=max_tokens, + ) + self._credential = credential + self._token = None + self._token_expires_at = 0 + + def _get_token(self) -> str: + """Get an Azure AD token for authentication. + + Returns: + Bearer token string. + + Raises: + ImportError: If azure-identity is not installed. + """ + import time + + # Return cached token if still valid (with 5 min buffer) + if self._token and self._token_expires_at > time.time() + 300: + return self._token + + try: + from azure.identity import DefaultAzureCredential + except ImportError: + raise ImportError( + "azure-identity package is required for AAD authentication. 
" + "Install with: pip install azure-identity" + ) + + if self._credential is None: + self._credential = DefaultAzureCredential() + + token = self._credential.get_token(self.SCOPE) + self._token = token.token + self._token_expires_at = token.expires_on + + return self._token + + def call(self, prompt: str, **kwargs) -> LLMResponse: + """Make a call to the Azure OpenAI API using AAD auth. + + Args: + prompt: The prompt to send. + **kwargs: Additional options. + + Returns: + LLMResponse with the generated content. + """ + if not self.endpoint: + raise ValueError("Azure OpenAI endpoint is required") + if not self.deployment and not kwargs.get("deployment"): + raise ValueError("Azure OpenAI deployment name is required") + + deployment = kwargs.get("deployment", self.deployment) + token = self._get_token() + + response = requests.post( + self._get_api_url(deployment), + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json={ + "messages": [{"role": "user", "content": prompt}], + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + "temperature": kwargs.get("temperature", self.temperature), + }, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + choice = data.get("choices", [{}])[0] + message = choice.get("message", {}) + + return LLMResponse( + content=message.get("content", ""), + model=data.get("model", deployment), + provider="azure", + tokens_used=data.get("usage", {}).get("total_tokens", 0), + finish_reason=choice.get("finish_reason"), + ) + + def call_with_tools( + self, + messages: list[dict], + tools: list[dict] | None = None, + **kwargs, + ) -> LLMResponse: + """Make a call to the Azure OpenAI API with tool support using AAD auth. + + Args: + messages: List of message dicts. + tools: List of tool definitions. + **kwargs: Additional options. + + Returns: + LLMResponse with content and/or tool_calls. 
+ """ + if not self.endpoint: + raise ValueError("Azure OpenAI endpoint is required") + if not self.deployment and not kwargs.get("deployment"): + raise ValueError("Azure OpenAI deployment name is required") + + deployment = kwargs.get("deployment", self.deployment) + token = self._get_token() + + request_body = { + "messages": messages, + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + "temperature": kwargs.get("temperature", self.temperature), + } + + if tools: + request_body["tools"] = tools + request_body["tool_choice"] = kwargs.get("tool_choice", "auto") + + response = requests.post( + self._get_api_url(deployment), + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json=request_body, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + choice = data.get("choices", [{}])[0] + message = choice.get("message", {}) + + # Parse tool calls if present + tool_calls = None + if message.get("tool_calls"): + tool_calls = [] + for tc in message["tool_calls"]: + func = tc.get("function", {}) + args = func.get("arguments", "{}") + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + args = {} + tool_calls.append( + ToolCall( + id=tc.get("id", ""), + name=func.get("name", ""), + arguments=args, + ) + ) + + return LLMResponse( + content=message.get("content", "") or "", + model=data.get("model", deployment), + provider="azure", + tokens_used=data.get("usage", {}).get("total_tokens", 0), + finish_reason=choice.get("finish_reason"), + tool_calls=tool_calls, + ) diff --git a/tools/ai-review/clients/providers/gemini_provider.py b/tools/ai-review/clients/providers/gemini_provider.py new file mode 100644 index 0000000..d6f389f --- /dev/null +++ b/tools/ai-review/clients/providers/gemini_provider.py @@ -0,0 +1,599 @@ +"""Google Gemini Provider + +Integration with Google's Gemini API for GCP customers. +Supports Gemini Pro, Gemini Ultra, and other models. 
+""" + +import json +import os +import sys + +import requests + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall + + +class GeminiProvider(BaseLLMProvider): + """Google Gemini API provider. + + Provides integration with Google's Gemini models. + + Supports: + - Gemini 1.5 Pro (gemini-1.5-pro) + - Gemini 1.5 Flash (gemini-1.5-flash) + - Gemini 1.0 Pro (gemini-pro) + + Environment Variables: + - GOOGLE_API_KEY: Google AI API key + - GEMINI_MODEL: Default model (optional) + """ + + API_URL = "https://generativelanguage.googleapis.com/v1beta/models" + + def __init__( + self, + api_key: str | None = None, + model: str = "gemini-1.5-pro", + temperature: float = 0, + max_tokens: int = 4096, + ): + """Initialize the Gemini provider. + + Args: + api_key: Google API key. Defaults to GOOGLE_API_KEY env var. + model: Model to use. Defaults to gemini-1.5-pro. + temperature: Sampling temperature (0-1). + max_tokens: Maximum tokens in response. + """ + self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "") + self.model = model or os.environ.get("GEMINI_MODEL", "gemini-1.5-pro") + self.temperature = temperature + self.max_tokens = max_tokens + + def _get_api_url(self, model: str | None = None, stream: bool = False) -> str: + """Build the API URL for a given model. + + Args: + model: Model name. Uses default if not specified. + stream: Whether to use streaming endpoint. + + Returns: + Full API URL. + """ + m = model or self.model + action = "streamGenerateContent" if stream else "generateContent" + return f"{self.API_URL}/{m}:{action}?key={self.api_key}" + + def call(self, prompt: str, **kwargs) -> LLMResponse: + """Make a call to the Gemini API. + + Args: + prompt: The prompt to send. + **kwargs: Additional options (model, temperature, max_tokens). + + Returns: + LLMResponse with the generated content. + + Raises: + ValueError: If API key is not set. 
+ requests.HTTPError: If the API request fails. + """ + if not self.api_key: + raise ValueError("Google API key is required") + + model = kwargs.get("model", self.model) + + response = requests.post( + self._get_api_url(model), + headers={"Content-Type": "application/json"}, + json={ + "contents": [{"parts": [{"text": prompt}]}], + "generationConfig": { + "temperature": kwargs.get("temperature", self.temperature), + "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens), + }, + }, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + # Extract content from response + content = "" + candidates = data.get("candidates", []) + if candidates: + parts = candidates[0].get("content", {}).get("parts", []) + for part in parts: + if "text" in part: + content += part["text"] + + # Get token counts + usage = data.get("usageMetadata", {}) + tokens_used = usage.get("promptTokenCount", 0) + usage.get( + "candidatesTokenCount", 0 + ) + + finish_reason = None + if candidates: + finish_reason = candidates[0].get("finishReason") + + return LLMResponse( + content=content, + model=model, + provider="gemini", + tokens_used=tokens_used, + finish_reason=finish_reason, + ) + + def call_with_tools( + self, + messages: list[dict], + tools: list[dict] | None = None, + **kwargs, + ) -> LLMResponse: + """Make a call to the Gemini API with tool support. + + Args: + messages: List of message dicts with 'role' and 'content'. + tools: List of tool definitions in OpenAI format. + **kwargs: Additional options. + + Returns: + LLMResponse with content and/or tool_calls. 
+ """ + if not self.api_key: + raise ValueError("Google API key is required") + + model = kwargs.get("model", self.model) + + # Convert OpenAI-style messages to Gemini format + gemini_contents = [] + system_instruction = None + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + if role == "system": + system_instruction = content + elif role == "assistant": + # Handle assistant messages with tool calls + parts = [] + if content: + parts.append({"text": content}) + + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + func = tc.get("function", {}) + args = func.get("arguments", {}) + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + args = {} + parts.append( + { + "functionCall": { + "name": func.get("name", ""), + "args": args, + } + } + ) + + gemini_contents.append({"role": "model", "parts": parts}) + elif role == "tool": + # Tool response in Gemini format + gemini_contents.append( + { + "role": "function", + "parts": [ + { + "functionResponse": { + "name": msg.get("name", ""), + "response": {"result": content}, + } + } + ], + } + ) + else: + # User message + gemini_contents.append({"role": "user", "parts": [{"text": content}]}) + + # Convert OpenAI-style tools to Gemini format + gemini_tools = None + if tools: + function_declarations = [] + for tool in tools: + if tool.get("type") == "function": + func = tool["function"] + function_declarations.append( + { + "name": func["name"], + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + } + ) + if function_declarations: + gemini_tools = [{"functionDeclarations": function_declarations}] + + request_body = { + "contents": gemini_contents, + "generationConfig": { + "temperature": kwargs.get("temperature", self.temperature), + "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens), + }, + } + + if system_instruction: + request_body["systemInstruction"] = { + "parts": [{"text": 
system_instruction}] + } + + if gemini_tools: + request_body["tools"] = gemini_tools + + response = requests.post( + self._get_api_url(model), + headers={"Content-Type": "application/json"}, + json=request_body, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + # Parse response + content = "" + tool_calls = None + + candidates = data.get("candidates", []) + if candidates: + parts = candidates[0].get("content", {}).get("parts", []) + for part in parts: + if "text" in part: + content += part["text"] + elif "functionCall" in part: + if tool_calls is None: + tool_calls = [] + fc = part["functionCall"] + tool_calls.append( + ToolCall( + id=f"call_{len(tool_calls)}", # Gemini doesn't provide IDs + name=fc.get("name", ""), + arguments=fc.get("args", {}), + ) + ) + + # Get token counts + usage = data.get("usageMetadata", {}) + tokens_used = usage.get("promptTokenCount", 0) + usage.get( + "candidatesTokenCount", 0 + ) + + finish_reason = None + if candidates: + finish_reason = candidates[0].get("finishReason") + + return LLMResponse( + content=content, + model=model, + provider="gemini", + tokens_used=tokens_used, + finish_reason=finish_reason, + tool_calls=tool_calls, + ) + + +class VertexAIGeminiProvider(BaseLLMProvider): + """Google Vertex AI Gemini provider for enterprise GCP deployments. + + Uses Vertex AI endpoints instead of the public Gemini API. + Supports regional deployments and IAM authentication. + + Environment Variables: + - GOOGLE_CLOUD_PROJECT: GCP project ID + - GOOGLE_CLOUD_REGION: GCP region (default: us-central1) + - VERTEX_AI_MODEL: Default model (optional) + """ + + def __init__( + self, + project: str | None = None, + region: str = "us-central1", + model: str = "gemini-1.5-pro", + temperature: float = 0, + max_tokens: int = 4096, + credentials=None, + ): + """Initialize the Vertex AI Gemini provider. + + Args: + project: GCP project ID. Defaults to GOOGLE_CLOUD_PROJECT env var. + region: GCP region. 
Defaults to us-central1. + model: Model to use. Defaults to gemini-1.5-pro. + temperature: Sampling temperature (0-1). + max_tokens: Maximum tokens in response. + credentials: Google credentials object. If not provided, + uses Application Default Credentials. + """ + self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT", "") + self.region = region or os.environ.get("GOOGLE_CLOUD_REGION", "us-central1") + self.model = model or os.environ.get("VERTEX_AI_MODEL", "gemini-1.5-pro") + self.temperature = temperature + self.max_tokens = max_tokens + self._credentials = credentials + self._token = None + self._token_expires_at = 0 + + def _get_token(self) -> str: + """Get a Google Cloud access token. + + Returns: + Access token string. + + Raises: + ImportError: If google-auth is not installed. + """ + import time + + # Return cached token if still valid (with 5 min buffer) + if self._token and self._token_expires_at > time.time() + 300: + return self._token + + try: + import google.auth + from google.auth.transport.requests import Request + except ImportError: + raise ImportError( + "google-auth package is required for Vertex AI authentication. " + "Install with: pip install google-auth" + ) + + if self._credentials is None: + self._credentials, _ = google.auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"] + ) + + if not self._credentials.valid: + self._credentials.refresh(Request()) + + self._token = self._credentials.token + # Tokens typically expire in 1 hour + self._token_expires_at = time.time() + 3500 + + return self._token + + def _get_api_url(self, model: str | None = None) -> str: + """Build the Vertex AI API URL. + + Args: + model: Model name. Uses default if not specified. + + Returns: + Full API URL. 
+ """ + m = model or self.model + return ( + f"https://{self.region}-aiplatform.googleapis.com/v1/" + f"projects/{self.project}/locations/{self.region}/" + f"publishers/google/models/{m}:generateContent" + ) + + def call(self, prompt: str, **kwargs) -> LLMResponse: + """Make a call to Vertex AI Gemini. + + Args: + prompt: The prompt to send. + **kwargs: Additional options. + + Returns: + LLMResponse with the generated content. + """ + if not self.project: + raise ValueError("GCP project ID is required") + + model = kwargs.get("model", self.model) + token = self._get_token() + + response = requests.post( + self._get_api_url(model), + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json={ + "contents": [{"parts": [{"text": prompt}]}], + "generationConfig": { + "temperature": kwargs.get("temperature", self.temperature), + "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens), + }, + }, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + # Extract content from response + content = "" + candidates = data.get("candidates", []) + if candidates: + parts = candidates[0].get("content", {}).get("parts", []) + for part in parts: + if "text" in part: + content += part["text"] + + # Get token counts + usage = data.get("usageMetadata", {}) + tokens_used = usage.get("promptTokenCount", 0) + usage.get( + "candidatesTokenCount", 0 + ) + + finish_reason = None + if candidates: + finish_reason = candidates[0].get("finishReason") + + return LLMResponse( + content=content, + model=model, + provider="vertex-ai", + tokens_used=tokens_used, + finish_reason=finish_reason, + ) + + def call_with_tools( + self, + messages: list[dict], + tools: list[dict] | None = None, + **kwargs, + ) -> LLMResponse: + """Make a call to Vertex AI Gemini with tool support. + + Args: + messages: List of message dicts. + tools: List of tool definitions. + **kwargs: Additional options. 
+ + Returns: + LLMResponse with content and/or tool_calls. + """ + if not self.project: + raise ValueError("GCP project ID is required") + + model = kwargs.get("model", self.model) + token = self._get_token() + + # Convert messages to Gemini format (same as GeminiProvider) + gemini_contents = [] + system_instruction = None + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + if role == "system": + system_instruction = content + elif role == "assistant": + parts = [] + if content: + parts.append({"text": content}) + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + func = tc.get("function", {}) + args = func.get("arguments", {}) + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + args = {} + parts.append( + { + "functionCall": { + "name": func.get("name", ""), + "args": args, + } + } + ) + gemini_contents.append({"role": "model", "parts": parts}) + elif role == "tool": + gemini_contents.append( + { + "role": "function", + "parts": [ + { + "functionResponse": { + "name": msg.get("name", ""), + "response": {"result": content}, + } + } + ], + } + ) + else: + gemini_contents.append({"role": "user", "parts": [{"text": content}]}) + + # Convert tools to Gemini format + gemini_tools = None + if tools: + function_declarations = [] + for tool in tools: + if tool.get("type") == "function": + func = tool["function"] + function_declarations.append( + { + "name": func["name"], + "description": func.get("description", ""), + "parameters": func.get("parameters", {}), + } + ) + if function_declarations: + gemini_tools = [{"functionDeclarations": function_declarations}] + + request_body = { + "contents": gemini_contents, + "generationConfig": { + "temperature": kwargs.get("temperature", self.temperature), + "maxOutputTokens": kwargs.get("max_tokens", self.max_tokens), + }, + } + + if system_instruction: + request_body["systemInstruction"] = { + "parts": [{"text": system_instruction}] + } + 
+ if gemini_tools: + request_body["tools"] = gemini_tools + + response = requests.post( + self._get_api_url(model), + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + }, + json=request_body, + timeout=120, + ) + response.raise_for_status() + data = response.json() + + # Parse response + content = "" + tool_calls = None + + candidates = data.get("candidates", []) + if candidates: + parts = candidates[0].get("content", {}).get("parts", []) + for part in parts: + if "text" in part: + content += part["text"] + elif "functionCall" in part: + if tool_calls is None: + tool_calls = [] + fc = part["functionCall"] + tool_calls.append( + ToolCall( + id=f"call_{len(tool_calls)}", + name=fc.get("name", ""), + arguments=fc.get("args", {}), + ) + ) + + usage = data.get("usageMetadata", {}) + tokens_used = usage.get("promptTokenCount", 0) + usage.get( + "candidatesTokenCount", 0 + ) + + finish_reason = None + if candidates: + finish_reason = candidates[0].get("finishReason") + + return LLMResponse( + content=content, + model=model, + provider="vertex-ai", + tokens_used=tokens_used, + finish_reason=finish_reason, + tool_calls=tool_calls, + ) diff --git a/tools/ai-review/compliance/__init__.py b/tools/ai-review/compliance/__init__.py new file mode 100644 index 0000000..4d6050d --- /dev/null +++ b/tools/ai-review/compliance/__init__.py @@ -0,0 +1,14 @@ +"""Compliance Module + +Provides audit trail, compliance reporting, and regulatory checks. 
+""" + +from compliance.audit_trail import AuditEvent, AuditLogger, AuditTrail +from compliance.codeowners import CodeownersChecker + +__all__ = [ + "AuditTrail", + "AuditLogger", + "AuditEvent", + "CodeownersChecker", +] diff --git a/tools/ai-review/compliance/audit_trail.py b/tools/ai-review/compliance/audit_trail.py new file mode 100644 index 0000000..d495394 --- /dev/null +++ b/tools/ai-review/compliance/audit_trail.py @@ -0,0 +1,430 @@ +"""Audit Trail + +Provides comprehensive audit logging for compliance requirements. +Supports HIPAA, SOC2, and other regulatory frameworks. +""" + +import hashlib +import json +import logging +import os +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Any + + +class AuditAction(Enum): + """Types of auditable actions.""" + + # Review actions + REVIEW_STARTED = "review_started" + REVIEW_COMPLETED = "review_completed" + REVIEW_FAILED = "review_failed" + + # Security actions + SECURITY_SCAN_STARTED = "security_scan_started" + SECURITY_SCAN_COMPLETED = "security_scan_completed" + SECURITY_FINDING_DETECTED = "security_finding_detected" + SECURITY_FINDING_RESOLVED = "security_finding_resolved" + + # Comment actions + COMMENT_POSTED = "comment_posted" + COMMENT_UPDATED = "comment_updated" + COMMENT_DELETED = "comment_deleted" + + # Label actions + LABEL_ADDED = "label_added" + LABEL_REMOVED = "label_removed" + + # Configuration actions + CONFIG_LOADED = "config_loaded" + CONFIG_CHANGED = "config_changed" + + # Access actions + API_CALL = "api_call" + AUTHENTICATION = "authentication" + + # Approval actions + APPROVAL_GRANTED = "approval_granted" + APPROVAL_REVOKED = "approval_revoked" + CHANGES_REQUESTED = "changes_requested" + + +@dataclass +class AuditEvent: + """An auditable event.""" + + action: AuditAction + timestamp: str + actor: str + resource_type: str + resource_id: str + repository: str + details: dict[str, Any] = 
field(default_factory=dict) + outcome: str = "success" + error: str | None = None + correlation_id: str | None = None + checksum: str | None = None + + def __post_init__(self): + """Calculate checksum for integrity verification.""" + if not self.checksum: + self.checksum = self._calculate_checksum() + + def _calculate_checksum(self) -> str: + """Calculate SHA-256 checksum of event data.""" + data = { + "action": self.action.value + if isinstance(self.action, AuditAction) + else self.action, + "timestamp": self.timestamp, + "actor": self.actor, + "resource_type": self.resource_type, + "resource_id": self.resource_id, + "repository": self.repository, + "details": self.details, + "outcome": self.outcome, + "error": self.error, + } + json_str = json.dumps(data, sort_keys=True) + return hashlib.sha256(json_str.encode()).hexdigest() + + def to_dict(self) -> dict: + """Convert event to dictionary.""" + data = asdict(self) + if isinstance(self.action, AuditAction): + data["action"] = self.action.value + return data + + def to_json(self) -> str: + """Convert event to JSON string.""" + return json.dumps(self.to_dict()) + + +class AuditLogger: + """Logger for audit events.""" + + def __init__( + self, + log_file: str | None = None, + log_to_stdout: bool = False, + log_level: str = "INFO", + ): + """Initialize audit logger. + + Args: + log_file: Path to audit log file. + log_to_stdout: Also log to stdout. + log_level: Logging level. 
+ """ + self.log_file = log_file + self.log_to_stdout = log_to_stdout + self.logger = logging.getLogger("audit") + self.logger.setLevel(getattr(logging, log_level.upper(), logging.INFO)) + + # Clear existing handlers + self.logger.handlers = [] + + # Add file handler if specified + if log_file: + log_dir = os.path.dirname(log_file) + if log_dir: + os.makedirs(log_dir, exist_ok=True) + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter( + logging.Formatter("%(message)s") # JSON lines format + ) + self.logger.addHandler(file_handler) + + # Add stdout handler if requested + if log_to_stdout: + stdout_handler = logging.StreamHandler() + stdout_handler.setFormatter(logging.Formatter("[AUDIT] %(message)s")) + self.logger.addHandler(stdout_handler) + + def log(self, event: AuditEvent): + """Log an audit event. + + Args: + event: The audit event to log. + """ + self.logger.info(event.to_json()) + + def log_action( + self, + action: AuditAction, + actor: str, + resource_type: str, + resource_id: str, + repository: str, + details: dict | None = None, + outcome: str = "success", + error: str | None = None, + correlation_id: str | None = None, + ): + """Log an action as an audit event. + + Args: + action: The action being performed. + actor: Who performed the action. + resource_type: Type of resource affected. + resource_id: ID of the resource. + repository: Repository context. + details: Additional details. + outcome: success, failure, or partial. + error: Error message if failed. + correlation_id: ID to correlate related events. 
+ """ + event = AuditEvent( + action=action, + timestamp=datetime.now(timezone.utc).isoformat(), + actor=actor, + resource_type=resource_type, + resource_id=resource_id, + repository=repository, + details=details or {}, + outcome=outcome, + error=error, + correlation_id=correlation_id, + ) + self.log(event) + + +class AuditTrail: + """High-level audit trail management.""" + + def __init__(self, config: dict): + """Initialize audit trail. + + Args: + config: Configuration dictionary. + """ + self.config = config + compliance_config = config.get("compliance", {}) + audit_config = compliance_config.get("audit", {}) + + self.enabled = audit_config.get("enabled", False) + self.log_file = audit_config.get("log_file", "audit.log") + self.log_to_stdout = audit_config.get("log_to_stdout", False) + self.retention_days = audit_config.get("retention_days", 90) + + if self.enabled: + self.logger = AuditLogger( + log_file=self.log_file, + log_to_stdout=self.log_to_stdout, + ) + else: + self.logger = None + + self._correlation_id = None + + def set_correlation_id(self, correlation_id: str): + """Set correlation ID for subsequent events. + + Args: + correlation_id: ID to correlate related events. + """ + self._correlation_id = correlation_id + + def log( + self, + action: AuditAction, + actor: str, + resource_type: str, + resource_id: str, + repository: str, + details: dict | None = None, + outcome: str = "success", + error: str | None = None, + ): + """Log an audit event. + + Args: + action: The action being performed. + actor: Who performed the action. + resource_type: Type of resource (pr, issue, comment, etc). + resource_id: ID of the resource. + repository: Repository (owner/repo). + details: Additional details. + outcome: success, failure, or partial. + error: Error message if failed. 
+ """ + if not self.enabled or not self.logger: + return + + self.logger.log_action( + action=action, + actor=actor, + resource_type=resource_type, + resource_id=resource_id, + repository=repository, + details=details, + outcome=outcome, + error=error, + correlation_id=self._correlation_id, + ) + + def log_review_started( + self, + repository: str, + pr_number: int, + reviewer: str = "openrabbit", + ): + """Log that a review has started.""" + self.log( + action=AuditAction.REVIEW_STARTED, + actor=reviewer, + resource_type="pull_request", + resource_id=str(pr_number), + repository=repository, + ) + + def log_review_completed( + self, + repository: str, + pr_number: int, + recommendation: str, + findings_count: int, + reviewer: str = "openrabbit", + ): + """Log that a review has completed.""" + self.log( + action=AuditAction.REVIEW_COMPLETED, + actor=reviewer, + resource_type="pull_request", + resource_id=str(pr_number), + repository=repository, + details={ + "recommendation": recommendation, + "findings_count": findings_count, + }, + ) + + def log_security_finding( + self, + repository: str, + pr_number: int, + finding: dict, + scanner: str = "openrabbit", + ): + """Log a security finding.""" + self.log( + action=AuditAction.SECURITY_FINDING_DETECTED, + actor=scanner, + resource_type="pull_request", + resource_id=str(pr_number), + repository=repository, + details={ + "severity": finding.get("severity"), + "category": finding.get("category"), + "file": finding.get("file"), + "line": finding.get("line"), + "cwe": finding.get("cwe"), + }, + ) + + def log_approval( + self, + repository: str, + pr_number: int, + approver: str, + approval_type: str = "ai", + ): + """Log an approval action.""" + self.log( + action=AuditAction.APPROVAL_GRANTED, + actor=approver, + resource_type="pull_request", + resource_id=str(pr_number), + repository=repository, + details={"approval_type": approval_type}, + ) + + def log_changes_requested( + self, + repository: str, + pr_number: int, + 
requester: str, + reason: str | None = None, + ): + """Log a changes requested action.""" + self.log( + action=AuditAction.CHANGES_REQUESTED, + actor=requester, + resource_type="pull_request", + resource_id=str(pr_number), + repository=repository, + details={"reason": reason} if reason else {}, + ) + + def generate_report( + self, + start_date: datetime | None = None, + end_date: datetime | None = None, + repository: str | None = None, + ) -> dict: + """Generate an audit report. + + Args: + start_date: Start of reporting period. + end_date: End of reporting period. + repository: Filter by repository. + + Returns: + Report dictionary with statistics and events. + """ + if not self.log_file or not os.path.exists(self.log_file): + return {"events": [], "statistics": {}} + + events = [] + with open(self.log_file) as f: + for line in f: + try: + event = json.loads(line.strip()) + event_time = datetime.fromisoformat( + event["timestamp"].replace("Z", "+00:00") + ) + + # Apply filters + if start_date and event_time < start_date: + continue + if end_date and event_time > end_date: + continue + if repository and event.get("repository") != repository: + continue + + events.append(event) + except (json.JSONDecodeError, KeyError): + continue + + # Calculate statistics + action_counts = {} + outcome_counts = {"success": 0, "failure": 0, "partial": 0} + security_findings = 0 + + for event in events: + action = event.get("action", "unknown") + action_counts[action] = action_counts.get(action, 0) + 1 + + outcome = event.get("outcome", "success") + if outcome in outcome_counts: + outcome_counts[outcome] += 1 + + if action == "security_finding_detected": + security_findings += 1 + + return { + "events": events, + "statistics": { + "total_events": len(events), + "action_counts": action_counts, + "outcome_counts": outcome_counts, + "security_findings": security_findings, + }, + "period": { + "start": start_date.isoformat() if start_date else None, + "end": end_date.isoformat() if 
end_date else None, + }, + } diff --git a/tools/ai-review/compliance/codeowners.py b/tools/ai-review/compliance/codeowners.py new file mode 100644 index 0000000..0b02c29 --- /dev/null +++ b/tools/ai-review/compliance/codeowners.py @@ -0,0 +1,314 @@ +"""CODEOWNERS Checker + +Parses and validates CODEOWNERS files for compliance enforcement. +""" + +import fnmatch +import logging +import os +import re +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class CodeOwnerRule: + """A CODEOWNERS rule.""" + + pattern: str + owners: list[str] + line_number: int + is_negation: bool = False + + def matches(self, path: str) -> bool: + """Check if a path matches this rule. + + Args: + path: File path to check. + + Returns: + True if the path matches. + """ + path = path.lstrip("/") + pattern = self.pattern.lstrip("/") + + # Handle directory patterns + if pattern.endswith("/"): + return path.startswith(pattern) or fnmatch.fnmatch(path, pattern + "*") + + # Handle ** patterns + if "**" in pattern: + regex = pattern.replace("**", ".*").replace("*", "[^/]*") + return bool(re.match(f"^{regex}$", path)) + + # Standard fnmatch + return fnmatch.fnmatch(path, pattern) or fnmatch.fnmatch(path, f"**/{pattern}") + + +class CodeownersChecker: + """Checker for CODEOWNERS file compliance.""" + + CODEOWNERS_LOCATIONS = [ + "CODEOWNERS", + ".github/CODEOWNERS", + ".gitea/CODEOWNERS", + "docs/CODEOWNERS", + ] + + def __init__(self, repo_root: str | None = None): + """Initialize CODEOWNERS checker. + + Args: + repo_root: Repository root path. 
+ """ + self.repo_root = repo_root or os.getcwd() + self.rules: list[CodeOwnerRule] = [] + self.codeowners_path: str | None = None + self.logger = logging.getLogger(__name__) + + self._load_codeowners() + + def _load_codeowners(self): + """Load CODEOWNERS file from repository.""" + for location in self.CODEOWNERS_LOCATIONS: + path = os.path.join(self.repo_root, location) + if os.path.exists(path): + self.codeowners_path = path + self._parse_codeowners(path) + break + + def _parse_codeowners(self, path: str): + """Parse a CODEOWNERS file. + + Args: + path: Path to CODEOWNERS file. + """ + with open(path) as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + + # Skip empty lines and comments + if not line or line.startswith("#"): + continue + + # Parse pattern and owners + parts = line.split() + if len(parts) < 2: + continue + + pattern = parts[0] + owners = parts[1:] + + # Check for negation (optional syntax) + is_negation = pattern.startswith("!") + if is_negation: + pattern = pattern[1:] + + self.rules.append( + CodeOwnerRule( + pattern=pattern, + owners=owners, + line_number=line_num, + is_negation=is_negation, + ) + ) + + def get_owners(self, path: str) -> list[str]: + """Get owners for a file path. + + Args: + path: File path to check. + + Returns: + List of owner usernames/teams. + """ + owners = [] + + # Apply rules in order (later rules override earlier ones) + for rule in self.rules: + if rule.matches(path): + if rule.is_negation: + owners = [] # Clear owners for negation + else: + owners = rule.owners + + return owners + + def get_owners_for_files(self, files: list[str]) -> dict[str, list[str]]: + """Get owners for multiple files. + + Args: + files: List of file paths. + + Returns: + Dict mapping file paths to owner lists. + """ + return {f: self.get_owners(f) for f in files} + + def get_required_reviewers(self, files: list[str]) -> set[str]: + """Get all required reviewers for a set of files. + + Args: + files: List of file paths. 
+ + Returns: + Set of all required reviewer usernames/teams. + """ + reviewers = set() + for f in files: + reviewers.update(self.get_owners(f)) + return reviewers + + def check_approval( + self, + files: list[str], + approvers: list[str], + ) -> dict: + """Check if files have required approvals. + + Args: + files: List of changed files. + approvers: List of users who approved. + + Returns: + Dict with approval status and missing approvers. + """ + required = self.get_required_reviewers(files) + approvers_set = set(approvers) + + # Normalize @ prefixes + required_normalized = {r.lstrip("@") for r in required} + approvers_normalized = {a.lstrip("@") for a in approvers_set} + + missing = required_normalized - approvers_normalized + + # Check for team approvals (simplified - actual implementation + # would need API calls to check team membership) + teams = {r for r in missing if "/" in r} + missing_users = missing - teams + + return { + "approved": len(missing_users) == 0, + "required_reviewers": list(required_normalized), + "actual_approvers": list(approvers_normalized), + "missing_approvers": list(missing_users), + "pending_teams": list(teams), + } + + def get_coverage_report(self, files: list[str]) -> dict: + """Generate a coverage report for files. + + Args: + files: List of file paths. + + Returns: + Coverage report with owned and unowned files. + """ + owned = [] + unowned = [] + + for f in files: + owners = self.get_owners(f) + if owners: + owned.append({"file": f, "owners": owners}) + else: + unowned.append(f) + + return { + "total_files": len(files), + "owned_files": len(owned), + "unowned_files": len(unowned), + "coverage_percent": (len(owned) / len(files) * 100) if files else 0, + "owned": owned, + "unowned": unowned, + } + + def validate_codeowners(self) -> dict: + """Validate the CODEOWNERS file. + + Returns: + Validation result with warnings and errors. 
+ """ + if not self.codeowners_path: + return { + "valid": False, + "errors": ["No CODEOWNERS file found"], + "warnings": [], + } + + errors = [] + warnings = [] + + # Check for empty rules + for rule in self.rules: + if not rule.owners: + errors.append( + f"Line {rule.line_number}: Pattern '{rule.pattern}' has no owners" + ) + + # Check for invalid owner formats + for rule in self.rules: + for owner in rule.owners: + if not owner.startswith("@") and "/" not in owner: + warnings.append( + f"Line {rule.line_number}: Owner '{owner}' should start with @ or be a team (org/team)" + ) + + # Check for overlapping patterns + patterns_seen = {} + for rule in self.rules: + if rule.pattern in patterns_seen: + warnings.append( + f"Line {rule.line_number}: Pattern '{rule.pattern}' duplicates line {patterns_seen[rule.pattern]}" + ) + patterns_seen[rule.pattern] = rule.line_number + + return { + "valid": len(errors) == 0, + "errors": errors, + "warnings": warnings, + "rules_count": len(self.rules), + "file_path": self.codeowners_path, + } + + @classmethod + def from_content(cls, content: str) -> "CodeownersChecker": + """Create checker from CODEOWNERS content string. + + Args: + content: CODEOWNERS file content. + + Returns: + CodeownersChecker instance. 
+ """ + checker = cls.__new__(cls) + checker.repo_root = None + checker.rules = [] + checker.codeowners_path = "" + checker.logger = logging.getLogger(__name__) + + for line_num, line in enumerate(content.split("\n"), 1): + line = line.strip() + if not line or line.startswith("#"): + continue + + parts = line.split() + if len(parts) < 2: + continue + + pattern = parts[0] + owners = parts[1:] + is_negation = pattern.startswith("!") + if is_negation: + pattern = pattern[1:] + + checker.rules.append( + CodeOwnerRule( + pattern=pattern, + owners=owners, + line_number=line_num, + is_negation=is_negation, + ) + ) + + return checker diff --git a/tools/ai-review/config.yml b/tools/ai-review/config.yml index d3dc77e..82ae042 100644 --- a/tools/ai-review/config.yml +++ b/tools/ai-review/config.yml @@ -1,233 +1,355 @@ -provider: openai # openai | openrouter | ollama +# OpenRabbit AI Code Review Configuration +# ========================================= + +# LLM Provider Configuration +# -------------------------- +# Available providers: openai | openrouter | ollama | anthropic | azure | gemini +provider: openai model: - openai: gpt-4.1-mini - openrouter: anthropic/claude-3.5-sonnet - ollama: codellama:13b + openai: gpt-4.1-mini + openrouter: anthropic/claude-3.5-sonnet + ollama: codellama:13b + anthropic: claude-3-5-sonnet-20241022 + azure: gpt-4 # Deployment name + gemini: gemini-1.5-pro temperature: 0 max_tokens: 4096 +# Azure OpenAI specific settings (when provider: azure) +azure: + endpoint: "" # Set via AZURE_OPENAI_ENDPOINT env var + deployment: "" # Set via AZURE_OPENAI_DEPLOYMENT env var + api_version: "2024-02-15-preview" + +# Google Gemini specific settings (when provider: gemini) +gemini: + project: "" # For Vertex AI, set via GOOGLE_CLOUD_PROJECT env var + region: "us-central1" + +# Rate Limits and Timeouts +# ------------------------ +rate_limits: + min_interval: 1.0 # Minimum seconds between API requests + +timeouts: + llm: 120 # LLM API timeout in seconds 
(OpenAI, OpenRouter, Anthropic, etc.) + ollama: 300 # Ollama timeout (longer for local models) + gitea: 30 # Gitea/GitHub API timeout + # Review settings +# --------------- review: - fail_on_severity: HIGH - max_diff_lines: 800 + fail_on_severity: HIGH + max_diff_lines: 800 + inline_comments: true + security_scan: true + +# File Ignore Patterns +# -------------------- +# Similar to .gitignore, controls which files are excluded from review +ignore: + use_defaults: true # Include default patterns (node_modules, .git, etc.) + file: ".ai-reviewignore" # Custom ignore file name + patterns: [] # Additional patterns to ignore + +# Agent Configuration +# ------------------- +agents: + issue: + enabled: true + auto_label: true + auto_triage: true + duplicate_threshold: 0.85 + events: + - opened + - labeled + pr: + enabled: true inline_comments: true security_scan: true - -# Agent settings -agents: - issue: - enabled: true - auto_label: true - auto_triage: true - duplicate_threshold: 0.85 - events: - - opened - - labeled - pr: - enabled: true - inline_comments: true - security_scan: true - events: - - opened - - synchronize - auto_summary: - enabled: true # Auto-generate summary for PRs with empty descriptions - post_as_comment: true # true = post as comment, false = update PR description - codebase: - enabled: true - schedule: "0 0 * * 0" # Weekly on Sunday - chat: - enabled: true - name: "Bartender" - max_iterations: 5 # Max tool call iterations per chat - tools: - - search_codebase - - read_file - - search_web - searxng_url: "" # Set via SEARXNG_URL env var or here - -# Interaction settings -# CUSTOMIZE YOUR BOT NAME HERE! 
-# Change mention_prefix to your preferred bot name: -# "@ai-bot" - Default -# "@bartender" - Friendly bar theme -# "@uni" - Short and simple -# "@joey" - Personal assistant name -# "@codebot" - Code-focused name -# NOTE: Also update the workflow files (.github/workflows/ or .gitea/workflows/) -# to match this prefix in the 'if: contains(...)' condition -interaction: - respond_to_mentions: true - mention_prefix: "@codebot" # Change this to customize your bot's name! - commands: - - help - - explain - - suggest - - security - - summarize # Generate PR summary (works on both issues and PRs) - - changelog # Generate Keep a Changelog format entries (PR comments only) - - explain-diff # Explain code changes in plain language (PR comments only) - - triage - - review-again - -# Enterprise settings -enterprise: - audit_log: true - audit_path: "/var/log/ai-review/" - metrics_enabled: true - rate_limit: - requests_per_minute: 30 - max_concurrent: 4 - -# Label mappings for auto-labeling -# Each label has: -# name: The label name to use/create (string) or full config (dict) -# aliases: Alternative names for auto-detection (optional) -# color: Hex color code without # (optional, for label creation) -# description: Label description (optional, for label creation) -labels: - priority: - critical: - name: "priority: critical" - color: "b60205" # Dark Red - description: "Critical priority - immediate attention required" - aliases: - ["Priority - Critical", "P0", "critical", "Priority/Critical"] - high: - name: "priority: high" - color: "d73a4a" # Red - description: "High priority issue" - aliases: ["Priority - High", "P1", "high", "Priority/High"] - medium: - name: "priority: medium" - color: "fbca04" # Yellow - description: "Medium priority issue" - aliases: ["Priority - Medium", "P2", "medium", "Priority/Medium"] - low: - name: "priority: low" - color: "28a745" # Green - description: "Low priority issue" - aliases: ["Priority - Low", "P3", "low", "Priority/Low"] - type: - bug: - 
name: "type: bug" - color: "d73a4a" # Red - description: "Something isn't working" - aliases: ["Kind/Bug", "bug", "Type: Bug", "Type/Bug", "Kind - Bug"] - feature: - name: "type: feature" - color: "1d76db" # Blue - description: "New feature request" - aliases: - [ - "Kind/Feature", - "feature", - "enhancement", - "Kind/Enhancement", - "Type: Feature", - "Type/Feature", - "Kind - Feature", - ] - question: - name: "type: question" - color: "cc317c" # Purple - description: "Further information is requested" - aliases: - [ - "Kind/Question", - "question", - "Type: Question", - "Type/Question", - "Kind - Question", - ] - docs: - name: "type: documentation" - color: "0075ca" # Light Blue - description: "Documentation improvements" - aliases: - [ - "Kind/Documentation", - "documentation", - "docs", - "Type: Documentation", - "Type/Documentation", - "Kind - Documentation", - ] - security: - name: "type: security" - color: "b60205" # Dark Red - description: "Security vulnerability or concern" - aliases: - [ - "Kind/Security", - "security", - "Type: Security", - "Type/Security", - "Kind - Security", - ] - testing: - name: "type: testing" - color: "0e8a16" # Green - description: "Related to testing" - aliases: - [ - "Kind/Testing", - "testing", - "tests", - "Type: Testing", - "Type/Testing", - "Kind - Testing", - ] - status: - ai_approved: - name: "ai-approved" - color: "28a745" # Green - description: "AI review approved this PR" - aliases: - [ - "Status - Approved", - "approved", - "Status/Approved", - "Status - AI Approved", - ] - ai_changes_required: - name: "ai-changes-required" - color: "d73a4a" # Red - description: "AI review found issues requiring changes" - aliases: - [ - "Status - Changes Required", - "changes-required", - "Status/Changes Required", - "Status - AI Changes Required", - ] - ai_reviewed: - name: "ai-reviewed" - color: "1d76db" # Blue - description: "This issue/PR has been reviewed by AI" - aliases: - [ - "Reviewed - Confirmed", - "reviewed", - 
"Status/Reviewed", - "Reviewed/Confirmed", - "Status - Reviewed", - ] - -# Label schema detection patterns -# Used by setup-labels command to detect existing naming conventions -label_patterns: - # Detect prefix-based naming (e.g., Kind/Bug, Type/Feature) - prefix_slash: "^(Kind|Type|Category)/(.+)$" - # Detect dash-separated naming (e.g., Priority - High, Status - Blocked) - prefix_dash: "^(Priority|Status|Reviewed) - (.+)$" - # Detect colon-separated naming (e.g., type: bug, priority: high) - colon: "^(type|priority|status): (.+)$" - -# Security scanning rules -security: + events: + - opened + - synchronize + auto_summary: + enabled: true + post_as_comment: true + codebase: enabled: true - fail_on_high: true - rules_file: "security/security_rules.yml" + schedule: "0 0 * * 0" # Weekly on Sunday + chat: + enabled: true + name: "Bartender" + max_iterations: 5 + tools: + - search_codebase + - read_file + - search_web + searxng_url: "" # Set via SEARXNG_URL env var + + # Dependency Security Agent + dependency: + enabled: true + scan_on_pr: true # Auto-scan PRs that modify dependency files + vulnerability_threshold: "medium" # low | medium | high | critical + update_suggestions: true # Suggest version updates + + # Test Coverage Agent + test_coverage: + enabled: true + suggest_tests: true + min_coverage_percent: 80 # Warn if coverage below this + + # Architecture Compliance Agent + architecture: + enabled: true + layers: + api: + can_import_from: [utils, models, services] + cannot_import_from: [db, repositories] + services: + can_import_from: [utils, models, repositories] + cannot_import_from: [api] + repositories: + can_import_from: [utils, models, db] + cannot_import_from: [api, services] + +# Interaction Settings +# -------------------- +# CUSTOMIZE YOUR BOT NAME HERE! 
+interaction: + respond_to_mentions: true + mention_prefix: "@codebot" + commands: + - help + - explain + - suggest + - security + - summarize + - changelog + - explain-diff + - triage + - review-again + # New commands + - check-deps # Check dependencies for vulnerabilities + - suggest-tests # Suggest test cases + - refactor-suggest # Suggest refactoring opportunities + - architecture # Check architecture compliance + - arch-check # Alias for architecture + +# Security Scanning +# ----------------- +security: + enabled: true + fail_on_high: true + rules_file: "security/security_rules.yml" + + # SAST Integration + sast: + enabled: true + bandit: true # Python AST-based security scanner + semgrep: true # Polyglot security scanner with custom rules + trivy: false # Container/filesystem scanner (requires trivy installed) + +# Notifications +# ------------- +notifications: + enabled: false + threshold: "warning" # info | warning | error | critical + + slack: + enabled: false + webhook_url: "" # Set via SLACK_WEBHOOK_URL env var + channel: "" # Override channel (optional) + username: "OpenRabbit" + + discord: + enabled: false + webhook_url: "" # Set via DISCORD_WEBHOOK_URL env var + username: "OpenRabbit" + avatar_url: "" + + # Custom webhooks for other integrations + webhooks: [] + # Example: + # - url: "https://your-webhook.example.com/notify" + # enabled: true + # headers: + # Authorization: "Bearer your-token" + +# Compliance & Audit +# ------------------ +compliance: + enabled: false + + # Audit Trail + audit: + enabled: false + log_file: "audit.log" + log_to_stdout: false + retention_days: 90 + + # CODEOWNERS Enforcement + codeowners: + enabled: false + require_approval: true # Require approval from code owners + + # Regulatory Compliance + regulations: + hipaa: false + soc2: false + pci_dss: false + gdpr: false + +# Enterprise Settings +# ------------------- +enterprise: + audit_log: true + audit_path: "/var/log/ai-review/" + metrics_enabled: true + rate_limit: + 
requests_per_minute: 30 + max_concurrent: 4 + +# Label Mappings +# -------------- +# Each label has: +# name: The label name to use/create +# aliases: Alternative names for auto-detection +# color: Hex color code without # +# description: Label description +labels: + priority: + critical: + name: "priority: critical" + color: "b60205" + description: "Critical priority - immediate attention required" + aliases: ["Priority - Critical", "P0", "critical", "Priority/Critical"] + high: + name: "priority: high" + color: "d73a4a" + description: "High priority issue" + aliases: ["Priority - High", "P1", "high", "Priority/High"] + medium: + name: "priority: medium" + color: "fbca04" + description: "Medium priority issue" + aliases: ["Priority - Medium", "P2", "medium", "Priority/Medium"] + low: + name: "priority: low" + color: "28a745" + description: "Low priority issue" + aliases: ["Priority - Low", "P3", "low", "Priority/Low"] + type: + bug: + name: "type: bug" + color: "d73a4a" + description: "Something isn't working" + aliases: ["Kind/Bug", "bug", "Type: Bug", "Type/Bug", "Kind - Bug"] + feature: + name: "type: feature" + color: "1d76db" + description: "New feature request" + aliases: + [ + "Kind/Feature", + "feature", + "enhancement", + "Kind/Enhancement", + "Type: Feature", + "Type/Feature", + "Kind - Feature", + ] + question: + name: "type: question" + color: "cc317c" + description: "Further information is requested" + aliases: + [ + "Kind/Question", + "question", + "Type: Question", + "Type/Question", + "Kind - Question", + ] + docs: + name: "type: documentation" + color: "0075ca" + description: "Documentation improvements" + aliases: + [ + "Kind/Documentation", + "documentation", + "docs", + "Type: Documentation", + "Type/Documentation", + "Kind - Documentation", + ] + security: + name: "type: security" + color: "b60205" + description: "Security vulnerability or concern" + aliases: + [ + "Kind/Security", + "security", + "Type: Security", + "Type/Security", + "Kind - 
Security", + ] + testing: + name: "type: testing" + color: "0e8a16" + description: "Related to testing" + aliases: + [ + "Kind/Testing", + "testing", + "tests", + "Type: Testing", + "Type/Testing", + "Kind - Testing", + ] + status: + ai_approved: + name: "ai-approved" + color: "28a745" + description: "AI review approved this PR" + aliases: + [ + "Status - Approved", + "approved", + "Status/Approved", + "Status - AI Approved", + ] + ai_changes_required: + name: "ai-changes-required" + color: "d73a4a" + description: "AI review found issues requiring changes" + aliases: + [ + "Status - Changes Required", + "changes-required", + "Status/Changes Required", + "Status - AI Changes Required", + ] + ai_reviewed: + name: "ai-reviewed" + color: "1d76db" + description: "This issue/PR has been reviewed by AI" + aliases: + [ + "Reviewed - Confirmed", + "reviewed", + "Status/Reviewed", + "Reviewed/Confirmed", + "Status - Reviewed", + ] + +# Label Pattern Detection +# ----------------------- +label_patterns: + prefix_slash: "^(Kind|Type|Category)/(.+)$" + prefix_dash: "^(Priority|Status|Reviewed) - (.+)$" + colon: "^(type|priority|status): (.+)$" diff --git a/tools/ai-review/notifications/__init__.py b/tools/ai-review/notifications/__init__.py new file mode 100644 index 0000000..02a2d6f --- /dev/null +++ b/tools/ai-review/notifications/__init__.py @@ -0,0 +1,20 @@ +"""Notifications Package + +Provides webhook-based notifications for Slack, Discord, and other platforms. 
+""" + +from notifications.notifier import ( + DiscordNotifier, + Notifier, + NotifierFactory, + SlackNotifier, + WebhookNotifier, +) + +__all__ = [ + "Notifier", + "SlackNotifier", + "DiscordNotifier", + "WebhookNotifier", + "NotifierFactory", +] diff --git a/tools/ai-review/notifications/notifier.py b/tools/ai-review/notifications/notifier.py new file mode 100644 index 0000000..de4a144 --- /dev/null +++ b/tools/ai-review/notifications/notifier.py @@ -0,0 +1,542 @@ +"""Notification System + +Provides webhook-based notifications for Slack, Discord, and other platforms. +Supports critical security findings, review summaries, and custom alerts. +""" + +import logging +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any + +import requests + + +class NotificationLevel(Enum): + """Notification severity levels.""" + + INFO = "info" + WARNING = "warning" + ERROR = "error" + CRITICAL = "critical" + + +@dataclass +class NotificationMessage: + """A notification message.""" + + title: str + message: str + level: NotificationLevel = NotificationLevel.INFO + fields: dict[str, str] | None = None + url: str | None = None + footer: str | None = None + + +class Notifier(ABC): + """Abstract base class for notification providers.""" + + @abstractmethod + def send(self, message: NotificationMessage) -> bool: + """Send a notification. + + Args: + message: The notification message to send. + + Returns: + True if sent successfully, False otherwise. + """ + pass + + @abstractmethod + def send_raw(self, payload: dict) -> bool: + """Send a raw payload to the webhook. + + Args: + payload: Raw payload in provider-specific format. + + Returns: + True if sent successfully, False otherwise. 
+ """ + pass + + +class SlackNotifier(Notifier): + """Slack webhook notifier.""" + + # Color mapping for different levels + LEVEL_COLORS = { + NotificationLevel.INFO: "#36a64f", # Green + NotificationLevel.WARNING: "#ffcc00", # Yellow + NotificationLevel.ERROR: "#ff6600", # Orange + NotificationLevel.CRITICAL: "#cc0000", # Red + } + + LEVEL_EMOJIS = { + NotificationLevel.INFO: ":information_source:", + NotificationLevel.WARNING: ":warning:", + NotificationLevel.ERROR: ":x:", + NotificationLevel.CRITICAL: ":rotating_light:", + } + + def __init__( + self, + webhook_url: str | None = None, + channel: str | None = None, + username: str = "OpenRabbit", + icon_emoji: str = ":robot_face:", + timeout: int = 10, + ): + """Initialize Slack notifier. + + Args: + webhook_url: Slack incoming webhook URL. + Defaults to SLACK_WEBHOOK_URL env var. + channel: Override channel (optional). + username: Bot username to display. + icon_emoji: Bot icon emoji. + timeout: Request timeout in seconds. + """ + self.webhook_url = webhook_url or os.environ.get("SLACK_WEBHOOK_URL", "") + self.channel = channel + self.username = username + self.icon_emoji = icon_emoji + self.timeout = timeout + self.logger = logging.getLogger(__name__) + + def send(self, message: NotificationMessage) -> bool: + """Send a notification to Slack.""" + if not self.webhook_url: + self.logger.warning("Slack webhook URL not configured") + return False + + # Build attachment + attachment = { + "color": self.LEVEL_COLORS.get(message.level, "#36a64f"), + "title": f"{self.LEVEL_EMOJIS.get(message.level, '')} {message.title}", + "text": message.message, + "mrkdwn_in": ["text", "fields"], + } + + if message.url: + attachment["title_link"] = message.url + + if message.fields: + attachment["fields"] = [ + {"title": k, "value": v, "short": len(v) < 40} + for k, v in message.fields.items() + ] + + if message.footer: + attachment["footer"] = message.footer + attachment["footer_icon"] = "https://github.com/favicon.ico" + + payload 
= { + "username": self.username, + "icon_emoji": self.icon_emoji, + "attachments": [attachment], + } + + if self.channel: + payload["channel"] = self.channel + + return self.send_raw(payload) + + def send_raw(self, payload: dict) -> bool: + """Send raw payload to Slack webhook.""" + if not self.webhook_url: + return False + + try: + response = requests.post( + self.webhook_url, + json=payload, + timeout=self.timeout, + ) + response.raise_for_status() + return True + except requests.RequestException as e: + self.logger.error(f"Failed to send Slack notification: {e}") + return False + + +class DiscordNotifier(Notifier): + """Discord webhook notifier.""" + + # Color mapping for different levels (Discord uses decimal colors) + LEVEL_COLORS = { + NotificationLevel.INFO: 3066993, # Green + NotificationLevel.WARNING: 16776960, # Yellow + NotificationLevel.ERROR: 16744448, # Orange + NotificationLevel.CRITICAL: 13369344, # Red + } + + def __init__( + self, + webhook_url: str | None = None, + username: str = "OpenRabbit", + avatar_url: str | None = None, + timeout: int = 10, + ): + """Initialize Discord notifier. + + Args: + webhook_url: Discord webhook URL. + Defaults to DISCORD_WEBHOOK_URL env var. + username: Bot username to display. + avatar_url: Bot avatar URL. + timeout: Request timeout in seconds. 
+ """ + self.webhook_url = webhook_url or os.environ.get("DISCORD_WEBHOOK_URL", "") + self.username = username + self.avatar_url = avatar_url + self.timeout = timeout + self.logger = logging.getLogger(__name__) + + def send(self, message: NotificationMessage) -> bool: + """Send a notification to Discord.""" + if not self.webhook_url: + self.logger.warning("Discord webhook URL not configured") + return False + + # Build embed + embed = { + "title": message.title, + "description": message.message, + "color": self.LEVEL_COLORS.get(message.level, 3066993), + } + + if message.url: + embed["url"] = message.url + + if message.fields: + embed["fields"] = [ + {"name": k, "value": v, "inline": len(v) < 40} + for k, v in message.fields.items() + ] + + if message.footer: + embed["footer"] = {"text": message.footer} + + payload = { + "username": self.username, + "embeds": [embed], + } + + if self.avatar_url: + payload["avatar_url"] = self.avatar_url + + return self.send_raw(payload) + + def send_raw(self, payload: dict) -> bool: + """Send raw payload to Discord webhook.""" + if not self.webhook_url: + return False + + try: + response = requests.post( + self.webhook_url, + json=payload, + timeout=self.timeout, + ) + response.raise_for_status() + return True + except requests.RequestException as e: + self.logger.error(f"Failed to send Discord notification: {e}") + return False + + +class WebhookNotifier(Notifier): + """Generic webhook notifier for custom integrations.""" + + def __init__( + self, + webhook_url: str, + headers: dict[str, str] | None = None, + timeout: int = 10, + ): + """Initialize generic webhook notifier. + + Args: + webhook_url: Webhook URL. + headers: Custom headers to include. + timeout: Request timeout in seconds. 
+ """ + self.webhook_url = webhook_url + self.headers = headers or {"Content-Type": "application/json"} + self.timeout = timeout + self.logger = logging.getLogger(__name__) + + def send(self, message: NotificationMessage) -> bool: + """Send a notification as JSON payload.""" + payload = { + "title": message.title, + "message": message.message, + "level": message.level.value, + "fields": message.fields or {}, + "url": message.url, + "footer": message.footer, + } + return self.send_raw(payload) + + def send_raw(self, payload: dict) -> bool: + """Send raw payload to webhook.""" + try: + response = requests.post( + self.webhook_url, + json=payload, + headers=self.headers, + timeout=self.timeout, + ) + response.raise_for_status() + return True + except requests.RequestException as e: + self.logger.error(f"Failed to send webhook notification: {e}") + return False + + +class NotifierFactory: + """Factory for creating notifier instances from config.""" + + @staticmethod + def create_from_config(config: dict) -> list[Notifier]: + """Create notifier instances from configuration. + + Args: + config: Configuration dictionary with 'notifications' section. + + Returns: + List of configured notifier instances. 
+ """ + notifiers = [] + notifications_config = config.get("notifications", {}) + + if not notifications_config.get("enabled", False): + return notifiers + + # Slack + slack_config = notifications_config.get("slack", {}) + if slack_config.get("enabled", False): + webhook_url = slack_config.get("webhook_url") or os.environ.get( + "SLACK_WEBHOOK_URL" + ) + if webhook_url: + notifiers.append( + SlackNotifier( + webhook_url=webhook_url, + channel=slack_config.get("channel"), + username=slack_config.get("username", "OpenRabbit"), + ) + ) + + # Discord + discord_config = notifications_config.get("discord", {}) + if discord_config.get("enabled", False): + webhook_url = discord_config.get("webhook_url") or os.environ.get( + "DISCORD_WEBHOOK_URL" + ) + if webhook_url: + notifiers.append( + DiscordNotifier( + webhook_url=webhook_url, + username=discord_config.get("username", "OpenRabbit"), + avatar_url=discord_config.get("avatar_url"), + ) + ) + + # Generic webhooks + webhooks_config = notifications_config.get("webhooks", []) + for webhook in webhooks_config: + if webhook.get("enabled", True) and webhook.get("url"): + notifiers.append( + WebhookNotifier( + webhook_url=webhook["url"], + headers=webhook.get("headers"), + ) + ) + + return notifiers + + @staticmethod + def should_notify( + level: NotificationLevel, + config: dict, + ) -> bool: + """Check if notification should be sent based on level threshold. + + Args: + level: Notification level. + config: Configuration dictionary. + + Returns: + True if notification should be sent. 
+ """ + notifications_config = config.get("notifications", {}) + if not notifications_config.get("enabled", False): + return False + + threshold = notifications_config.get("threshold", "warning") + level_order = ["info", "warning", "error", "critical"] + + try: + threshold_idx = level_order.index(threshold) + level_idx = level_order.index(level.value) + return level_idx >= threshold_idx + except ValueError: + return True + + +class NotificationService: + """High-level notification service for sending alerts.""" + + def __init__(self, config: dict): + """Initialize notification service. + + Args: + config: Configuration dictionary. + """ + self.config = config + self.notifiers = NotifierFactory.create_from_config(config) + self.logger = logging.getLogger(__name__) + + def notify(self, message: NotificationMessage) -> bool: + """Send notification to all configured notifiers. + + Args: + message: Notification message. + + Returns: + True if at least one notification succeeded. + """ + if not NotifierFactory.should_notify(message.level, self.config): + return True # Not an error, just below threshold + + if not self.notifiers: + self.logger.debug("No notifiers configured") + return True + + success = False + for notifier in self.notifiers: + try: + if notifier.send(message): + success = True + except Exception as e: + self.logger.error(f"Notifier failed: {e}") + + return success + + def notify_security_finding( + self, + repo: str, + pr_number: int, + finding: dict, + pr_url: str | None = None, + ) -> bool: + """Send notification for a security finding. + + Args: + repo: Repository name (owner/repo). + pr_number: Pull request number. + finding: Security finding dict with severity, description, etc. + pr_url: URL to the pull request. + + Returns: + True if notification succeeded. 
+ """ + severity = finding.get("severity", "MEDIUM").upper() + level_map = { + "HIGH": NotificationLevel.CRITICAL, + "MEDIUM": NotificationLevel.WARNING, + "LOW": NotificationLevel.INFO, + } + level = level_map.get(severity, NotificationLevel.WARNING) + + message = NotificationMessage( + title=f"Security Finding in {repo} PR #{pr_number}", + message=finding.get("description", "Security issue detected"), + level=level, + fields={ + "Severity": severity, + "Category": finding.get("category", "Unknown"), + "File": finding.get("file", "N/A"), + "Line": str(finding.get("line", "N/A")), + }, + url=pr_url, + footer=f"CWE: {finding.get('cwe', 'N/A')}", + ) + + return self.notify(message) + + def notify_review_complete( + self, + repo: str, + pr_number: int, + summary: dict, + pr_url: str | None = None, + ) -> bool: + """Send notification for completed review. + + Args: + repo: Repository name (owner/repo). + pr_number: Pull request number. + summary: Review summary dict. + pr_url: URL to the pull request. + + Returns: + True if notification succeeded. 
+ """ + recommendation = summary.get("recommendation", "COMMENT") + level_map = { + "APPROVE": NotificationLevel.INFO, + "COMMENT": NotificationLevel.INFO, + "REQUEST_CHANGES": NotificationLevel.WARNING, + } + level = level_map.get(recommendation, NotificationLevel.INFO) + + high_issues = summary.get("high_severity_count", 0) + if high_issues > 0: + level = NotificationLevel.ERROR + + message = NotificationMessage( + title=f"Review Complete: {repo} PR #{pr_number}", + message=summary.get("summary", "AI review completed"), + level=level, + fields={ + "Recommendation": recommendation, + "High Severity": str(high_issues), + "Medium Severity": str(summary.get("medium_severity_count", 0)), + "Files Reviewed": str(summary.get("files_reviewed", 0)), + }, + url=pr_url, + footer="OpenRabbit AI Review", + ) + + return self.notify(message) + + def notify_error( + self, + repo: str, + error: str, + context: str | None = None, + ) -> bool: + """Send notification for an error. + + Args: + repo: Repository name. + error: Error message. + context: Additional context. + + Returns: + True if notification succeeded. + """ + message = NotificationMessage( + title=f"Error in {repo}", + message=error, + level=NotificationLevel.ERROR, + fields={"Context": context} if context else None, + footer="OpenRabbit Error Alert", + ) + + return self.notify(message) diff --git a/tools/ai-review/security/sast_scanner.py b/tools/ai-review/security/sast_scanner.py new file mode 100644 index 0000000..6b286eb --- /dev/null +++ b/tools/ai-review/security/sast_scanner.py @@ -0,0 +1,431 @@ +"""SAST Scanner Integration + +Integrates with external SAST tools like Bandit and Semgrep +to provide comprehensive security analysis. 
+""" + +import json +import logging +import os +import shutil +import subprocess +import tempfile +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class SASTFinding: + """A finding from a SAST tool.""" + + tool: str + rule_id: str + severity: str # CRITICAL, HIGH, MEDIUM, LOW + file: str + line: int + message: str + code_snippet: str | None = None + cwe: str | None = None + owasp: str | None = None + fix_recommendation: str | None = None + + +@dataclass +class SASTReport: + """Combined report from all SAST tools.""" + + total_findings: int + findings_by_severity: dict[str, int] + findings_by_tool: dict[str, int] + findings: list[SASTFinding] + tools_run: list[str] + errors: list[str] = field(default_factory=list) + + +class SASTScanner: + """Aggregator for multiple SAST tools.""" + + def __init__(self, config: dict | None = None): + """Initialize the SAST scanner. + + Args: + config: Configuration dictionary with tool settings. + """ + self.config = config or {} + self.logger = logging.getLogger(self.__class__.__name__) + + def scan_directory(self, path: str) -> SASTReport: + """Scan a directory with all enabled SAST tools. + + Args: + path: Path to the directory to scan. + + Returns: + Combined SASTReport from all tools. 
+ """ + all_findings = [] + tools_run = [] + errors = [] + + sast_config = self.config.get("security", {}).get("sast", {}) + + # Run Bandit (Python) + if sast_config.get("bandit", True): + if self._is_tool_available("bandit"): + try: + findings = self._run_bandit(path) + all_findings.extend(findings) + tools_run.append("bandit") + except Exception as e: + errors.append(f"Bandit error: {e}") + else: + self.logger.debug("Bandit not installed, skipping") + + # Run Semgrep + if sast_config.get("semgrep", True): + if self._is_tool_available("semgrep"): + try: + findings = self._run_semgrep(path) + all_findings.extend(findings) + tools_run.append("semgrep") + except Exception as e: + errors.append(f"Semgrep error: {e}") + else: + self.logger.debug("Semgrep not installed, skipping") + + # Run Trivy (if enabled for filesystem scanning) + if sast_config.get("trivy", False): + if self._is_tool_available("trivy"): + try: + findings = self._run_trivy(path) + all_findings.extend(findings) + tools_run.append("trivy") + except Exception as e: + errors.append(f"Trivy error: {e}") + else: + self.logger.debug("Trivy not installed, skipping") + + # Calculate statistics + by_severity = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0} + by_tool = {} + + for finding in all_findings: + sev = finding.severity.upper() + if sev in by_severity: + by_severity[sev] += 1 + tool = finding.tool + by_tool[tool] = by_tool.get(tool, 0) + 1 + + return SASTReport( + total_findings=len(all_findings), + findings_by_severity=by_severity, + findings_by_tool=by_tool, + findings=all_findings, + tools_run=tools_run, + errors=errors, + ) + + def scan_content(self, content: str, filename: str) -> list[SASTFinding]: + """Scan file content with SAST tools. + + Args: + content: File content to scan. + filename: Name of the file (for language detection). + + Returns: + List of SASTFinding objects. 
+ """ + # Create temporary file for scanning + with tempfile.NamedTemporaryFile( + mode="w", + suffix=os.path.splitext(filename)[1], + delete=False, + ) as f: + f.write(content) + temp_path = f.name + + try: + report = self.scan_directory(os.path.dirname(temp_path)) + # Filter findings for our specific file + findings = [ + f + for f in report.findings + if os.path.basename(f.file) == os.path.basename(temp_path) + ] + # Update file path to original filename + for finding in findings: + finding.file = filename + return findings + finally: + os.unlink(temp_path) + + def scan_diff(self, diff: str) -> list[SASTFinding]: + """Scan a diff for security issues. + + Only scans added/modified lines. + + Args: + diff: Git diff content. + + Returns: + List of SASTFinding objects. + """ + findings = [] + + # Parse diff and extract added content per file + files_content = {} + current_file = None + current_content = [] + + for line in diff.splitlines(): + if line.startswith("diff --git"): + if current_file and current_content: + files_content[current_file] = "\n".join(current_content) + current_file = None + current_content = [] + # Extract filename + match = line.split(" b/") + if len(match) > 1: + current_file = match[1] + elif line.startswith("+") and not line.startswith("+++"): + if current_file: + current_content.append(line[1:]) # Remove + prefix + + # Don't forget last file + if current_file and current_content: + files_content[current_file] = "\n".join(current_content) + + # Scan each file's content + for filename, content in files_content.items(): + if content.strip(): + file_findings = self.scan_content(content, filename) + findings.extend(file_findings) + + return findings + + def _is_tool_available(self, tool: str) -> bool: + """Check if a tool is installed and available.""" + return shutil.which(tool) is not None + + def _run_bandit(self, path: str) -> list[SASTFinding]: + """Run Bandit security scanner. + + Args: + path: Path to scan. 
+ + Returns: + List of SASTFinding objects. + """ + findings = [] + + try: + result = subprocess.run( + [ + "bandit", + "-r", + path, + "-f", + "json", + "-ll", # Only high and medium severity + "--quiet", + ], + capture_output=True, + text=True, + timeout=120, + ) + + if result.stdout: + data = json.loads(result.stdout) + + for issue in data.get("results", []): + severity = issue.get("issue_severity", "MEDIUM").upper() + + findings.append( + SASTFinding( + tool="bandit", + rule_id=issue.get("test_id", ""), + severity=severity, + file=issue.get("filename", ""), + line=issue.get("line_number", 0), + message=issue.get("issue_text", ""), + code_snippet=issue.get("code", ""), + cwe=f"CWE-{issue.get('issue_cwe', {}).get('id', '')}" + if issue.get("issue_cwe") + else None, + fix_recommendation=issue.get("more_info", ""), + ) + ) + + except subprocess.TimeoutExpired: + self.logger.warning("Bandit scan timed out") + except json.JSONDecodeError as e: + self.logger.warning(f"Failed to parse Bandit output: {e}") + except Exception as e: + self.logger.warning(f"Bandit scan failed: {e}") + + return findings + + def _run_semgrep(self, path: str) -> list[SASTFinding]: + """Run Semgrep security scanner. + + Args: + path: Path to scan. + + Returns: + List of SASTFinding objects. 
+ """ + findings = [] + + # Get Semgrep config from settings + sast_config = self.config.get("security", {}).get("sast", {}) + semgrep_rules = sast_config.get("semgrep_rules", "p/security-audit") + + try: + result = subprocess.run( + [ + "semgrep", + "--config", + semgrep_rules, + "--json", + "--quiet", + path, + ], + capture_output=True, + text=True, + timeout=180, + ) + + if result.stdout: + data = json.loads(result.stdout) + + for finding in data.get("results", []): + # Map Semgrep severity to our scale + sev_map = { + "ERROR": "HIGH", + "WARNING": "MEDIUM", + "INFO": "LOW", + } + severity = sev_map.get( + finding.get("extra", {}).get("severity", "WARNING"), "MEDIUM" + ) + + metadata = finding.get("extra", {}).get("metadata", {}) + + findings.append( + SASTFinding( + tool="semgrep", + rule_id=finding.get("check_id", ""), + severity=severity, + file=finding.get("path", ""), + line=finding.get("start", {}).get("line", 0), + message=finding.get("extra", {}).get("message", ""), + code_snippet=finding.get("extra", {}).get("lines", ""), + cwe=metadata.get("cwe", [None])[0] + if metadata.get("cwe") + else None, + owasp=metadata.get("owasp", [None])[0] + if metadata.get("owasp") + else None, + fix_recommendation=metadata.get("fix", ""), + ) + ) + + except subprocess.TimeoutExpired: + self.logger.warning("Semgrep scan timed out") + except json.JSONDecodeError as e: + self.logger.warning(f"Failed to parse Semgrep output: {e}") + except Exception as e: + self.logger.warning(f"Semgrep scan failed: {e}") + + return findings + + def _run_trivy(self, path: str) -> list[SASTFinding]: + """Run Trivy filesystem scanner. + + Args: + path: Path to scan. + + Returns: + List of SASTFinding objects. 
+ """ + findings = [] + + try: + result = subprocess.run( + [ + "trivy", + "fs", + "--format", + "json", + "--security-checks", + "vuln,secret,config", + path, + ], + capture_output=True, + text=True, + timeout=180, + ) + + if result.stdout: + data = json.loads(result.stdout) + + for result_item in data.get("Results", []): + target = result_item.get("Target", "") + + # Process vulnerabilities + for vuln in result_item.get("Vulnerabilities", []): + severity = vuln.get("Severity", "MEDIUM").upper() + + findings.append( + SASTFinding( + tool="trivy", + rule_id=vuln.get("VulnerabilityID", ""), + severity=severity, + file=target, + line=0, + message=vuln.get("Title", ""), + cwe=vuln.get("CweIDs", [None])[0] + if vuln.get("CweIDs") + else None, + fix_recommendation=f"Upgrade to {vuln.get('FixedVersion', 'latest')}" + if vuln.get("FixedVersion") + else None, + ) + ) + + # Process secrets + for secret in result_item.get("Secrets", []): + findings.append( + SASTFinding( + tool="trivy", + rule_id=secret.get("RuleID", ""), + severity="HIGH", + file=target, + line=secret.get("StartLine", 0), + message=f"Secret detected: {secret.get('Title', '')}", + code_snippet=secret.get("Match", ""), + ) + ) + + except subprocess.TimeoutExpired: + self.logger.warning("Trivy scan timed out") + except json.JSONDecodeError as e: + self.logger.warning(f"Failed to parse Trivy output: {e}") + except Exception as e: + self.logger.warning(f"Trivy scan failed: {e}") + + return findings + + +def get_sast_scanner(config: dict | None = None) -> SASTScanner: + """Get a configured SAST scanner instance. + + Args: + config: Configuration dictionary. + + Returns: + Configured SASTScanner instance. 
+ """ + return SASTScanner(config=config) diff --git a/tools/ai-review/utils/__init__.py b/tools/ai-review/utils/__init__.py index a77ba35..b10333c 100644 --- a/tools/ai-review/utils/__init__.py +++ b/tools/ai-review/utils/__init__.py @@ -1,9 +1,14 @@ """Utility Functions Package This package contains utility functions for webhook sanitization, -safe event dispatching, and other helper functions. +safe event dispatching, ignore patterns, and other helper functions. """ +from utils.ignore_patterns import ( + IgnorePatterns, + get_ignore_patterns, + should_ignore_file, +) from utils.webhook_sanitizer import ( extract_minimal_context, sanitize_webhook_data, @@ -16,4 +21,7 @@ __all__ = [ "validate_repository_format", "extract_minimal_context", "validate_webhook_signature", + "IgnorePatterns", + "get_ignore_patterns", + "should_ignore_file", ] diff --git a/tools/ai-review/utils/ignore_patterns.py b/tools/ai-review/utils/ignore_patterns.py new file mode 100644 index 0000000..0e3541b --- /dev/null +++ b/tools/ai-review/utils/ignore_patterns.py @@ -0,0 +1,358 @@ +"""AI Review Ignore Patterns + +Provides .gitignore-style pattern matching for excluding files from AI review. +Reads patterns from .ai-reviewignore files in the repository. 
+""" + +import fnmatch +import os +import re +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class IgnoreRule: + """A single ignore rule.""" + + pattern: str + negation: bool = False + directory_only: bool = False + anchored: bool = False + regex: re.Pattern = None + + def __post_init__(self): + """Compile the pattern to regex.""" + self.regex = self._compile_pattern() + + def _compile_pattern(self) -> re.Pattern: + """Convert gitignore pattern to regex.""" + pattern = self.pattern + + # Handle directory-only patterns + if pattern.endswith("/"): + pattern = pattern[:-1] + self.directory_only = True + + # Handle anchored patterns (starting with /) + if pattern.startswith("/"): + pattern = pattern[1:] + self.anchored = True + + # Escape special regex characters except * and ? + regex_pattern = "" + i = 0 + while i < len(pattern): + c = pattern[i] + if c == "*": + if i + 1 < len(pattern) and pattern[i + 1] == "*": + # ** matches everything including / + if i + 2 < len(pattern) and pattern[i + 2] == "/": + regex_pattern += "(?:.*/)?" + i += 3 + continue + else: + regex_pattern += ".*" + i += 2 + continue + else: + # * matches everything except / + regex_pattern += "[^/]*" + elif c == "?": + regex_pattern += "[^/]" + elif c == "[": + # Character class + j = i + 1 + if j < len(pattern) and pattern[j] == "!": + regex_pattern += "[^" + j += 1 + else: + regex_pattern += "[" + while j < len(pattern) and pattern[j] != "]": + regex_pattern += pattern[j] + j += 1 + if j < len(pattern): + regex_pattern += "]" + i = j + elif c in ".^$+{}|()": + regex_pattern += "\\" + c + else: + regex_pattern += c + i += 1 + + # Anchor pattern + if self.anchored: + regex_pattern = "^" + regex_pattern + else: + regex_pattern = "(?:^|/)" + regex_pattern + + # Match to end or as directory prefix + regex_pattern += "(?:$|/)" + + return re.compile(regex_pattern) + + def matches(self, path: str, is_directory: bool = False) -> bool: + """Check if a path matches this rule. 
+ + Args: + path: Relative path to check. + is_directory: Whether the path is a directory. + + Returns: + True if the path matches. + """ + if self.directory_only and not is_directory: + return False + + # Normalize path + path = path.replace("\\", "/") + if not path.startswith("/"): + path = "/" + path + + return bool(self.regex.search(path)) + + +class IgnorePatterns: + """Manages .ai-reviewignore patterns for a repository.""" + + DEFAULT_PATTERNS = [ + # Version control + ".git/", + ".svn/", + ".hg/", + # Dependencies + "node_modules/", + "vendor/", + "venv/", + ".venv/", + "__pycache__/", + "*.pyc", + # Build outputs + "dist/", + "build/", + "out/", + "target/", + "*.min.js", + "*.min.css", + "*.bundle.js", + # IDE/Editor + ".idea/", + ".vscode/", + "*.swp", + "*.swo", + # Generated files + "*.lock", + "package-lock.json", + "yarn.lock", + "poetry.lock", + "Pipfile.lock", + "Cargo.lock", + "go.sum", + # Binary files + "*.exe", + "*.dll", + "*.so", + "*.dylib", + "*.bin", + "*.o", + "*.a", + # Media files + "*.png", + "*.jpg", + "*.jpeg", + "*.gif", + "*.svg", + "*.ico", + "*.mp3", + "*.mp4", + "*.wav", + "*.pdf", + # Archives + "*.zip", + "*.tar", + "*.gz", + "*.rar", + "*.7z", + # Large data files + "*.csv", + "*.json.gz", + "*.sql", + "*.sqlite", + "*.db", + ] + + def __init__( + self, + repo_root: str | None = None, + ignore_file: str = ".ai-reviewignore", + use_defaults: bool = True, + ): + """Initialize ignore patterns. + + Args: + repo_root: Repository root path. + ignore_file: Name of ignore file to read. + use_defaults: Whether to include default patterns. 
+ """ + self.repo_root = repo_root or os.getcwd() + self.ignore_file = ignore_file + self.rules: list[IgnoreRule] = [] + + # Load default patterns + if use_defaults: + for pattern in self.DEFAULT_PATTERNS: + self._add_pattern(pattern) + + # Load patterns from ignore file + self._load_ignore_file() + + def _load_ignore_file(self): + """Load patterns from .ai-reviewignore file.""" + ignore_path = os.path.join(self.repo_root, self.ignore_file) + + if not os.path.exists(ignore_path): + return + + try: + with open(ignore_path) as f: + for line in f: + line = line.rstrip("\n\r") + + # Skip empty lines and comments + if not line or line.startswith("#"): + continue + + self._add_pattern(line) + except Exception: + pass # Ignore errors reading the file + + def _add_pattern(self, pattern: str): + """Add a pattern to the rules list. + + Args: + pattern: Pattern string (gitignore format). + """ + # Check for negation + negation = False + if pattern.startswith("!"): + negation = True + pattern = pattern[1:] + + if not pattern: + return + + self.rules.append(IgnoreRule(pattern=pattern, negation=negation)) + + def is_ignored(self, path: str, is_directory: bool = False) -> bool: + """Check if a path should be ignored. + + Args: + path: Relative path to check. + is_directory: Whether the path is a directory. + + Returns: + True if the path should be ignored. + """ + # Normalize path + path = path.replace("\\", "/").lstrip("/") + + # Check each rule in order (later rules override earlier ones) + ignored = False + for rule in self.rules: + if rule.matches(path, is_directory): + ignored = not rule.negation + + return ignored + + def filter_paths(self, paths: list[str]) -> list[str]: + """Filter a list of paths, removing ignored ones. + + Args: + paths: List of relative paths. + + Returns: + Filtered list of paths. 
+ """ + return [p for p in paths if not self.is_ignored(p)] + + def filter_diff_files(self, files: list[dict]) -> list[dict]: + """Filter diff file objects, removing ignored ones. + + Args: + files: List of file dicts with 'filename' or 'path' key. + + Returns: + Filtered list of file dicts. + """ + result = [] + for f in files: + path = f.get("filename") or f.get("path") or f.get("name", "") + if not self.is_ignored(path): + result.append(f) + return result + + def should_review_file(self, filename: str) -> bool: + """Check if a file should be reviewed. + + Args: + filename: File path to check. + + Returns: + True if the file should be reviewed. + """ + return not self.is_ignored(filename) + + @classmethod + def from_config( + cls, config: dict, repo_root: str | None = None + ) -> "IgnorePatterns": + """Create IgnorePatterns from config. + + Args: + config: Configuration dictionary. + repo_root: Repository root path. + + Returns: + IgnorePatterns instance. + """ + ignore_config = config.get("ignore", {}) + use_defaults = ignore_config.get("use_defaults", True) + ignore_file = ignore_config.get("file", ".ai-reviewignore") + + instance = cls( + repo_root=repo_root, + ignore_file=ignore_file, + use_defaults=use_defaults, + ) + + # Add any extra patterns from config + extra_patterns = ignore_config.get("patterns", []) + for pattern in extra_patterns: + instance._add_pattern(pattern) + + return instance + + +def get_ignore_patterns(repo_root: str | None = None) -> IgnorePatterns: + """Get ignore patterns for a repository. + + Args: + repo_root: Repository root path. + + Returns: + IgnorePatterns instance. + """ + return IgnorePatterns(repo_root=repo_root) + + +def should_ignore_file(filename: str, repo_root: str | None = None) -> bool: + """Quick check if a file should be ignored. + + Args: + filename: File path to check. + repo_root: Repository root path. + + Returns: + True if the file should be ignored. 
+ """ + return get_ignore_patterns(repo_root).is_ignored(filename)