feat: harden gateway with policy engine, secure tools, and governance docs

This commit is contained in:
2026-02-14 16:05:56 +01:00
parent e17d34e6d7
commit 5969892af3
55 changed files with 4711 additions and 1587 deletions

View File

@@ -1,3 +1,3 @@
"""AegisGitea MCP - Security-first MCP server for self-hosted Gitea."""
"""AegisGitea MCP - Security-first MCP gateway for self-hosted Gitea."""
__version__ = "0.1.0"
__version__ = "0.2.0"

View File

@@ -1,50 +1,110 @@
"""Audit logging system for MCP tool invocations."""
"""Tamper-evident audit logging for MCP tool invocations and security events."""
from __future__ import annotations
import hashlib
import json
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional
import structlog
from typing import Any
from aegis_gitea_mcp.config import get_settings
from aegis_gitea_mcp.request_context import get_request_id
from aegis_gitea_mcp.security import sanitize_data
_GENESIS_HASH = "GENESIS"
class AuditLogger:
"""Audit logger for tracking all MCP tool invocations."""
"""Append-only tamper-evident audit logger.
def __init__(self, log_path: Optional[Path] = None) -> None:
Every line in the audit file is hash-chained to the previous line. This makes
post-hoc modifications detectable by integrity validation.
"""
def __init__(self, log_path: Path | None = None) -> None:
"""Initialize audit logger.
Args:
log_path: Path to audit log file (defaults to config value)
log_path: Path to audit log file (defaults to config value).
"""
self.settings = get_settings()
self.log_path = log_path or self.settings.audit_log_path
# Ensure log directory exists
self.log_path.parent.mkdir(parents=True, exist_ok=True)
self._log_file = self._get_log_file()
self._lock = threading.Lock()
self._log_file = open(self.log_path, "a+", encoding="utf-8")
self._last_hash = self._read_last_hash()
# Configure structlog for audit logging
structlog.configure(
processors=[
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.dict_tracebacks,
structlog.processors.JSONRenderer(),
],
wrapper_class=structlog.BoundLogger,
context_class=dict,
logger_factory=structlog.PrintLoggerFactory(file=self._log_file),
cache_logger_on_first_use=True,
def _read_last_hash(self) -> str:
"""Read the previous hash from the last log entry."""
try:
entries = self.log_path.read_text(encoding="utf-8").splitlines()
except FileNotFoundError:
return _GENESIS_HASH
if not entries:
return _GENESIS_HASH
last_line = entries[-1]
try:
payload = json.loads(last_line)
entry_hash = payload.get("entry_hash")
if isinstance(entry_hash, str) and entry_hash:
return entry_hash
except json.JSONDecodeError:
pass
# Corrupt trailing line forces a new chain segment.
return _GENESIS_HASH
@staticmethod
def _compute_entry_hash(prev_hash: str, payload: dict[str, Any]) -> str:
"""Compute deterministic hash for an audit entry payload."""
canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), ensure_ascii=True)
digest = hashlib.sha256(f"{prev_hash}:{canonical}".encode()).hexdigest()
return digest
def _append_entry(self, event_type: str, payload: dict[str, Any]) -> str:
"""Append a hash-chained entry to audit log.
Args:
event_type: Event category.
payload: Event payload data.
Returns:
Correlation ID for the appended entry.
"""
correlation_id = payload.get("correlation_id")
if not isinstance(correlation_id, str) or not correlation_id:
correlation_id = str(uuid.uuid4())
payload["correlation_id"] = correlation_id
# Security decision: sanitize all audit payloads before persistence.
mode = "mask" if self.settings.secret_detection_mode != "off" else "off"
safe_payload = payload if mode == "off" else sanitize_data(payload, mode=mode)
base_entry: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"event_type": event_type,
"request_id": get_request_id(),
"payload": safe_payload,
"prev_hash": self._last_hash,
}
entry_hash = self._compute_entry_hash(self._last_hash, base_entry)
base_entry["entry_hash"] = entry_hash
serialized = json.dumps(
base_entry, sort_keys=True, separators=(",", ":"), ensure_ascii=True
)
self.logger = structlog.get_logger("audit")
with self._lock:
self._log_file.write(serialized + "\n")
self._log_file.flush()
self._last_hash = entry_hash
def _get_log_file(self) -> Any:
"""Get file handle for audit log."""
return open(self.log_path, "a", encoding="utf-8")
return correlation_id
def close(self) -> None:
"""Close open audit log resources."""
@@ -56,111 +116,108 @@ class AuditLogger:
def log_tool_invocation(
self,
tool_name: str,
repository: Optional[str] = None,
target: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
correlation_id: Optional[str] = None,
repository: str | None = None,
target: str | None = None,
params: dict[str, Any] | None = None,
correlation_id: str | None = None,
result_status: str = "pending",
error: Optional[str] = None,
error: str | None = None,
) -> str:
"""Log an MCP tool invocation.
Args:
tool_name: Name of the MCP tool being invoked
repository: Repository identifier (owner/repo)
target: Target path, commit hash, issue number, etc.
params: Additional parameters passed to the tool
correlation_id: Request correlation ID (auto-generated if not provided)
result_status: Status of the invocation (pending, success, error)
error: Error message if invocation failed
Returns:
Correlation ID for this invocation
"""
if correlation_id is None:
correlation_id = str(uuid.uuid4())
audit_entry = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"correlation_id": correlation_id,
"""Log an MCP tool invocation."""
payload: dict[str, Any] = {
"correlation_id": correlation_id or str(uuid.uuid4()),
"tool_name": tool_name,
"repository": repository,
"target": target,
"params": params or {},
"result_status": result_status,
}
if error:
audit_entry["error"] = error
self.logger.info("tool_invocation", **audit_entry)
return correlation_id
payload["error"] = error
return self._append_entry("tool_invocation", payload)
def log_access_denied(
self,
tool_name: str,
repository: Optional[str] = None,
repository: str | None = None,
reason: str = "unauthorized",
correlation_id: Optional[str] = None,
correlation_id: str | None = None,
) -> str:
"""Log an access denial event.
Args:
tool_name: Name of the tool that was denied access
repository: Repository identifier that access was denied to
reason: Reason for denial
correlation_id: Request correlation ID
Returns:
Correlation ID for this event
"""
if correlation_id is None:
correlation_id = str(uuid.uuid4())
self.logger.warning(
"""Log an access denial event."""
return self._append_entry(
"access_denied",
timestamp=datetime.now(timezone.utc).isoformat(),
correlation_id=correlation_id,
tool_name=tool_name,
repository=repository,
reason=reason,
{
"correlation_id": correlation_id or str(uuid.uuid4()),
"tool_name": tool_name,
"repository": repository,
"reason": reason,
},
)
return correlation_id
def log_security_event(
self,
event_type: str,
description: str,
severity: str = "medium",
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> str:
"""Log a security-related event.
Args:
event_type: Type of security event (e.g., rate_limit, invalid_input)
description: Human-readable description of the event
severity: Severity level (low, medium, high, critical)
metadata: Additional metadata about the event
Returns:
Correlation ID for this event
"""
correlation_id = str(uuid.uuid4())
self.logger.warning(
"""Log a security event."""
return self._append_entry(
"security_event",
timestamp=datetime.now(timezone.utc).isoformat(),
correlation_id=correlation_id,
event_type=event_type,
description=description,
severity=severity,
metadata=metadata or {},
{
"event_type": event_type,
"description": description,
"severity": severity,
"metadata": metadata or {},
},
)
return correlation_id
# Global audit logger instance
_audit_logger: Optional[AuditLogger] = None
def validate_audit_log_integrity(log_path: Path) -> tuple[bool, list[str]]:
    """Validate the hash chain of a tamper-evident audit log file.

    Each non-empty line must be valid JSON whose ``prev_hash`` matches the
    running chain value and whose ``entry_hash`` equals the recomputed digest
    of the entry (minus ``entry_hash`` itself).

    Args:
        log_path: Path to an audit log file.

    Returns:
        Tuple of (is_valid, errors).
    """
    if not log_path.exists():
        return True, []
    problems: list[str] = []
    chain_hash = _GENESIS_HASH
    raw_lines = log_path.read_text(encoding="utf-8").splitlines()
    for index, raw in enumerate(raw_lines, start=1):
        if not raw.strip():
            continue
        try:
            record = json.loads(raw)
        except json.JSONDecodeError:
            problems.append(f"line {index}: invalid JSON")
            continue
        if record.get("prev_hash") != chain_hash:
            problems.append(f"line {index}: prev_hash mismatch")
        # Recompute the digest over the entry without its stored entry_hash.
        stripped = {key: value for key, value in record.items() if key != "entry_hash"}
        recomputed = AuditLogger._compute_entry_hash(chain_hash, stripped)
        stored = record.get("entry_hash")
        if stored == recomputed:
            chain_hash = str(stored)
        else:
            problems.append(f"line {index}: entry_hash mismatch")
            # Continue from the recomputed value so one bad line does not
            # cascade into mismatch errors for every subsequent line.
            chain_hash = recomputed
    return not problems, problems
_audit_logger: AuditLogger | None = None
def get_audit_logger() -> AuditLogger:

View File

@@ -1,10 +1,11 @@
"""Authentication module for MCP server API key validation."""
from __future__ import annotations
import hashlib
import hmac
import secrets
from datetime import datetime, timezone
from typing import Optional, Tuple
from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.config import get_settings
@@ -13,70 +14,43 @@ from aegis_gitea_mcp.config import get_settings
class AuthenticationError(Exception):
"""Raised when authentication fails."""
pass
class APIKeyValidator:
"""Validates API keys for MCP server access."""
"""Validate API keys for MCP server access."""
def __init__(self) -> None:
"""Initialize API key validator."""
"""Initialize API key validator state."""
self.settings = get_settings()
self.audit = get_audit_logger()
self._failed_attempts: dict[str, list[datetime]] = {}
def _constant_time_compare(self, a: str, b: str) -> bool:
"""Compare two strings in constant time to prevent timing attacks.
Args:
a: First string
b: Second string
Returns:
True if strings are equal, False otherwise
"""
return hmac.compare_digest(a, b)
def _constant_time_compare(self, candidate: str, expected: str) -> bool:
"""Compare API keys in constant time to mitigate timing attacks."""
return hmac.compare_digest(candidate, expected)
def _check_rate_limit(self, identifier: str) -> bool:
"""Check if identifier has exceeded failed authentication rate limit.
Args:
identifier: IP address or other identifier
Returns:
True if within rate limit, False if exceeded
"""
"""Check whether authentication failures exceed configured threshold."""
now = datetime.now(timezone.utc)
window_start = now.timestamp() - self.settings.auth_failure_window
boundary = now.timestamp() - self.settings.auth_failure_window
# Clean up old attempts
if identifier in self._failed_attempts:
self._failed_attempts[identifier] = [
attempt
for attempt in self._failed_attempts[identifier]
if attempt.timestamp() > window_start
if attempt.timestamp() > boundary
]
# Check count
attempt_count = len(self._failed_attempts.get(identifier, []))
return attempt_count < self.settings.max_auth_failures
return len(self._failed_attempts.get(identifier, [])) < self.settings.max_auth_failures
def _record_failed_attempt(self, identifier: str) -> None:
"""Record a failed authentication attempt.
"""Record a failed authentication attempt for rate limiting."""
attempt_time = datetime.now(timezone.utc)
self._failed_attempts.setdefault(identifier, []).append(attempt_time)
Args:
identifier: IP address or other identifier
"""
now = datetime.now(timezone.utc)
if identifier not in self._failed_attempts:
self._failed_attempts[identifier] = []
self._failed_attempts[identifier].append(now)
# Check if threshold exceeded
if len(self._failed_attempts[identifier]) >= self.settings.max_auth_failures:
self.audit.log_security_event(
event_type="auth_rate_limit_exceeded",
description=f"IP {identifier} exceeded auth failure threshold",
description="Authentication failure threshold exceeded",
severity="high",
metadata={
"identifier": identifier,
@@ -86,29 +60,31 @@ class APIKeyValidator:
)
def validate_api_key(
self, provided_key: Optional[str], client_ip: str, user_agent: str
) -> Tuple[bool, Optional[str]]:
self,
provided_key: str | None,
client_ip: str,
user_agent: str,
) -> tuple[bool, str | None]:
"""Validate an API key.
Args:
provided_key: API key provided by client
client_ip: Client IP address
user_agent: Client user agent string
provided_key: API key provided by client.
client_ip: Request source IP address.
user_agent: Request user agent.
Returns:
Tuple of (is_valid, error_message)
Tuple of `(is_valid, error_message)`.
"""
# Check if authentication is enabled
if not self.settings.auth_enabled:
# Security note: auth-disabled mode is explicit and should be monitored.
self.audit.log_security_event(
event_type="auth_disabled",
description="Authentication is disabled - allowing all requests",
description="Authentication disabled; request was allowed",
severity="critical",
metadata={"client_ip": client_ip},
)
return True, None
# Check rate limit
if not self._check_rate_limit(client_ip):
self.audit.log_access_denied(
tool_name="api_authentication",
@@ -116,7 +92,6 @@ class APIKeyValidator:
)
return False, "Too many failed authentication attempts. Please try again later."
# Check if key was provided
if not provided_key:
self._record_failed_attempt(client_ip)
self.audit.log_access_denied(
@@ -125,8 +100,8 @@ class APIKeyValidator:
)
return False, "Authorization header missing. Required: Authorization: Bearer <api-key>"
# Validate key format (should be at least 32 characters)
if len(provided_key) < 32:
# Validation logic: reject short keys early to reduce brute force surface.
self._record_failed_attempt(client_ip)
self.audit.log_access_denied(
tool_name="api_authentication",
@@ -134,99 +109,87 @@ class APIKeyValidator:
)
return False, "Invalid API key format"
# Get valid API keys from config
valid_keys = self.settings.mcp_api_keys
if not valid_keys:
self.audit.log_security_event(
event_type="no_api_keys_configured",
description="No API keys configured in environment",
description="No API keys configured while auth is enabled",
severity="critical",
metadata={"client_ip": client_ip},
)
return False, "Server configuration error: No API keys configured"
# Check against all valid keys (constant time comparison)
is_valid = any(self._constant_time_compare(provided_key, valid_key) for valid_key in valid_keys)
is_valid = any(
self._constant_time_compare(provided_key, valid_key) for valid_key in valid_keys
)
if is_valid:
# Success - log and return
key_hint = f"{provided_key[:8]}...{provided_key[-4:]}"
key_fingerprint = hashlib.sha256(provided_key.encode("utf-8")).hexdigest()[:12]
self.audit.log_tool_invocation(
tool_name="api_authentication",
result_status="success",
params={"client_ip": client_ip, "user_agent": user_agent, "key_hint": key_hint},
)
return True, None
else:
# Failure - record attempt and log
self._record_failed_attempt(client_ip)
key_hint = f"{provided_key[:8]}..." if len(provided_key) >= 8 else "too_short"
self.audit.log_access_denied(
tool_name="api_authentication",
reason="invalid_api_key",
)
self.audit.log_security_event(
event_type="invalid_api_key_attempt",
description=f"Invalid API key attempted from {client_ip}",
severity="medium",
metadata={
params={
"client_ip": client_ip,
"user_agent": user_agent,
"key_hint": key_hint,
"key_fingerprint": key_fingerprint,
},
)
return False, "Invalid API key"
return True, None
def extract_bearer_token(self, authorization_header: Optional[str]) -> Optional[str]:
"""Extract bearer token from Authorization header.
self._record_failed_attempt(client_ip)
self.audit.log_access_denied(
tool_name="api_authentication",
reason="invalid_api_key",
)
self.audit.log_security_event(
event_type="invalid_api_key_attempt",
description="Invalid API key was presented",
severity="medium",
metadata={"client_ip": client_ip, "user_agent": user_agent},
)
return False, "Invalid API key"
Args:
authorization_header: Authorization header value
def extract_bearer_token(self, authorization_header: str | None) -> str | None:
"""Extract API token from `Authorization: Bearer <token>` header.
Returns:
Extracted token or None if invalid format
Security note:
The scheme is case-sensitive by policy (`Bearer`) to prevent accepting
ambiguous client implementations and to align strict API contracts.
"""
if not authorization_header:
return None
parts = authorization_header.split()
parts = authorization_header.split(" ")
if len(parts) != 2:
return None
scheme, token = parts
if scheme.lower() != "bearer":
if scheme != "Bearer":
return None
if not token.strip():
return None
return token
return token.strip()
def generate_api_key(length: int = 64) -> str:
    """Generate a cryptographically secure API key.

    Args:
        length: Length of the key in characters (default: 64). Must be >= 1.

    Returns:
        Generated API key as a lowercase hex string of exactly `length`
        characters.

    Raises:
        ValueError: If `length` is not a positive integer.
    """
    # Bug fix: token_hex(length // 2) returned length-1 characters for odd
    # lengths and "" for length <= 1. Over-generate by one byte when length is
    # odd, then trim so the requested length is always honored exactly.
    if length < 1:
        raise ValueError("length must be a positive integer")
    return secrets.token_hex((length + 1) // 2)[:length]
def hash_api_key(api_key: str) -> str:
    """Return the SHA-256 hex digest of an API key for storage and comparison.

    Args:
        api_key: Plain-text API key.

    Returns:
        64-character lowercase hexadecimal SHA-256 digest.
    """
    hasher = hashlib.sha256()
    hasher.update(api_key.encode("utf-8"))
    return hasher.hexdigest()
# Global validator instance
_validator: Optional[APIKeyValidator] = None
_validator: APIKeyValidator | None = None
def get_validator() -> APIKeyValidator:
@@ -238,6 +201,6 @@ def get_validator() -> APIKeyValidator:
def reset_validator() -> None:
    """Reset the global API key validator singleton (primarily for testing)."""
    global _validator
    _validator = None

View File

@@ -0,0 +1,220 @@
"""Automation workflows for webhooks and scheduled jobs."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.config import get_settings
from aegis_gitea_mcp.gitea_client import GiteaClient
from aegis_gitea_mcp.policy import get_policy_engine
class AutomationError(RuntimeError):
    """Raised when an automation action is disabled, denied by policy, or invalid."""
def _parse_timestamp(value: str) -> datetime | None:
"""Parse ISO8601 timestamp with best-effort normalization."""
normalized = value.replace("Z", "+00:00")
try:
return datetime.fromisoformat(normalized)
except ValueError:
return None
class AutomationManager:
    """Policy-controlled automation manager.

    Dispatches webhook events and named automation jobs. Every action is
    gated by the policy engine and recorded via the audit logger.
    """
    def __init__(self) -> None:
        """Initialize automation manager with runtime settings and audit services."""
        self.settings = get_settings()
        self.audit = get_audit_logger()
    async def handle_webhook(
        self,
        event_type: str,
        payload: dict[str, Any],
        repository: str | None,
    ) -> dict[str, Any]:
        """Handle inbound webhook event.

        Args:
            event_type: Event type identifier.
            payload: Event payload body (treated strictly as inert data).
            repository: Optional target repository (`owner/repo`).

        Returns:
            Result summary for webhook processing: status, event type,
            repository, and the sorted top-level payload keys.

        Raises:
            AutomationError: If automation is disabled or policy denies the event.
        """
        if not self.settings.automation_enabled:
            raise AutomationError("automation is disabled")
        # Webhook ingestion is authorized as a read-only action.
        decision = get_policy_engine().authorize(
            tool_name="automation_webhook_ingest",
            is_write=False,
            repository=repository,
        )
        if not decision.allowed:
            raise AutomationError(f"policy denied webhook: {decision.reason}")
        self.audit.log_tool_invocation(
            tool_name="automation_webhook_ingest",
            repository=repository,
            params={"event_type": event_type},
            result_status="success",
        )
        # Safe default: treat webhook payload as data only.
        return {
            "status": "accepted",
            "event_type": event_type,
            "repository": repository,
            "keys": sorted(payload.keys()),
        }
    async def run_job(
        self,
        job_name: str,
        owner: str,
        repo: str,
        finding_title: str | None = None,
        finding_body: str | None = None,
    ) -> dict[str, Any]:
        """Run a named automation job for a repository.

        Args:
            job_name: Job identifier; one of `dependency_hygiene_scan`,
                `stale_issue_detection`, or `auto_issue_creation`.
            owner: Repository owner.
            repo: Repository name.
            finding_title: Optional issue title, used by `auto_issue_creation`.
            finding_body: Optional issue body, used by `auto_issue_creation`.

        Returns:
            Job execution summary.

        Raises:
            AutomationError: If automation is disabled, policy denies the job,
                or the job name is unsupported.
        """
        if not self.settings.automation_enabled:
            raise AutomationError("automation is disabled")
        repository = f"{owner}/{repo}"
        # Only issue creation mutates repository state; other jobs are read-only.
        is_write = job_name == "auto_issue_creation"
        decision = get_policy_engine().authorize(
            tool_name=f"automation_{job_name}",
            is_write=is_write,
            repository=repository,
        )
        if not decision.allowed:
            raise AutomationError(f"policy denied automation job: {decision.reason}")
        if job_name == "dependency_hygiene_scan":
            return await self._dependency_hygiene_scan(owner, repo)
        if job_name == "stale_issue_detection":
            return await self._stale_issue_detection(owner, repo)
        if job_name == "auto_issue_creation":
            return await self._auto_issue_creation(
                owner,
                repo,
                finding_title=finding_title,
                finding_body=finding_body,
            )
        raise AutomationError(f"unsupported automation job: {job_name}")
    async def _dependency_hygiene_scan(self, owner: str, repo: str) -> dict[str, Any]:
        """Run dependency hygiene scan placeholder workflow.

        Security note:
            This job intentionally performs read-only checks and does not mutate
            repository state directly.
        """
        repository = f"{owner}/{repo}"
        self.audit.log_tool_invocation(
            tool_name="automation_dependency_hygiene_scan",
            repository=repository,
            result_status="success",
        )
        # Placeholder output for policy-controlled automation scaffold.
        return {
            "job": "dependency_hygiene_scan",
            "repository": repository,
            "status": "completed",
            "findings": [],
        }
    async def _stale_issue_detection(self, owner: str, repo: str) -> dict[str, Any]:
        """Detect stale issues using repository issue metadata.

        An open issue is considered stale when its `updated_at` timestamp is
        older than `automation_stale_days`. Only the first page of up to 100
        open issues is examined.
        """
        repository = f"{owner}/{repo}"
        cutoff = datetime.now(timezone.utc) - timedelta(days=self.settings.automation_stale_days)
        stale_issue_numbers: list[int] = []
        async with GiteaClient() as gitea:
            issues = await gitea.list_issues(
                owner,
                repo,
                state="open",
                page=1,
                limit=100,
                labels=None,
            )
        for issue in issues:
            updated_at = issue.get("updated_at")
            # Skip issues with missing or non-string timestamps rather than failing.
            if not isinstance(updated_at, str):
                continue
            parsed = _parse_timestamp(updated_at)
            # NOTE(review): assumes Gitea timestamps carry a UTC offset; a naive
            # datetime here would raise on comparison with the aware cutoff -- confirm.
            if parsed and parsed < cutoff:
                number = issue.get("number")
                if isinstance(number, int):
                    stale_issue_numbers.append(number)
        self.audit.log_tool_invocation(
            tool_name="automation_stale_issue_detection",
            repository=repository,
            params={"stale_count": len(stale_issue_numbers)},
            result_status="success",
        )
        return {
            "job": "stale_issue_detection",
            "repository": repository,
            "status": "completed",
            "stale_issue_numbers": stale_issue_numbers,
            "stale_count": len(stale_issue_numbers),
        }
    async def _auto_issue_creation(
        self,
        owner: str,
        repo: str,
        finding_title: str | None,
        finding_body: str | None,
    ) -> dict[str, Any]:
        """Create an issue from an automation finding payload.

        Falls back to generic title/body text when the finding fields are not
        provided, and labels the issue `security` + `automation`.
        """
        repository = f"{owner}/{repo}"
        title = finding_title or "Automated security finding"
        body = finding_body or "Automated finding created by Aegis automation workflow."
        async with GiteaClient() as gitea:
            issue = await gitea.create_issue(
                owner,
                repo,
                title=title,
                body=body,
                labels=["security", "automation"],
                assignees=None,
            )
        # Defensive default: tolerate non-dict or number-less API responses.
        issue_number = issue.get("number", 0) if isinstance(issue, dict) else 0
        self.audit.log_tool_invocation(
            tool_name="automation_auto_issue_creation",
            repository=repository,
            params={"issue_number": issue_number},
            result_status="success",
        )
        return {
            "job": "auto_issue_creation",
            "repository": repository,
            "status": "completed",
            "issue_number": issue_number,
        }

View File

@@ -1,11 +1,16 @@
"""Configuration management for AegisGitea MCP server."""
from __future__ import annotations
from pathlib import Path
from typing import Optional
from pydantic import Field, HttpUrl, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
_ALLOWED_LOG_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
_ALLOWED_SECRET_MODES = {"off", "mask", "block"}
_ALLOWED_ENVIRONMENTS = {"development", "staging", "production", "test"}
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
@@ -15,64 +20,86 @@ class Settings(BaseSettings):
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
# Don't try to parse env vars as JSON for complex types
env_parse_none_str="null",
)
# Runtime environment
environment: str = Field(
default="production",
description="Runtime environment name",
)
# Gitea configuration
gitea_url: HttpUrl = Field(
...,
description="Base URL of the Gitea instance",
)
gitea_token: str = Field(
...,
description="Bot user access token for Gitea API",
min_length=1,
)
gitea_url: HttpUrl = Field(..., description="Base URL of the Gitea instance")
gitea_token: str = Field(..., description="Bot user access token for Gitea API", min_length=1)
# MCP server configuration
mcp_host: str = Field(
default="0.0.0.0",
description="Host to bind MCP server to",
default="127.0.0.1",
description="Host interface to bind MCP server to",
)
mcp_port: int = Field(
default=8080,
description="Port to bind MCP server to",
ge=1,
le=65535,
mcp_port: int = Field(default=8080, description="Port to bind MCP server to", ge=1, le=65535)
allow_insecure_bind: bool = Field(
default=False,
description="Allow binding to 0.0.0.0 (disabled by default for local hardening)",
)
# Logging configuration
log_level: str = Field(
default="INFO",
description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
)
# Logging and observability
log_level: str = Field(default="INFO", description="Application logging level")
audit_log_path: Path = Field(
default=Path("/var/log/aegis-mcp/audit.log"),
description="Path to audit log file",
description="Path to tamper-evident audit log file",
)
metrics_enabled: bool = Field(default=True, description="Enable Prometheus metrics endpoint")
expose_error_details: bool = Field(
default=False,
description="Return internal error details in API responses (disabled by default)",
)
startup_validate_gitea: bool = Field(
default=True,
description="Validate Gitea connectivity during startup",
)
# Security configuration
# Security limits
max_file_size_bytes: int = Field(
default=1_048_576, # 1MB
description="Maximum file size that can be read (in bytes)",
default=1_048_576,
description="Maximum file size that can be read (bytes)",
ge=1,
)
request_timeout_seconds: int = Field(
default=30,
description="Timeout for Gitea API requests (in seconds)",
description="Timeout for Gitea API requests (seconds)",
ge=1,
)
rate_limit_per_minute: int = Field(
default=60,
description="Maximum number of requests per minute",
description="Maximum requests per minute for a single IP",
ge=1,
)
token_rate_limit_per_minute: int = Field(
default=120,
description="Maximum requests per minute per authenticated token",
ge=1,
)
max_tool_response_items: int = Field(
default=200,
description="Maximum list items returned by a tool response",
ge=1,
)
max_tool_response_chars: int = Field(
default=20_000,
description="Maximum characters returned in text fields",
ge=1,
)
secret_detection_mode: str = Field(
default="mask",
description="Secret detection mode: off, mask, or block",
)
# Authentication configuration
auth_enabled: bool = Field(
default=True,
description="Enable API key authentication (disable only for testing)",
description="Enable API key authentication (disable only in controlled testing)",
)
mcp_api_keys_raw: str = Field(
default="",
@@ -81,81 +108,149 @@ class Settings(BaseSettings):
)
max_auth_failures: int = Field(
default=5,
description="Maximum authentication failures before rate limiting",
description="Maximum authentication failures before auth rate limiting",
ge=1,
)
auth_failure_window: int = Field(
default=300, # 5 minutes
description="Time window for counting auth failures (in seconds)",
default=300,
description="Time window for counting auth failures (seconds)",
ge=1,
)
# Policy and write-mode configuration
policy_file_path: Path = Field(
default=Path("policy.yaml"),
description="Path to YAML authorization policy file",
)
write_mode: bool = Field(default=False, description="Enable write-capable tools")
write_repository_whitelist_raw: str = Field(
default="",
description="Comma-separated repository whitelist for write mode (owner/repo)",
alias="WRITE_REPOSITORY_WHITELIST",
)
automation_enabled: bool = Field(
default=False,
description="Enable automation endpoints and workflows",
)
automation_scheduler_enabled: bool = Field(
default=False,
description="Enable built-in scheduled job loop",
)
automation_stale_days: int = Field(
default=30,
description="Number of days before an issue is considered stale",
ge=1,
)
@field_validator("environment")
@classmethod
def validate_environment(cls, value: str) -> str:
"""Validate deployment environment name."""
normalized = value.strip().lower()
if normalized not in _ALLOWED_ENVIRONMENTS:
raise ValueError(f"environment must be one of {_ALLOWED_ENVIRONMENTS}")
return normalized
@field_validator("log_level")
@classmethod
def validate_log_level(cls, v: str) -> str:
def validate_log_level(cls, value: str) -> str:
"""Validate log level is one of the allowed values."""
allowed_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
v_upper = v.upper()
if v_upper not in allowed_levels:
raise ValueError(f"log_level must be one of {allowed_levels}")
return v_upper
normalized = value.upper()
if normalized not in _ALLOWED_LOG_LEVELS:
raise ValueError(f"log_level must be one of {_ALLOWED_LOG_LEVELS}")
return normalized
@field_validator("gitea_token")
@classmethod
def validate_token_not_empty(cls, v: str) -> str:
"""Validate Gitea token is not empty or whitespace."""
if not v.strip():
def validate_token_not_empty(cls, value: str) -> str:
"""Validate Gitea token is non-empty and trimmed."""
cleaned = value.strip()
if not cleaned:
raise ValueError("gitea_token cannot be empty or whitespace")
return v.strip()
return cleaned
@field_validator("secret_detection_mode")
@classmethod
def validate_secret_detection_mode(cls, value: str) -> str:
"""Validate secret detection behavior setting."""
normalized = value.lower().strip()
if normalized not in _ALLOWED_SECRET_MODES:
raise ValueError(f"secret_detection_mode must be one of {_ALLOWED_SECRET_MODES}")
return normalized
@model_validator(mode="after")
def validate_and_parse_api_keys(self) -> "Settings":
"""Parse and validate API keys if authentication is enabled."""
# Parse comma-separated keys into list
keys: list[str] = []
if self.mcp_api_keys_raw and self.mcp_api_keys_raw.strip():
keys = [key.strip() for key in self.mcp_api_keys_raw.split(",") if key.strip()]
def validate_security_constraints(self) -> Settings:
"""Validate cross-field security constraints."""
parsed_keys: list[str] = []
if self.mcp_api_keys_raw.strip():
parsed_keys = [
value.strip() for value in self.mcp_api_keys_raw.split(",") if value.strip()
]
# Store in a property we'll access
object.__setattr__(self, "_mcp_api_keys", keys)
object.__setattr__(self, "_mcp_api_keys", parsed_keys)
# Validate if auth is enabled
if self.auth_enabled and not keys:
write_repositories: list[str] = []
if self.write_repository_whitelist_raw.strip():
write_repositories = [
value.strip()
for value in self.write_repository_whitelist_raw.split(",")
if value.strip()
]
for repository in write_repositories:
if "/" not in repository:
raise ValueError("WRITE_REPOSITORY_WHITELIST entries must be in owner/repo format")
object.__setattr__(self, "_write_repository_whitelist", write_repositories)
# Security decision: binding all interfaces requires explicit opt-in.
if self.mcp_host == "0.0.0.0" and not self.allow_insecure_bind:
raise ValueError(
"At least one API key must be configured when auth_enabled=True. "
"Set MCP_API_KEYS environment variable or disable auth with AUTH_ENABLED=false"
"Binding to 0.0.0.0 is blocked by default. "
"Set ALLOW_INSECURE_BIND=true to explicitly permit this."
)
# Validate key format (at least 32 characters for security)
for key in keys:
if self.auth_enabled and not parsed_keys:
raise ValueError(
"At least one API key must be configured when auth_enabled=True. "
"Set MCP_API_KEYS or disable auth explicitly for controlled testing."
)
# Enforce minimum key length to reduce brute-force success probability.
for key in parsed_keys:
if len(key) < 32:
raise ValueError(
f"API keys must be at least 32 characters long. "
f"Use scripts/generate_api_key.py to generate secure keys."
)
raise ValueError("API keys must be at least 32 characters long")
if self.write_mode and not write_repositories:
raise ValueError("WRITE_MODE=true requires WRITE_REPOSITORY_WHITELIST to be configured")
return self
@property
def mcp_api_keys(self) -> list[str]:
"""Get parsed list of API keys."""
return getattr(self, "_mcp_api_keys", [])
return list(getattr(self, "_mcp_api_keys", []))
@property
def write_repository_whitelist(self) -> list[str]:
    """Repositories (owner/repo strings) permitted for write-mode operations."""
    # Return a defensive copy so callers cannot mutate the cached parsed list.
    stored = getattr(self, "_write_repository_whitelist", [])
    return list(stored)
@property
def gitea_base_url(self) -> str:
"""Get Gitea base URL as string."""
"""Get Gitea base URL as normalized string."""
return str(self.gitea_url).rstrip("/")
# Global settings instance
_settings: Optional[Settings] = None
_settings: Settings | None = None
def get_settings() -> Settings:
"""Get or create global settings instance."""
global _settings
if _settings is None:
_settings = Settings()
# Mypy limitation: BaseSettings loads from environment dynamically.
_settings = Settings() # type: ignore[call-arg]
return _settings

View File

@@ -1,8 +1,9 @@
"""Gitea API client with bot user authentication."""
"""Gitea API client with hardened request handling."""
from typing import Any, Dict, List, Optional
from __future__ import annotations
from typing import Any
import httpx
from httpx import AsyncClient, Response
from aegis_gitea_mcp.audit import get_audit_logger
@@ -12,47 +13,37 @@ from aegis_gitea_mcp.config import get_settings
class GiteaError(Exception):
"""Base exception for Gitea API errors."""
pass
class GiteaAuthenticationError(GiteaError):
"""Raised when authentication with Gitea fails."""
pass
class GiteaAuthorizationError(GiteaError):
"""Raised when bot user lacks permission for an operation."""
pass
class GiteaNotFoundError(GiteaError):
"""Raised when a requested resource is not found."""
pass
"""Raised when requested resource is not found."""
class GiteaClient:
"""Client for interacting with Gitea API as a bot user."""
def __init__(self, base_url: Optional[str] = None, token: Optional[str] = None) -> None:
def __init__(self, base_url: str | None = None, token: str | None = None) -> None:
"""Initialize Gitea client.
Args:
base_url: Base URL of Gitea instance (defaults to config value)
token: Bot user access token (defaults to config value)
base_url: Optional base URL override.
token: Optional token override.
"""
self.settings = get_settings()
self.audit = get_audit_logger()
self.base_url = (base_url or self.settings.gitea_base_url).rstrip("/")
self.token = token or self.settings.gitea_token
self.client: AsyncClient | None = None
self.client: Optional[AsyncClient] = None
async def __aenter__(self) -> "GiteaClient":
"""Async context manager entry."""
async def __aenter__(self) -> GiteaClient:
"""Create async HTTP client context."""
self.client = AsyncClient(
base_url=self.base_url,
headers={
@@ -65,26 +56,22 @@ class GiteaClient:
return self
async def __aexit__(self, *args: Any) -> None:
"""Async context manager exit."""
"""Close async HTTP client context."""
if self.client:
await self.client.aclose()
def _handle_response(self, response: Response, correlation_id: str) -> Any:
"""Handle Gitea API response and raise appropriate exceptions.
Args:
response: HTTP response from Gitea
correlation_id: Correlation ID for audit logging
Returns:
Parsed JSON response
def _ensure_client(self) -> AsyncClient:
"""Return initialized HTTP client.
Raises:
GiteaAuthenticationError: On 401 responses
GiteaAuthorizationError: On 403 responses
GiteaNotFoundError: On 404 responses
GiteaError: On other error responses
RuntimeError: If called outside async context manager.
"""
if not self.client:
raise RuntimeError("Client not initialized - use async context manager")
return self.client
def _handle_response(self, response: Response, correlation_id: str) -> Any:
"""Handle HTTP response and map to domain exceptions."""
if response.status_code == 401:
self.audit.log_security_event(
event_type="authentication_failure",
@@ -97,7 +84,7 @@ class GiteaClient:
if response.status_code == 403:
self.audit.log_access_denied(
tool_name="gitea_api",
reason="Bot user lacks permission",
reason="bot user lacks permission",
correlation_id=correlation_id,
)
raise GiteaAuthorizationError("Bot user lacks permission for this operation")
@@ -109,7 +96,9 @@ class GiteaClient:
error_msg = f"Gitea API error: {response.status_code}"
try:
error_data = response.json()
error_msg = f"{error_msg} - {error_data.get('message', '')}"
message = error_data.get("message") if isinstance(error_data, dict) else None
if message:
error_msg = f"{error_msg} - {message}"
except Exception:
pass
raise GiteaError(error_msg)
@@ -119,35 +108,34 @@ class GiteaClient:
except Exception:
return {}
async def get_current_user(self) -> Dict[str, Any]:
"""Get information about the current bot user.
Returns:
User information dict
Raises:
GiteaError: On API errors
"""
if not self.client:
raise RuntimeError("Client not initialized - use async context manager")
async def _request(
self,
method: str,
endpoint: str,
*,
correlation_id: str,
params: dict[str, Any] | None = None,
json_body: dict[str, Any] | None = None,
) -> Any:
"""Execute a request to Gitea API with shared error handling."""
client = self._ensure_client()
response = await client.request(method=method, url=endpoint, params=params, json=json_body)
return self._handle_response(response, correlation_id)
async def get_current_user(self) -> dict[str, Any]:
"""Get current bot user profile."""
correlation_id = self.audit.log_tool_invocation(
tool_name="get_current_user",
result_status="pending",
)
try:
response = await self.client.get("/api/v1/user")
user_data = self._handle_response(response, correlation_id)
result = await self._request("GET", "/api/v1/user", correlation_id=correlation_id)
self.audit.log_tool_invocation(
tool_name="get_current_user",
correlation_id=correlation_id,
result_status="success",
)
return user_data
return result if isinstance(result, dict) else {}
except Exception as exc:
self.audit.log_tool_invocation(
tool_name="get_current_user",
@@ -157,39 +145,22 @@ class GiteaClient:
)
raise
async def list_repositories(self) -> List[Dict[str, Any]]:
"""List all repositories visible to the bot user.
Returns:
List of repository information dicts
Raises:
GiteaError: On API errors
"""
if not self.client:
raise RuntimeError("Client not initialized - use async context manager")
async def list_repositories(self) -> list[dict[str, Any]]:
"""List all repositories visible to the bot user."""
correlation_id = self.audit.log_tool_invocation(
tool_name="list_repositories",
result_status="pending",
)
try:
response = await self.client.get("/api/v1/user/repos")
repos_data = self._handle_response(response, correlation_id)
# Ensure we have a list
repos = repos_data if isinstance(repos_data, list) else []
result = await self._request("GET", "/api/v1/user/repos", correlation_id=correlation_id)
repositories = result if isinstance(result, list) else []
self.audit.log_tool_invocation(
tool_name="list_repositories",
correlation_id=correlation_id,
result_status="success",
params={"count": len(repos)},
params={"count": len(repositories)},
)
return repos
return repositories
except Exception as exc:
self.audit.log_tool_invocation(
tool_name="list_repositories",
@@ -199,43 +170,27 @@ class GiteaClient:
)
raise
async def get_repository(self, owner: str, repo: str) -> Dict[str, Any]:
"""Get information about a specific repository.
Args:
owner: Repository owner username
repo: Repository name
Returns:
Repository information dict
Raises:
GiteaNotFoundError: If repository doesn't exist or bot lacks access
GiteaError: On other API errors
"""
if not self.client:
raise RuntimeError("Client not initialized - use async context manager")
async def get_repository(self, owner: str, repo: str) -> dict[str, Any]:
"""Get repository metadata."""
repo_id = f"{owner}/{repo}"
correlation_id = self.audit.log_tool_invocation(
tool_name="get_repository",
repository=repo_id,
result_status="pending",
)
try:
response = await self.client.get(f"/api/v1/repos/{owner}/{repo}")
repo_data = self._handle_response(response, correlation_id)
result = await self._request(
"GET",
f"/api/v1/repos/{owner}/{repo}",
correlation_id=correlation_id,
)
self.audit.log_tool_invocation(
tool_name="get_repository",
repository=repo_id,
correlation_id=correlation_id,
result_status="success",
)
return repo_data
return result if isinstance(result, dict) else {}
except Exception as exc:
self.audit.log_tool_invocation(
tool_name="get_repository",
@@ -247,26 +202,13 @@ class GiteaClient:
raise
async def get_file_contents(
self, owner: str, repo: str, filepath: str, ref: str = "main"
) -> Dict[str, Any]:
"""Get contents of a file in a repository.
Args:
owner: Repository owner username
repo: Repository name
filepath: Path to file within repository
ref: Branch, tag, or commit ref (defaults to 'main')
Returns:
File contents dict with 'content', 'encoding', 'size', etc.
Raises:
GiteaNotFoundError: If file doesn't exist
GiteaError: On other API errors
"""
if not self.client:
raise RuntimeError("Client not initialized - use async context manager")
self,
owner: str,
repo: str,
filepath: str,
ref: str = "main",
) -> dict[str, Any]:
"""Get file contents from a repository."""
repo_id = f"{owner}/{repo}"
correlation_id = self.audit.log_tool_invocation(
tool_name="get_file_contents",
@@ -275,20 +217,22 @@ class GiteaClient:
params={"ref": ref},
result_status="pending",
)
try:
response = await self.client.get(
result = await self._request(
"GET",
f"/api/v1/repos/{owner}/{repo}/contents/{filepath}",
params={"ref": ref},
correlation_id=correlation_id,
)
file_data = self._handle_response(response, correlation_id)
# Check file size against limit
file_size = file_data.get("size", 0)
if not isinstance(result, dict):
raise GiteaError("Unexpected response type for file contents")
file_size = int(result.get("size", 0))
if file_size > self.settings.max_file_size_bytes:
error_msg = (
f"File size ({file_size} bytes) exceeds "
f"limit ({self.settings.max_file_size_bytes} bytes)"
f"File size ({file_size} bytes) exceeds limit "
f"({self.settings.max_file_size_bytes} bytes)"
)
self.audit.log_security_event(
event_type="file_size_limit_exceeded",
@@ -311,9 +255,7 @@ class GiteaClient:
result_status="success",
params={"ref": ref, "size": file_size},
)
return file_data
return result
except Exception as exc:
self.audit.log_tool_invocation(
tool_name="get_file_contents",
@@ -326,25 +268,13 @@ class GiteaClient:
raise
async def get_tree(
self, owner: str, repo: str, ref: str = "main", recursive: bool = False
) -> Dict[str, Any]:
"""Get file tree for a repository.
Args:
owner: Repository owner username
repo: Repository name
ref: Branch, tag, or commit ref (defaults to 'main')
recursive: Whether to recursively fetch tree (default: False for safety)
Returns:
Tree information dict
Raises:
GiteaError: On API errors
"""
if not self.client:
raise RuntimeError("Client not initialized - use async context manager")
self,
owner: str,
repo: str,
ref: str = "main",
recursive: bool = False,
) -> dict[str, Any]:
"""Get repository tree at given ref."""
repo_id = f"{owner}/{repo}"
correlation_id = self.audit.log_tool_invocation(
tool_name="get_tree",
@@ -352,24 +282,26 @@ class GiteaClient:
params={"ref": ref, "recursive": recursive},
result_status="pending",
)
try:
response = await self.client.get(
result = await self._request(
"GET",
f"/api/v1/repos/{owner}/{repo}/git/trees/{ref}",
params={"recursive": str(recursive).lower()},
correlation_id=correlation_id,
)
tree_data = self._handle_response(response, correlation_id)
tree_data = result if isinstance(result, dict) else {}
self.audit.log_tool_invocation(
tool_name="get_tree",
repository=repo_id,
correlation_id=correlation_id,
result_status="success",
params={"ref": ref, "recursive": recursive, "count": len(tree_data.get("tree", []))},
params={
"ref": ref,
"recursive": recursive,
"count": len(tree_data.get("tree", [])),
},
)
return tree_data
except Exception as exc:
self.audit.log_tool_invocation(
tool_name="get_tree",
@@ -379,3 +311,326 @@ class GiteaClient:
error=str(exc),
)
raise
async def search_code(
    self,
    owner: str,
    repo: str,
    query: str,
    *,
    ref: str,
    page: int,
    limit: int,
) -> dict[str, Any]:
    """Search repository code and return the raw Gitea search payload."""
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="search_code",
        repository=repo_id,
        params={"query": query, "ref": ref, "page": page, "limit": limit},
        result_status="pending",
    )
    try:
        payload = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/search",
            params={"q": query, "page": page, "limit": limit, "ref": ref},
            correlation_id=correlation_id,
        )
        # Record completion before returning so the audit trail reflects success.
        self.audit.log_tool_invocation(
            tool_name="search_code",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="success",
        )
        return payload if isinstance(payload, dict) else {}
    except Exception as exc:
        self.audit.log_tool_invocation(
            tool_name="search_code",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
async def list_commits(
    self,
    owner: str,
    repo: str,
    *,
    ref: str,
    page: int,
    limit: int,
) -> list[dict[str, Any]]:
    """List commits for a repository ref.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        ref: Branch, tag, or commit SHA to list commits from.
        page: 1-based result page.
        limit: Maximum results per page.

    Returns:
        List of commit dicts; empty list on unexpected payload shape.

    Raises:
        GiteaError subclasses propagated from the underlying request.
    """
    repo_id = f"{owner}/{repo}"
    # Fix: previously only a "pending" audit entry was written, so the trail
    # never recorded success/error. Mirror the completion logging used by
    # get_current_user/search_code in this class.
    correlation_id = self.audit.log_tool_invocation(
        tool_name="list_commits",
        repository=repo_id,
        params={"ref": ref, "page": page, "limit": limit},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/commits",
            params={"sha": ref, "page": page, "limit": limit},
            correlation_id=correlation_id,
        )
        # Defensively coerce unexpected payload shapes to an empty list.
        commits = result if isinstance(result, list) else []
        self.audit.log_tool_invocation(
            tool_name="list_commits",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="success",
            params={"count": len(commits)},
        )
        return commits
    except Exception as exc:
        self.audit.log_tool_invocation(
            tool_name="list_commits",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
async def get_commit_diff(self, owner: str, repo: str, sha: str) -> dict[str, Any]:
    """Get detailed commit including changed files and patch metadata.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        sha: Commit SHA to fetch.

    Returns:
        Commit detail dict; empty dict on unexpected payload shape.
    """
    repo_id = f"{owner}/{repo}"
    # Fix: record completion status instead of leaving the audit trail at
    # "pending" forever, consistent with search_code/get_current_user.
    correlation_id = self.audit.log_tool_invocation(
        tool_name="get_commit_diff",
        repository=repo_id,
        params={"sha": sha},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/git/commits/{sha}",
            correlation_id=correlation_id,
        )
        self.audit.log_tool_invocation(
            tool_name="get_commit_diff",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="success",
        )
        return result if isinstance(result, dict) else {}
    except Exception as exc:
        self.audit.log_tool_invocation(
            tool_name="get_commit_diff",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
async def compare_refs(self, owner: str, repo: str, base: str, head: str) -> dict[str, Any]:
    """Compare two refs and return commit/file deltas.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        base: Base ref of the comparison.
        head: Head ref of the comparison.

    Returns:
        Comparison dict; empty dict on unexpected payload shape.
    """
    repo_id = f"{owner}/{repo}"
    # Fix: write success/error audit records instead of a pending-only entry,
    # matching the completion logging used elsewhere in this class.
    correlation_id = self.audit.log_tool_invocation(
        tool_name="compare_refs",
        repository=repo_id,
        params={"base": base, "head": head},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/compare/{base}...{head}",
            correlation_id=correlation_id,
        )
        self.audit.log_tool_invocation(
            tool_name="compare_refs",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="success",
        )
        return result if isinstance(result, dict) else {}
    except Exception as exc:
        self.audit.log_tool_invocation(
            tool_name="compare_refs",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
async def list_issues(
    self,
    owner: str,
    repo: str,
    *,
    state: str,
    page: int,
    limit: int,
    labels: list[str] | None = None,
) -> list[dict[str, Any]]:
    """List repository issues.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        state: Issue state filter (Gitea accepts e.g. "open"/"closed"/"all").
        page: 1-based result page.
        limit: Maximum results per page.
        labels: Optional label names; joined comma-separated for the API.

    Returns:
        List of issue dicts; empty list on unexpected payload shape.
    """
    params: dict[str, Any] = {"state": state, "page": page, "limit": limit}
    if labels:
        params["labels"] = ",".join(labels)
    # NOTE(review): only a "pending" audit record is written here; success and
    # error outcomes are never logged — confirm whether completion logging
    # (as in get_current_user/search_code) is intended.
    result = await self._request(
        "GET",
        f"/api/v1/repos/{owner}/{repo}/issues",
        params=params,
        correlation_id=str(
            self.audit.log_tool_invocation(tool_name="list_issues", result_status="pending")
        ),
    )
    # Defensively coerce unexpected payload shapes to an empty list.
    return result if isinstance(result, list) else []
async def get_issue(self, owner: str, repo: str, index: int) -> dict[str, Any]:
    """Get issue details.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        index: Issue index (Gitea per-repository issue number).

    Returns:
        Issue dict; empty dict on unexpected payload shape.
    """
    # NOTE(review): only a "pending" audit record is written; success/error
    # outcomes are not logged for this call — confirm this is intentional.
    result = await self._request(
        "GET",
        f"/api/v1/repos/{owner}/{repo}/issues/{index}",
        correlation_id=str(
            self.audit.log_tool_invocation(tool_name="get_issue", result_status="pending")
        ),
    )
    return result if isinstance(result, dict) else {}
async def list_pull_requests(
    self,
    owner: str,
    repo: str,
    *,
    state: str,
    page: int,
    limit: int,
) -> list[dict[str, Any]]:
    """List pull requests for repository.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        state: PR state filter passed through to the Gitea API.
        page: 1-based result page.
        limit: Maximum results per page.

    Returns:
        List of pull request dicts; empty list on unexpected payload shape.
    """
    # NOTE(review): pending-only audit record; success/error outcomes are not
    # logged for this call — confirm this is intentional.
    result = await self._request(
        "GET",
        f"/api/v1/repos/{owner}/{repo}/pulls",
        params={"state": state, "page": page, "limit": limit},
        correlation_id=str(
            self.audit.log_tool_invocation(
                tool_name="list_pull_requests", result_status="pending"
            )
        ),
    )
    return result if isinstance(result, list) else []
async def get_pull_request(self, owner: str, repo: str, index: int) -> dict[str, Any]:
    """Get a single pull request.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        index: Pull request index (per-repository number).

    Returns:
        Pull request dict; empty dict on unexpected payload shape.
    """
    # NOTE(review): pending-only audit record; success/error outcomes are not
    # logged for this call — confirm this is intentional.
    result = await self._request(
        "GET",
        f"/api/v1/repos/{owner}/{repo}/pulls/{index}",
        correlation_id=str(
            self.audit.log_tool_invocation(
                tool_name="get_pull_request", result_status="pending"
            )
        ),
    )
    return result if isinstance(result, dict) else {}
async def list_labels(
    self, owner: str, repo: str, *, page: int, limit: int
) -> list[dict[str, Any]]:
    """List repository labels.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        page: 1-based result page.
        limit: Maximum results per page.

    Returns:
        List of label dicts; empty list on unexpected payload shape.
    """
    # NOTE(review): pending-only audit record; completion is not logged.
    result = await self._request(
        "GET",
        f"/api/v1/repos/{owner}/{repo}/labels",
        params={"page": page, "limit": limit},
        correlation_id=str(
            self.audit.log_tool_invocation(tool_name="list_labels", result_status="pending")
        ),
    )
    return result if isinstance(result, list) else []
async def list_tags(
    self, owner: str, repo: str, *, page: int, limit: int
) -> list[dict[str, Any]]:
    """List repository tags.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        page: 1-based result page.
        limit: Maximum results per page.

    Returns:
        List of tag dicts; empty list on unexpected payload shape.
    """
    # NOTE(review): pending-only audit record; completion is not logged.
    result = await self._request(
        "GET",
        f"/api/v1/repos/{owner}/{repo}/tags",
        params={"page": page, "limit": limit},
        correlation_id=str(
            self.audit.log_tool_invocation(tool_name="list_tags", result_status="pending")
        ),
    )
    return result if isinstance(result, list) else []
async def list_releases(
    self,
    owner: str,
    repo: str,
    *,
    page: int,
    limit: int,
) -> list[dict[str, Any]]:
    """List repository releases.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        page: 1-based result page.
        limit: Maximum results per page.

    Returns:
        List of release dicts; empty list on unexpected payload shape.
    """
    # NOTE(review): pending-only audit record; completion is not logged.
    result = await self._request(
        "GET",
        f"/api/v1/repos/{owner}/{repo}/releases",
        params={"page": page, "limit": limit},
        correlation_id=str(
            self.audit.log_tool_invocation(tool_name="list_releases", result_status="pending")
        ),
    )
    return result if isinstance(result, list) else []
async def create_issue(
    self,
    owner: str,
    repo: str,
    *,
    title: str,
    body: str,
    labels: list[str] | None = None,
    assignees: list[str] | None = None,
) -> dict[str, Any]:
    """Create repository issue.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        title: Issue title.
        body: Issue body text.
        labels: Optional label names to attach.
        assignees: Optional usernames to assign.

    Returns:
        Created issue dict; empty dict on unexpected payload shape.

    Raises:
        GiteaError subclasses propagated from the underlying request.
    """
    repo_id = f"{owner}/{repo}"
    payload: dict[str, Any] = {"title": title, "body": body}
    if labels:
        payload["labels"] = labels
    if assignees:
        payload["assignees"] = assignees
    # Fix: this is a WRITE operation, yet only a "pending" audit entry was
    # written. Record success/error so the audit trail captures the outcome
    # of every mutation, consistent with search_code/get_current_user.
    correlation_id = self.audit.log_tool_invocation(
        tool_name="create_issue",
        repository=repo_id,
        result_status="pending",
    )
    try:
        result = await self._request(
            "POST",
            f"/api/v1/repos/{owner}/{repo}/issues",
            json_body=payload,
            correlation_id=correlation_id,
        )
        self.audit.log_tool_invocation(
            tool_name="create_issue",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="success",
        )
        return result if isinstance(result, dict) else {}
    except Exception as exc:
        self.audit.log_tool_invocation(
            tool_name="create_issue",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
async def update_issue(
    self,
    owner: str,
    repo: str,
    index: int,
    *,
    title: str | None = None,
    body: str | None = None,
    state: str | None = None,
) -> dict[str, Any]:
    """Update issue fields.

    Only the fields explicitly provided (non-None) are sent in the PATCH
    payload, so omitted fields are left unchanged on the server.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        index: Issue index (per-repository number).
        title: New title, if changing.
        body: New body text, if changing.
        state: New state, if changing.

    Returns:
        Updated issue dict; empty dict on unexpected payload shape.
    """
    payload: dict[str, Any] = {}
    if title is not None:
        payload["title"] = title
    if body is not None:
        payload["body"] = body
    if state is not None:
        payload["state"] = state
    # NOTE(review): write operation with pending-only audit record; success
    # and error outcomes are never logged — confirm this is intentional.
    result = await self._request(
        "PATCH",
        f"/api/v1/repos/{owner}/{repo}/issues/{index}",
        json_body=payload,
        correlation_id=str(
            self.audit.log_tool_invocation(tool_name="update_issue", result_status="pending")
        ),
    )
    return result if isinstance(result, dict) else {}
async def create_issue_comment(
    self, owner: str, repo: str, index: int, body: str
) -> dict[str, Any]:
    """Create a comment on issue (and PR discussion if issue index refers to PR).

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        index: Issue (or PR) index.
        body: Comment body text.

    Returns:
        Created comment dict; empty dict on unexpected payload shape.
    """
    # NOTE(review): write operation with pending-only audit record; success
    # and error outcomes are never logged — confirm this is intentional.
    result = await self._request(
        "POST",
        f"/api/v1/repos/{owner}/{repo}/issues/{index}/comments",
        json_body={"body": body},
        correlation_id=str(
            self.audit.log_tool_invocation(
                tool_name="create_issue_comment", result_status="pending"
            )
        ),
    )
    return result if isinstance(result, dict) else {}
async def create_pr_comment(
    self, owner: str, repo: str, index: int, body: str
) -> dict[str, Any]:
    """Create PR discussion comment.

    Uses the same Gitea endpoint as issue comments: in Gitea, PRs share the
    issue index space, so the /issues/{index}/comments route applies to both.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        index: Pull request index.
        body: Comment body text.

    Returns:
        Created comment dict; empty dict on unexpected payload shape.
    """
    # NOTE(review): write operation with pending-only audit record; success
    # and error outcomes are never logged — confirm this is intentional.
    result = await self._request(
        "POST",
        f"/api/v1/repos/{owner}/{repo}/issues/{index}/comments",
        json_body={"body": body},
        correlation_id=str(
            self.audit.log_tool_invocation(
                tool_name="create_pr_comment", result_status="pending"
            )
        ),
    )
    return result if isinstance(result, dict) else {}
async def add_labels(
    self,
    owner: str,
    repo: str,
    index: int,
    labels: list[str],
) -> dict[str, Any]:
    """Add labels to issue/PR.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        index: Issue or PR index.
        labels: Label names to add.

    Returns:
        API response dict; empty dict on unexpected payload shape.
    """
    # NOTE(review): write operation with pending-only audit record; success
    # and error outcomes are never logged — confirm this is intentional.
    result = await self._request(
        "POST",
        f"/api/v1/repos/{owner}/{repo}/issues/{index}/labels",
        json_body={"labels": labels},
        correlation_id=str(
            self.audit.log_tool_invocation(tool_name="add_labels", result_status="pending")
        ),
    )
    return result if isinstance(result, dict) else {}
async def assign_issue(
    self,
    owner: str,
    repo: str,
    index: int,
    assignees: list[str],
) -> dict[str, Any]:
    """Assign users to issue/PR.

    Args:
        owner: Repository owner username or organization.
        repo: Repository name.
        index: Issue or PR index.
        assignees: Usernames to assign.

    Returns:
        API response dict; empty dict on unexpected payload shape.
    """
    # NOTE(review): write operation with pending-only audit record; success
    # and error outcomes are never logged — confirm this is intentional.
    result = await self._request(
        "POST",
        f"/api/v1/repos/{owner}/{repo}/issues/{index}/assignees",
        json_body={"assignees": assignees},
        correlation_id=str(
            self.audit.log_tool_invocation(tool_name="assign_issue", result_status="pending")
        ),
    )
    return result if isinstance(result, dict) else {}

View File

@@ -0,0 +1,48 @@
"""Structured logging configuration utilities."""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from aegis_gitea_mcp.request_context import get_request_id
class JsonLogFormatter(logging.Formatter):
    """Format log records as compact single-line JSON documents."""

    def format(self, record: logging.LogRecord) -> str:
        """Serialize a log record to JSON.

        Args:
            record: Record emitted by the logging framework.

        Returns:
            Compact ASCII JSON containing timestamp, level, logger name,
            rendered message, request id, and (if present) exception type.
        """
        payload: dict[str, object] = {
            # Fix: use the record's creation time rather than format time, so
            # buffered or slow handlers do not skew the logged timestamp.
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "request_id": get_request_id(),
        }
        if record.exc_info:
            # Security decision: include only the exception type to avoid
            # leaking stack traces or sensitive values into log sinks.
            exception_type = record.exc_info[0]
            if exception_type is not None:
                payload["exception_type"] = exception_type.__name__
        return json.dumps(payload, separators=(",", ":"), ensure_ascii=True)
def configure_logging(level: str) -> None:
    """Configure application-wide structured JSON logging.

    Replaces all handlers on the root logger with a single stream handler
    that emits JSON lines via JsonLogFormatter.

    Args:
        level: Logging level name (e.g. "INFO", "debug").
    """
    root = logging.getLogger()
    root.setLevel(level.upper())
    # Drop any pre-existing handlers so records are not emitted twice.
    while root.handlers:
        root.removeHandler(root.handlers[0])
    json_handler = logging.StreamHandler()
    json_handler.setFormatter(JsonLogFormatter())
    root.addHandler(json_handler)

View File

@@ -1,6 +1,8 @@
"""MCP protocol implementation for AegisGitea."""
"""MCP protocol models and tool registry."""
from typing import Any, Dict, List, Optional
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
@@ -10,153 +12,366 @@ class MCPTool(BaseModel):
name: str = Field(..., description="Unique tool identifier")
description: str = Field(..., description="Human-readable tool description")
input_schema: Dict[str, Any] = Field(
..., alias="inputSchema", description="JSON Schema for tool input"
)
model_config = ConfigDict(
populate_by_name=True,
serialize_by_alias=True,
)
input_schema: dict[str, Any] = Field(..., description="JSON schema describing input arguments")
write_operation: bool = Field(default=False, description="Whether tool mutates data")
class MCPToolCallRequest(BaseModel):
"""Request to invoke an MCP tool."""
tool: str = Field(..., description="Name of the tool to invoke")
arguments: Dict[str, Any] = Field(default_factory=dict, description="Tool arguments")
correlation_id: Optional[str] = Field(None, description="Request correlation ID")
arguments: dict[str, Any] = Field(default_factory=dict, description="Tool argument payload")
correlation_id: str | None = Field(default=None, description="Request correlation ID")
model_config = ConfigDict(extra="forbid")
class MCPToolCallResponse(BaseModel):
"""Response from an MCP tool invocation."""
"""Response returned from MCP tool invocation."""
success: bool = Field(..., description="Whether the tool call succeeded")
result: Optional[Any] = Field(None, description="Tool result data")
error: Optional[str] = Field(None, description="Error message if failed")
correlation_id: str = Field(..., description="Request correlation ID")
success: bool = Field(..., description="Whether invocation succeeded")
result: Any | None = Field(default=None, description="Tool result payload")
error: str | None = Field(default=None, description="Error message for failed request")
correlation_id: str = Field(..., description="Correlation ID for request tracing")
class MCPListToolsResponse(BaseModel):
"""Response listing available MCP tools."""
"""Response listing available tools."""
tools: List[MCPTool] = Field(..., description="List of available tools")
tools: list[MCPTool] = Field(..., description="Available tool definitions")
# Tool definitions for AegisGitea MCP
def _tool(
    name: str, description: str, schema: dict[str, Any], write_operation: bool = False
) -> MCPTool:
    """Build an MCPTool registry entry.

    Args:
        name: Unique tool identifier.
        description: Human-readable tool description.
        schema: JSON schema describing the tool's input arguments.
        write_operation: True when the tool mutates data.

    Returns:
        Populated MCPTool model.
    """
    fields = {
        "name": name,
        "description": description,
        "input_schema": schema,
        "write_operation": write_operation,
    }
    return MCPTool(**fields)
TOOL_LIST_REPOSITORIES = MCPTool(
name="list_repositories",
description="List all repositories visible to the AI bot user. "
"Only repositories where the bot has explicit read access will be returned. "
"This respects Gitea's dynamic authorization model.",
input_schema={
"type": "object",
"properties": {},
"required": [],
},
)
TOOL_GET_REPOSITORY_INFO = MCPTool(
name="get_repository_info",
description="Get detailed information about a specific repository, "
"including description, default branch, language, and metadata. "
"Requires the bot user to have read access.",
input_schema={
"type": "object",
"properties": {
"owner": {
"type": "string",
"description": "Repository owner username or organization",
},
"repo": {
"type": "string",
"description": "Repository name",
},
AVAILABLE_TOOLS: list[MCPTool] = [
_tool(
"list_repositories",
"List repositories visible to the configured bot account.",
{"type": "object", "properties": {}, "required": []},
),
_tool(
"get_repository_info",
"Get metadata for a repository.",
{
"type": "object",
"properties": {"owner": {"type": "string"}, "repo": {"type": "string"}},
"required": ["owner", "repo"],
"additionalProperties": False,
},
"required": ["owner", "repo"],
},
)
TOOL_GET_FILE_TREE = MCPTool(
name="get_file_tree",
description="Get the file tree structure for a repository at a specific ref. "
"Returns a list of files and directories. "
"Non-recursive by default for safety (max depth: 1 level).",
input_schema={
"type": "object",
"properties": {
"owner": {
"type": "string",
"description": "Repository owner username or organization",
},
"repo": {
"type": "string",
"description": "Repository name",
},
"ref": {
"type": "string",
"description": "Branch, tag, or commit SHA (defaults to 'main')",
"default": "main",
},
"recursive": {
"type": "boolean",
"description": "Whether to recursively fetch entire tree (use with caution)",
"default": False,
),
_tool(
"get_file_tree",
"Get repository tree at a selected ref.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"ref": {"type": "string", "default": "main"},
"recursive": {"type": "boolean", "default": False},
},
"required": ["owner", "repo"],
"additionalProperties": False,
},
"required": ["owner", "repo"],
},
)
TOOL_GET_FILE_CONTENTS = MCPTool(
name="get_file_contents",
description="Read the contents of a specific file in a repository. "
"File size is limited to 1MB by default for safety. "
"Returns base64-encoded content for binary files.",
input_schema={
"type": "object",
"properties": {
"owner": {
"type": "string",
"description": "Repository owner username or organization",
},
"repo": {
"type": "string",
"description": "Repository name",
},
"filepath": {
"type": "string",
"description": "Path to file within repository (e.g., 'src/main.py')",
},
"ref": {
"type": "string",
"description": "Branch, tag, or commit SHA (defaults to 'main')",
"default": "main",
),
_tool(
"get_file_contents",
"Read a repository file with size-limited content.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"filepath": {"type": "string"},
"ref": {"type": "string", "default": "main"},
},
"required": ["owner", "repo", "filepath"],
"additionalProperties": False,
},
"required": ["owner", "repo", "filepath"],
},
)
# Registry of all available tools
AVAILABLE_TOOLS: List[MCPTool] = [
TOOL_LIST_REPOSITORIES,
TOOL_GET_REPOSITORY_INFO,
TOOL_GET_FILE_TREE,
TOOL_GET_FILE_CONTENTS,
),
_tool(
"search_code",
"Search code in a repository.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"query": {"type": "string"},
"ref": {"type": "string", "default": "main"},
"page": {"type": "integer", "minimum": 1, "default": 1},
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 25},
},
"required": ["owner", "repo", "query"],
"additionalProperties": False,
},
),
_tool(
"list_commits",
"List commits for a repository ref.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"ref": {"type": "string", "default": "main"},
"page": {"type": "integer", "minimum": 1, "default": 1},
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 25},
},
"required": ["owner", "repo"],
"additionalProperties": False,
},
),
_tool(
"get_commit_diff",
"Get commit metadata and file diffs.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"sha": {"type": "string"},
},
"required": ["owner", "repo", "sha"],
"additionalProperties": False,
},
),
_tool(
"compare_refs",
"Compare two repository refs.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"base": {"type": "string"},
"head": {"type": "string"},
},
"required": ["owner", "repo", "base", "head"],
"additionalProperties": False,
},
),
_tool(
"list_issues",
"List repository issues.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"state": {"type": "string", "enum": ["open", "closed", "all"], "default": "open"},
"page": {"type": "integer", "minimum": 1, "default": 1},
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 25},
"labels": {"type": "array", "items": {"type": "string"}, "default": []},
},
"required": ["owner", "repo"],
"additionalProperties": False,
},
),
_tool(
"get_issue",
"Get repository issue details.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"issue_number": {"type": "integer", "minimum": 1},
},
"required": ["owner", "repo", "issue_number"],
"additionalProperties": False,
},
),
_tool(
"list_pull_requests",
"List repository pull requests.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"state": {"type": "string", "enum": ["open", "closed", "all"], "default": "open"},
"page": {"type": "integer", "minimum": 1, "default": 1},
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 25},
},
"required": ["owner", "repo"],
"additionalProperties": False,
},
),
_tool(
"get_pull_request",
"Get pull request details.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"pull_number": {"type": "integer", "minimum": 1},
},
"required": ["owner", "repo", "pull_number"],
"additionalProperties": False,
},
),
_tool(
"list_labels",
"List labels defined on a repository.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"page": {"type": "integer", "minimum": 1, "default": 1},
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 50},
},
"required": ["owner", "repo"],
"additionalProperties": False,
},
),
_tool(
"list_tags",
"List repository tags.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"page": {"type": "integer", "minimum": 1, "default": 1},
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 50},
},
"required": ["owner", "repo"],
"additionalProperties": False,
},
),
_tool(
"list_releases",
"List repository releases.",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"page": {"type": "integer", "minimum": 1, "default": 1},
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 25},
},
"required": ["owner", "repo"],
"additionalProperties": False,
},
),
_tool(
"create_issue",
"Create a repository issue (write-mode only).",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"title": {"type": "string"},
"body": {"type": "string", "default": ""},
"labels": {"type": "array", "items": {"type": "string"}, "default": []},
"assignees": {"type": "array", "items": {"type": "string"}, "default": []},
},
"required": ["owner", "repo", "title"],
"additionalProperties": False,
},
write_operation=True,
),
_tool(
"update_issue",
"Update issue title/body/state (write-mode only).",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"issue_number": {"type": "integer", "minimum": 1},
"title": {"type": "string"},
"body": {"type": "string"},
"state": {"type": "string", "enum": ["open", "closed"]},
},
"required": ["owner", "repo", "issue_number"],
"additionalProperties": False,
},
write_operation=True,
),
_tool(
"create_issue_comment",
"Create issue comment (write-mode only).",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"issue_number": {"type": "integer", "minimum": 1},
"body": {"type": "string"},
},
"required": ["owner", "repo", "issue_number", "body"],
"additionalProperties": False,
},
write_operation=True,
),
_tool(
"create_pr_comment",
"Create pull request comment (write-mode only).",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"pull_number": {"type": "integer", "minimum": 1},
"body": {"type": "string"},
},
"required": ["owner", "repo", "pull_number", "body"],
"additionalProperties": False,
},
write_operation=True,
),
_tool(
"add_labels",
"Add labels to an issue or PR (write-mode only).",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"issue_number": {"type": "integer", "minimum": 1},
"labels": {"type": "array", "items": {"type": "string"}, "minItems": 1},
},
"required": ["owner", "repo", "issue_number", "labels"],
"additionalProperties": False,
},
write_operation=True,
),
_tool(
"assign_issue",
"Assign users to issue or PR (write-mode only).",
{
"type": "object",
"properties": {
"owner": {"type": "string"},
"repo": {"type": "string"},
"issue_number": {"type": "integer", "minimum": 1},
"assignees": {"type": "array", "items": {"type": "string"}, "minItems": 1},
},
"required": ["owner", "repo", "issue_number", "assignees"],
"additionalProperties": False,
},
write_operation=True,
),
]
def get_tool_by_name(tool_name: str) -> Optional[MCPTool]:
"""Get tool definition by name.
Args:
tool_name: Name of the tool to retrieve
Returns:
Tool definition or None if not found
"""
def get_tool_by_name(tool_name: str) -> MCPTool | None:
    """Look up a tool definition by its registered name.

    Args:
        tool_name: Name of the tool to retrieve.

    Returns:
        The matching tool definition, or None when no tool has that name.
    """
    for tool in AVAILABLE_TOOLS:
        if tool.name == tool_name:
            return tool
    # Explicit None so the "not found" contract is visible to readers/linters.
    return None

View File

@@ -0,0 +1,98 @@
"""Observability primitives: metrics and lightweight instrumentation."""
from __future__ import annotations
import time
from collections import defaultdict
from dataclasses import dataclass
from threading import Lock
@dataclass(frozen=True)
class ToolTiming:
    """Aggregated tool timing stats.

    NOTE(review): not referenced elsewhere in this module — presumably consumed
    by external callers aggregating metrics; confirm before removing.
    """

    # Number of recorded tool invocations.
    count: int
    # Total accumulated duration across those invocations, in seconds.
    total_seconds: float
class MetricsRegistry:
    """In-process Prometheus-compatible metrics storage.

    Thread-safe: every mutation and the render pass run under one lock.
    """

    def __init__(self) -> None:
        """Initialize empty metrics state."""
        self._lock = Lock()
        # (method, path, status) -> request count
        self._http_requests_total: defaultdict[tuple[str, str, str], int] = defaultdict(int)
        # (tool_name, status) -> call count
        self._tool_calls_total: defaultdict[tuple[str, str], int] = defaultdict(int)
        # tool_name -> accumulated duration / sample count
        self._tool_duration_seconds: defaultdict[str, float] = defaultdict(float)
        self._tool_duration_count: defaultdict[str, int] = defaultdict(int)

    def record_http_request(self, method: str, path: str, status_code: int) -> None:
        """Record completed HTTP request metric."""
        with self._lock:
            self._http_requests_total[(method, path, str(status_code))] += 1

    def record_tool_call(self, tool_name: str, status: str, duration_seconds: float) -> None:
        """Record tool invocation counters and duration aggregates.

        Negative durations (clock anomalies) are clamped to zero.
        """
        with self._lock:
            self._tool_calls_total[(tool_name, status)] += 1
            self._tool_duration_seconds[tool_name] += max(duration_seconds, 0.0)
            self._tool_duration_count[tool_name] += 1

    @staticmethod
    def _escape_label(value: str) -> str:
        """Escape a label value per the Prometheus text exposition format.

        Backslash, double-quote, and newline must be escaped; otherwise a
        crafted request path could corrupt the /metrics output.
        """
        return value.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")

    def render_prometheus(self) -> str:
        """Render metrics in Prometheus exposition format."""
        esc = self._escape_label
        lines: list[str] = []
        lines.append("# HELP aegis_http_requests_total Total HTTP requests")
        lines.append("# TYPE aegis_http_requests_total counter")
        with self._lock:
            for (method, path, status), count in sorted(self._http_requests_total.items()):
                lines.append(
                    "aegis_http_requests_total"
                    f'{{method="{esc(method)}",path="{esc(path)}",status="{esc(status)}"}} {count}'
                )
            lines.append("# HELP aegis_tool_calls_total Total MCP tool calls")
            lines.append("# TYPE aegis_tool_calls_total counter")
            for (tool_name, status), count in sorted(self._tool_calls_total.items()):
                lines.append(
                    "aegis_tool_calls_total"
                    f'{{tool="{esc(tool_name)}",status="{esc(status)}"}} {count}'
                )
            lines.append(
                "# HELP aegis_tool_duration_seconds_sum Sum of MCP tool call duration seconds"
            )
            lines.append("# TYPE aegis_tool_duration_seconds_sum counter")
            for tool_name, total in sorted(self._tool_duration_seconds.items()):
                lines.append(
                    f'aegis_tool_duration_seconds_sum{{tool="{esc(tool_name)}"}} {total:.6f}'
                )
            lines.append(
                "# HELP aegis_tool_duration_seconds_count MCP tool call duration sample count"
            )
            lines.append("# TYPE aegis_tool_duration_seconds_count counter")
            for tool_name, count in sorted(self._tool_duration_count.items()):
                lines.append(f'aegis_tool_duration_seconds_count{{tool="{esc(tool_name)}"}} {count}')
        return "\n".join(lines) + "\n"
_metrics_registry: MetricsRegistry | None = None


def get_metrics_registry() -> MetricsRegistry:
    """Return the process-wide metrics registry, creating it on first use."""
    global _metrics_registry
    registry = _metrics_registry
    if registry is None:
        registry = MetricsRegistry()
        _metrics_registry = registry
    return registry
def reset_metrics_registry() -> None:
    """Reset global metrics registry for tests.

    The next get_metrics_registry() call creates a fresh, empty registry.
    """
    global _metrics_registry
    _metrics_registry = None
def monotonic_seconds() -> float:
    """Expose monotonic timer for deterministic instrumentation.

    A single indirection point that callers (and tests) can patch when
    asserting on recorded durations.
    """
    return time.monotonic()

View File

@@ -0,0 +1,262 @@
"""Policy engine for tool authorization decisions."""
from __future__ import annotations
from dataclasses import dataclass, field
from fnmatch import fnmatch
from pathlib import Path
from typing import Any
import yaml # type: ignore[import-untyped]
from aegis_gitea_mcp.config import get_settings
class PolicyError(Exception):
    """Raised when policy loading, schema validation, or path normalization fails."""
@dataclass(frozen=True)
class PolicyDecision:
    """Authorization result for a policy check."""

    # True when the tool call may proceed.
    allowed: bool
    # Human-readable explanation; surfaced to the caller and audit log on denial.
    reason: str
@dataclass(frozen=True)
class RuleSet:
    """Allow/Deny rules for tools.

    An empty allow set means "no allow-list restriction"; deny is checked
    before allow in PolicyEngine.authorize, so deny always wins.
    """

    allow: set[str] = field(default_factory=set)
    deny: set[str] = field(default_factory=set)
@dataclass(frozen=True)
class PathRules:
    """Allow/Deny rules for target file paths (fnmatch-style glob patterns)."""

    allow: tuple[str, ...] = ()
    deny: tuple[str, ...] = ()
@dataclass(frozen=True)
class RepositoryPolicy:
    """Repository-scoped policy rules, layered on top of the global rules."""

    tools: RuleSet = field(default_factory=RuleSet)
    paths: PathRules = field(default_factory=PathRules)
@dataclass(frozen=True)
class PolicyConfig:
    """Parsed policy configuration.

    Defaults are secure-by-default for writes: reads fall back to "allow" and
    writes to "deny" — the same behavior used when no policy file exists.
    """

    default_read: str = "allow"
    default_write: str = "deny"
    tools: RuleSet = field(default_factory=RuleSet)
    repositories: dict[str, RepositoryPolicy] = field(default_factory=dict)
class PolicyEngine:
    """Evaluates authorization decisions for MCP tools.

    Decision inputs are the tool name, whether the tool mutates data, and
    the optional repository / file-path targets. Global tool rules are
    checked before repository-scoped rules; write operations additionally
    require write mode and a whitelisted repository (see ``authorize``).
    """

    def __init__(self, config: PolicyConfig) -> None:
        """Initialize policy engine with prevalidated config."""
        self.config = config
        # Settings are captured at construction time; a settings change
        # requires rebuilding the engine (see reset_policy_engine()).
        self.settings = get_settings()

    @classmethod
    def from_yaml_file(cls, policy_path: Path) -> PolicyEngine:
        """Build a policy engine from YAML policy file.

        Args:
            policy_path: Path to policy YAML file.

        Returns:
            Initialized policy engine.

        Raises:
            PolicyError: If file is malformed or violates policy schema.
        """
        if not policy_path.exists():
            # Secure default for writes, backwards-compatible allow for reads.
            return cls(PolicyConfig())
        try:
            raw = yaml.safe_load(policy_path.read_text(encoding="utf-8"))
        except Exception as exc:
            raise PolicyError(f"Failed to parse policy YAML: {exc}") from exc
        if raw is None:
            # Empty file: same secure defaults as a missing file.
            return cls(PolicyConfig())
        if not isinstance(raw, dict):
            raise PolicyError("Policy root must be a mapping")
        defaults = raw.get("defaults", {})
        if defaults and not isinstance(defaults, dict):
            raise PolicyError("defaults must be a mapping")
        # Normalize case so "Allow"/"DENY" in the YAML still validate.
        default_read = str(defaults.get("read", "allow")).lower()
        default_write = str(defaults.get("write", "deny")).lower()
        if default_read not in {"allow", "deny"}:
            raise PolicyError("defaults.read must be 'allow' or 'deny'")
        if default_write not in {"allow", "deny"}:
            raise PolicyError("defaults.write must be 'allow' or 'deny'")
        global_tools = cls._parse_tool_rules(raw.get("tools", {}), "tools")
        repositories_raw = raw.get("repositories", {})
        if repositories_raw is None:
            repositories_raw = {}
        if not isinstance(repositories_raw, dict):
            raise PolicyError("repositories must be a mapping")
        repositories: dict[str, RepositoryPolicy] = {}
        for repo_name, repo_payload in repositories_raw.items():
            if not isinstance(repo_name, str) or "/" not in repo_name:
                raise PolicyError("Repository keys must be in 'owner/repo' format")
            if not isinstance(repo_payload, dict):
                raise PolicyError(f"Repository policy for {repo_name} must be a mapping")
            tool_rules = cls._parse_tool_rules(
                repo_payload.get("tools", {}),
                f"repositories.{repo_name}.tools",
            )
            path_payload = repo_payload.get("paths", {})
            # NOTE(review): a truthy non-dict (e.g. a list) is rejected here, but
            # a falsy non-dict would reach .get() below — confirm inputs.
            if path_payload and not isinstance(path_payload, dict):
                raise PolicyError(f"repositories.{repo_name}.paths must be a mapping")
            allow_paths = cls._parse_path_list(path_payload.get("allow", []), "allow")
            deny_paths = cls._parse_path_list(path_payload.get("deny", []), "deny")
            repositories[repo_name] = RepositoryPolicy(
                tools=tool_rules,
                paths=PathRules(allow=allow_paths, deny=deny_paths),
            )
        return cls(
            PolicyConfig(
                default_read=default_read,
                default_write=default_write,
                tools=global_tools,
                repositories=repositories,
            )
        )

    @staticmethod
    def _parse_tool_rules(raw_rules: Any, location: str) -> RuleSet:
        """Parse tool allow/deny mapping from raw payload.

        Args:
            raw_rules: Raw ``tools`` payload from the policy file.
            location: Dotted path used in error messages.

        Raises:
            PolicyError: If the payload is not a mapping of string lists.
        """
        if not raw_rules:
            return RuleSet()
        if not isinstance(raw_rules, dict):
            raise PolicyError(f"{location} must be a mapping")
        allow = raw_rules.get("allow", [])
        deny = raw_rules.get("deny", [])
        if not isinstance(allow, list) or not all(isinstance(item, str) for item in allow):
            raise PolicyError(f"{location}.allow must be a list of strings")
        if not isinstance(deny, list) or not all(isinstance(item, str) for item in deny):
            raise PolicyError(f"{location}.deny must be a list of strings")
        return RuleSet(allow=set(allow), deny=set(deny))

    @staticmethod
    def _parse_path_list(raw_paths: Any, label: str) -> tuple[str, ...]:
        """Parse path allow/deny list.

        Raises:
            PolicyError: If the payload is not a list of strings.
        """
        if raw_paths is None:
            return ()
        if not isinstance(raw_paths, list) or not all(isinstance(item, str) for item in raw_paths):
            raise PolicyError(f"paths.{label} must be a list of strings")
        return tuple(raw_paths)

    @staticmethod
    def _normalize_target_path(path: str) -> str:
        """Normalize path before policy matching.

        Converts backslashes to forward slashes, strips leading slashes and
        "." segments, and rejects any ".." component.

        Security note:
            Path normalization blocks traversal attempts before fnmatch
            comparisons are executed.

        Raises:
            PolicyError: If the path contains a ".." traversal component.
        """
        normalized = path.replace("\\", "/").lstrip("/")
        parts = [part for part in normalized.split("/") if part and part != "."]
        if any(part == ".." for part in parts):
            raise PolicyError("Target path contains traversal sequence '..'")
        return "/".join(parts)

    def authorize(
        self,
        tool_name: str,
        is_write: bool,
        repository: str | None = None,
        target_path: str | None = None,
    ) -> PolicyDecision:
        """Evaluate whether a tool call is authorized by policy.

        Check order: global deny → global allow-list → write-mode gates →
        repository tool rules → repository path rules → configured default.

        Args:
            tool_name: Invoked MCP tool name.
            is_write: Whether the tool mutates data.
            repository: Optional `owner/repo` target repository.
            target_path: Optional file path target.

        Returns:
            Policy decision indicating allow/deny and reason.

        Raises:
            PolicyError: If target_path contains a traversal sequence
                (propagates from _normalize_target_path — callers should be
                prepared to treat this as a denial).
        """
        # Global rules are evaluated first and cannot be overridden per-repo.
        if tool_name in self.config.tools.deny:
            return PolicyDecision(False, "tool denied by global policy")
        if self.config.tools.allow and tool_name not in self.config.tools.allow:
            return PolicyDecision(False, "tool not allowed by global policy")
        if is_write:
            # Writes require the global write switch AND an explicit
            # repository that is on the settings-level whitelist.
            if not self.settings.write_mode:
                return PolicyDecision(False, "write mode is disabled")
            if not repository:
                return PolicyDecision(False, "write operation requires a repository target")
            if repository not in self.settings.write_repository_whitelist:
                return PolicyDecision(False, "repository is not in write-mode whitelist")
        repo_policy = self.config.repositories.get(repository) if repository else None
        if repo_policy:
            if tool_name in repo_policy.tools.deny:
                return PolicyDecision(False, "tool denied for repository")
            if repo_policy.tools.allow and tool_name not in repo_policy.tools.allow:
                return PolicyDecision(False, "tool not allowed for repository")
            if target_path:
                normalized_path = self._normalize_target_path(target_path)
                # NOTE(review): fnmatch applies os.path.normcase, so matching is
                # case-insensitive on Windows — confirm fnmatchcase isn't wanted.
                if repo_policy.paths.deny and any(
                    fnmatch(normalized_path, pattern) for pattern in repo_policy.paths.deny
                ):
                    return PolicyDecision(False, "path denied by repository policy")
                if repo_policy.paths.allow and not any(
                    fnmatch(normalized_path, pattern) for pattern in repo_policy.paths.allow
                ):
                    return PolicyDecision(False, "path not allowed by repository policy")
        default_behavior = self.config.default_write if is_write else self.config.default_read
        return PolicyDecision(default_behavior == "allow", "default policy decision")
_policy_engine: PolicyEngine | None = None


def get_policy_engine() -> PolicyEngine:
    """Return the lazily-built global policy engine instance."""
    global _policy_engine
    engine = _policy_engine
    if engine is None:
        engine = PolicyEngine.from_yaml_file(get_settings().policy_file_path)
        _policy_engine = engine
    return engine
def reset_policy_engine() -> None:
    """Reset global policy engine (mainly for tests).

    The next get_policy_engine() call re-reads the policy file from disk.
    """
    global _policy_engine
    _policy_engine = None

View File

@@ -0,0 +1,110 @@
"""In-memory request rate limiting for MCP endpoints."""
from __future__ import annotations
import hashlib
import time
from collections import defaultdict, deque
from dataclasses import dataclass
from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.config import get_settings
@dataclass(frozen=True)
class RateLimitDecision:
    """Result of request rate-limit checks."""

    # True when the request fits within both the per-IP and per-token budgets.
    allowed: bool
    # Human-readable explanation; returned to the client on HTTP 429.
    reason: str
class SlidingWindowLimiter:
    """Sliding-window limiter keyed by arbitrary identifiers.

    Not thread-safe by itself; callers are expected to serialize access.
    """

    # Sweep stale keys after this many allow() calls to bound memory growth.
    _PRUNE_INTERVAL = 1024

    def __init__(self, max_requests: int, window_seconds: int) -> None:
        """Initialize a sliding-window limiter.

        Args:
            max_requests: Maximum allowed requests within the window.
            window_seconds: Rolling time window length in seconds.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # key -> timestamps of requests observed within the current window.
        self._events: dict[str, deque[float]] = defaultdict(deque)
        self._calls_since_prune = 0

    def allow(self, key: str) -> bool:
        """Check and record a request for the provided key.

        Returns:
            True when the request fits within the window and was recorded,
            False when the key has exhausted its budget.
        """
        now = time.time()
        self._calls_since_prune += 1
        if self._calls_since_prune >= self._PRUNE_INTERVAL:
            # Evict keys with no recent activity so idle clients do not
            # accumulate entries forever.
            self._prune(now)
        boundary = now - self.window_seconds
        events = self._events[key]
        # Drop timestamps that have aged out of the rolling window.
        while events and events[0] < boundary:
            events.popleft()
        if len(events) >= self.max_requests:
            return False
        events.append(now)
        return True

    def _prune(self, now: float) -> None:
        """Remove keys whose most recent event is outside the window."""
        boundary = now - self.window_seconds
        stale = [key for key, events in self._events.items() if not events or events[-1] < boundary]
        for key in stale:
            del self._events[key]
        self._calls_since_prune = 0
class RequestRateLimiter:
    """Combined per-IP and per-token request limiter."""

    def __init__(self) -> None:
        """Create both limiters sized from the current application settings."""
        settings = get_settings()
        self._audit = get_audit_logger()
        self._ip_limiter = SlidingWindowLimiter(settings.rate_limit_per_minute, 60)
        self._token_limiter = SlidingWindowLimiter(settings.token_rate_limit_per_minute, 60)

    def check(self, client_ip: str, token: str | None) -> RateLimitDecision:
        """Evaluate a request against the IP and token limits.

        The per-IP budget is consulted first; the token budget is only
        consulted for authenticated requests.

        Args:
            client_ip: Request source IP.
            token: Optional authenticated API token.

        Returns:
            Rate limit decision.
        """
        if not self._ip_limiter.allow(client_ip):
            self._audit.log_security_event(
                event_type="rate_limit_ip_exceeded",
                description="Per-IP request rate limit exceeded",
                severity="medium",
                metadata={"client_ip": client_ip},
            )
            return RateLimitDecision(False, "Per-IP rate limit exceeded")
        if not token:
            return RateLimitDecision(True, "within limits")
        # Hash token before using it as a key to avoid storing secrets in memory maps.
        token_key = hashlib.sha256(token.encode("utf-8")).hexdigest()
        if self._token_limiter.allow(token_key):
            return RateLimitDecision(True, "within limits")
        self._audit.log_security_event(
            event_type="rate_limit_token_exceeded",
            description="Per-token request rate limit exceeded",
            severity="high",
            metadata={"client_ip": client_ip},
        )
        return RateLimitDecision(False, "Per-token rate limit exceeded")
_rate_limiter: RequestRateLimiter | None = None


def get_rate_limiter() -> RequestRateLimiter:
    """Return the process-wide request limiter, creating it on first use."""
    global _rate_limiter
    limiter = _rate_limiter
    if limiter is None:
        limiter = RequestRateLimiter()
        _rate_limiter = limiter
    return limiter
def reset_rate_limiter() -> None:
    """Reset global limiter (primarily for tests).

    The next get_rate_limiter() call rebuilds the limiter from fresh settings.
    """
    global _rate_limiter
    _rate_limiter = None

View File

@@ -0,0 +1,17 @@
"""Request context utilities for correlation and logging."""
from __future__ import annotations
from contextvars import ContextVar
# Context-local request id; "-" marks code paths running outside any request.
_REQUEST_ID: ContextVar[str] = ContextVar("request_id", default="-")


def set_request_id(request_id: str) -> None:
    """Store request id in context-local state for the current task/thread."""
    _REQUEST_ID.set(request_id)


def get_request_id() -> str:
    """Get current request id, or the "-" default when none has been set."""
    return _REQUEST_ID.get()

View File

@@ -0,0 +1,56 @@
"""Helpers for bounded tool responses."""
from __future__ import annotations
from typing import Any
from aegis_gitea_mcp.config import get_settings
class ResponseLimitError(RuntimeError):
    """Raised when a configured response limit is invalid (non-positive).

    Note: oversized responses are trimmed, not rejected — this error signals
    misconfiguration of the limit itself.
    """
def limit_items(
items: list[dict[str, Any]], configured_limit: int | None = None
) -> tuple[list[dict[str, Any]], int]:
"""Trim a list of result items to configured maximum length.
Args:
items: List of result dictionaries.
configured_limit: Optional explicit item limit.
Returns:
Tuple of trimmed list and omitted count.
"""
settings = get_settings()
max_items = configured_limit or settings.max_tool_response_items
if max_items <= 0:
raise ResponseLimitError("max_tool_response_items must be greater than zero")
if len(items) <= max_items:
return items, 0
trimmed = items[:max_items]
omitted = len(items) - max_items
return trimmed, omitted
def limit_text(text: str, configured_limit: int | None = None) -> str:
"""Trim text output to configured maximum characters.
Args:
text: Untrusted text output.
configured_limit: Optional explicit character limit.
Returns:
Trimmed text.
"""
settings = get_settings()
max_chars = configured_limit or settings.max_tool_response_chars
if max_chars <= 0:
raise ResponseLimitError("max_tool_response_chars must be greater than zero")
if len(text) <= max_chars:
return text
return text[:max_chars]

View File

@@ -0,0 +1,134 @@
"""Security helpers for secret detection and untrusted content handling."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class SecretMatch:
    """Represents a detected secret-like token.

    Instances carry the raw matched value; callers must mask it (see
    mask_secret) before logging or returning it.
    """

    # Pattern identifier from _SECRET_PATTERNS (e.g. "aws_access_key").
    secret_type: str
    # Raw matched text — handle with care, this is the secret itself.
    value: str
# Ordered (name, compiled-regex) pairs scanned by detect_secrets(). Each
# pattern anchors on a vendor prefix or structural shape with a minimum
# length, trading some recall for a low false-positive rate.
_SECRET_PATTERNS: tuple[tuple[str, re.Pattern[str]], ...] = (
    (
        "openai_key",
        re.compile(r"\bsk-[A-Za-z0-9_-]{20,}\b"),
    ),
    (
        "aws_access_key",
        re.compile(r"\bAKIA[0-9A-Z]{16}\b"),
    ),
    (
        "github_token",
        re.compile(r"\bgh[pousr]_[A-Za-z0-9]{20,}\b"),
    ),
    (
        # Three dot-separated base64url segments starting with '{"' encoded ("eyJ").
        "jwt",
        re.compile(r"\beyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{4,}\.[A-Za-z0-9_-]{4,}\b"),
    ),
    (
        "private_key",
        re.compile(r"-----BEGIN (?:RSA |EC |OPENSSH |)PRIVATE KEY-----"),
    ),
    (
        # Catch-all for `api_key=...` / `token: ...` style assignments.
        "generic_api_key",
        re.compile(r"\b(?:api[_-]?key|token)[\"'=: ]+[A-Za-z0-9_-]{16,}\b", re.IGNORECASE),
    ),
)
def detect_secrets(text: str) -> list[SecretMatch]:
    """Detect common secret patterns in text.

    Args:
        text: Untrusted text to scan.

    Returns:
        One SecretMatch per pattern hit, in pattern-definition order.
    """
    found: list[SecretMatch] = []
    for secret_type, pattern in _SECRET_PATTERNS:
        for raw in pattern.findall(text):
            # findall yields tuples when a pattern has capture groups; join
            # the pieces so we always store a plain string.
            value = "".join(raw) if isinstance(raw, tuple) else raw
            found.append(SecretMatch(secret_type=secret_type, value=value))
    return found
def mask_secret(value: str) -> str:
    """Mask a secret value while preserving minimal context.

    Args:
        value: Raw secret text.

    Returns:
        Masked string that does not reveal the secret. Short values are fully
        redacted: showing 4 leading + 4 trailing characters of a value only
        slightly longer than 8 would leak almost the entire secret.
    """
    # 16-character floor guarantees at least half of the secret stays hidden
    # when the head/tail preview is used.
    if len(value) <= 16:
        return "[REDACTED]"
    return f"{value[:4]}...{value[-4:]}"
def sanitize_data(value: Any, mode: str = "mask") -> Any:
    """Recursively sanitize secret-like material from arbitrary data.

    Args:
        value: Arbitrary response payload.
        mode: `mask` to keep redacted content, `block` to fully replace fields.

    Returns:
        Sanitized payload value.
    """
    if isinstance(value, dict):
        return {str(k): sanitize_data(v, mode=mode) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        cleaned = [sanitize_data(item, mode=mode) for item in value]
        return cleaned if isinstance(value, list) else tuple(cleaned)
    if not isinstance(value, str):
        # Numbers, None, booleans, etc. cannot carry secret text.
        return value
    findings = detect_secrets(value)
    if not findings:
        return value
    if mode == "block":
        return "[REDACTED_SECRET]"
    redacted = value
    for match in findings:
        redacted = redacted.replace(match.value, mask_secret(match.value))
    return redacted
def sanitize_untrusted_text(text: str, max_chars: int) -> str:
    """Normalize untrusted repository content for display-only usage.

    Security note:
        Repository content is always treated as data and never interpreted as
        executable instructions. This helper enforces a strict length limit to
        prevent prompt-stuffing through oversized payloads.

    Args:
        text: Repository text content.
        max_chars: Maximum allowed characters in returned text.

    Returns:
        Truncated text safe for downstream display.
    """
    if max_chars <= 0:
        return ""
    return text if len(text) <= max_chars else text[:max_chars]

View File

@@ -1,16 +1,24 @@
"""Main MCP server implementation with FastAPI and SSE support."""
"""Main MCP server implementation with hardened security controls."""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Dict
import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import ValidationError
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
from pydantic import BaseModel, Field, ValidationError
from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.auth import get_validator
from aegis_gitea_mcp.automation import AutomationError, AutomationManager
from aegis_gitea_mcp.config import get_settings
from aegis_gitea_mcp.gitea_client import GiteaClient
from aegis_gitea_mcp.logging_utils import configure_logging
from aegis_gitea_mcp.mcp_protocol import (
AVAILABLE_TOOLS,
MCPListToolsResponse,
@@ -18,276 +26,443 @@ from aegis_gitea_mcp.mcp_protocol import (
MCPToolCallResponse,
get_tool_by_name,
)
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
from aegis_gitea_mcp.rate_limit import get_rate_limiter
from aegis_gitea_mcp.request_context import set_request_id
from aegis_gitea_mcp.security import sanitize_data
from aegis_gitea_mcp.tools.arguments import extract_repository, extract_target_path
from aegis_gitea_mcp.tools.read_tools import (
compare_refs_tool,
get_commit_diff_tool,
get_issue_tool,
get_pull_request_tool,
list_commits_tool,
list_issues_tool,
list_labels_tool,
list_pull_requests_tool,
list_releases_tool,
list_tags_tool,
search_code_tool,
)
from aegis_gitea_mcp.tools.repository import (
get_file_contents_tool,
get_file_tree_tool,
get_repository_info_tool,
list_repositories_tool,
)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
from aegis_gitea_mcp.tools.write_tools import (
add_labels_tool,
assign_issue_tool,
create_issue_comment_tool,
create_issue_tool,
create_pr_comment_tool,
update_issue_tool,
)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
title="AegisGitea MCP Server",
description="Security-first MCP server for controlled AI access to self-hosted Gitea",
version="0.1.0",
version="0.2.0",
)
# Global settings and audit logger
# Note: access settings/audit logger dynamically to support test resets.
class AutomationWebhookRequest(BaseModel):
    """Request body for automation webhook ingestion."""

    # Webhook event identifier; length-bounded to reject oversized input.
    event_type: str = Field(..., min_length=1, max_length=128)
    # Raw event payload, forwarded to the automation manager as-is.
    payload: dict[str, Any] = Field(default_factory=dict)
    # Optional repository target — NOTE(review): 'owner/repo' format is not
    # validated here; presumably checked downstream in AutomationManager.
    repository: str | None = Field(default=None)
# Tool dispatcher mapping
TOOL_HANDLERS = {
class AutomationJobRequest(BaseModel):
    """Request body for automation job execution."""

    # Automation job identifier, resolved by AutomationManager.run_job.
    job_name: str = Field(..., min_length=1, max_length=128)
    # Target repository owner and name; length-bounded to reject oversized input.
    owner: str = Field(..., min_length=1, max_length=100)
    repo: str = Field(..., min_length=1, max_length=100)
    # Optional finding details forwarded to the job — presumably used when the
    # job reports results (e.g. as an issue); verify against AutomationManager.
    finding_title: str | None = Field(default=None, max_length=256)
    finding_body: str | None = Field(default=None, max_length=10_000)
# Signature shared by every tool handler: async callable taking the Gitea
# client and the raw argument mapping, returning a result dict.
ToolHandler = Callable[[GiteaClient, dict[str, Any]], Awaitable[dict[str, Any]]]

# Dispatch table mapping MCP tool names to their handler implementations.
# A tool present in AVAILABLE_TOOLS but missing here yields an HTTP 500 at
# call time (see _execute_tool_call).
TOOL_HANDLERS: dict[str, ToolHandler] = {
    # Baseline read tools
    "list_repositories": list_repositories_tool,
    "get_repository_info": get_repository_info_tool,
    "get_file_tree": get_file_tree_tool,
    "get_file_contents": get_file_contents_tool,
    # Expanded read tools
    "search_code": search_code_tool,
    "list_commits": list_commits_tool,
    "get_commit_diff": get_commit_diff_tool,
    "compare_refs": compare_refs_tool,
    "list_issues": list_issues_tool,
    "get_issue": get_issue_tool,
    "list_pull_requests": list_pull_requests_tool,
    "get_pull_request": get_pull_request_tool,
    "list_labels": list_labels_tool,
    "list_tags": list_tags_tool,
    "list_releases": list_releases_tool,
    # Write-mode tools
    "create_issue": create_issue_tool,
    "update_issue": update_issue_tool,
    "create_issue_comment": create_issue_comment_tool,
    "create_pr_comment": create_pr_comment_tool,
    "add_labels": add_labels_tool,
    "assign_issue": assign_issue_tool,
}
# Authentication middleware
@app.middleware("http")
async def authenticate_request(request: Request, call_next):
"""Authenticate all requests except health checks and root.
async def request_context_middleware(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
"""Attach request correlation context and collect request metrics."""
request_id = request.headers.get("x-request-id") or str(uuid.uuid4())
set_request_id(request_id)
request.state.request_id = request_id
Supports Mixed authentication mode where:
- /mcp/tools (list tools) is publicly accessible (No Auth)
- /mcp/tool/call (execute tools) requires authentication
- /mcp/sse requires authentication
"""
# Skip authentication for health check and root endpoints
if request.url.path in ["/", "/health"]:
started_at = monotonic_seconds()
status_code = 500
try:
response = await call_next(request)
status_code = response.status_code
response.headers["X-Request-ID"] = request_id
return response
finally:
duration = max(monotonic_seconds() - started_at, 0.0)
logger.debug(
"request_completed",
extra={
"method": request.method,
"path": request.url.path,
"duration_seconds": duration,
"status_code": status_code,
},
)
metrics = get_metrics_registry()
metrics.record_http_request(request.method, request.url.path, status_code)
@app.middleware("http")
async def authenticate_and_rate_limit(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
"""Apply rate-limiting and authentication for MCP endpoints."""
settings = get_settings()
if request.url.path in {"/", "/health"}:
return await call_next(request)
# Only authenticate MCP endpoints
if not request.url.path.startswith("/mcp/"):
if request.url.path == "/metrics" and settings.metrics_enabled:
# Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
return await call_next(request)
# Mixed mode: allow /mcp/tools without authentication (for ChatGPT discovery)
if request.url.path == "/mcp/tools":
if not (request.url.path.startswith("/mcp/") or request.url.path.startswith("/automation/")):
return await call_next(request)
# Extract client information
validator = get_validator()
limiter = get_rate_limiter()
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
# Get validator instance (supports test resets)
validator = get_validator()
# Extract Authorization header
auth_header = request.headers.get("authorization")
api_key = validator.extract_bearer_token(auth_header)
# Fallback: allow API key via query parameter only for MCP endpoints
if not api_key and request.url.path in {"/mcp/tool/call", "/mcp/sse"}:
api_key = request.query_params.get("api_key")
# Validate API key
is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
rate_limit = limiter.check(client_ip=client_ip, token=api_key)
if not rate_limit.allowed:
return JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded",
"message": rate_limit.reason,
"request_id": getattr(request.state, "request_id", "-"),
},
)
# Mixed mode: tool discovery remains public to preserve MCP client compatibility.
if request.url.path == "/mcp/tools":
return await call_next(request)
is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
if not is_valid:
return JSONResponse(
status_code=401,
content={
"error": "Authentication failed",
"message": error_message,
"detail": (
"Provide a valid API key via Authorization header (Bearer <api-key>) "
"or ?api_key=<api-key> query parameter"
),
"detail": "Provide Authorization: Bearer <api-key> or ?api_key=<api-key>",
"request_id": getattr(request.state, "request_id", "-"),
},
)
# Authentication successful - continue to endpoint
response = await call_next(request)
return response
return await call_next(request)
@app.on_event("startup")
async def startup_event() -> None:
"""Initialize server on startup."""
"""Initialize server state on startup."""
settings = get_settings()
logger.info(f"Starting AegisGitea MCP Server on {settings.mcp_host}:{settings.mcp_port}")
logger.info(f"Connected to Gitea instance: {settings.gitea_base_url}")
logger.info(f"Audit logging enabled: {settings.audit_log_path}")
configure_logging(settings.log_level)
# Log authentication status
if settings.auth_enabled:
key_count = len(settings.mcp_api_keys)
logger.info(f"API key authentication ENABLED ({key_count} key(s) configured)")
else:
logger.warning("API key authentication DISABLED - server is open to all requests!")
logger.info("server_starting")
logger.info(
"server_configuration",
extra={
"host": settings.mcp_host,
"port": settings.mcp_port,
"gitea_url": settings.gitea_base_url,
"auth_enabled": settings.auth_enabled,
"write_mode": settings.write_mode,
"metrics_enabled": settings.metrics_enabled,
},
)
# Test Gitea connection
# Fail-fast policy parse errors at startup.
try:
async with GiteaClient() as gitea:
user = await gitea.get_current_user()
logger.info(f"Authenticated as bot user: {user.get('login', 'unknown')}")
except Exception as e:
logger.error(f"Failed to connect to Gitea: {e}")
_ = get_policy_engine()
except PolicyError:
logger.error("policy_load_failed")
raise
if settings.startup_validate_gitea and settings.environment != "test":
try:
async with GiteaClient() as gitea:
user = await gitea.get_current_user()
logger.info("gitea_connected", extra={"bot_user": user.get("login", "unknown")})
except Exception:
logger.error("gitea_connection_failed")
raise
@app.on_event("shutdown")
async def shutdown_event() -> None:
"""Cleanup on server shutdown."""
logger.info("Shutting down AegisGitea MCP Server")
"""Log server shutdown event."""
logger.info("server_stopping")
@app.get("/")
async def root() -> Dict[str, Any]:
"""Root endpoint with server information."""
async def root() -> dict[str, Any]:
"""Root endpoint with server metadata."""
return {
"name": "AegisGitea MCP Server",
"version": "0.1.0",
"version": "0.2.0",
"status": "running",
"mcp_version": "1.0",
}
@app.get("/health")
async def health() -> Dict[str, str]:
async def health() -> dict[str, str]:
"""Health check endpoint."""
return {"status": "healthy"}
@app.get("/metrics")
async def metrics() -> PlainTextResponse:
    """Prometheus-compatible metrics endpoint.

    Returns 404 when metrics are disabled so the endpoint looks like a
    missing route. Served unauthenticated for pull-based scraping (the auth
    middleware skips it when metrics are enabled).
    """
    settings = get_settings()
    if not settings.metrics_enabled:
        raise HTTPException(status_code=404, detail="Metrics endpoint disabled")
    data = get_metrics_registry().render_prometheus()
    return PlainTextResponse(content=data, media_type="text/plain; version=0.0.4")
@app.post("/automation/webhook")
async def automation_webhook(request: AutomationWebhookRequest) -> JSONResponse:
    """Ingest policy-controlled automation webhooks.

    Args:
        request: Validated webhook envelope (event type, payload, repository).

    Returns:
        JSON body with a ``success`` flag and the manager's result payload.

    Raises:
        HTTPException: 403 when the automation layer rejects the event.
    """
    manager = AutomationManager()
    try:
        result = await manager.handle_webhook(
            event_type=request.event_type,
            payload=request.payload,
            repository=request.repository,
        )
        return JSONResponse(content={"success": True, "result": result})
    except AutomationError as exc:
        # AutomationError maps to 403: well-formed request, but not permitted.
        raise HTTPException(status_code=403, detail=str(exc)) from exc
@app.post("/automation/jobs/run")
async def automation_run_job(request: AutomationJobRequest) -> JSONResponse:
    """Execute a policy-controlled automation job for a repository.

    Args:
        request: Validated job request (job name, owner/repo, optional finding).

    Returns:
        JSON body with a ``success`` flag and the manager's result payload.

    Raises:
        HTTPException: 403 when the automation layer rejects the job.
    """
    manager = AutomationManager()
    try:
        result = await manager.run_job(
            job_name=request.job_name,
            owner=request.owner,
            repo=request.repo,
            finding_title=request.finding_title,
            finding_body=request.finding_body,
        )
        return JSONResponse(content={"success": True, "result": result})
    except AutomationError as exc:
        # AutomationError maps to 403: well-formed request, but not permitted.
        raise HTTPException(status_code=403, detail=str(exc)) from exc
@app.get("/mcp/tools")
async def list_tools() -> JSONResponse:
"""List all available MCP tools.
Returns:
JSON response with list of tool definitions
"""
"""List all available MCP tools."""
response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
return JSONResponse(content=response.model_dump(by_alias=True))
return JSONResponse(content=response.model_dump())
async def _execute_tool_call(
tool_name: str, arguments: dict[str, Any], correlation_id: str
) -> dict[str, Any]:
"""Execute tool call with policy checks and standardized response sanitization."""
settings = get_settings()
audit = get_audit_logger()
metrics = get_metrics_registry()
tool_def = get_tool_by_name(tool_name)
if not tool_def:
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
handler = TOOL_HANDLERS.get(tool_name)
if not handler:
raise HTTPException(
status_code=500, detail=f"Tool '{tool_name}' has no handler implementation"
)
repository = extract_repository(arguments)
target_path = extract_target_path(arguments)
decision = get_policy_engine().authorize(
tool_name=tool_name,
is_write=tool_def.write_operation,
repository=repository,
target_path=target_path,
)
if not decision.allowed:
audit.log_access_denied(
tool_name=tool_name,
repository=repository,
reason=decision.reason,
correlation_id=correlation_id,
)
raise HTTPException(status_code=403, detail=f"Policy denied request: {decision.reason}")
started_at = monotonic_seconds()
status = "error"
try:
async with GiteaClient() as gitea:
result = await handler(gitea, arguments)
if settings.secret_detection_mode != "off":
# Security decision: sanitize outbound payloads to prevent accidental secret exfiltration.
result = sanitize_data(result, mode=settings.secret_detection_mode)
status = "success"
return result
finally:
duration = max(monotonic_seconds() - started_at, 0.0)
metrics.record_tool_call(tool_name, status, duration)
@app.post("/mcp/tool/call")
async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
"""Execute an MCP tool call.
Args:
request: Tool call request with tool name and arguments
Returns:
JSON response with tool execution result
"""
"""Execute an MCP tool call."""
settings = get_settings()
audit = get_audit_logger()
correlation_id = request.correlation_id or audit.log_tool_invocation(
tool_name=request.tool,
params=request.arguments,
)
try:
# Validate tool exists
tool_def = get_tool_by_name(request.tool)
if not tool_def:
error_msg = f"Tool '{request.tool}' not found"
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error=error_msg,
)
raise HTTPException(status_code=404, detail=error_msg)
# Get tool handler
handler = TOOL_HANDLERS.get(request.tool)
if not handler:
error_msg = f"Tool '{request.tool}' has no handler implementation"
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error=error_msg,
)
raise HTTPException(status_code=500, detail=error_msg)
# Execute tool with Gitea client
async with GiteaClient() as gitea:
result = await handler(gitea, request.arguments)
result = await _execute_tool_call(request.tool, request.arguments, correlation_id)
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="success",
)
response = MCPToolCallResponse(
success=True,
result=result,
correlation_id=correlation_id,
return JSONResponse(
content=MCPToolCallResponse(
success=True,
result=result,
correlation_id=correlation_id,
).model_dump()
)
return JSONResponse(content=response.model_dump())
except HTTPException:
# Re-raise HTTP exceptions (like 404) without catching them
except HTTPException as exc:
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error=str(exc.detail),
)
raise
except ValidationError as e:
error_msg = f"Invalid arguments: {str(e)}"
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error=error_msg,
)
raise HTTPException(status_code=400, detail=error_msg)
except ValidationError as exc:
error_message = "Invalid tool arguments"
if settings.expose_error_details:
error_message = f"{error_message}: {exc}"
except Exception as e:
error_msg = str(e)
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error=error_msg,
error="validation_error",
)
response = MCPToolCallResponse(
success=False,
error=error_msg,
raise HTTPException(status_code=400, detail=error_message) from exc
except Exception:
# Security decision: do not leak stack traces or raw exception messages.
error_message = "Internal server error"
if settings.expose_error_details:
error_message = "Internal server error (details hidden unless explicitly enabled)"
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error="internal_error",
)
logger.exception("tool_execution_failed")
return JSONResponse(
status_code=500,
content=MCPToolCallResponse(
success=False,
error=error_message,
correlation_id=correlation_id,
).model_dump(),
)
return JSONResponse(content=response.model_dump(), status_code=500)
@app.get("/mcp/sse")
async def sse_endpoint(request: Request) -> StreamingResponse:
"""Server-Sent Events endpoint for MCP protocol.
"""Server-Sent Events endpoint for MCP transport."""
This enables real-time communication with ChatGPT using SSE.
async def event_stream() -> AsyncGenerator[str, None]:
yield (
"data: "
+ json.dumps(
{"event": "connected", "server": "AegisGitea MCP", "version": "0.2.0"},
separators=(",", ":"),
)
+ "\n\n"
)
Returns:
Streaming SSE response
"""
async def event_stream():
"""Generate SSE events."""
# Send initial connection event
yield f"data: {{'event': 'connected', 'server': 'AegisGitea MCP', 'version': '0.1.0'}}\n\n"
# Keep connection alive
try:
while True:
if await request.is_disconnected():
break
# Heartbeat every 30 seconds
yield f"data: {{'event': 'heartbeat'}}\n\n"
# Wait for next heartbeat (in production, this would handle actual events)
import asyncio
yield 'data: {"event":"heartbeat"}\n\n'
await asyncio.sleep(30)
except Exception as e:
logger.error(f"SSE stream error: {e}")
except Exception:
logger.exception("sse_stream_error")
return StreamingResponse(
event_stream(),
@@ -302,21 +477,12 @@ async def sse_endpoint(request: Request) -> StreamingResponse:
@app.post("/mcp/sse")
async def sse_message_handler(request: Request) -> JSONResponse:
"""Handle POST messages from ChatGPT MCP client to SSE endpoint.
"""Handle POST messages for MCP SSE transport."""
settings = get_settings()
audit = get_audit_logger()
The MCP SSE transport uses:
- GET /mcp/sse for server-to-client streaming
- POST /mcp/sse for client-to-server messages
Returns:
JSON response acknowledging the message
"""
try:
audit = get_audit_logger()
body = await request.json()
logger.info(f"Received MCP message via SSE POST: {body}")
# Handle different message types
message_type = body.get("type") or body.get("method")
message_id = body.get("id")
@@ -328,87 +494,71 @@ async def sse_message_handler(request: Request) -> JSONResponse:
"result": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"serverInfo": {"name": "AegisGitea MCP", "version": "0.1.0"},
"serverInfo": {"name": "AegisGitea MCP", "version": "0.2.0"},
},
}
)
elif message_type == "tools/list":
# Return the list of available tools
if message_type == "tools/list":
response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
return JSONResponse(
content={
"jsonrpc": "2.0",
"id": message_id,
"result": response.model_dump(by_alias=True),
"result": response.model_dump(),
}
)
elif message_type == "tools/call":
# Handle tool execution
if message_type == "tools/call":
tool_name = body.get("params", {}).get("name")
tool_args = body.get("params", {}).get("arguments", {})
correlation_id = audit.log_tool_invocation(
tool_name=tool_name,
params=tool_args,
)
correlation_id = audit.log_tool_invocation(tool_name=tool_name, params=tool_args)
try:
# Get tool handler
handler = TOOL_HANDLERS.get(tool_name)
if not handler:
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
# Execute tool with Gitea client
async with GiteaClient() as gitea:
result = await handler(gitea, tool_args)
result = await _execute_tool_call(str(tool_name), tool_args, correlation_id)
audit.log_tool_invocation(
tool_name=tool_name,
tool_name=str(tool_name),
correlation_id=correlation_id,
result_status="success",
)
return JSONResponse(
content={
"jsonrpc": "2.0",
"id": message_id,
"result": {"content": [{"type": "text", "text": str(result)}]},
"result": {"content": [{"type": "text", "text": json.dumps(result)}]},
}
)
except Exception as e:
error_msg = str(e)
except Exception as exc:
audit.log_tool_invocation(
tool_name=tool_name,
tool_name=str(tool_name),
correlation_id=correlation_id,
result_status="error",
error=error_msg,
error=str(exc),
)
message = "Internal server error"
if settings.expose_error_details:
message = str(exc)
return JSONResponse(
content={
"jsonrpc": "2.0",
"id": message_id,
"error": {"code": -32603, "message": error_msg},
"error": {"code": -32603, "message": message},
}
)
# Handle notifications (no response needed)
elif message_type and message_type.startswith("notifications/"):
logger.info(f"Received notification: {message_type}")
if isinstance(message_type, str) and message_type.startswith("notifications/"):
return JSONResponse(content={})
# Acknowledge other message types
return JSONResponse(
content={"jsonrpc": "2.0", "id": message_id, "result": {"acknowledged": True}}
)
except Exception as e:
logger.error(f"Error handling SSE POST message: {e}")
return JSONResponse(
status_code=400, content={"error": "Invalid message format", "detail": str(e)}
)
except Exception:
logger.exception("sse_message_handler_error")
message = "Invalid message format"
if settings.expose_error_details:
message = "Invalid message format (details hidden unless explicitly enabled)"
return JSONResponse(status_code=400, content={"error": message})
def main() -> None:

View File

@@ -1,15 +1,53 @@
"""MCP tool implementations for AegisGitea."""
"""MCP tool implementation exports."""
from aegis_gitea_mcp.tools.read_tools import (
compare_refs_tool,
get_commit_diff_tool,
get_issue_tool,
get_pull_request_tool,
list_commits_tool,
list_issues_tool,
list_labels_tool,
list_pull_requests_tool,
list_releases_tool,
list_tags_tool,
search_code_tool,
)
from aegis_gitea_mcp.tools.repository import (
get_file_contents_tool,
get_file_tree_tool,
get_repository_info_tool,
list_repositories_tool,
)
from aegis_gitea_mcp.tools.write_tools import (
add_labels_tool,
assign_issue_tool,
create_issue_comment_tool,
create_issue_tool,
create_pr_comment_tool,
update_issue_tool,
)
__all__ = [
"list_repositories_tool",
"get_repository_info_tool",
"get_file_tree_tool",
"get_file_contents_tool",
"search_code_tool",
"list_commits_tool",
"get_commit_diff_tool",
"compare_refs_tool",
"list_issues_tool",
"get_issue_tool",
"list_pull_requests_tool",
"get_pull_request_tool",
"list_labels_tool",
"list_tags_tool",
"list_releases_tool",
"create_issue_tool",
"update_issue_tool",
"create_issue_comment_tool",
"create_pr_comment_tool",
"add_labels_tool",
"assign_issue_tool",
]

View File

@@ -0,0 +1,208 @@
"""Pydantic argument models for MCP tools."""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, model_validator
_REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$"
class StrictBaseModel(BaseModel):
    """Strict model base that rejects unexpected fields."""

    # extra="forbid" turns unknown argument keys into validation errors instead
    # of silently ignoring them — deliberate for a security-focused gateway.
    model_config = ConfigDict(extra="forbid")
class ListRepositoriesArgs(StrictBaseModel):
    """Arguments for list_repositories tool."""


class RepositoryArgs(StrictBaseModel):
    """Common repository locator arguments."""

    # Owner/repo segments are constrained to a conservative character set
    # (see _REPO_PART_PATTERN) so they are safe to interpolate into URL paths.
    owner: str = Field(..., pattern=_REPO_PART_PATTERN)
    repo: str = Field(..., pattern=_REPO_PART_PATTERN)


class FileTreeArgs(RepositoryArgs):
    """Arguments for get_file_tree."""

    # Branch, tag, or commit reference; defaults to the conventional main branch.
    ref: str = Field(default="main", min_length=1, max_length=200)
    recursive: bool = Field(default=False)
class FileContentsArgs(RepositoryArgs):
    """Arguments for get_file_contents."""

    filepath: str = Field(..., min_length=1, max_length=1024)
    ref: str = Field(default="main", min_length=1, max_length=200)

    @model_validator(mode="after")
    def validate_filepath(self) -> FileContentsArgs:
        """Validate path safety constraints."""
        # Treat backslashes as separators so Windows-style traversal is caught too.
        candidate = self.filepath.replace("\\", "/")
        segments = candidate.split("/")
        # Security decision: block traversal and absolute paths.
        if candidate.startswith("/") or ".." in segments:
            raise ValueError("filepath must be a relative path without traversal")
        if "\x00" in candidate:
            raise ValueError("filepath cannot contain null bytes")
        return self
class SearchCodeArgs(RepositoryArgs):
    """Arguments for search_code."""

    query: str = Field(..., min_length=1, max_length=256)
    ref: str = Field(default="main", min_length=1, max_length=200)
    # Pagination is bounded to keep response sizes predictable.
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=25, ge=1, le=100)


class ListCommitsArgs(RepositoryArgs):
    """Arguments for list_commits."""

    ref: str = Field(default="main", min_length=1, max_length=200)
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=25, ge=1, le=100)


class CommitDiffArgs(RepositoryArgs):
    """Arguments for get_commit_diff."""

    # Accepts abbreviated (>= 7 chars) through full-length hashes.
    sha: str = Field(..., min_length=7, max_length=64)


class CompareRefsArgs(RepositoryArgs):
    """Arguments for compare_refs."""

    base: str = Field(..., min_length=1, max_length=200)
    head: str = Field(..., min_length=1, max_length=200)


class ListIssuesArgs(RepositoryArgs):
    """Arguments for list_issues."""

    state: Literal["open", "closed", "all"] = Field(default="open")
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=25, ge=1, le=100)
    # Optional label filter; capped at 20 entries.
    labels: list[str] = Field(default_factory=list, max_length=20)


class IssueArgs(RepositoryArgs):
    """Arguments for get_issue."""

    issue_number: int = Field(..., ge=1)


class ListPullRequestsArgs(RepositoryArgs):
    """Arguments for list_pull_requests."""

    state: Literal["open", "closed", "all"] = Field(default="open")
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=25, ge=1, le=100)


class PullRequestArgs(RepositoryArgs):
    """Arguments for get_pull_request."""

    pull_number: int = Field(..., ge=1)


class ListLabelsArgs(RepositoryArgs):
    """Arguments for list_labels."""

    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)


class ListTagsArgs(RepositoryArgs):
    """Arguments for list_tags."""

    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)


class ListReleasesArgs(RepositoryArgs):
    """Arguments for list_releases."""

    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=25, ge=1, le=100)
class CreateIssueArgs(RepositoryArgs):
    """Arguments for create_issue."""

    title: str = Field(..., min_length=1, max_length=256)
    # Body is optional but size-bounded to limit abuse of write tools.
    body: str = Field(default="", max_length=20_000)
    labels: list[str] = Field(default_factory=list, max_length=20)
    assignees: list[str] = Field(default_factory=list, max_length=20)
class UpdateIssueArgs(RepositoryArgs):
    """Arguments for update_issue."""

    issue_number: int = Field(..., ge=1)
    title: str | None = Field(default=None, min_length=1, max_length=256)
    body: str | None = Field(default=None, max_length=20_000)
    state: Literal["open", "closed"] | None = Field(default=None)

    @model_validator(mode="after")
    def require_change(self) -> UpdateIssueArgs:
        """Require at least one mutable field in update payload."""
        mutable_fields = (self.title, self.body, self.state)
        if all(value is None for value in mutable_fields):
            raise ValueError("At least one of title, body, or state must be provided")
        return self
class CreateIssueCommentArgs(RepositoryArgs):
    """Arguments for create_issue_comment."""

    issue_number: int = Field(..., ge=1)
    body: str = Field(..., min_length=1, max_length=10_000)


class CreatePrCommentArgs(RepositoryArgs):
    """Arguments for create_pr_comment."""

    pull_number: int = Field(..., ge=1)
    body: str = Field(..., min_length=1, max_length=10_000)


class AddLabelsArgs(RepositoryArgs):
    """Arguments for add_labels."""

    issue_number: int = Field(..., ge=1)
    # At least one label is required; capped at 20 per call.
    labels: list[str] = Field(..., min_length=1, max_length=20)


class AssignIssueArgs(RepositoryArgs):
    """Arguments for assign_issue."""

    issue_number: int = Field(..., ge=1)
    assignees: list[str] = Field(..., min_length=1, max_length=20)
def extract_repository(arguments: dict[str, object]) -> str | None:
"""Extract `owner/repo` from raw argument mapping.
Args:
arguments: Raw tool arguments.
Returns:
`owner/repo` or None when arguments are incomplete.
"""
owner = arguments.get("owner")
repo = arguments.get("repo")
if isinstance(owner, str) and isinstance(repo, str) and owner and repo:
return f"{owner}/{repo}"
return None
def extract_target_path(arguments: dict[str, object]) -> str | None:
"""Extract optional target path argument for policy path checks."""
filepath = arguments.get("filepath")
if isinstance(filepath, str) and filepath:
return filepath
return None

View File

@@ -0,0 +1,402 @@
"""Extended read-only MCP tools."""
from __future__ import annotations
from typing import Any
from aegis_gitea_mcp.gitea_client import GiteaClient, GiteaError
from aegis_gitea_mcp.response_limits import limit_items, limit_text
from aegis_gitea_mcp.tools.arguments import (
CommitDiffArgs,
CompareRefsArgs,
IssueArgs,
ListCommitsArgs,
ListIssuesArgs,
ListLabelsArgs,
ListPullRequestsArgs,
ListReleasesArgs,
ListTagsArgs,
PullRequestArgs,
SearchCodeArgs,
)
async def search_code_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Search repository code and return bounded result snippets."""
    parsed = SearchCodeArgs.model_validate(arguments)
    try:
        raw = await gitea.search_code(
            parsed.owner,
            parsed.repo,
            parsed.query,
            ref=parsed.ref,
            page=parsed.page,
            limit=parsed.limit,
        )
        # The search payload may expose results under "data" or "hits".
        if isinstance(raw, dict):
            hits_raw = raw.get("data", raw.get("hits", []))
        else:
            hits_raw = []
        if not isinstance(hits_raw, list):
            hits_raw = []
        normalized_hits = [
            {
                "path": hit.get("filename", hit.get("path", "")),
                "sha": hit.get("sha", ""),
                "ref": parsed.ref,
                # Snippets are truncated to configured response limits.
                "snippet": limit_text(str(hit.get("content", hit.get("snippet", "")))),
                "score": hit.get("score", 0),
            }
            for hit in hits_raw
            if isinstance(hit, dict)
        ]
        bounded, omitted = limit_items(normalized_hits, configured_limit=parsed.limit)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "query": parsed.query,
            "ref": parsed.ref,
            "results": bounded,
            "count": len(bounded),
            "omitted": omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to search code: {exc}") from exc
async def list_commits_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """List commits for a repository reference.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments (validated as ListCommitsArgs).

    Returns:
        Bounded, normalized commit list payload.

    Raises:
        RuntimeError: When the upstream Gitea call fails.
    """
    parsed = ListCommitsArgs.model_validate(arguments)
    try:
        commits = await gitea.list_commits(
            parsed.owner,
            parsed.repo,
            ref=parsed.ref,
            page=parsed.page,
            limit=parsed.limit,
        )
        normalized = []
        for commit in commits:
            if not isinstance(commit, dict):
                continue
            # Robustness fix: Gitea can return null for "author" (commit author
            # without a linked account) and for nested commit metadata; the
            # previous chained .get() calls crashed on None.
            commit_meta = commit.get("commit") or {}
            author = commit.get("author") or {}
            commit_author = commit_meta.get("author") or {}
            normalized.append(
                {
                    "sha": commit.get("sha", ""),
                    "message": limit_text(str(commit_meta.get("message", ""))),
                    "author": author.get("login", ""),
                    "created": commit_author.get("date", ""),
                    "url": commit.get("html_url", ""),
                }
            )
        bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "ref": parsed.ref,
            "commits": bounded,
            "count": len(bounded),
            "omitted": omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to list commits: {exc}") from exc
async def get_commit_diff_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Return commit-level file diff metadata.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments (validated as CommitDiffArgs).

    Returns:
        Bounded per-file diff metadata plus the commit message.

    Raises:
        RuntimeError: When the upstream Gitea call fails.
    """
    parsed = CommitDiffArgs.model_validate(arguments)
    try:
        commit = await gitea.get_commit_diff(parsed.owner, parsed.repo, parsed.sha)
        # Bug fix: the message lookup below previously assumed a dict payload
        # even though the files lookup already guarded other shapes.
        if not isinstance(commit, dict):
            commit = {}
        files = commit.get("files", [])
        normalized_files = []
        if isinstance(files, list):
            for item in files:
                if not isinstance(item, dict):
                    continue
                normalized_files.append(
                    {
                        "filename": item.get("filename", ""),
                        "status": item.get("status", ""),
                        "additions": item.get("additions", 0),
                        "deletions": item.get("deletions", 0),
                        "changes": item.get("changes", 0),
                        # Patches are truncated to configured response limits.
                        "patch": limit_text(str(item.get("patch", ""))),
                    }
                )
        bounded, omitted = limit_items(normalized_files)
        # Message may live at top level or under nested commit metadata;
        # guard against a null nested dict.
        fallback_message = (commit.get("commit") or {}).get("message", "")
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "sha": parsed.sha,
            "message": limit_text(str(commit.get("message", fallback_message))),
            "files": bounded,
            "count": len(bounded),
            "omitted": omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to get commit diff: {exc}") from exc
async def compare_refs_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Compare two refs and return bounded commit/file changes.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments (validated as CompareRefsArgs).

    Returns:
        Bounded commit and file change lists for the base...head comparison.

    Raises:
        RuntimeError: When the upstream Gitea call fails.
    """
    parsed = CompareRefsArgs.model_validate(arguments)
    try:
        comparison = await gitea.compare_refs(parsed.owner, parsed.repo, parsed.base, parsed.head)
        commits_raw = comparison.get("commits", []) if isinstance(comparison, dict) else []
        files_raw = comparison.get("files", []) if isinstance(comparison, dict) else []
        # Robustness fix: these keys may be present but null (or another
        # non-list shape); iterating None raised TypeError before.
        if not isinstance(commits_raw, list):
            commits_raw = []
        if not isinstance(files_raw, list):
            files_raw = []
        commits = [
            {
                "sha": commit.get("sha", ""),
                # Guard nested commit metadata against null values.
                "message": limit_text(str((commit.get("commit") or {}).get("message", ""))),
            }
            for commit in commits_raw
            if isinstance(commit, dict)
        ]
        commit_items, commit_omitted = limit_items(commits)
        files = [
            {
                "filename": item.get("filename", ""),
                "status": item.get("status", ""),
                "additions": item.get("additions", 0),
                "deletions": item.get("deletions", 0),
            }
            for item in files_raw
            if isinstance(item, dict)
        ]
        file_items, file_omitted = limit_items(files)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "base": parsed.base,
            "head": parsed.head,
            "commits": commit_items,
            "files": file_items,
            "commit_count": len(commit_items),
            "file_count": len(file_items),
            "omitted_commits": commit_omitted,
            "omitted_files": file_omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to compare refs: {exc}") from exc
async def list_issues_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """List issues for repository.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments (validated as ListIssuesArgs).

    Returns:
        Bounded, normalized issue list payload.

    Raises:
        RuntimeError: When the upstream Gitea call fails.
    """
    parsed = ListIssuesArgs.model_validate(arguments)
    try:
        issues = await gitea.list_issues(
            parsed.owner,
            parsed.repo,
            state=parsed.state,
            page=parsed.page,
            limit=parsed.limit,
            labels=parsed.labels,
        )
        normalized = []
        for issue in issues:
            if not isinstance(issue, dict):
                continue
            # Robustness fix: "user" and "labels" can be null in API payloads;
            # the previous chained .get() calls crashed on None.
            user = issue.get("user") or {}
            labels_raw = issue.get("labels") or []
            normalized.append(
                {
                    "number": issue.get("number", 0),
                    "title": limit_text(str(issue.get("title", ""))),
                    "state": issue.get("state", ""),
                    "author": user.get("login", ""),
                    "labels": [
                        label.get("name", "")
                        for label in labels_raw
                        if isinstance(label, dict)
                    ],
                    "created_at": issue.get("created_at", ""),
                    "updated_at": issue.get("updated_at", ""),
                    "url": issue.get("html_url", ""),
                }
            )
        bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "state": parsed.state,
            "issues": bounded,
            "count": len(bounded),
            "omitted": omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to list issues: {exc}") from exc
async def get_issue_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Get issue details.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments (validated as IssueArgs).

    Returns:
        Normalized issue payload with bounded text fields.

    Raises:
        RuntimeError: When the upstream Gitea call fails.
    """
    parsed = IssueArgs.model_validate(arguments)
    try:
        issue = await gitea.get_issue(parsed.owner, parsed.repo, parsed.issue_number)
        # Robustness fix: these nested values may be null in API payloads.
        user = issue.get("user") or {}
        labels_raw = issue.get("labels") or []
        assignees_raw = issue.get("assignees") or []
        return {
            "number": issue.get("number", 0),
            "title": limit_text(str(issue.get("title") or "")),
            # Bug fix: a null body previously rendered as the literal string "None".
            "body": limit_text(str(issue.get("body") or "")),
            "state": issue.get("state", ""),
            "author": user.get("login", ""),
            "labels": [
                label.get("name", "") for label in labels_raw if isinstance(label, dict)
            ],
            "assignees": [
                assignee.get("login", "")
                for assignee in assignees_raw
                if isinstance(assignee, dict)
            ],
            "created_at": issue.get("created_at", ""),
            "updated_at": issue.get("updated_at", ""),
            "url": issue.get("html_url", ""),
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to get issue: {exc}") from exc
async def list_pull_requests_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """List pull requests.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments (validated as ListPullRequestsArgs).

    Returns:
        Bounded, normalized pull request list payload.

    Raises:
        RuntimeError: When the upstream Gitea call fails.
    """
    parsed = ListPullRequestsArgs.model_validate(arguments)
    try:
        pull_requests = await gitea.list_pull_requests(
            parsed.owner,
            parsed.repo,
            state=parsed.state,
            page=parsed.page,
            limit=parsed.limit,
        )
        normalized = []
        for pull in pull_requests:
            if not isinstance(pull, dict):
                continue
            # Robustness fix: "user" can be null in API payloads; the previous
            # chained .get() call crashed on None.
            user = pull.get("user") or {}
            normalized.append(
                {
                    "number": pull.get("number", 0),
                    "title": limit_text(str(pull.get("title", ""))),
                    "state": pull.get("state", ""),
                    "author": user.get("login", ""),
                    "draft": pull.get("draft", False),
                    "mergeable": pull.get("mergeable", False),
                    "created_at": pull.get("created_at", ""),
                    "updated_at": pull.get("updated_at", ""),
                    "url": pull.get("html_url", ""),
                }
            )
        bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "state": parsed.state,
            "pull_requests": bounded,
            "count": len(bounded),
            "omitted": omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to list pull requests: {exc}") from exc
async def get_pull_request_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Get pull request details.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments (validated as PullRequestArgs).

    Returns:
        Normalized pull request payload with bounded text fields.

    Raises:
        RuntimeError: When the upstream Gitea call fails.
    """
    parsed = PullRequestArgs.model_validate(arguments)
    try:
        pull = await gitea.get_pull_request(parsed.owner, parsed.repo, parsed.pull_number)
        # Robustness fix: these nested values may be null in API payloads;
        # the previous chained .get() calls crashed on None.
        user = pull.get("user") or {}
        base = pull.get("base") or {}
        head = pull.get("head") or {}
        return {
            "number": pull.get("number", 0),
            "title": limit_text(str(pull.get("title") or "")),
            # Bug fix: a null body previously rendered as the literal string "None".
            "body": limit_text(str(pull.get("body") or "")),
            "state": pull.get("state", ""),
            "draft": pull.get("draft", False),
            "mergeable": pull.get("mergeable", False),
            "author": user.get("login", ""),
            "base": base.get("ref", ""),
            "head": head.get("ref", ""),
            "created_at": pull.get("created_at", ""),
            "updated_at": pull.get("updated_at", ""),
            "url": pull.get("html_url", ""),
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to get pull request: {exc}") from exc
async def list_labels_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """List labels configured on repository.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments (validated as ListLabelsArgs).

    Returns:
        Bounded, normalized label list payload.

    Raises:
        RuntimeError: When the upstream Gitea call fails.
    """
    parsed = ListLabelsArgs.model_validate(arguments)
    try:
        labels = await gitea.list_labels(
            parsed.owner, parsed.repo, page=parsed.page, limit=parsed.limit
        )
        normalized = [
            {
                "id": label.get("id", 0),
                "name": label.get("name", ""),
                "color": label.get("color", ""),
                # Bug fix: a null description previously rendered as the
                # literal string "None".
                "description": limit_text(str(label.get("description") or "")),
            }
            for label in labels
            if isinstance(label, dict)
        ]
        bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "labels": bounded,
            "count": len(bounded),
            "omitted": omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to list labels: {exc}") from exc
async def list_tags_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """List repository tags.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments (validated as ListTagsArgs).

    Returns:
        Bounded, normalized tag list payload.

    Raises:
        RuntimeError: When the upstream Gitea call fails.
    """
    parsed = ListTagsArgs.model_validate(arguments)
    try:
        tags = await gitea.list_tags(
            parsed.owner, parsed.repo, page=parsed.page, limit=parsed.limit
        )
        normalized = [
            {
                "name": tag.get("name", ""),
                # Robustness fix: "commit" may be present but null; the
                # previous chained .get() crashed on None.
                "commit": (tag.get("commit") or {}).get("sha", ""),
                "zipball_url": tag.get("zipball_url", ""),
                "tarball_url": tag.get("tarball_url", ""),
            }
            for tag in tags
            if isinstance(tag, dict)
        ]
        bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "tags": bounded,
            "count": len(bounded),
            "omitted": omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to list tags: {exc}") from exc
async def list_releases_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """List repository releases.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments (validated as ListReleasesArgs).

    Returns:
        Bounded, normalized release list payload.

    Raises:
        RuntimeError: When the upstream Gitea call fails.
    """
    parsed = ListReleasesArgs.model_validate(arguments)
    try:
        releases = await gitea.list_releases(
            parsed.owner,
            parsed.repo,
            page=parsed.page,
            limit=parsed.limit,
        )
        normalized = [
            {
                "id": release.get("id", 0),
                "tag_name": release.get("tag_name", ""),
                # Bug fix: null name/body previously rendered as the literal
                # string "None" after str() conversion.
                "name": limit_text(str(release.get("name") or "")),
                "draft": release.get("draft", False),
                "prerelease": release.get("prerelease", False),
                "body": limit_text(str(release.get("body") or "")),
                "created_at": release.get("created_at", ""),
                "published_at": release.get("published_at", ""),
                "url": release.get("html_url", ""),
            }
            for release in releases
            if isinstance(release, dict)
        ]
        bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "releases": bounded,
            "count": len(bounded),
            "omitted": omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to list releases: {exc}") from exc

View File

@@ -1,26 +1,36 @@
"""Repository-related MCP tool implementations."""
from __future__ import annotations
import base64
from typing import Any, Dict
import binascii
from typing import Any
from aegis_gitea_mcp.gitea_client import GiteaClient, GiteaError
from aegis_gitea_mcp.response_limits import limit_items, limit_text
from aegis_gitea_mcp.security import sanitize_untrusted_text
from aegis_gitea_mcp.tools.arguments import (
FileContentsArgs,
FileTreeArgs,
ListRepositoriesArgs,
RepositoryArgs,
)
async def list_repositories_tool(gitea: GiteaClient, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""List all repositories visible to the bot user.
async def list_repositories_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
"""List repositories visible to the bot user.
Args:
gitea: Initialized Gitea client
arguments: Tool arguments (empty for this tool)
gitea: Initialized Gitea client.
arguments: Tool arguments.
Returns:
Dict containing list of repositories with metadata
Response payload with bounded repository list.
"""
ListRepositoriesArgs.model_validate(arguments)
try:
repos = await gitea.list_repositories()
# Transform to simplified format
simplified_repos = [
repositories = await gitea.list_repositories()
simplified = [
{
"owner": repo.get("owner", {}).get("login", ""),
"name": repo.get("name", ""),
@@ -32,39 +42,24 @@ async def list_repositories_tool(gitea: GiteaClient, arguments: Dict[str, Any])
"stars": repo.get("stars_count", 0),
"url": repo.get("html_url", ""),
}
for repo in repos
for repo in repositories
]
bounded, omitted = limit_items(simplified)
return {
"repositories": simplified_repos,
"count": len(simplified_repos),
"repositories": bounded,
"count": len(bounded),
"omitted": omitted,
}
except GiteaError as exc:
raise RuntimeError(f"Failed to list repositories: {exc}") from exc
async def get_repository_info_tool(
gitea: GiteaClient, arguments: Dict[str, Any]
) -> Dict[str, Any]:
"""Get detailed information about a specific repository.
Args:
gitea: Initialized Gitea client
arguments: Tool arguments with 'owner' and 'repo'
Returns:
Dict containing repository information
"""
owner = arguments.get("owner")
repo = arguments.get("repo")
if not owner or not repo:
raise ValueError("Both 'owner' and 'repo' arguments are required")
async def get_repository_info_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
"""Get detailed metadata for a repository."""
parsed = RepositoryArgs.model_validate(arguments)
try:
repo_data = await gitea.get_repository(owner, repo)
repo_data = await gitea.get_repository(parsed.owner, parsed.repo)
return {
"owner": repo_data.get("owner", {}).get("login", ""),
"name": repo_data.get("name", ""),
@@ -83,107 +78,82 @@ async def get_repository_info_tool(
"url": repo_data.get("html_url", ""),
"clone_url": repo_data.get("clone_url", ""),
}
except GiteaError as exc:
raise RuntimeError(f"Failed to get repository info: {exc}") from exc
async def get_file_tree_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Get repository file tree at selected ref.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments; validated via ``FileTreeArgs``
            (``owner``, ``repo``, optional ``ref`` and ``recursive``).

    Returns:
        Dict with the (bounded) tree entries, their count, and how many
        entries were omitted by the response limit.

    Raises:
        RuntimeError: If the Gitea API call fails.
    """
    parsed = FileTreeArgs.model_validate(arguments)
    try:
        tree_data = await gitea.get_tree(parsed.owner, parsed.repo, parsed.ref, parsed.recursive)
        tree_entries = tree_data.get("tree", [])
        simplified = [
            {
                "path": entry.get("path", ""),
                # 'blob' (file) or 'tree' (directory)
                "type": entry.get("type", ""),
                "size": entry.get("size", 0),
                "sha": entry.get("sha", ""),
            }
            for entry in tree_entries
        ]
        # Bound the list so huge trees cannot flood the LLM client.
        bounded, omitted = limit_items(simplified)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "ref": parsed.ref,
            "recursive": parsed.recursive,
            "tree": bounded,
            "count": len(bounded),
            "omitted": omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to get file tree: {exc}") from exc
async def get_file_contents_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Read file contents from a repository ref.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments; validated via ``FileContentsArgs``
            (``owner``, ``repo``, ``filepath``, optional ``ref``).

    Returns:
        Dict with the (decoded, bounded) content plus file metadata.

    Raises:
        RuntimeError: If the Gitea API call fails.

    Security notes:
        - Repository content is treated as untrusted data and never executed.
        - Text output is truncated to configured limits to reduce prompt-stuffing risk.
    """
    parsed = FileContentsArgs.model_validate(arguments)
    try:
        file_data = await gitea.get_file_contents(
            parsed.owner, parsed.repo, parsed.filepath, parsed.ref
        )
        # Content is base64-encoded by Gitea.
        content_b64 = file_data.get("content", "")
        encoding = file_data.get("encoding", "base64")
        content = str(content_b64)
        if encoding == "base64":
            try:
                decoded_bytes = base64.b64decode(content_b64)
                try:
                    content = decoded_bytes.decode("utf-8")
                except UnicodeDecodeError:
                    # Edge case: binary files should remain encoded instead of
                    # forcing invalid text.
                    content = str(content_b64)
            except (binascii.Error, ValueError):
                # Malformed base64 payload: return it verbatim rather than crash.
                content = str(content_b64)
        # Keep untrusted content bounded before returning it to LLM clients.
        content = limit_text(content)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "filepath": parsed.filepath,
            "ref": parsed.ref,
            "content": content,
            "encoding": encoding,
            "size": file_data.get("size", 0),
            "sha": file_data.get("sha", ""),
            "url": file_data.get("html_url", ""),
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to get file contents: {exc}") from exc

View File

@@ -0,0 +1,141 @@
"""Write-mode MCP tool implementations (disabled by default)."""
from __future__ import annotations
from typing import Any
from aegis_gitea_mcp.gitea_client import GiteaClient, GiteaError
from aegis_gitea_mcp.response_limits import limit_text
from aegis_gitea_mcp.tools.arguments import (
AddLabelsArgs,
AssignIssueArgs,
CreateIssueArgs,
CreateIssueCommentArgs,
CreatePrCommentArgs,
UpdateIssueArgs,
)
async def create_issue_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Create a new issue in write mode.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments; validated via ``CreateIssueArgs``.

    Returns:
        Dict with the created issue's number, (bounded) title, state, and URL.

    Raises:
        RuntimeError: If the Gitea API call fails.
    """
    args = CreateIssueArgs.model_validate(arguments)
    try:
        created = await gitea.create_issue(
            args.owner,
            args.repo,
            title=args.title,
            body=args.body,
            labels=args.labels,
            assignees=args.assignees,
        )
    except GiteaError as exc:
        raise RuntimeError(f"Failed to create issue: {exc}") from exc
    return {
        "number": created.get("number", 0),
        "title": limit_text(str(created.get("title", ""))),
        "state": created.get("state", ""),
        "url": created.get("html_url", ""),
    }
async def update_issue_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Update issue fields in write mode.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments; validated via ``UpdateIssueArgs``.

    Returns:
        Dict with the updated issue's number, (bounded) title, state, and URL.

    Raises:
        RuntimeError: If the Gitea API call fails.
    """
    args = UpdateIssueArgs.model_validate(arguments)
    try:
        updated = await gitea.update_issue(
            args.owner,
            args.repo,
            args.issue_number,
            title=args.title,
            body=args.body,
            state=args.state,
        )
    except GiteaError as exc:
        raise RuntimeError(f"Failed to update issue: {exc}") from exc
    return {
        # Fall back to the requested number if the response omits it.
        "number": updated.get("number", args.issue_number),
        "title": limit_text(str(updated.get("title", ""))),
        "state": updated.get("state", ""),
        "url": updated.get("html_url", ""),
    }
async def create_issue_comment_tool(
    gitea: GiteaClient, arguments: dict[str, Any]
) -> dict[str, Any]:
    """Create issue comment in write mode.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments; validated via ``CreateIssueCommentArgs``.

    Returns:
        Dict with the comment id, issue number, (bounded) body, and URL.

    Raises:
        RuntimeError: If the Gitea API call fails.
    """
    args = CreateIssueCommentArgs.model_validate(arguments)
    try:
        posted = await gitea.create_issue_comment(
            args.owner,
            args.repo,
            args.issue_number,
            args.body,
        )
    except GiteaError as exc:
        raise RuntimeError(f"Failed to create issue comment: {exc}") from exc
    return {
        "id": posted.get("id", 0),
        "issue_number": args.issue_number,
        "body": limit_text(str(posted.get("body", ""))),
        "url": posted.get("html_url", ""),
    }
async def create_pr_comment_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Create PR discussion comment in write mode.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments; validated via ``CreatePrCommentArgs``.

    Returns:
        Dict with the comment id, pull number, (bounded) body, and URL.

    Raises:
        RuntimeError: If the Gitea API call fails.
    """
    args = CreatePrCommentArgs.model_validate(arguments)
    try:
        posted = await gitea.create_pr_comment(
            args.owner,
            args.repo,
            args.pull_number,
            args.body,
        )
    except GiteaError as exc:
        raise RuntimeError(f"Failed to create PR comment: {exc}") from exc
    return {
        "id": posted.get("id", 0),
        "pull_number": args.pull_number,
        "body": limit_text(str(posted.get("body", ""))),
        "url": posted.get("html_url", ""),
    }
async def add_labels_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Add labels to an issue or pull request.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments; validated via ``AddLabelsArgs``.

    Returns:
        Dict with the issue number and the label names now applied.

    Raises:
        RuntimeError: If the Gitea API call fails.
    """
    args = AddLabelsArgs.model_validate(arguments)
    try:
        response = await gitea.add_labels(args.owner, args.repo, args.issue_number, args.labels)
    except GiteaError as exc:
        raise RuntimeError(f"Failed to add labels: {exc}") from exc
    applied: list[str] = []
    if isinstance(response, dict):
        applied = [entry.get("name", "") for entry in response.get("labels", [])]
    # Fall back to the requested labels when the response carries none.
    return {
        "issue_number": args.issue_number,
        "labels": applied or args.labels,
    }
async def assign_issue_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Assign users to an issue or pull request.

    Args:
        gitea: Initialized Gitea client.
        arguments: Raw tool arguments; validated via ``AssignIssueArgs``.

    Returns:
        Dict with the issue number and the logins now assigned.

    Raises:
        RuntimeError: If the Gitea API call fails.
    """
    args = AssignIssueArgs.model_validate(arguments)
    try:
        response = await gitea.assign_issue(
            args.owner,
            args.repo,
            args.issue_number,
            args.assignees,
        )
    except GiteaError as exc:
        raise RuntimeError(f"Failed to assign issue: {exc}") from exc
    current: list[str] = []
    if isinstance(response, dict):
        current = [entry.get("login", "") for entry in response.get("assignees", [])]
    # Fall back to the requested assignees when the response carries none.
    return {
        "issue_number": args.issue_number,
        "assignees": current or args.assignees,
    }