feat: harden gateway with policy engine, secure tools, and governance docs
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
"""AegisGitea MCP - Security-first MCP server for self-hosted Gitea."""
|
||||
"""AegisGitea MCP - Security-first MCP gateway for self-hosted Gitea."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__version__ = "0.2.0"
|
||||
|
||||
@@ -1,50 +1,110 @@
|
||||
"""Audit logging system for MCP tool invocations."""
|
||||
"""Tamper-evident audit logging for MCP tool invocations and security events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import structlog
|
||||
from typing import Any
|
||||
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
from aegis_gitea_mcp.request_context import get_request_id
|
||||
from aegis_gitea_mcp.security import sanitize_data
|
||||
|
||||
_GENESIS_HASH = "GENESIS"
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""Audit logger for tracking all MCP tool invocations."""
|
||||
"""Append-only tamper-evident audit logger.
|
||||
|
||||
def __init__(self, log_path: Optional[Path] = None) -> None:
|
||||
Every line in the audit file is hash-chained to the previous line. This makes
|
||||
post-hoc modifications detectable by integrity validation.
|
||||
"""
|
||||
|
||||
def __init__(self, log_path: Path | None = None) -> None:
|
||||
"""Initialize audit logger.
|
||||
|
||||
Args:
|
||||
log_path: Path to audit log file (defaults to config value)
|
||||
log_path: Path to audit log file (defaults to config value).
|
||||
"""
|
||||
self.settings = get_settings()
|
||||
self.log_path = log_path or self.settings.audit_log_path
|
||||
|
||||
# Ensure log directory exists
|
||||
self.log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._log_file = self._get_log_file()
|
||||
self._lock = threading.Lock()
|
||||
self._log_file = open(self.log_path, "a+", encoding="utf-8")
|
||||
self._last_hash = self._read_last_hash()
|
||||
|
||||
# Configure structlog for audit logging
|
||||
structlog.configure(
|
||||
processors=[
|
||||
structlog.processors.TimeStamper(fmt="iso", utc=True),
|
||||
structlog.processors.dict_tracebacks,
|
||||
structlog.processors.JSONRenderer(),
|
||||
],
|
||||
wrapper_class=structlog.BoundLogger,
|
||||
context_class=dict,
|
||||
logger_factory=structlog.PrintLoggerFactory(file=self._log_file),
|
||||
cache_logger_on_first_use=True,
|
||||
def _read_last_hash(self) -> str:
|
||||
"""Read the previous hash from the last log entry."""
|
||||
try:
|
||||
entries = self.log_path.read_text(encoding="utf-8").splitlines()
|
||||
except FileNotFoundError:
|
||||
return _GENESIS_HASH
|
||||
|
||||
if not entries:
|
||||
return _GENESIS_HASH
|
||||
|
||||
last_line = entries[-1]
|
||||
try:
|
||||
payload = json.loads(last_line)
|
||||
entry_hash = payload.get("entry_hash")
|
||||
if isinstance(entry_hash, str) and entry_hash:
|
||||
return entry_hash
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Corrupt trailing line forces a new chain segment.
|
||||
return _GENESIS_HASH
|
||||
|
||||
@staticmethod
|
||||
def _compute_entry_hash(prev_hash: str, payload: dict[str, Any]) -> str:
|
||||
"""Compute deterministic hash for an audit entry payload."""
|
||||
canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), ensure_ascii=True)
|
||||
digest = hashlib.sha256(f"{prev_hash}:{canonical}".encode()).hexdigest()
|
||||
return digest
|
||||
|
||||
def _append_entry(self, event_type: str, payload: dict[str, Any]) -> str:
|
||||
"""Append a hash-chained entry to audit log.
|
||||
|
||||
Args:
|
||||
event_type: Event category.
|
||||
payload: Event payload data.
|
||||
|
||||
Returns:
|
||||
Correlation ID for the appended entry.
|
||||
"""
|
||||
correlation_id = payload.get("correlation_id")
|
||||
if not isinstance(correlation_id, str) or not correlation_id:
|
||||
correlation_id = str(uuid.uuid4())
|
||||
payload["correlation_id"] = correlation_id
|
||||
|
||||
# Security decision: sanitize all audit payloads before persistence.
|
||||
mode = "mask" if self.settings.secret_detection_mode != "off" else "off"
|
||||
safe_payload = payload if mode == "off" else sanitize_data(payload, mode=mode)
|
||||
|
||||
base_entry: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"event_type": event_type,
|
||||
"request_id": get_request_id(),
|
||||
"payload": safe_payload,
|
||||
"prev_hash": self._last_hash,
|
||||
}
|
||||
entry_hash = self._compute_entry_hash(self._last_hash, base_entry)
|
||||
base_entry["entry_hash"] = entry_hash
|
||||
|
||||
serialized = json.dumps(
|
||||
base_entry, sort_keys=True, separators=(",", ":"), ensure_ascii=True
|
||||
)
|
||||
|
||||
self.logger = structlog.get_logger("audit")
|
||||
with self._lock:
|
||||
self._log_file.write(serialized + "\n")
|
||||
self._log_file.flush()
|
||||
self._last_hash = entry_hash
|
||||
|
||||
def _get_log_file(self) -> Any:
|
||||
"""Get file handle for audit log."""
|
||||
return open(self.log_path, "a", encoding="utf-8")
|
||||
return correlation_id
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close open audit log resources."""
|
||||
@@ -56,111 +116,108 @@ class AuditLogger:
|
||||
def log_tool_invocation(
|
||||
self,
|
||||
tool_name: str,
|
||||
repository: Optional[str] = None,
|
||||
target: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
repository: str | None = None,
|
||||
target: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
result_status: str = "pending",
|
||||
error: Optional[str] = None,
|
||||
error: str | None = None,
|
||||
) -> str:
|
||||
"""Log an MCP tool invocation.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the MCP tool being invoked
|
||||
repository: Repository identifier (owner/repo)
|
||||
target: Target path, commit hash, issue number, etc.
|
||||
params: Additional parameters passed to the tool
|
||||
correlation_id: Request correlation ID (auto-generated if not provided)
|
||||
result_status: Status of the invocation (pending, success, error)
|
||||
error: Error message if invocation failed
|
||||
|
||||
Returns:
|
||||
Correlation ID for this invocation
|
||||
"""
|
||||
if correlation_id is None:
|
||||
correlation_id = str(uuid.uuid4())
|
||||
|
||||
audit_entry = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"correlation_id": correlation_id,
|
||||
"""Log an MCP tool invocation."""
|
||||
payload: dict[str, Any] = {
|
||||
"correlation_id": correlation_id or str(uuid.uuid4()),
|
||||
"tool_name": tool_name,
|
||||
"repository": repository,
|
||||
"target": target,
|
||||
"params": params or {},
|
||||
"result_status": result_status,
|
||||
}
|
||||
|
||||
if error:
|
||||
audit_entry["error"] = error
|
||||
|
||||
self.logger.info("tool_invocation", **audit_entry)
|
||||
return correlation_id
|
||||
payload["error"] = error
|
||||
return self._append_entry("tool_invocation", payload)
|
||||
|
||||
def log_access_denied(
|
||||
self,
|
||||
tool_name: str,
|
||||
repository: Optional[str] = None,
|
||||
repository: str | None = None,
|
||||
reason: str = "unauthorized",
|
||||
correlation_id: Optional[str] = None,
|
||||
correlation_id: str | None = None,
|
||||
) -> str:
|
||||
"""Log an access denial event.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool that was denied access
|
||||
repository: Repository identifier that access was denied to
|
||||
reason: Reason for denial
|
||||
correlation_id: Request correlation ID
|
||||
|
||||
Returns:
|
||||
Correlation ID for this event
|
||||
"""
|
||||
if correlation_id is None:
|
||||
correlation_id = str(uuid.uuid4())
|
||||
|
||||
self.logger.warning(
|
||||
"""Log an access denial event."""
|
||||
return self._append_entry(
|
||||
"access_denied",
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
correlation_id=correlation_id,
|
||||
tool_name=tool_name,
|
||||
repository=repository,
|
||||
reason=reason,
|
||||
{
|
||||
"correlation_id": correlation_id or str(uuid.uuid4()),
|
||||
"tool_name": tool_name,
|
||||
"repository": repository,
|
||||
"reason": reason,
|
||||
},
|
||||
)
|
||||
return correlation_id
|
||||
|
||||
def log_security_event(
|
||||
self,
|
||||
event_type: str,
|
||||
description: str,
|
||||
severity: str = "medium",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Log a security-related event.
|
||||
|
||||
Args:
|
||||
event_type: Type of security event (e.g., rate_limit, invalid_input)
|
||||
description: Human-readable description of the event
|
||||
severity: Severity level (low, medium, high, critical)
|
||||
metadata: Additional metadata about the event
|
||||
|
||||
Returns:
|
||||
Correlation ID for this event
|
||||
"""
|
||||
correlation_id = str(uuid.uuid4())
|
||||
|
||||
self.logger.warning(
|
||||
"""Log a security event."""
|
||||
return self._append_entry(
|
||||
"security_event",
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
correlation_id=correlation_id,
|
||||
event_type=event_type,
|
||||
description=description,
|
||||
severity=severity,
|
||||
metadata=metadata or {},
|
||||
{
|
||||
"event_type": event_type,
|
||||
"description": description,
|
||||
"severity": severity,
|
||||
"metadata": metadata or {},
|
||||
},
|
||||
)
|
||||
return correlation_id
|
||||
|
||||
|
||||
# Global audit logger instance
|
||||
_audit_logger: Optional[AuditLogger] = None
|
||||
def validate_audit_log_integrity(log_path: Path) -> tuple[bool, list[str]]:
    """Validate audit log hash chain integrity.

    Walks the file line by line, checking that each entry's ``prev_hash``
    matches the running chain state and that its stored ``entry_hash`` equals
    a recomputation over the entry body.

    Args:
        log_path: Path to an audit log file.

    Returns:
        Tuple of (is_valid, errors).
    """
    if not log_path.exists():
        # Nothing to verify: an absent log is vacuously valid.
        return True, []

    problems: list[str] = []
    chain_hash = _GENESIS_HASH

    raw_lines = log_path.read_text(encoding="utf-8").splitlines()
    for lineno, raw in enumerate(raw_lines, start=1):
        if not raw.strip():
            continue

        try:
            record = json.loads(raw)
        except json.JSONDecodeError:
            problems.append(f"line {lineno}: invalid JSON")
            continue

        if record.get("prev_hash") != chain_hash:
            problems.append(f"line {lineno}: prev_hash mismatch")

        # Recompute the hash over the entry with the stored hash stripped out.
        stored_hash = record.get("entry_hash")
        body = {key: value for key, value in record.items() if key != "entry_hash"}
        recomputed = AuditLogger._compute_entry_hash(chain_hash, body)

        if stored_hash == recomputed:
            chain_hash = str(stored_hash)
        else:
            problems.append(f"line {lineno}: entry_hash mismatch")
            # Advance the chain with the recomputed value so verification
            # can continue past a corrupted entry.
            chain_hash = recomputed

    return not problems, problems
|
||||
|
||||
|
||||
_audit_logger: AuditLogger | None = None
|
||||
|
||||
|
||||
def get_audit_logger() -> AuditLogger:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Authentication module for MCP server API key validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from aegis_gitea_mcp.audit import get_audit_logger
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
@@ -13,70 +14,43 @@ from aegis_gitea_mcp.config import get_settings
|
||||
class AuthenticationError(Exception):
|
||||
"""Raised when authentication fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class APIKeyValidator:
|
||||
"""Validates API keys for MCP server access."""
|
||||
"""Validate API keys for MCP server access."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize API key validator."""
|
||||
"""Initialize API key validator state."""
|
||||
self.settings = get_settings()
|
||||
self.audit = get_audit_logger()
|
||||
self._failed_attempts: dict[str, list[datetime]] = {}
|
||||
|
||||
def _constant_time_compare(self, a: str, b: str) -> bool:
|
||||
"""Compare two strings in constant time to prevent timing attacks.
|
||||
|
||||
Args:
|
||||
a: First string
|
||||
b: Second string
|
||||
|
||||
Returns:
|
||||
True if strings are equal, False otherwise
|
||||
"""
|
||||
return hmac.compare_digest(a, b)
|
||||
def _constant_time_compare(self, candidate: str, expected: str) -> bool:
|
||||
"""Compare API keys in constant time to mitigate timing attacks."""
|
||||
return hmac.compare_digest(candidate, expected)
|
||||
|
||||
def _check_rate_limit(self, identifier: str) -> bool:
|
||||
"""Check if identifier has exceeded failed authentication rate limit.
|
||||
|
||||
Args:
|
||||
identifier: IP address or other identifier
|
||||
|
||||
Returns:
|
||||
True if within rate limit, False if exceeded
|
||||
"""
|
||||
"""Check whether authentication failures exceed configured threshold."""
|
||||
now = datetime.now(timezone.utc)
|
||||
window_start = now.timestamp() - self.settings.auth_failure_window
|
||||
boundary = now.timestamp() - self.settings.auth_failure_window
|
||||
|
||||
# Clean up old attempts
|
||||
if identifier in self._failed_attempts:
|
||||
self._failed_attempts[identifier] = [
|
||||
attempt
|
||||
for attempt in self._failed_attempts[identifier]
|
||||
if attempt.timestamp() > window_start
|
||||
if attempt.timestamp() > boundary
|
||||
]
|
||||
|
||||
# Check count
|
||||
attempt_count = len(self._failed_attempts.get(identifier, []))
|
||||
return attempt_count < self.settings.max_auth_failures
|
||||
return len(self._failed_attempts.get(identifier, [])) < self.settings.max_auth_failures
|
||||
|
||||
def _record_failed_attempt(self, identifier: str) -> None:
|
||||
"""Record a failed authentication attempt.
|
||||
"""Record a failed authentication attempt for rate limiting."""
|
||||
attempt_time = datetime.now(timezone.utc)
|
||||
self._failed_attempts.setdefault(identifier, []).append(attempt_time)
|
||||
|
||||
Args:
|
||||
identifier: IP address or other identifier
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
if identifier not in self._failed_attempts:
|
||||
self._failed_attempts[identifier] = []
|
||||
self._failed_attempts[identifier].append(now)
|
||||
|
||||
# Check if threshold exceeded
|
||||
if len(self._failed_attempts[identifier]) >= self.settings.max_auth_failures:
|
||||
self.audit.log_security_event(
|
||||
event_type="auth_rate_limit_exceeded",
|
||||
description=f"IP {identifier} exceeded auth failure threshold",
|
||||
description="Authentication failure threshold exceeded",
|
||||
severity="high",
|
||||
metadata={
|
||||
"identifier": identifier,
|
||||
@@ -86,29 +60,31 @@ class APIKeyValidator:
|
||||
)
|
||||
|
||||
def validate_api_key(
|
||||
self, provided_key: Optional[str], client_ip: str, user_agent: str
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
self,
|
||||
provided_key: str | None,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Validate an API key.
|
||||
|
||||
Args:
|
||||
provided_key: API key provided by client
|
||||
client_ip: Client IP address
|
||||
user_agent: Client user agent string
|
||||
provided_key: API key provided by client.
|
||||
client_ip: Request source IP address.
|
||||
user_agent: Request user agent.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
Tuple of `(is_valid, error_message)`.
|
||||
"""
|
||||
# Check if authentication is enabled
|
||||
if not self.settings.auth_enabled:
|
||||
# Security note: auth-disabled mode is explicit and should be monitored.
|
||||
self.audit.log_security_event(
|
||||
event_type="auth_disabled",
|
||||
description="Authentication is disabled - allowing all requests",
|
||||
description="Authentication disabled; request was allowed",
|
||||
severity="critical",
|
||||
metadata={"client_ip": client_ip},
|
||||
)
|
||||
return True, None
|
||||
|
||||
# Check rate limit
|
||||
if not self._check_rate_limit(client_ip):
|
||||
self.audit.log_access_denied(
|
||||
tool_name="api_authentication",
|
||||
@@ -116,7 +92,6 @@ class APIKeyValidator:
|
||||
)
|
||||
return False, "Too many failed authentication attempts. Please try again later."
|
||||
|
||||
# Check if key was provided
|
||||
if not provided_key:
|
||||
self._record_failed_attempt(client_ip)
|
||||
self.audit.log_access_denied(
|
||||
@@ -125,8 +100,8 @@ class APIKeyValidator:
|
||||
)
|
||||
return False, "Authorization header missing. Required: Authorization: Bearer <api-key>"
|
||||
|
||||
# Validate key format (should be at least 32 characters)
|
||||
if len(provided_key) < 32:
|
||||
# Validation logic: reject short keys early to reduce brute force surface.
|
||||
self._record_failed_attempt(client_ip)
|
||||
self.audit.log_access_denied(
|
||||
tool_name="api_authentication",
|
||||
@@ -134,99 +109,87 @@ class APIKeyValidator:
|
||||
)
|
||||
return False, "Invalid API key format"
|
||||
|
||||
# Get valid API keys from config
|
||||
valid_keys = self.settings.mcp_api_keys
|
||||
|
||||
if not valid_keys:
|
||||
self.audit.log_security_event(
|
||||
event_type="no_api_keys_configured",
|
||||
description="No API keys configured in environment",
|
||||
description="No API keys configured while auth is enabled",
|
||||
severity="critical",
|
||||
metadata={"client_ip": client_ip},
|
||||
)
|
||||
return False, "Server configuration error: No API keys configured"
|
||||
|
||||
# Check against all valid keys (constant time comparison)
|
||||
is_valid = any(self._constant_time_compare(provided_key, valid_key) for valid_key in valid_keys)
|
||||
is_valid = any(
|
||||
self._constant_time_compare(provided_key, valid_key) for valid_key in valid_keys
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
# Success - log and return
|
||||
key_hint = f"{provided_key[:8]}...{provided_key[-4:]}"
|
||||
key_fingerprint = hashlib.sha256(provided_key.encode("utf-8")).hexdigest()[:12]
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="api_authentication",
|
||||
result_status="success",
|
||||
params={"client_ip": client_ip, "user_agent": user_agent, "key_hint": key_hint},
|
||||
)
|
||||
return True, None
|
||||
else:
|
||||
# Failure - record attempt and log
|
||||
self._record_failed_attempt(client_ip)
|
||||
key_hint = f"{provided_key[:8]}..." if len(provided_key) >= 8 else "too_short"
|
||||
self.audit.log_access_denied(
|
||||
tool_name="api_authentication",
|
||||
reason="invalid_api_key",
|
||||
)
|
||||
self.audit.log_security_event(
|
||||
event_type="invalid_api_key_attempt",
|
||||
description=f"Invalid API key attempted from {client_ip}",
|
||||
severity="medium",
|
||||
metadata={
|
||||
params={
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent,
|
||||
"key_hint": key_hint,
|
||||
"key_fingerprint": key_fingerprint,
|
||||
},
|
||||
)
|
||||
return False, "Invalid API key"
|
||||
return True, None
|
||||
|
||||
def extract_bearer_token(self, authorization_header: Optional[str]) -> Optional[str]:
|
||||
"""Extract bearer token from Authorization header.
|
||||
self._record_failed_attempt(client_ip)
|
||||
self.audit.log_access_denied(
|
||||
tool_name="api_authentication",
|
||||
reason="invalid_api_key",
|
||||
)
|
||||
self.audit.log_security_event(
|
||||
event_type="invalid_api_key_attempt",
|
||||
description="Invalid API key was presented",
|
||||
severity="medium",
|
||||
metadata={"client_ip": client_ip, "user_agent": user_agent},
|
||||
)
|
||||
return False, "Invalid API key"
|
||||
|
||||
Args:
|
||||
authorization_header: Authorization header value
|
||||
def extract_bearer_token(self, authorization_header: str | None) -> str | None:
|
||||
"""Extract API token from `Authorization: Bearer <token>` header.
|
||||
|
||||
Returns:
|
||||
Extracted token or None if invalid format
|
||||
Security note:
|
||||
The scheme is case-sensitive by policy (`Bearer`) to prevent accepting
|
||||
ambiguous client implementations and to align strict API contracts.
|
||||
"""
|
||||
if not authorization_header:
|
||||
return None
|
||||
|
||||
parts = authorization_header.split()
|
||||
parts = authorization_header.split(" ")
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
scheme, token = parts
|
||||
if scheme.lower() != "bearer":
|
||||
if scheme != "Bearer":
|
||||
return None
|
||||
if not token.strip():
|
||||
return None
|
||||
|
||||
return token
|
||||
return token.strip()
|
||||
|
||||
|
||||
def generate_api_key(length: int = 64) -> str:
|
||||
"""Generate a cryptographically secure API key.
|
||||
|
||||
Args:
|
||||
length: Length of the key in characters (default: 64)
|
||||
length: Length of key in characters.
|
||||
|
||||
Returns:
|
||||
Generated API key as hex string
|
||||
Generated API key string.
|
||||
"""
|
||||
return secrets.token_hex(length // 2)
|
||||
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
"""Hash an API key for secure storage (future use).
|
||||
|
||||
Args:
|
||||
api_key: Plain text API key
|
||||
|
||||
Returns:
|
||||
SHA256 hash of the key
|
||||
"""
|
||||
return hashlib.sha256(api_key.encode()).hexdigest()
|
||||
"""Hash an API key for secure storage and comparison."""
|
||||
return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# Global validator instance
|
||||
_validator: Optional[APIKeyValidator] = None
|
||||
_validator: APIKeyValidator | None = None
|
||||
|
||||
|
||||
def get_validator() -> APIKeyValidator:
|
||||
@@ -238,6 +201,6 @@ def get_validator() -> APIKeyValidator:
|
||||
|
||||
|
||||
def reset_validator() -> None:
|
||||
"""Reset global validator instance (primarily for testing)."""
|
||||
"""Reset global API key validator instance (primarily for testing)."""
|
||||
global _validator
|
||||
_validator = None
|
||||
|
||||
220
src/aegis_gitea_mcp/automation.py
Normal file
220
src/aegis_gitea_mcp/automation.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Automation workflows for webhooks and scheduled jobs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from aegis_gitea_mcp.audit import get_audit_logger
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
from aegis_gitea_mcp.gitea_client import GiteaClient
|
||||
from aegis_gitea_mcp.policy import get_policy_engine
|
||||
|
||||
|
||||
class AutomationError(RuntimeError):
    """Signals that an automation action was denied, disabled, or unknown."""
|
||||
|
||||
|
||||
def _parse_timestamp(value: str) -> datetime | None:
|
||||
"""Parse ISO8601 timestamp with best-effort normalization."""
|
||||
normalized = value.replace("Z", "+00:00")
|
||||
try:
|
||||
return datetime.fromisoformat(normalized)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
class AutomationManager:
    """Policy-controlled automation manager.

    Coordinates inbound webhook handling and named automation jobs. Every
    action is first gated through the policy engine and then recorded via
    the audit logger, so automation activity is both authorized and traceable.
    """

    def __init__(self) -> None:
        """Initialize automation manager with runtime services."""
        # Process-wide singletons supplied by the config and audit modules.
        self.settings = get_settings()
        self.audit = get_audit_logger()

    async def handle_webhook(
        self,
        event_type: str,
        payload: dict[str, Any],
        repository: str | None,
    ) -> dict[str, Any]:
        """Handle inbound webhook event.

        Args:
            event_type: Event type identifier.
            payload: Event payload body.
            repository: Optional target repository (`owner/repo`).

        Returns:
            Result summary for webhook processing.

        Raises:
            AutomationError: If automation is disabled or policy denies ingest.
        """
        if not self.settings.automation_enabled:
            raise AutomationError("automation is disabled")

        # Ingestion is authorized as a read-only operation; the payload is
        # never executed or written back anywhere from this path.
        # NOTE(review): no webhook signature/secret verification happens here —
        # confirm the transport layer authenticates webhook senders.
        decision = get_policy_engine().authorize(
            tool_name="automation_webhook_ingest",
            is_write=False,
            repository=repository,
        )
        if not decision.allowed:
            raise AutomationError(f"policy denied webhook: {decision.reason}")

        self.audit.log_tool_invocation(
            tool_name="automation_webhook_ingest",
            repository=repository,
            params={"event_type": event_type},
            result_status="success",
        )

        # Safe default: treat webhook payload as data only.
        # Only top-level key names are echoed back, never payload values.
        return {
            "status": "accepted",
            "event_type": event_type,
            "repository": repository,
            "keys": sorted(payload.keys()),
        }

    async def run_job(
        self,
        job_name: str,
        owner: str,
        repo: str,
        finding_title: str | None = None,
        finding_body: str | None = None,
    ) -> dict[str, Any]:
        """Run a named automation job for a repository.

        Args:
            job_name: Job identifier. One of ``dependency_hygiene_scan``,
                ``stale_issue_detection``, or ``auto_issue_creation``.
            owner: Repository owner.
            repo: Repository name.
            finding_title: Optional issue title for ``auto_issue_creation``.
            finding_body: Optional issue body for ``auto_issue_creation``.

        Returns:
            Job execution summary.

        Raises:
            AutomationError: If automation is disabled, policy denies the job,
                or ``job_name`` is not a supported job.
        """
        if not self.settings.automation_enabled:
            raise AutomationError("automation is disabled")

        repository = f"{owner}/{repo}"
        # Only auto_issue_creation mutates repository state; all other jobs
        # are authorized as read-only.
        is_write = job_name == "auto_issue_creation"

        decision = get_policy_engine().authorize(
            tool_name=f"automation_{job_name}",
            is_write=is_write,
            repository=repository,
        )
        if not decision.allowed:
            raise AutomationError(f"policy denied automation job: {decision.reason}")

        # Dispatch to the job implementation after the policy gate passed.
        if job_name == "dependency_hygiene_scan":
            return await self._dependency_hygiene_scan(owner, repo)
        if job_name == "stale_issue_detection":
            return await self._stale_issue_detection(owner, repo)
        if job_name == "auto_issue_creation":
            return await self._auto_issue_creation(
                owner,
                repo,
                finding_title=finding_title,
                finding_body=finding_body,
            )

        raise AutomationError(f"unsupported automation job: {job_name}")

    async def _dependency_hygiene_scan(self, owner: str, repo: str) -> dict[str, Any]:
        """Run dependency hygiene scan placeholder workflow.

        Security note:
            This job intentionally performs read-only checks and does not mutate
            repository state directly.
        """
        repository = f"{owner}/{repo}"
        self.audit.log_tool_invocation(
            tool_name="automation_dependency_hygiene_scan",
            repository=repository,
            result_status="success",
        )

        # Placeholder output for policy-controlled automation scaffold.
        # No actual dependency analysis is performed yet; findings is empty.
        return {
            "job": "dependency_hygiene_scan",
            "repository": repository,
            "status": "completed",
            "findings": [],
        }

    async def _stale_issue_detection(self, owner: str, repo: str) -> dict[str, Any]:
        """Detect stale issues using repository issue metadata.

        An open issue is considered stale when its ``updated_at`` timestamp is
        older than ``automation_stale_days`` from settings.
        """
        repository = f"{owner}/{repo}"
        cutoff = datetime.now(timezone.utc) - timedelta(days=self.settings.automation_stale_days)

        stale_issue_numbers: list[int] = []
        async with GiteaClient() as gitea:
            # NOTE(review): only the first page (up to 100 open issues) is
            # fetched — confirm whether pagination is needed for large repos.
            issues = await gitea.list_issues(
                owner,
                repo,
                state="open",
                page=1,
                limit=100,
                labels=None,
            )

        for issue in issues:
            updated_at = issue.get("updated_at")
            # Skip entries with a missing or non-string timestamp field.
            if not isinstance(updated_at, str):
                continue
            parsed = _parse_timestamp(updated_at)
            if parsed and parsed < cutoff:
                number = issue.get("number")
                if isinstance(number, int):
                    stale_issue_numbers.append(number)

        self.audit.log_tool_invocation(
            tool_name="automation_stale_issue_detection",
            repository=repository,
            params={"stale_count": len(stale_issue_numbers)},
            result_status="success",
        )

        return {
            "job": "stale_issue_detection",
            "repository": repository,
            "status": "completed",
            "stale_issue_numbers": stale_issue_numbers,
            "stale_count": len(stale_issue_numbers),
        }

    async def _auto_issue_creation(
        self,
        owner: str,
        repo: str,
        finding_title: str | None,
        finding_body: str | None,
    ) -> dict[str, Any]:
        """Create issue from automation finding payload.

        Falls back to generic title/body text when the caller supplies none.
        """
        repository = f"{owner}/{repo}"
        title = finding_title or "Automated security finding"
        body = finding_body or "Automated finding created by Aegis automation workflow."

        async with GiteaClient() as gitea:
            issue = await gitea.create_issue(
                owner,
                repo,
                title=title,
                body=body,
                labels=["security", "automation"],
                assignees=None,
            )

        # Defensive default: if the client returns a non-dict, report issue 0.
        issue_number = issue.get("number", 0) if isinstance(issue, dict) else 0
        self.audit.log_tool_invocation(
            tool_name="automation_auto_issue_creation",
            repository=repository,
            params={"issue_number": issue_number},
            result_status="success",
        )
        return {
            "job": "auto_issue_creation",
            "repository": repository,
            "status": "completed",
            "issue_number": issue_number,
        }
|
||||
@@ -1,11 +1,16 @@
|
||||
"""Configuration management for AegisGitea MCP server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, HttpUrl, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
_ALLOWED_LOG_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
|
||||
_ALLOWED_SECRET_MODES = {"off", "mask", "block"}
|
||||
_ALLOWED_ENVIRONMENTS = {"development", "staging", "production", "test"}
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
@@ -15,64 +20,86 @@ class Settings(BaseSettings):
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
# Don't try to parse env vars as JSON for complex types
|
||||
env_parse_none_str="null",
|
||||
)
|
||||
|
||||
# Runtime environment
|
||||
environment: str = Field(
|
||||
default="production",
|
||||
description="Runtime environment name",
|
||||
)
|
||||
|
||||
# Gitea configuration
|
||||
gitea_url: HttpUrl = Field(
|
||||
...,
|
||||
description="Base URL of the Gitea instance",
|
||||
)
|
||||
gitea_token: str = Field(
|
||||
...,
|
||||
description="Bot user access token for Gitea API",
|
||||
min_length=1,
|
||||
)
|
||||
gitea_url: HttpUrl = Field(..., description="Base URL of the Gitea instance")
|
||||
gitea_token: str = Field(..., description="Bot user access token for Gitea API", min_length=1)
|
||||
|
||||
# MCP server configuration
|
||||
mcp_host: str = Field(
|
||||
default="0.0.0.0",
|
||||
description="Host to bind MCP server to",
|
||||
default="127.0.0.1",
|
||||
description="Host interface to bind MCP server to",
|
||||
)
|
||||
mcp_port: int = Field(
|
||||
default=8080,
|
||||
description="Port to bind MCP server to",
|
||||
ge=1,
|
||||
le=65535,
|
||||
mcp_port: int = Field(default=8080, description="Port to bind MCP server to", ge=1, le=65535)
|
||||
allow_insecure_bind: bool = Field(
|
||||
default=False,
|
||||
description="Allow binding to 0.0.0.0 (disabled by default for local hardening)",
|
||||
)
|
||||
|
||||
# Logging configuration
|
||||
log_level: str = Field(
|
||||
default="INFO",
|
||||
description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
|
||||
)
|
||||
# Logging and observability
|
||||
log_level: str = Field(default="INFO", description="Application logging level")
|
||||
audit_log_path: Path = Field(
|
||||
default=Path("/var/log/aegis-mcp/audit.log"),
|
||||
description="Path to audit log file",
|
||||
description="Path to tamper-evident audit log file",
|
||||
)
|
||||
metrics_enabled: bool = Field(default=True, description="Enable Prometheus metrics endpoint")
|
||||
expose_error_details: bool = Field(
|
||||
default=False,
|
||||
description="Return internal error details in API responses (disabled by default)",
|
||||
)
|
||||
startup_validate_gitea: bool = Field(
|
||||
default=True,
|
||||
description="Validate Gitea connectivity during startup",
|
||||
)
|
||||
|
||||
# Security configuration
|
||||
# Security limits
|
||||
max_file_size_bytes: int = Field(
|
||||
default=1_048_576, # 1MB
|
||||
description="Maximum file size that can be read (in bytes)",
|
||||
default=1_048_576,
|
||||
description="Maximum file size that can be read (bytes)",
|
||||
ge=1,
|
||||
)
|
||||
request_timeout_seconds: int = Field(
|
||||
default=30,
|
||||
description="Timeout for Gitea API requests (in seconds)",
|
||||
description="Timeout for Gitea API requests (seconds)",
|
||||
ge=1,
|
||||
)
|
||||
rate_limit_per_minute: int = Field(
|
||||
default=60,
|
||||
description="Maximum number of requests per minute",
|
||||
description="Maximum requests per minute for a single IP",
|
||||
ge=1,
|
||||
)
|
||||
token_rate_limit_per_minute: int = Field(
|
||||
default=120,
|
||||
description="Maximum requests per minute per authenticated token",
|
||||
ge=1,
|
||||
)
|
||||
max_tool_response_items: int = Field(
|
||||
default=200,
|
||||
description="Maximum list items returned by a tool response",
|
||||
ge=1,
|
||||
)
|
||||
max_tool_response_chars: int = Field(
|
||||
default=20_000,
|
||||
description="Maximum characters returned in text fields",
|
||||
ge=1,
|
||||
)
|
||||
secret_detection_mode: str = Field(
|
||||
default="mask",
|
||||
description="Secret detection mode: off, mask, or block",
|
||||
)
|
||||
|
||||
# Authentication configuration
|
||||
auth_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable API key authentication (disable only for testing)",
|
||||
description="Enable API key authentication (disable only in controlled testing)",
|
||||
)
|
||||
mcp_api_keys_raw: str = Field(
|
||||
default="",
|
||||
@@ -81,81 +108,149 @@ class Settings(BaseSettings):
|
||||
)
|
||||
max_auth_failures: int = Field(
|
||||
default=5,
|
||||
description="Maximum authentication failures before rate limiting",
|
||||
description="Maximum authentication failures before auth rate limiting",
|
||||
ge=1,
|
||||
)
|
||||
auth_failure_window: int = Field(
|
||||
default=300, # 5 minutes
|
||||
description="Time window for counting auth failures (in seconds)",
|
||||
default=300,
|
||||
description="Time window for counting auth failures (seconds)",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
# Policy and write-mode configuration
|
||||
policy_file_path: Path = Field(
|
||||
default=Path("policy.yaml"),
|
||||
description="Path to YAML authorization policy file",
|
||||
)
|
||||
write_mode: bool = Field(default=False, description="Enable write-capable tools")
|
||||
write_repository_whitelist_raw: str = Field(
|
||||
default="",
|
||||
description="Comma-separated repository whitelist for write mode (owner/repo)",
|
||||
alias="WRITE_REPOSITORY_WHITELIST",
|
||||
)
|
||||
automation_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Enable automation endpoints and workflows",
|
||||
)
|
||||
automation_scheduler_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Enable built-in scheduled job loop",
|
||||
)
|
||||
automation_stale_days: int = Field(
|
||||
default=30,
|
||||
description="Number of days before an issue is considered stale",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
@field_validator("environment")
@classmethod
def validate_environment(cls, value: str) -> str:
    """Normalize and validate the deployment environment name."""
    candidate = value.strip().lower()
    if candidate in _ALLOWED_ENVIRONMENTS:
        return candidate
    raise ValueError(f"environment must be one of {_ALLOWED_ENVIRONMENTS}")
|
||||
|
||||
@field_validator("log_level")
|
||||
@classmethod
|
||||
def validate_log_level(cls, v: str) -> str:
|
||||
def validate_log_level(cls, value: str) -> str:
|
||||
"""Validate log level is one of the allowed values."""
|
||||
allowed_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
|
||||
v_upper = v.upper()
|
||||
if v_upper not in allowed_levels:
|
||||
raise ValueError(f"log_level must be one of {allowed_levels}")
|
||||
return v_upper
|
||||
normalized = value.upper()
|
||||
if normalized not in _ALLOWED_LOG_LEVELS:
|
||||
raise ValueError(f"log_level must be one of {_ALLOWED_LOG_LEVELS}")
|
||||
return normalized
|
||||
|
||||
@field_validator("gitea_token")
|
||||
@classmethod
|
||||
def validate_token_not_empty(cls, v: str) -> str:
|
||||
"""Validate Gitea token is not empty or whitespace."""
|
||||
if not v.strip():
|
||||
def validate_token_not_empty(cls, value: str) -> str:
|
||||
"""Validate Gitea token is non-empty and trimmed."""
|
||||
cleaned = value.strip()
|
||||
if not cleaned:
|
||||
raise ValueError("gitea_token cannot be empty or whitespace")
|
||||
return v.strip()
|
||||
return cleaned
|
||||
|
||||
@field_validator("secret_detection_mode")
@classmethod
def validate_secret_detection_mode(cls, value: str) -> str:
    """Normalize and validate the secret detection behavior setting."""
    mode = value.lower().strip()
    if mode in _ALLOWED_SECRET_MODES:
        return mode
    raise ValueError(f"secret_detection_mode must be one of {_ALLOWED_SECRET_MODES}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_and_parse_api_keys(self) -> "Settings":
|
||||
"""Parse and validate API keys if authentication is enabled."""
|
||||
# Parse comma-separated keys into list
|
||||
keys: list[str] = []
|
||||
if self.mcp_api_keys_raw and self.mcp_api_keys_raw.strip():
|
||||
keys = [key.strip() for key in self.mcp_api_keys_raw.split(",") if key.strip()]
|
||||
def validate_security_constraints(self) -> Settings:
|
||||
"""Validate cross-field security constraints."""
|
||||
parsed_keys: list[str] = []
|
||||
if self.mcp_api_keys_raw.strip():
|
||||
parsed_keys = [
|
||||
value.strip() for value in self.mcp_api_keys_raw.split(",") if value.strip()
|
||||
]
|
||||
|
||||
# Store in a property we'll access
|
||||
object.__setattr__(self, "_mcp_api_keys", keys)
|
||||
object.__setattr__(self, "_mcp_api_keys", parsed_keys)
|
||||
|
||||
# Validate if auth is enabled
|
||||
if self.auth_enabled and not keys:
|
||||
write_repositories: list[str] = []
|
||||
if self.write_repository_whitelist_raw.strip():
|
||||
write_repositories = [
|
||||
value.strip()
|
||||
for value in self.write_repository_whitelist_raw.split(",")
|
||||
if value.strip()
|
||||
]
|
||||
|
||||
for repository in write_repositories:
|
||||
if "/" not in repository:
|
||||
raise ValueError("WRITE_REPOSITORY_WHITELIST entries must be in owner/repo format")
|
||||
|
||||
object.__setattr__(self, "_write_repository_whitelist", write_repositories)
|
||||
|
||||
# Security decision: binding all interfaces requires explicit opt-in.
|
||||
if self.mcp_host == "0.0.0.0" and not self.allow_insecure_bind:
|
||||
raise ValueError(
|
||||
"At least one API key must be configured when auth_enabled=True. "
|
||||
"Set MCP_API_KEYS environment variable or disable auth with AUTH_ENABLED=false"
|
||||
"Binding to 0.0.0.0 is blocked by default. "
|
||||
"Set ALLOW_INSECURE_BIND=true to explicitly permit this."
|
||||
)
|
||||
|
||||
# Validate key format (at least 32 characters for security)
|
||||
for key in keys:
|
||||
if self.auth_enabled and not parsed_keys:
|
||||
raise ValueError(
|
||||
"At least one API key must be configured when auth_enabled=True. "
|
||||
"Set MCP_API_KEYS or disable auth explicitly for controlled testing."
|
||||
)
|
||||
|
||||
# Enforce minimum key length to reduce brute-force success probability.
|
||||
for key in parsed_keys:
|
||||
if len(key) < 32:
|
||||
raise ValueError(
|
||||
f"API keys must be at least 32 characters long. "
|
||||
f"Use scripts/generate_api_key.py to generate secure keys."
|
||||
)
|
||||
raise ValueError("API keys must be at least 32 characters long")
|
||||
|
||||
if self.write_mode and not write_repositories:
|
||||
raise ValueError("WRITE_MODE=true requires WRITE_REPOSITORY_WHITELIST to be configured")
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def mcp_api_keys(self) -> list[str]:
|
||||
"""Get parsed list of API keys."""
|
||||
return getattr(self, "_mcp_api_keys", [])
|
||||
return list(getattr(self, "_mcp_api_keys", []))
|
||||
|
||||
@property
def write_repository_whitelist(self) -> list[str]:
    """Repositories (owner/repo) permitted for write-mode operations.

    Returns a fresh copy so callers cannot mutate the cached parse result.
    """
    stored = getattr(self, "_write_repository_whitelist", [])
    return list(stored)
|
||||
|
||||
@property
|
||||
def gitea_base_url(self) -> str:
|
||||
"""Get Gitea base URL as string."""
|
||||
"""Get Gitea base URL as normalized string."""
|
||||
return str(self.gitea_url).rstrip("/")
|
||||
|
||||
|
||||
# Global settings instance
|
||||
_settings: Optional[Settings] = None
|
||||
_settings: Settings | None = None
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""Get or create global settings instance."""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
_settings = Settings()
|
||||
# Mypy limitation: BaseSettings loads from environment dynamically.
|
||||
_settings = Settings() # type: ignore[call-arg]
|
||||
return _settings
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Gitea API client with bot user authentication."""
|
||||
"""Gitea API client with hardened request handling."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from httpx import AsyncClient, Response
|
||||
|
||||
from aegis_gitea_mcp.audit import get_audit_logger
|
||||
@@ -12,47 +13,37 @@ from aegis_gitea_mcp.config import get_settings
|
||||
class GiteaError(Exception):
|
||||
"""Base exception for Gitea API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GiteaAuthenticationError(GiteaError):
|
||||
"""Raised when authentication with Gitea fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GiteaAuthorizationError(GiteaError):
|
||||
"""Raised when bot user lacks permission for an operation."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GiteaNotFoundError(GiteaError):
|
||||
"""Raised when a requested resource is not found."""
|
||||
|
||||
pass
|
||||
"""Raised when requested resource is not found."""
|
||||
|
||||
|
||||
class GiteaClient:
|
||||
"""Client for interacting with Gitea API as a bot user."""
|
||||
|
||||
def __init__(self, base_url: Optional[str] = None, token: Optional[str] = None) -> None:
|
||||
def __init__(self, base_url: str | None = None, token: str | None = None) -> None:
|
||||
"""Initialize Gitea client.
|
||||
|
||||
Args:
|
||||
base_url: Base URL of Gitea instance (defaults to config value)
|
||||
token: Bot user access token (defaults to config value)
|
||||
base_url: Optional base URL override.
|
||||
token: Optional token override.
|
||||
"""
|
||||
self.settings = get_settings()
|
||||
self.audit = get_audit_logger()
|
||||
|
||||
self.base_url = (base_url or self.settings.gitea_base_url).rstrip("/")
|
||||
self.token = token or self.settings.gitea_token
|
||||
self.client: AsyncClient | None = None
|
||||
|
||||
self.client: Optional[AsyncClient] = None
|
||||
|
||||
async def __aenter__(self) -> "GiteaClient":
|
||||
"""Async context manager entry."""
|
||||
async def __aenter__(self) -> GiteaClient:
|
||||
"""Create async HTTP client context."""
|
||||
self.client = AsyncClient(
|
||||
base_url=self.base_url,
|
||||
headers={
|
||||
@@ -65,26 +56,22 @@ class GiteaClient:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any) -> None:
|
||||
"""Async context manager exit."""
|
||||
"""Close async HTTP client context."""
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
def _handle_response(self, response: Response, correlation_id: str) -> Any:
|
||||
"""Handle Gitea API response and raise appropriate exceptions.
|
||||
|
||||
Args:
|
||||
response: HTTP response from Gitea
|
||||
correlation_id: Correlation ID for audit logging
|
||||
|
||||
Returns:
|
||||
Parsed JSON response
|
||||
def _ensure_client(self) -> AsyncClient:
|
||||
"""Return initialized HTTP client.
|
||||
|
||||
Raises:
|
||||
GiteaAuthenticationError: On 401 responses
|
||||
GiteaAuthorizationError: On 403 responses
|
||||
GiteaNotFoundError: On 404 responses
|
||||
GiteaError: On other error responses
|
||||
RuntimeError: If called outside async context manager.
|
||||
"""
|
||||
if not self.client:
|
||||
raise RuntimeError("Client not initialized - use async context manager")
|
||||
return self.client
|
||||
|
||||
def _handle_response(self, response: Response, correlation_id: str) -> Any:
|
||||
"""Handle HTTP response and map to domain exceptions."""
|
||||
if response.status_code == 401:
|
||||
self.audit.log_security_event(
|
||||
event_type="authentication_failure",
|
||||
@@ -97,7 +84,7 @@ class GiteaClient:
|
||||
if response.status_code == 403:
|
||||
self.audit.log_access_denied(
|
||||
tool_name="gitea_api",
|
||||
reason="Bot user lacks permission",
|
||||
reason="bot user lacks permission",
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
raise GiteaAuthorizationError("Bot user lacks permission for this operation")
|
||||
@@ -109,7 +96,9 @@ class GiteaClient:
|
||||
error_msg = f"Gitea API error: {response.status_code}"
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = f"{error_msg} - {error_data.get('message', '')}"
|
||||
message = error_data.get("message") if isinstance(error_data, dict) else None
|
||||
if message:
|
||||
error_msg = f"{error_msg} - {message}"
|
||||
except Exception:
|
||||
pass
|
||||
raise GiteaError(error_msg)
|
||||
@@ -119,35 +108,34 @@ class GiteaClient:
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
async def get_current_user(self) -> Dict[str, Any]:
|
||||
"""Get information about the current bot user.
|
||||
|
||||
Returns:
|
||||
User information dict
|
||||
|
||||
Raises:
|
||||
GiteaError: On API errors
|
||||
"""
|
||||
if not self.client:
|
||||
raise RuntimeError("Client not initialized - use async context manager")
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
*,
|
||||
correlation_id: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
json_body: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Execute a request to Gitea API with shared error handling."""
|
||||
client = self._ensure_client()
|
||||
response = await client.request(method=method, url=endpoint, params=params, json=json_body)
|
||||
return self._handle_response(response, correlation_id)
|
||||
|
||||
async def get_current_user(self) -> dict[str, Any]:
|
||||
"""Get current bot user profile."""
|
||||
correlation_id = self.audit.log_tool_invocation(
|
||||
tool_name="get_current_user",
|
||||
result_status="pending",
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.get("/api/v1/user")
|
||||
user_data = self._handle_response(response, correlation_id)
|
||||
|
||||
result = await self._request("GET", "/api/v1/user", correlation_id=correlation_id)
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="get_current_user",
|
||||
correlation_id=correlation_id,
|
||||
result_status="success",
|
||||
)
|
||||
|
||||
return user_data
|
||||
|
||||
return result if isinstance(result, dict) else {}
|
||||
except Exception as exc:
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="get_current_user",
|
||||
@@ -157,39 +145,22 @@ class GiteaClient:
|
||||
)
|
||||
raise
|
||||
|
||||
async def list_repositories(self) -> List[Dict[str, Any]]:
|
||||
"""List all repositories visible to the bot user.
|
||||
|
||||
Returns:
|
||||
List of repository information dicts
|
||||
|
||||
Raises:
|
||||
GiteaError: On API errors
|
||||
"""
|
||||
if not self.client:
|
||||
raise RuntimeError("Client not initialized - use async context manager")
|
||||
|
||||
async def list_repositories(self) -> list[dict[str, Any]]:
|
||||
"""List all repositories visible to the bot user."""
|
||||
correlation_id = self.audit.log_tool_invocation(
|
||||
tool_name="list_repositories",
|
||||
result_status="pending",
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.get("/api/v1/user/repos")
|
||||
repos_data = self._handle_response(response, correlation_id)
|
||||
|
||||
# Ensure we have a list
|
||||
repos = repos_data if isinstance(repos_data, list) else []
|
||||
|
||||
result = await self._request("GET", "/api/v1/user/repos", correlation_id=correlation_id)
|
||||
repositories = result if isinstance(result, list) else []
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="list_repositories",
|
||||
correlation_id=correlation_id,
|
||||
result_status="success",
|
||||
params={"count": len(repos)},
|
||||
params={"count": len(repositories)},
|
||||
)
|
||||
|
||||
return repos
|
||||
|
||||
return repositories
|
||||
except Exception as exc:
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="list_repositories",
|
||||
@@ -199,43 +170,27 @@ class GiteaClient:
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_repository(self, owner: str, repo: str) -> Dict[str, Any]:
|
||||
"""Get information about a specific repository.
|
||||
|
||||
Args:
|
||||
owner: Repository owner username
|
||||
repo: Repository name
|
||||
|
||||
Returns:
|
||||
Repository information dict
|
||||
|
||||
Raises:
|
||||
GiteaNotFoundError: If repository doesn't exist or bot lacks access
|
||||
GiteaError: On other API errors
|
||||
"""
|
||||
if not self.client:
|
||||
raise RuntimeError("Client not initialized - use async context manager")
|
||||
|
||||
async def get_repository(self, owner: str, repo: str) -> dict[str, Any]:
|
||||
"""Get repository metadata."""
|
||||
repo_id = f"{owner}/{repo}"
|
||||
correlation_id = self.audit.log_tool_invocation(
|
||||
tool_name="get_repository",
|
||||
repository=repo_id,
|
||||
result_status="pending",
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.get(f"/api/v1/repos/{owner}/{repo}")
|
||||
repo_data = self._handle_response(response, correlation_id)
|
||||
|
||||
result = await self._request(
|
||||
"GET",
|
||||
f"/api/v1/repos/{owner}/{repo}",
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="get_repository",
|
||||
repository=repo_id,
|
||||
correlation_id=correlation_id,
|
||||
result_status="success",
|
||||
)
|
||||
|
||||
return repo_data
|
||||
|
||||
return result if isinstance(result, dict) else {}
|
||||
except Exception as exc:
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="get_repository",
|
||||
@@ -247,26 +202,13 @@ class GiteaClient:
|
||||
raise
|
||||
|
||||
async def get_file_contents(
|
||||
self, owner: str, repo: str, filepath: str, ref: str = "main"
|
||||
) -> Dict[str, Any]:
|
||||
"""Get contents of a file in a repository.
|
||||
|
||||
Args:
|
||||
owner: Repository owner username
|
||||
repo: Repository name
|
||||
filepath: Path to file within repository
|
||||
ref: Branch, tag, or commit ref (defaults to 'main')
|
||||
|
||||
Returns:
|
||||
File contents dict with 'content', 'encoding', 'size', etc.
|
||||
|
||||
Raises:
|
||||
GiteaNotFoundError: If file doesn't exist
|
||||
GiteaError: On other API errors
|
||||
"""
|
||||
if not self.client:
|
||||
raise RuntimeError("Client not initialized - use async context manager")
|
||||
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
filepath: str,
|
||||
ref: str = "main",
|
||||
) -> dict[str, Any]:
|
||||
"""Get file contents from a repository."""
|
||||
repo_id = f"{owner}/{repo}"
|
||||
correlation_id = self.audit.log_tool_invocation(
|
||||
tool_name="get_file_contents",
|
||||
@@ -275,20 +217,22 @@ class GiteaClient:
|
||||
params={"ref": ref},
|
||||
result_status="pending",
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.get(
|
||||
result = await self._request(
|
||||
"GET",
|
||||
f"/api/v1/repos/{owner}/{repo}/contents/{filepath}",
|
||||
params={"ref": ref},
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
file_data = self._handle_response(response, correlation_id)
|
||||
|
||||
# Check file size against limit
|
||||
file_size = file_data.get("size", 0)
|
||||
if not isinstance(result, dict):
|
||||
raise GiteaError("Unexpected response type for file contents")
|
||||
|
||||
file_size = int(result.get("size", 0))
|
||||
if file_size > self.settings.max_file_size_bytes:
|
||||
error_msg = (
|
||||
f"File size ({file_size} bytes) exceeds "
|
||||
f"limit ({self.settings.max_file_size_bytes} bytes)"
|
||||
f"File size ({file_size} bytes) exceeds limit "
|
||||
f"({self.settings.max_file_size_bytes} bytes)"
|
||||
)
|
||||
self.audit.log_security_event(
|
||||
event_type="file_size_limit_exceeded",
|
||||
@@ -311,9 +255,7 @@ class GiteaClient:
|
||||
result_status="success",
|
||||
params={"ref": ref, "size": file_size},
|
||||
)
|
||||
|
||||
return file_data
|
||||
|
||||
return result
|
||||
except Exception as exc:
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="get_file_contents",
|
||||
@@ -326,25 +268,13 @@ class GiteaClient:
|
||||
raise
|
||||
|
||||
async def get_tree(
|
||||
self, owner: str, repo: str, ref: str = "main", recursive: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Get file tree for a repository.
|
||||
|
||||
Args:
|
||||
owner: Repository owner username
|
||||
repo: Repository name
|
||||
ref: Branch, tag, or commit ref (defaults to 'main')
|
||||
recursive: Whether to recursively fetch tree (default: False for safety)
|
||||
|
||||
Returns:
|
||||
Tree information dict
|
||||
|
||||
Raises:
|
||||
GiteaError: On API errors
|
||||
"""
|
||||
if not self.client:
|
||||
raise RuntimeError("Client not initialized - use async context manager")
|
||||
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
ref: str = "main",
|
||||
recursive: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Get repository tree at given ref."""
|
||||
repo_id = f"{owner}/{repo}"
|
||||
correlation_id = self.audit.log_tool_invocation(
|
||||
tool_name="get_tree",
|
||||
@@ -352,24 +282,26 @@ class GiteaClient:
|
||||
params={"ref": ref, "recursive": recursive},
|
||||
result_status="pending",
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.get(
|
||||
result = await self._request(
|
||||
"GET",
|
||||
f"/api/v1/repos/{owner}/{repo}/git/trees/{ref}",
|
||||
params={"recursive": str(recursive).lower()},
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
tree_data = self._handle_response(response, correlation_id)
|
||||
|
||||
tree_data = result if isinstance(result, dict) else {}
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="get_tree",
|
||||
repository=repo_id,
|
||||
correlation_id=correlation_id,
|
||||
result_status="success",
|
||||
params={"ref": ref, "recursive": recursive, "count": len(tree_data.get("tree", []))},
|
||||
params={
|
||||
"ref": ref,
|
||||
"recursive": recursive,
|
||||
"count": len(tree_data.get("tree", [])),
|
||||
},
|
||||
)
|
||||
|
||||
return tree_data
|
||||
|
||||
except Exception as exc:
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="get_tree",
|
||||
@@ -379,3 +311,326 @@ class GiteaClient:
|
||||
error=str(exc),
|
||||
)
|
||||
raise
|
||||
|
||||
async def search_code(
    self,
    owner: str,
    repo: str,
    query: str,
    *,
    ref: str,
    page: int,
    limit: int,
) -> dict[str, Any]:
    """Search repository code by query.

    Logs a pending audit entry before the request and a terminal
    success/error entry afterwards, sharing one correlation id.
    """
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="search_code",
        repository=repo_id,
        params={"query": query, "ref": ref, "page": page, "limit": limit},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/search",
            params={"q": query, "page": page, "limit": limit, "ref": ref},
            correlation_id=correlation_id,
        )
    except Exception as exc:
        self.audit.log_tool_invocation(
            tool_name="search_code",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="search_code",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    # Defensive: tolerate an unexpected (non-dict) response shape.
    return result if isinstance(result, dict) else {}
|
||||
|
||||
async def list_commits(
    self,
    owner: str,
    repo: str,
    *,
    ref: str,
    page: int,
    limit: int,
) -> list[dict[str, Any]]:
    """List commits for a repository ref.

    Returns:
        List of commit dicts; an unexpected response shape yields [].
    """
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="list_commits",
        repository=repo_id,
        params={"ref": ref, "page": page, "limit": limit},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/commits",
            # Gitea's commits endpoint selects the ref via the "sha" parameter.
            params={"sha": ref, "page": page, "limit": limit},
            correlation_id=correlation_id,
        )
    except Exception as exc:
        # Fix: previously only a "pending" entry was written, leaving the
        # audit trail unterminated; mirror search_code's terminal logging.
        self.audit.log_tool_invocation(
            tool_name="list_commits",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="list_commits",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    return result if isinstance(result, list) else []
|
||||
|
||||
async def get_commit_diff(self, owner: str, repo: str, sha: str) -> dict[str, Any]:
    """Get detailed commit including changed files and patch metadata."""
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="get_commit_diff",
        repository=repo_id,
        params={"sha": sha},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/git/commits/{sha}",
            correlation_id=correlation_id,
        )
    except Exception as exc:
        # Fix: terminate the audit trail (was left "pending" before).
        self.audit.log_tool_invocation(
            tool_name="get_commit_diff",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="get_commit_diff",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    return result if isinstance(result, dict) else {}
|
||||
|
||||
async def compare_refs(self, owner: str, repo: str, base: str, head: str) -> dict[str, Any]:
    """Compare two refs and return commit/file deltas."""
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="compare_refs",
        repository=repo_id,
        params={"base": base, "head": head},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/compare/{base}...{head}",
            correlation_id=correlation_id,
        )
    except Exception as exc:
        # Fix: terminate the audit trail (was left "pending" before).
        self.audit.log_tool_invocation(
            tool_name="compare_refs",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="compare_refs",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    return result if isinstance(result, dict) else {}
|
||||
|
||||
async def list_issues(
    self,
    owner: str,
    repo: str,
    *,
    state: str,
    page: int,
    limit: int,
    labels: list[str] | None = None,
) -> list[dict[str, Any]]:
    """List repository issues, optionally filtered by labels."""
    params: dict[str, Any] = {"state": state, "page": page, "limit": limit}
    if labels:
        # Gitea expects a comma-separated label filter.
        params["labels"] = ",".join(labels)

    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="list_issues",
        repository=repo_id,
        params={"state": state, "page": page, "limit": limit},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/issues",
            params=params,
            correlation_id=correlation_id,
        )
    except Exception as exc:
        # Fix: terminate the audit trail (was left "pending" before).
        self.audit.log_tool_invocation(
            tool_name="list_issues",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="list_issues",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    return result if isinstance(result, list) else []
|
||||
|
||||
async def get_issue(self, owner: str, repo: str, index: int) -> dict[str, Any]:
    """Get issue details by index."""
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="get_issue",
        repository=repo_id,
        params={"index": index},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/issues/{index}",
            correlation_id=correlation_id,
        )
    except Exception as exc:
        # Fix: terminate the audit trail (was left "pending" before).
        self.audit.log_tool_invocation(
            tool_name="get_issue",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="get_issue",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    return result if isinstance(result, dict) else {}
|
||||
|
||||
async def list_pull_requests(
    self,
    owner: str,
    repo: str,
    *,
    state: str,
    page: int,
    limit: int,
) -> list[dict[str, Any]]:
    """List pull requests for a repository."""
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="list_pull_requests",
        repository=repo_id,
        params={"state": state, "page": page, "limit": limit},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/pulls",
            params={"state": state, "page": page, "limit": limit},
            correlation_id=correlation_id,
        )
    except Exception as exc:
        # Fix: terminate the audit trail (was left "pending" before).
        self.audit.log_tool_invocation(
            tool_name="list_pull_requests",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="list_pull_requests",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    return result if isinstance(result, list) else []
|
||||
|
||||
async def get_pull_request(self, owner: str, repo: str, index: int) -> dict[str, Any]:
    """Get a single pull request by index."""
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="get_pull_request",
        repository=repo_id,
        params={"index": index},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/pulls/{index}",
            correlation_id=correlation_id,
        )
    except Exception as exc:
        # Fix: terminate the audit trail (was left "pending" before).
        self.audit.log_tool_invocation(
            tool_name="get_pull_request",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="get_pull_request",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    return result if isinstance(result, dict) else {}
|
||||
|
||||
async def list_labels(
    self, owner: str, repo: str, *, page: int, limit: int
) -> list[dict[str, Any]]:
    """List repository labels."""
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="list_labels",
        repository=repo_id,
        params={"page": page, "limit": limit},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/labels",
            params={"page": page, "limit": limit},
            correlation_id=correlation_id,
        )
    except Exception as exc:
        # Fix: terminate the audit trail (was left "pending" before).
        self.audit.log_tool_invocation(
            tool_name="list_labels",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="list_labels",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    return result if isinstance(result, list) else []
|
||||
|
||||
async def list_tags(
    self, owner: str, repo: str, *, page: int, limit: int
) -> list[dict[str, Any]]:
    """List repository tags."""
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="list_tags",
        repository=repo_id,
        params={"page": page, "limit": limit},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/tags",
            params={"page": page, "limit": limit},
            correlation_id=correlation_id,
        )
    except Exception as exc:
        # Fix: terminate the audit trail (was left "pending" before).
        self.audit.log_tool_invocation(
            tool_name="list_tags",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="list_tags",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    return result if isinstance(result, list) else []
|
||||
|
||||
async def list_releases(
    self,
    owner: str,
    repo: str,
    *,
    page: int,
    limit: int,
) -> list[dict[str, Any]]:
    """List repository releases."""
    repo_id = f"{owner}/{repo}"
    correlation_id = self.audit.log_tool_invocation(
        tool_name="list_releases",
        repository=repo_id,
        params={"page": page, "limit": limit},
        result_status="pending",
    )
    try:
        result = await self._request(
            "GET",
            f"/api/v1/repos/{owner}/{repo}/releases",
            params={"page": page, "limit": limit},
            correlation_id=correlation_id,
        )
    except Exception as exc:
        # Fix: terminate the audit trail (was left "pending" before).
        self.audit.log_tool_invocation(
            tool_name="list_releases",
            repository=repo_id,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc),
        )
        raise
    self.audit.log_tool_invocation(
        tool_name="list_releases",
        repository=repo_id,
        correlation_id=correlation_id,
        result_status="success",
    )
    return result if isinstance(result, list) else []
|
||||
|
||||
async def create_issue(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
*,
|
||||
title: str,
|
||||
body: str,
|
||||
labels: list[str] | None = None,
|
||||
assignees: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create repository issue."""
|
||||
payload: dict[str, Any] = {"title": title, "body": body}
|
||||
if labels:
|
||||
payload["labels"] = labels
|
||||
if assignees:
|
||||
payload["assignees"] = assignees
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/api/v1/repos/{owner}/{repo}/issues",
|
||||
json_body=payload,
|
||||
correlation_id=str(
|
||||
self.audit.log_tool_invocation(tool_name="create_issue", result_status="pending")
|
||||
),
|
||||
)
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
async def update_issue(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
index: int,
|
||||
*,
|
||||
title: str | None = None,
|
||||
body: str | None = None,
|
||||
state: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update issue fields."""
|
||||
payload: dict[str, Any] = {}
|
||||
if title is not None:
|
||||
payload["title"] = title
|
||||
if body is not None:
|
||||
payload["body"] = body
|
||||
if state is not None:
|
||||
payload["state"] = state
|
||||
result = await self._request(
|
||||
"PATCH",
|
||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}",
|
||||
json_body=payload,
|
||||
correlation_id=str(
|
||||
self.audit.log_tool_invocation(tool_name="update_issue", result_status="pending")
|
||||
),
|
||||
)
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
async def create_issue_comment(
|
||||
self, owner: str, repo: str, index: int, body: str
|
||||
) -> dict[str, Any]:
|
||||
"""Create a comment on issue (and PR discussion if issue index refers to PR)."""
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}/comments",
|
||||
json_body={"body": body},
|
||||
correlation_id=str(
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="create_issue_comment", result_status="pending"
|
||||
)
|
||||
),
|
||||
)
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
async def create_pr_comment(
|
||||
self, owner: str, repo: str, index: int, body: str
|
||||
) -> dict[str, Any]:
|
||||
"""Create PR discussion comment."""
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}/comments",
|
||||
json_body={"body": body},
|
||||
correlation_id=str(
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="create_pr_comment", result_status="pending"
|
||||
)
|
||||
),
|
||||
)
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
async def add_labels(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
index: int,
|
||||
labels: list[str],
|
||||
) -> dict[str, Any]:
|
||||
"""Add labels to issue/PR."""
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}/labels",
|
||||
json_body={"labels": labels},
|
||||
correlation_id=str(
|
||||
self.audit.log_tool_invocation(tool_name="add_labels", result_status="pending")
|
||||
),
|
||||
)
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
async def assign_issue(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
index: int,
|
||||
assignees: list[str],
|
||||
) -> dict[str, Any]:
|
||||
"""Assign users to issue/PR."""
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}/assignees",
|
||||
json_body={"assignees": assignees},
|
||||
correlation_id=str(
|
||||
self.audit.log_tool_invocation(tool_name="assign_issue", result_status="pending")
|
||||
),
|
||||
)
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
48
src/aegis_gitea_mcp/logging_utils.py
Normal file
48
src/aegis_gitea_mcp/logging_utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Structured logging configuration utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from aegis_gitea_mcp.request_context import get_request_id
|
||||
|
||||
|
||||
class JsonLogFormatter(logging.Formatter):
    """Render log records as compact single-line JSON documents."""

    def format(self, record: logging.LogRecord) -> str:
        """Serialize *record* to a JSON string.

        The payload carries a UTC timestamp, level name, logger name, the
        formatted message, and the current request id from context.
        """
        payload = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "request_id": get_request_id(),
        }

        if record.exc_info:
            # Security decision: include only the exception type so stack
            # traces never leak into the log stream.
            exc_type = record.exc_info[0]
            if exc_type is not None:
                payload["exception_type"] = str(exc_type.__name__)

        return json.dumps(payload, separators=(",", ":"), ensure_ascii=True)
|
||||
|
||||
|
||||
def configure_logging(level: str) -> None:
    """Configure application-wide structured JSON logging.

    Args:
        level: Logging level string (e.g. "info"); upper-cased before use.
    """
    root = logging.getLogger()
    root.setLevel(level.upper())

    # Remove any pre-existing handlers so repeated configuration never
    # duplicates output lines.
    for stale_handler in list(root.handlers):
        root.removeHandler(stale_handler)

    handler = logging.StreamHandler()
    handler.setFormatter(JsonLogFormatter())
    root.addHandler(handler)
|
||||
@@ -1,6 +1,8 @@
|
||||
"""MCP protocol implementation for AegisGitea."""
|
||||
"""MCP protocol models and tool registry."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@@ -10,153 +12,366 @@ class MCPTool(BaseModel):
|
||||
|
||||
name: str = Field(..., description="Unique tool identifier")
|
||||
description: str = Field(..., description="Human-readable tool description")
|
||||
input_schema: Dict[str, Any] = Field(
|
||||
..., alias="inputSchema", description="JSON Schema for tool input"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
)
|
||||
input_schema: dict[str, Any] = Field(..., description="JSON schema describing input arguments")
|
||||
write_operation: bool = Field(default=False, description="Whether tool mutates data")
|
||||
|
||||
|
||||
class MCPToolCallRequest(BaseModel):
    """Request to invoke an MCP tool.

    Unknown fields are rejected (`extra="forbid"`) so malformed or
    injected request keys fail validation instead of being ignored.
    """

    tool: str = Field(..., description="Name of the tool to invoke")
    arguments: dict[str, Any] = Field(default_factory=dict, description="Tool argument payload")
    correlation_id: str | None = Field(default=None, description="Request correlation ID")

    model_config = ConfigDict(extra="forbid")
||||
|
||||
|
||||
class MCPToolCallResponse(BaseModel):
    """Response returned from MCP tool invocation.

    Exactly one of `result` / `error` is expected to be populated
    depending on `success`; `correlation_id` always echoes the request.
    """

    success: bool = Field(..., description="Whether invocation succeeded")
    result: Any | None = Field(default=None, description="Tool result payload")
    error: str | None = Field(default=None, description="Error message for failed request")
    correlation_id: str = Field(..., description="Correlation ID for request tracing")
||||
|
||||
|
||||
class MCPListToolsResponse(BaseModel):
    """Response listing the tools available from this gateway."""

    tools: list[MCPTool] = Field(..., description="Available tool definitions")
||||
|
||||
|
||||
# Tool definitions for AegisGitea MCP
|
||||
def _tool(
    name: str, description: str, schema: dict[str, Any], write_operation: bool = False
) -> MCPTool:
    """Build one MCPTool registry entry.

    Args:
        name: Unique tool identifier.
        description: Human-readable tool description.
        schema: JSON schema describing the input arguments.
        write_operation: Whether the tool mutates data; defaults to read-only.

    Returns:
        Populated MCPTool instance.
    """
    entry = MCPTool(
        name=name,
        description=description,
        input_schema=schema,
        write_operation=write_operation,
    )
    return entry
|
||||
|
||||
TOOL_LIST_REPOSITORIES = MCPTool(
|
||||
name="list_repositories",
|
||||
description="List all repositories visible to the AI bot user. "
|
||||
"Only repositories where the bot has explicit read access will be returned. "
|
||||
"This respects Gitea's dynamic authorization model.",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
|
||||
TOOL_GET_REPOSITORY_INFO = MCPTool(
|
||||
name="get_repository_info",
|
||||
description="Get detailed information about a specific repository, "
|
||||
"including description, default branch, language, and metadata. "
|
||||
"Requires the bot user to have read access.",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner username or organization",
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name",
|
||||
},
|
||||
AVAILABLE_TOOLS: list[MCPTool] = [
|
||||
_tool(
|
||||
"list_repositories",
|
||||
"List repositories visible to the configured bot account.",
|
||||
{"type": "object", "properties": {}, "required": []},
|
||||
),
|
||||
_tool(
|
||||
"get_repository_info",
|
||||
"Get metadata for a repository.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"owner": {"type": "string"}, "repo": {"type": "string"}},
|
||||
"required": ["owner", "repo"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"required": ["owner", "repo"],
|
||||
},
|
||||
)
|
||||
|
||||
TOOL_GET_FILE_TREE = MCPTool(
|
||||
name="get_file_tree",
|
||||
description="Get the file tree structure for a repository at a specific ref. "
|
||||
"Returns a list of files and directories. "
|
||||
"Non-recursive by default for safety (max depth: 1 level).",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner username or organization",
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name",
|
||||
},
|
||||
"ref": {
|
||||
"type": "string",
|
||||
"description": "Branch, tag, or commit SHA (defaults to 'main')",
|
||||
"default": "main",
|
||||
},
|
||||
"recursive": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to recursively fetch entire tree (use with caution)",
|
||||
"default": False,
|
||||
),
|
||||
_tool(
|
||||
"get_file_tree",
|
||||
"Get repository tree at a selected ref.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"ref": {"type": "string", "default": "main"},
|
||||
"recursive": {"type": "boolean", "default": False},
|
||||
},
|
||||
"required": ["owner", "repo"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"required": ["owner", "repo"],
|
||||
},
|
||||
)
|
||||
|
||||
TOOL_GET_FILE_CONTENTS = MCPTool(
|
||||
name="get_file_contents",
|
||||
description="Read the contents of a specific file in a repository. "
|
||||
"File size is limited to 1MB by default for safety. "
|
||||
"Returns base64-encoded content for binary files.",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner username or organization",
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name",
|
||||
},
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to file within repository (e.g., 'src/main.py')",
|
||||
},
|
||||
"ref": {
|
||||
"type": "string",
|
||||
"description": "Branch, tag, or commit SHA (defaults to 'main')",
|
||||
"default": "main",
|
||||
),
|
||||
_tool(
|
||||
"get_file_contents",
|
||||
"Read a repository file with size-limited content.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"filepath": {"type": "string"},
|
||||
"ref": {"type": "string", "default": "main"},
|
||||
},
|
||||
"required": ["owner", "repo", "filepath"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"required": ["owner", "repo", "filepath"],
|
||||
},
|
||||
)
|
||||
|
||||
# Registry of all available tools
|
||||
AVAILABLE_TOOLS: List[MCPTool] = [
|
||||
TOOL_LIST_REPOSITORIES,
|
||||
TOOL_GET_REPOSITORY_INFO,
|
||||
TOOL_GET_FILE_TREE,
|
||||
TOOL_GET_FILE_CONTENTS,
|
||||
),
|
||||
_tool(
|
||||
"search_code",
|
||||
"Search code in a repository.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"query": {"type": "string"},
|
||||
"ref": {"type": "string", "default": "main"},
|
||||
"page": {"type": "integer", "minimum": 1, "default": 1},
|
||||
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 25},
|
||||
},
|
||||
"required": ["owner", "repo", "query"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"list_commits",
|
||||
"List commits for a repository ref.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"ref": {"type": "string", "default": "main"},
|
||||
"page": {"type": "integer", "minimum": 1, "default": 1},
|
||||
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 25},
|
||||
},
|
||||
"required": ["owner", "repo"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"get_commit_diff",
|
||||
"Get commit metadata and file diffs.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"sha": {"type": "string"},
|
||||
},
|
||||
"required": ["owner", "repo", "sha"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"compare_refs",
|
||||
"Compare two repository refs.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"base": {"type": "string"},
|
||||
"head": {"type": "string"},
|
||||
},
|
||||
"required": ["owner", "repo", "base", "head"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"list_issues",
|
||||
"List repository issues.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"state": {"type": "string", "enum": ["open", "closed", "all"], "default": "open"},
|
||||
"page": {"type": "integer", "minimum": 1, "default": 1},
|
||||
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 25},
|
||||
"labels": {"type": "array", "items": {"type": "string"}, "default": []},
|
||||
},
|
||||
"required": ["owner", "repo"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"get_issue",
|
||||
"Get repository issue details.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"issue_number": {"type": "integer", "minimum": 1},
|
||||
},
|
||||
"required": ["owner", "repo", "issue_number"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"list_pull_requests",
|
||||
"List repository pull requests.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"state": {"type": "string", "enum": ["open", "closed", "all"], "default": "open"},
|
||||
"page": {"type": "integer", "minimum": 1, "default": 1},
|
||||
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 25},
|
||||
},
|
||||
"required": ["owner", "repo"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"get_pull_request",
|
||||
"Get pull request details.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"pull_number": {"type": "integer", "minimum": 1},
|
||||
},
|
||||
"required": ["owner", "repo", "pull_number"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"list_labels",
|
||||
"List labels defined on a repository.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"page": {"type": "integer", "minimum": 1, "default": 1},
|
||||
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 50},
|
||||
},
|
||||
"required": ["owner", "repo"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"list_tags",
|
||||
"List repository tags.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"page": {"type": "integer", "minimum": 1, "default": 1},
|
||||
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 50},
|
||||
},
|
||||
"required": ["owner", "repo"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"list_releases",
|
||||
"List repository releases.",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"page": {"type": "integer", "minimum": 1, "default": 1},
|
||||
"limit": {"type": "integer", "minimum": 1, "maximum": 100, "default": 25},
|
||||
},
|
||||
"required": ["owner", "repo"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
_tool(
|
||||
"create_issue",
|
||||
"Create a repository issue (write-mode only).",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"title": {"type": "string"},
|
||||
"body": {"type": "string", "default": ""},
|
||||
"labels": {"type": "array", "items": {"type": "string"}, "default": []},
|
||||
"assignees": {"type": "array", "items": {"type": "string"}, "default": []},
|
||||
},
|
||||
"required": ["owner", "repo", "title"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
write_operation=True,
|
||||
),
|
||||
_tool(
|
||||
"update_issue",
|
||||
"Update issue title/body/state (write-mode only).",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"issue_number": {"type": "integer", "minimum": 1},
|
||||
"title": {"type": "string"},
|
||||
"body": {"type": "string"},
|
||||
"state": {"type": "string", "enum": ["open", "closed"]},
|
||||
},
|
||||
"required": ["owner", "repo", "issue_number"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
write_operation=True,
|
||||
),
|
||||
_tool(
|
||||
"create_issue_comment",
|
||||
"Create issue comment (write-mode only).",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"issue_number": {"type": "integer", "minimum": 1},
|
||||
"body": {"type": "string"},
|
||||
},
|
||||
"required": ["owner", "repo", "issue_number", "body"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
write_operation=True,
|
||||
),
|
||||
_tool(
|
||||
"create_pr_comment",
|
||||
"Create pull request comment (write-mode only).",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"pull_number": {"type": "integer", "minimum": 1},
|
||||
"body": {"type": "string"},
|
||||
},
|
||||
"required": ["owner", "repo", "pull_number", "body"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
write_operation=True,
|
||||
),
|
||||
_tool(
|
||||
"add_labels",
|
||||
"Add labels to an issue or PR (write-mode only).",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"issue_number": {"type": "integer", "minimum": 1},
|
||||
"labels": {"type": "array", "items": {"type": "string"}, "minItems": 1},
|
||||
},
|
||||
"required": ["owner", "repo", "issue_number", "labels"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
write_operation=True,
|
||||
),
|
||||
_tool(
|
||||
"assign_issue",
|
||||
"Assign users to issue or PR (write-mode only).",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {"type": "string"},
|
||||
"repo": {"type": "string"},
|
||||
"issue_number": {"type": "integer", "minimum": 1},
|
||||
"assignees": {"type": "array", "items": {"type": "string"}, "minItems": 1},
|
||||
},
|
||||
"required": ["owner", "repo", "issue_number", "assignees"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
write_operation=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_tool_by_name(tool_name: str) -> MCPTool | None:
    """Get tool definition by name.

    Args:
        tool_name: Name of the tool to retrieve.

    Returns:
        Tool definition, or None when no registered tool has that name.
    """
    for tool in AVAILABLE_TOOLS:
        if tool.name == tool_name:
            return tool
    # Explicit miss: previously the function fell through to an implicit
    # None, which is easy to misread as a missing return.
    return None
||||
98
src/aegis_gitea_mcp/observability.py
Normal file
98
src/aegis_gitea_mcp/observability.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Observability primitives: metrics and lightweight instrumentation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ToolTiming:
    """Immutable aggregate of tool call timings."""

    # Number of recorded invocations.
    count: int
    # Total wall-clock duration across all invocations, in seconds.
    total_seconds: float
|
||||
|
||||
class MetricsRegistry:
    """Thread-safe in-process metrics store with Prometheus text rendering."""

    def __init__(self) -> None:
        """Initialize empty metrics state."""
        self._lock = Lock()
        # (method, path, status) -> request count.
        self._http_requests_total: defaultdict[tuple[str, str, str], int] = defaultdict(int)
        # (tool, status) -> call count.
        self._tool_calls_total: defaultdict[tuple[str, str], int] = defaultdict(int)
        # tool -> cumulative duration in seconds.
        self._tool_duration_seconds: defaultdict[str, float] = defaultdict(float)
        # tool -> number of duration samples.
        self._tool_duration_count: defaultdict[str, int] = defaultdict(int)

    def record_http_request(self, method: str, path: str, status_code: int) -> None:
        """Record one completed HTTP request.

        Args:
            method: HTTP method of the request.
            path: Request path label.
            status_code: Response status code (stringified for the label).
        """
        with self._lock:
            self._http_requests_total[(method, path, str(status_code))] += 1

    def record_tool_call(self, tool_name: str, status: str, duration_seconds: float) -> None:
        """Record one tool invocation and its duration.

        Negative durations (e.g. clock anomalies) are clamped to zero so
        the aggregate sum stays monotonic.
        """
        with self._lock:
            self._tool_calls_total[(tool_name, status)] += 1
            self._tool_duration_seconds[tool_name] += max(duration_seconds, 0.0)
            self._tool_duration_count[tool_name] += 1

    def render_prometheus(self) -> str:
        """Render metrics in Prometheus exposition format.

        Fix: take one consistent snapshot of every counter family while
        holding the lock, then format outside it — so concurrent writers
        are never observed mid-update and the lock is held briefly.
        """
        with self._lock:
            http_totals = sorted(self._http_requests_total.items())
            tool_totals = sorted(self._tool_calls_total.items())
            duration_sums = sorted(self._tool_duration_seconds.items())
            duration_counts = sorted(self._tool_duration_count.items())

        lines: list[str] = []

        lines.append("# HELP aegis_http_requests_total Total HTTP requests")
        lines.append("# TYPE aegis_http_requests_total counter")
        for (method, path, status), count in http_totals:
            lines.append(
                "aegis_http_requests_total"
                f'{{method="{method}",path="{path}",status="{status}"}} {count}'
            )

        lines.append("# HELP aegis_tool_calls_total Total MCP tool calls")
        lines.append("# TYPE aegis_tool_calls_total counter")
        for (tool_name, status), count in tool_totals:
            lines.append(
                f'aegis_tool_calls_total{{tool="{tool_name}",status="{status}"}} {count}'
            )

        lines.append(
            "# HELP aegis_tool_duration_seconds_sum Sum of MCP tool call duration seconds"
        )
        lines.append("# TYPE aegis_tool_duration_seconds_sum counter")
        for tool_name, total in duration_sums:
            lines.append(f'aegis_tool_duration_seconds_sum{{tool="{tool_name}"}} {total:.6f}')

        lines.append(
            "# HELP aegis_tool_duration_seconds_count MCP tool call duration sample count"
        )
        lines.append("# TYPE aegis_tool_duration_seconds_count counter")
        for tool_name, count in duration_counts:
            lines.append(f'aegis_tool_duration_seconds_count{{tool="{tool_name}"}} {count}')

        return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
_metrics_registry: MetricsRegistry | None = None


def get_metrics_registry() -> MetricsRegistry:
    """Return the process-wide metrics registry, creating it on first use."""
    global _metrics_registry
    if _metrics_registry is None:
        # Lazy singleton: created the first time any caller asks for it.
        _metrics_registry = MetricsRegistry()
    return _metrics_registry
|
||||
|
||||
|
||||
def reset_metrics_registry() -> None:
    """Drop the global registry so tests start from a clean slate."""
    global _metrics_registry
    _metrics_registry = None
|
||||
|
||||
|
||||
def monotonic_seconds() -> float:
    """Return the current monotonic clock reading, in seconds.

    Wrapping `time.monotonic` in a named helper keeps instrumentation
    call sites deterministic to patch in tests.
    """
    return time.monotonic()
|
||||
262
src/aegis_gitea_mcp/policy.py
Normal file
262
src/aegis_gitea_mcp/policy.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""Policy engine for tool authorization decisions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
|
||||
|
||||
class PolicyError(Exception):
    """Raised when a policy file cannot be loaded or fails validation."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PolicyDecision:
    """Outcome of a policy authorization check."""

    # True when the tool call is permitted.
    allowed: bool
    # Human-readable explanation of the decision.
    reason: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RuleSet:
    """Explicit allow/deny tool-name sets."""

    # Tool names explicitly allowed; empty means "no allow-list restriction".
    allow: set[str] = field(default_factory=set)
    # Tool names explicitly denied.
    deny: set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PathRules:
    """Allow/deny patterns for target file paths."""

    # Path patterns that are explicitly allowed.
    allow: tuple[str, ...] = ()
    # Path patterns that are explicitly denied.
    deny: tuple[str, ...] = ()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RepositoryPolicy:
    """Policy rules scoped to a single `owner/repo` repository."""

    # Tool allow/deny rules for this repository.
    tools: RuleSet = field(default_factory=RuleSet)
    # Path allow/deny rules for this repository.
    paths: PathRules = field(default_factory=PathRules)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PolicyConfig:
    """Fully parsed policy configuration.

    Defaults are permissive for reads and deny-by-default for writes,
    matching the loader's behavior when no policy file exists.
    """

    # Default decision for read operations: "allow" or "deny".
    default_read: str = "allow"
    # Default decision for write operations: "allow" or "deny".
    default_write: str = "deny"
    # Global tool allow/deny rules.
    tools: RuleSet = field(default_factory=RuleSet)
    # Per-repository overrides keyed by "owner/repo".
    repositories: dict[str, RepositoryPolicy] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PolicyEngine:
|
||||
"""Evaluates authorization decisions for MCP tools."""
|
||||
|
||||
def __init__(self, config: PolicyConfig) -> None:
|
||||
"""Initialize policy engine with prevalidated config."""
|
||||
self.config = config
|
||||
self.settings = get_settings()
|
||||
|
||||
@classmethod
|
||||
def from_yaml_file(cls, policy_path: Path) -> PolicyEngine:
|
||||
"""Build a policy engine from YAML policy file.
|
||||
|
||||
Args:
|
||||
policy_path: Path to policy YAML file.
|
||||
|
||||
Returns:
|
||||
Initialized policy engine.
|
||||
|
||||
Raises:
|
||||
PolicyError: If file is malformed or violates policy schema.
|
||||
"""
|
||||
if not policy_path.exists():
|
||||
# Secure default for writes, backwards-compatible allow for reads.
|
||||
return cls(PolicyConfig())
|
||||
|
||||
try:
|
||||
raw = yaml.safe_load(policy_path.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise PolicyError(f"Failed to parse policy YAML: {exc}") from exc
|
||||
|
||||
if raw is None:
|
||||
return cls(PolicyConfig())
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
raise PolicyError("Policy root must be a mapping")
|
||||
|
||||
defaults = raw.get("defaults", {})
|
||||
if defaults and not isinstance(defaults, dict):
|
||||
raise PolicyError("defaults must be a mapping")
|
||||
|
||||
default_read = str(defaults.get("read", "allow")).lower()
|
||||
default_write = str(defaults.get("write", "deny")).lower()
|
||||
if default_read not in {"allow", "deny"}:
|
||||
raise PolicyError("defaults.read must be 'allow' or 'deny'")
|
||||
if default_write not in {"allow", "deny"}:
|
||||
raise PolicyError("defaults.write must be 'allow' or 'deny'")
|
||||
|
||||
global_tools = cls._parse_tool_rules(raw.get("tools", {}), "tools")
|
||||
|
||||
repositories_raw = raw.get("repositories", {})
|
||||
if repositories_raw is None:
|
||||
repositories_raw = {}
|
||||
if not isinstance(repositories_raw, dict):
|
||||
raise PolicyError("repositories must be a mapping")
|
||||
|
||||
repositories: dict[str, RepositoryPolicy] = {}
|
||||
for repo_name, repo_payload in repositories_raw.items():
|
||||
if not isinstance(repo_name, str) or "/" not in repo_name:
|
||||
raise PolicyError("Repository keys must be in 'owner/repo' format")
|
||||
if not isinstance(repo_payload, dict):
|
||||
raise PolicyError(f"Repository policy for {repo_name} must be a mapping")
|
||||
|
||||
tool_rules = cls._parse_tool_rules(
|
||||
repo_payload.get("tools", {}),
|
||||
f"repositories.{repo_name}.tools",
|
||||
)
|
||||
|
||||
path_payload = repo_payload.get("paths", {})
|
||||
if path_payload and not isinstance(path_payload, dict):
|
||||
raise PolicyError(f"repositories.{repo_name}.paths must be a mapping")
|
||||
|
||||
allow_paths = cls._parse_path_list(path_payload.get("allow", []), "allow")
|
||||
deny_paths = cls._parse_path_list(path_payload.get("deny", []), "deny")
|
||||
|
||||
repositories[repo_name] = RepositoryPolicy(
|
||||
tools=tool_rules,
|
||||
paths=PathRules(allow=allow_paths, deny=deny_paths),
|
||||
)
|
||||
|
||||
return cls(
|
||||
PolicyConfig(
|
||||
default_read=default_read,
|
||||
default_write=default_write,
|
||||
tools=global_tools,
|
||||
repositories=repositories,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_tool_rules(raw_rules: Any, location: str) -> RuleSet:
|
||||
"""Parse tool allow/deny mapping from raw payload."""
|
||||
if not raw_rules:
|
||||
return RuleSet()
|
||||
if not isinstance(raw_rules, dict):
|
||||
raise PolicyError(f"{location} must be a mapping")
|
||||
|
||||
allow = raw_rules.get("allow", [])
|
||||
deny = raw_rules.get("deny", [])
|
||||
|
||||
if not isinstance(allow, list) or not all(isinstance(item, str) for item in allow):
|
||||
raise PolicyError(f"{location}.allow must be a list of strings")
|
||||
if not isinstance(deny, list) or not all(isinstance(item, str) for item in deny):
|
||||
raise PolicyError(f"{location}.deny must be a list of strings")
|
||||
|
||||
return RuleSet(allow=set(allow), deny=set(deny))
|
||||
|
||||
@staticmethod
|
||||
def _parse_path_list(raw_paths: Any, label: str) -> tuple[str, ...]:
|
||||
"""Parse path allow/deny list."""
|
||||
if raw_paths is None:
|
||||
return ()
|
||||
if not isinstance(raw_paths, list) or not all(isinstance(item, str) for item in raw_paths):
|
||||
raise PolicyError(f"paths.{label} must be a list of strings")
|
||||
return tuple(raw_paths)
|
||||
|
||||
@staticmethod
def _normalize_target_path(path: str) -> str:
    """Canonicalize a target path prior to pattern matching.

    Backslashes become forward slashes, leading slashes and empty or
    "." segments are dropped, and any ".." segment aborts evaluation.

    Security note:
        Rejecting traversal segments here ensures fnmatch rules cannot
        be bypassed with "../" tricks.

    Args:
        path: Raw path supplied by the caller.

    Returns:
        Normalized slash-separated path.

    Raises:
        PolicyError: If the path contains a ".." segment.
    """
    segments: list[str] = []
    for segment in path.replace("\\", "/").lstrip("/").split("/"):
        if not segment or segment == ".":
            continue
        if segment == "..":
            raise PolicyError("Target path contains traversal sequence '..'")
        segments.append(segment)
    return "/".join(segments)
|
||||
|
||||
def authorize(
    self,
    tool_name: str,
    is_write: bool,
    repository: str | None = None,
    target_path: str | None = None,
) -> PolicyDecision:
    """Evaluate whether a tool call is authorized by policy.

    Checks are applied in strict order: global tool deny, global tool
    allow-list, write-mode gates, repository-level tool rules, path
    rules, then the configured default for read or write. Deny rules
    always take precedence over allow rules at the same level.

    Args:
        tool_name: Invoked MCP tool name.
        is_write: Whether the tool mutates data.
        repository: Optional `owner/repo` target repository.
        target_path: Optional file path target.

    Returns:
        Policy decision indicating allow/deny and reason.

    Raises:
        PolicyError: If target_path contains a '..' traversal segment.
    """
    # Global deny takes precedence over everything else.
    if tool_name in self.config.tools.deny:
        return PolicyDecision(False, "tool denied by global policy")

    # A non-empty global allow-list is exclusive: unlisted tools are refused.
    if self.config.tools.allow and tool_name not in self.config.tools.allow:
        return PolicyDecision(False, "tool not allowed by global policy")

    if is_write:
        # Writes require the global write-mode switch ...
        if not self.settings.write_mode:
            return PolicyDecision(False, "write mode is disabled")

        # ... an explicit repository target ...
        if not repository:
            return PolicyDecision(False, "write operation requires a repository target")

        # ... and that repository must be explicitly whitelisted for writes.
        if repository not in self.settings.write_repository_whitelist:
            return PolicyDecision(False, "repository is not in write-mode whitelist")

    repo_policy = self.config.repositories.get(repository) if repository else None

    if repo_policy:
        # Repository-level tool rules mirror the global deny/allow semantics.
        if tool_name in repo_policy.tools.deny:
            return PolicyDecision(False, "tool denied for repository")
        if repo_policy.tools.allow and tool_name not in repo_policy.tools.allow:
            return PolicyDecision(False, "tool not allowed for repository")

        if target_path:
            # Normalization rejects '..' traversal before fnmatch runs.
            normalized_path = self._normalize_target_path(target_path)
            if repo_policy.paths.deny and any(
                fnmatch(normalized_path, pattern) for pattern in repo_policy.paths.deny
            ):
                return PolicyDecision(False, "path denied by repository policy")
            if repo_policy.paths.allow and not any(
                fnmatch(normalized_path, pattern) for pattern in repo_policy.paths.allow
            ):
                return PolicyDecision(False, "path not allowed by repository policy")

    # Nothing matched explicitly: fall back to the configured default.
    default_behavior = self.config.default_write if is_write else self.config.default_read
    return PolicyDecision(default_behavior == "allow", "default policy decision")
|
||||
|
||||
|
||||
# Process-wide singleton; created lazily on first access.
_policy_engine: PolicyEngine | None = None


def get_policy_engine() -> PolicyEngine:
    """Return the process-wide policy engine, loading it on first use.

    The engine is parsed from the configured policy file once and cached
    for subsequent calls.
    """
    global _policy_engine
    if _policy_engine is None:
        _policy_engine = PolicyEngine.from_yaml_file(get_settings().policy_file_path)
    return _policy_engine
|
||||
|
||||
|
||||
def reset_policy_engine() -> None:
    """Discard the cached policy engine so the next access reloads it.

    Intended for test isolation between configuration changes.
    """
    global _policy_engine
    _policy_engine = None
|
||||
110
src/aegis_gitea_mcp/rate_limit.py
Normal file
110
src/aegis_gitea_mcp/rate_limit.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""In-memory request rate limiting for MCP endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
from aegis_gitea_mcp.audit import get_audit_logger
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RateLimitDecision:
    """Result of request rate-limit checks."""

    # True when the request is within all configured limits.
    allowed: bool
    # Human-readable explanation (e.g. which limit was exceeded).
    reason: str
|
||||
|
||||
|
||||
class SlidingWindowLimiter:
    """Sliding-window limiter keyed by arbitrary identifiers."""

    def __init__(self, max_requests: int, window_seconds: int) -> None:
        """Initialize a sliding-window limiter.

        (The previous docstring said "fixed-window", contradicting the
        class: `allow` prunes a rolling window, which is sliding-window
        semantics.)

        Args:
            max_requests: Maximum allowed requests within the window.
            window_seconds: Rolling time window length in seconds.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # Per-key timestamps of accepted requests, oldest first.
        # NOTE(review): keys are never evicted once seen, so memory grows
        # with the number of distinct keys — acceptable for in-memory use,
        # but worth a periodic sweep if key cardinality is unbounded.
        self._events: dict[str, deque[float]] = defaultdict(deque)

    def allow(self, key: str) -> bool:
        """Check and record a request for the provided key.

        Args:
            key: Identifier the limit is tracked under (IP, token hash, ...).

        Returns:
            True if the request is admitted; False when the key has already
            used its quota within the current window.
        """
        now = time.time()
        boundary = now - self.window_seconds

        events = self._events[key]
        # Drop timestamps that fell out of the rolling window.
        while events and events[0] < boundary:
            events.popleft()

        if len(events) >= self.max_requests:
            return False

        events.append(now)
        return True
|
||||
|
||||
|
||||
class RequestRateLimiter:
    """Combined per-IP and per-token request limiter."""

    def __init__(self) -> None:
        """Initialize with current settings."""
        settings = get_settings()
        self._audit = get_audit_logger()
        self._ip_limiter = SlidingWindowLimiter(settings.rate_limit_per_minute, 60)
        self._token_limiter = SlidingWindowLimiter(settings.token_rate_limit_per_minute, 60)

    def _emit_limit_event(
        self, event_type: str, description: str, severity: str, client_ip: str
    ) -> None:
        """Record a rate-limit violation in the audit log."""
        self._audit.log_security_event(
            event_type=event_type,
            description=description,
            severity=severity,
            metadata={"client_ip": client_ip},
        )

    def check(self, client_ip: str, token: str | None) -> RateLimitDecision:
        """Evaluate request against IP and token limits.

        The per-IP limit is checked first, then the per-token limit.

        Args:
            client_ip: Request source IP.
            token: Optional authenticated API token.

        Returns:
            Rate limit decision.
        """
        if not self._ip_limiter.allow(client_ip):
            self._emit_limit_event(
                "rate_limit_ip_exceeded",
                "Per-IP request rate limit exceeded",
                "medium",
                client_ip,
            )
            return RateLimitDecision(False, "Per-IP rate limit exceeded")

        if token:
            # Hash token before using it as a key to avoid storing secrets in memory maps.
            digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
            if not self._token_limiter.allow(digest):
                self._emit_limit_event(
                    "rate_limit_token_exceeded",
                    "Per-token request rate limit exceeded",
                    "high",
                    client_ip,
                )
                return RateLimitDecision(False, "Per-token rate limit exceeded")

        return RateLimitDecision(True, "within limits")
|
||||
|
||||
|
||||
# Process-wide singleton; created lazily on first access.
_rate_limiter: RequestRateLimiter | None = None


def get_rate_limiter() -> RequestRateLimiter:
    """Return the process-wide request limiter, creating it on first use."""
    global _rate_limiter
    if _rate_limiter is None:
        _rate_limiter = RequestRateLimiter()
    return _rate_limiter
|
||||
|
||||
|
||||
def reset_rate_limiter() -> None:
    """Discard the cached limiter so the next access rebuilds it.

    Intended for test isolation between configuration changes.
    """
    global _rate_limiter
    _rate_limiter = None
|
||||
17
src/aegis_gitea_mcp/request_context.py
Normal file
17
src/aegis_gitea_mcp/request_context.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Request context utilities for correlation and logging."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
# Context-local correlation id; "-" signals that no request is bound yet.
_REQUEST_ID: ContextVar[str] = ContextVar("request_id", default="-")


def set_request_id(request_id: str) -> None:
    """Bind the correlation id for the current execution context."""
    _REQUEST_ID.set(request_id)


def get_request_id() -> str:
    """Return the correlation id bound to this context ("-" when unset)."""
    return _REQUEST_ID.get()
|
||||
56
src/aegis_gitea_mcp/response_limits.py
Normal file
56
src/aegis_gitea_mcp/response_limits.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Helpers for bounded tool responses."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
|
||||
|
||||
class ResponseLimitError(RuntimeError):
    """Raised when response processing exceeds configured safety limits.

    Signals a misconfiguration (non-positive limit), not oversized data:
    oversized data is trimmed, never rejected.
    """
|
||||
|
||||
|
||||
def limit_items(
|
||||
items: list[dict[str, Any]], configured_limit: int | None = None
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Trim a list of result items to configured maximum length.
|
||||
|
||||
Args:
|
||||
items: List of result dictionaries.
|
||||
configured_limit: Optional explicit item limit.
|
||||
|
||||
Returns:
|
||||
Tuple of trimmed list and omitted count.
|
||||
"""
|
||||
settings = get_settings()
|
||||
max_items = configured_limit or settings.max_tool_response_items
|
||||
if max_items <= 0:
|
||||
raise ResponseLimitError("max_tool_response_items must be greater than zero")
|
||||
|
||||
if len(items) <= max_items:
|
||||
return items, 0
|
||||
|
||||
trimmed = items[:max_items]
|
||||
omitted = len(items) - max_items
|
||||
return trimmed, omitted
|
||||
|
||||
|
||||
def limit_text(text: str, configured_limit: int | None = None) -> str:
|
||||
"""Trim text output to configured maximum characters.
|
||||
|
||||
Args:
|
||||
text: Untrusted text output.
|
||||
configured_limit: Optional explicit character limit.
|
||||
|
||||
Returns:
|
||||
Trimmed text.
|
||||
"""
|
||||
settings = get_settings()
|
||||
max_chars = configured_limit or settings.max_tool_response_chars
|
||||
if max_chars <= 0:
|
||||
raise ResponseLimitError("max_tool_response_chars must be greater than zero")
|
||||
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
return text[:max_chars]
|
||||
134
src/aegis_gitea_mcp/security.py
Normal file
134
src/aegis_gitea_mcp/security.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Security helpers for secret detection and untrusted content handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SecretMatch:
    """Represents a detected secret-like token."""

    # Pattern label from _SECRET_PATTERNS (e.g. "aws_access_key").
    secret_type: str
    # The raw matched text; treat as sensitive.
    value: str
|
||||
|
||||
|
||||
# Ordered (type, pattern) pairs scanned by detect_secrets(). Word boundaries
# anchor most matches to keep false positives down.
_SECRET_PATTERNS: tuple[tuple[str, re.Pattern[str]], ...] = (
    (
        "openai_key",
        re.compile(r"\bsk-[A-Za-z0-9_-]{20,}\b"),
    ),
    (
        "aws_access_key",
        re.compile(r"\bAKIA[0-9A-Z]{16}\b"),
    ),
    (
        "github_token",
        re.compile(r"\bgh[pousr]_[A-Za-z0-9]{20,}\b"),
    ),
    (
        # Three base64url segments starting with "eyJ" (JSON header).
        "jwt",
        re.compile(r"\beyJ[A-Za-z0-9_-]{8,}\.[A-Za-z0-9_-]{4,}\.[A-Za-z0-9_-]{4,}\b"),
    ),
    (
        "private_key",
        re.compile(r"-----BEGIN (?:RSA |EC |OPENSSH |)PRIVATE KEY-----"),
    ),
    (
        # Catch-all for "api_key=..."/"token: ..." style assignments.
        "generic_api_key",
        re.compile(r"\b(?:api[_-]?key|token)[\"'=: ]+[A-Za-z0-9_-]{16,}\b", re.IGNORECASE),
    ),
)
|
||||
|
||||
|
||||
def detect_secrets(text: str) -> list[SecretMatch]:
    """Detect common secret patterns in text.

    Args:
        text: Untrusted text to scan.

    Returns:
        List of detected secret-like values, in pattern order.
    """
    findings: list[SecretMatch] = []
    for secret_type, pattern in _SECRET_PATTERNS:
        for hit in pattern.findall(text):
            # findall yields tuples when a pattern has capture groups;
            # flatten those back into a single string.
            value = "".join(hit) if isinstance(hit, tuple) else hit
            findings.append(SecretMatch(secret_type=secret_type, value=value))
    return findings
|
||||
|
||||
|
||||
def mask_secret(value: str) -> str:
    """Mask a secret value while preserving minimal context.

    Args:
        value: Raw secret text.

    Returns:
        Masked string that does not reveal the secret.
    """
    # Short values are fully redacted: keeping the edges of a secret this
    # small would leak most of it.
    if len(value) <= 8:
        return "[REDACTED]"
    head, tail = value[:4], value[-4:]
    return f"{head}...{tail}"
|
||||
|
||||
|
||||
def sanitize_data(value: Any, mode: str = "mask") -> Any:
    """Recursively sanitize secret-like material from arbitrary data.

    Args:
        value: Arbitrary response payload.
        mode: `mask` to keep redacted content, `block` to fully replace fields.

    Returns:
        Sanitized payload value.
    """
    # Recurse through containers, preserving their concrete type.
    if isinstance(value, dict):
        return {str(k): sanitize_data(v, mode=mode) for k, v in value.items()}
    if isinstance(value, list):
        return [sanitize_data(entry, mode=mode) for entry in value]
    if isinstance(value, tuple):
        return tuple(sanitize_data(entry, mode=mode) for entry in value)
    if not isinstance(value, str):
        # Non-string scalars cannot contain secrets; pass through untouched.
        return value

    findings = detect_secrets(value)
    if not findings:
        return value

    if mode == "block":
        return "[REDACTED_SECRET]"

    sanitized = value
    for finding in findings:
        sanitized = sanitized.replace(finding.value, mask_secret(finding.value))
    return sanitized
|
||||
|
||||
|
||||
def sanitize_untrusted_text(text: str, max_chars: int) -> str:
    """Normalize untrusted repository content for display-only usage.

    Security note:
        Repository content is always treated as data and never interpreted
        as executable instructions. This helper enforces a strict length
        limit to prevent prompt-stuffing through oversized payloads.

    Args:
        text: Repository text content.
        max_chars: Maximum allowed characters in returned text.

    Returns:
        Truncated text safe for downstream display.
    """
    # A non-positive budget yields nothing at all.
    if max_chars <= 0:
        return ""
    return text if len(text) <= max_chars else text[:max_chars]
|
||||
@@ -1,16 +1,24 @@
|
||||
"""Main MCP server implementation with FastAPI and SSE support."""
|
||||
"""Main MCP server implementation with hardened security controls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import ValidationError
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from aegis_gitea_mcp.audit import get_audit_logger
|
||||
from aegis_gitea_mcp.auth import get_validator
|
||||
from aegis_gitea_mcp.automation import AutomationError, AutomationManager
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
from aegis_gitea_mcp.gitea_client import GiteaClient
|
||||
from aegis_gitea_mcp.logging_utils import configure_logging
|
||||
from aegis_gitea_mcp.mcp_protocol import (
|
||||
AVAILABLE_TOOLS,
|
||||
MCPListToolsResponse,
|
||||
@@ -18,276 +26,443 @@ from aegis_gitea_mcp.mcp_protocol import (
|
||||
MCPToolCallResponse,
|
||||
get_tool_by_name,
|
||||
)
|
||||
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
|
||||
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
|
||||
from aegis_gitea_mcp.rate_limit import get_rate_limiter
|
||||
from aegis_gitea_mcp.request_context import set_request_id
|
||||
from aegis_gitea_mcp.security import sanitize_data
|
||||
from aegis_gitea_mcp.tools.arguments import extract_repository, extract_target_path
|
||||
from aegis_gitea_mcp.tools.read_tools import (
|
||||
compare_refs_tool,
|
||||
get_commit_diff_tool,
|
||||
get_issue_tool,
|
||||
get_pull_request_tool,
|
||||
list_commits_tool,
|
||||
list_issues_tool,
|
||||
list_labels_tool,
|
||||
list_pull_requests_tool,
|
||||
list_releases_tool,
|
||||
list_tags_tool,
|
||||
search_code_tool,
|
||||
)
|
||||
from aegis_gitea_mcp.tools.repository import (
|
||||
get_file_contents_tool,
|
||||
get_file_tree_tool,
|
||||
get_repository_info_tool,
|
||||
list_repositories_tool,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
from aegis_gitea_mcp.tools.write_tools import (
|
||||
add_labels_tool,
|
||||
assign_issue_tool,
|
||||
create_issue_comment_tool,
|
||||
create_issue_tool,
|
||||
create_pr_comment_tool,
|
||||
update_issue_tool,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(
|
||||
title="AegisGitea MCP Server",
|
||||
description="Security-first MCP server for controlled AI access to self-hosted Gitea",
|
||||
version="0.1.0",
|
||||
version="0.2.0",
|
||||
)
|
||||
|
||||
# Global settings and audit logger
|
||||
# Note: access settings/audit logger dynamically to support test resets.
|
||||
|
||||
class AutomationWebhookRequest(BaseModel):
    """Request body for automation webhook ingestion."""

    # Webhook event identifier; bounded to deter abusive payloads.
    event_type: str = Field(..., min_length=1, max_length=128)
    # Raw event payload forwarded to the automation manager.
    payload: dict[str, Any] = Field(default_factory=dict)
    # Optional repository target passed through to the handler.
    repository: str | None = Field(default=None)
|
||||
|
||||
|
||||
# Tool dispatcher mapping
|
||||
TOOL_HANDLERS = {
|
||||
class AutomationJobRequest(BaseModel):
    """Request body for automation job execution."""

    # Registered automation job to run.
    job_name: str = Field(..., min_length=1, max_length=128)
    # Target repository owner and name, passed separately.
    owner: str = Field(..., min_length=1, max_length=100)
    repo: str = Field(..., min_length=1, max_length=100)
    # Optional finding details; lengths bounded to cap stored/posted text.
    finding_title: str | None = Field(default=None, max_length=256)
    finding_body: str | None = Field(default=None, max_length=10_000)
||||
|
||||
|
||||
# Signature shared by every tool handler: (client, arguments) -> result payload.
ToolHandler = Callable[[GiteaClient, dict[str, Any]], Awaitable[dict[str, Any]]]

# Dispatch table from MCP tool name to its async handler. Every key should
# have a matching definition returned by get_tool_by_name / AVAILABLE_TOOLS.
TOOL_HANDLERS: dict[str, ToolHandler] = {
    # Baseline read tools
    "list_repositories": list_repositories_tool,
    "get_repository_info": get_repository_info_tool,
    "get_file_tree": get_file_tree_tool,
    "get_file_contents": get_file_contents_tool,
    # Expanded read tools
    "search_code": search_code_tool,
    "list_commits": list_commits_tool,
    "get_commit_diff": get_commit_diff_tool,
    "compare_refs": compare_refs_tool,
    "list_issues": list_issues_tool,
    "get_issue": get_issue_tool,
    "list_pull_requests": list_pull_requests_tool,
    "get_pull_request": get_pull_request_tool,
    "list_labels": list_labels_tool,
    "list_tags": list_tags_tool,
    "list_releases": list_releases_tool,
    # Write-mode tools
    "create_issue": create_issue_tool,
    "update_issue": update_issue_tool,
    "create_issue_comment": create_issue_comment_tool,
    "create_pr_comment": create_pr_comment_tool,
    "add_labels": add_labels_tool,
    "assign_issue": assign_issue_tool,
}
|
||||
|
||||
|
||||
# Authentication middleware
|
||||
@app.middleware("http")
|
||||
async def authenticate_request(request: Request, call_next):
|
||||
"""Authenticate all requests except health checks and root.
|
||||
async def request_context_middleware(
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable[Response]],
|
||||
) -> Response:
|
||||
"""Attach request correlation context and collect request metrics."""
|
||||
request_id = request.headers.get("x-request-id") or str(uuid.uuid4())
|
||||
set_request_id(request_id)
|
||||
request.state.request_id = request_id
|
||||
|
||||
Supports Mixed authentication mode where:
|
||||
- /mcp/tools (list tools) is publicly accessible (No Auth)
|
||||
- /mcp/tool/call (execute tools) requires authentication
|
||||
- /mcp/sse requires authentication
|
||||
"""
|
||||
# Skip authentication for health check and root endpoints
|
||||
if request.url.path in ["/", "/health"]:
|
||||
started_at = monotonic_seconds()
|
||||
status_code = 500
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
status_code = response.status_code
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
return response
|
||||
finally:
|
||||
duration = max(monotonic_seconds() - started_at, 0.0)
|
||||
logger.debug(
|
||||
"request_completed",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"duration_seconds": duration,
|
||||
"status_code": status_code,
|
||||
},
|
||||
)
|
||||
metrics = get_metrics_registry()
|
||||
metrics.record_http_request(request.method, request.url.path, status_code)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def authenticate_and_rate_limit(
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable[Response]],
|
||||
) -> Response:
|
||||
"""Apply rate-limiting and authentication for MCP endpoints."""
|
||||
settings = get_settings()
|
||||
|
||||
if request.url.path in {"/", "/health"}:
|
||||
return await call_next(request)
|
||||
|
||||
# Only authenticate MCP endpoints
|
||||
if not request.url.path.startswith("/mcp/"):
|
||||
if request.url.path == "/metrics" and settings.metrics_enabled:
|
||||
# Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
|
||||
return await call_next(request)
|
||||
|
||||
# Mixed mode: allow /mcp/tools without authentication (for ChatGPT discovery)
|
||||
if request.url.path == "/mcp/tools":
|
||||
if not (request.url.path.startswith("/mcp/") or request.url.path.startswith("/automation/")):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract client information
|
||||
validator = get_validator()
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
# Get validator instance (supports test resets)
|
||||
validator = get_validator()
|
||||
|
||||
# Extract Authorization header
|
||||
auth_header = request.headers.get("authorization")
|
||||
api_key = validator.extract_bearer_token(auth_header)
|
||||
|
||||
# Fallback: allow API key via query parameter only for MCP endpoints
|
||||
if not api_key and request.url.path in {"/mcp/tool/call", "/mcp/sse"}:
|
||||
api_key = request.query_params.get("api_key")
|
||||
|
||||
# Validate API key
|
||||
is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
|
||||
rate_limit = limiter.check(client_ip=client_ip, token=api_key)
|
||||
if not rate_limit.allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"error": "Rate limit exceeded",
|
||||
"message": rate_limit.reason,
|
||||
"request_id": getattr(request.state, "request_id", "-"),
|
||||
},
|
||||
)
|
||||
|
||||
# Mixed mode: tool discovery remains public to preserve MCP client compatibility.
|
||||
if request.url.path == "/mcp/tools":
|
||||
return await call_next(request)
|
||||
|
||||
is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
|
||||
if not is_valid:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": "Authentication failed",
|
||||
"message": error_message,
|
||||
"detail": (
|
||||
"Provide a valid API key via Authorization header (Bearer <api-key>) "
|
||||
"or ?api_key=<api-key> query parameter"
|
||||
),
|
||||
"detail": "Provide Authorization: Bearer <api-key> or ?api_key=<api-key>",
|
||||
"request_id": getattr(request.state, "request_id", "-"),
|
||||
},
|
||||
)
|
||||
|
||||
# Authentication successful - continue to endpoint
|
||||
response = await call_next(request)
|
||||
return response
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
"""Initialize server on startup."""
|
||||
"""Initialize server state on startup."""
|
||||
settings = get_settings()
|
||||
logger.info(f"Starting AegisGitea MCP Server on {settings.mcp_host}:{settings.mcp_port}")
|
||||
logger.info(f"Connected to Gitea instance: {settings.gitea_base_url}")
|
||||
logger.info(f"Audit logging enabled: {settings.audit_log_path}")
|
||||
configure_logging(settings.log_level)
|
||||
|
||||
# Log authentication status
|
||||
if settings.auth_enabled:
|
||||
key_count = len(settings.mcp_api_keys)
|
||||
logger.info(f"API key authentication ENABLED ({key_count} key(s) configured)")
|
||||
else:
|
||||
logger.warning("API key authentication DISABLED - server is open to all requests!")
|
||||
logger.info("server_starting")
|
||||
logger.info(
|
||||
"server_configuration",
|
||||
extra={
|
||||
"host": settings.mcp_host,
|
||||
"port": settings.mcp_port,
|
||||
"gitea_url": settings.gitea_base_url,
|
||||
"auth_enabled": settings.auth_enabled,
|
||||
"write_mode": settings.write_mode,
|
||||
"metrics_enabled": settings.metrics_enabled,
|
||||
},
|
||||
)
|
||||
|
||||
# Test Gitea connection
|
||||
# Fail-fast policy parse errors at startup.
|
||||
try:
|
||||
async with GiteaClient() as gitea:
|
||||
user = await gitea.get_current_user()
|
||||
logger.info(f"Authenticated as bot user: {user.get('login', 'unknown')}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Gitea: {e}")
|
||||
_ = get_policy_engine()
|
||||
except PolicyError:
|
||||
logger.error("policy_load_failed")
|
||||
raise
|
||||
|
||||
if settings.startup_validate_gitea and settings.environment != "test":
|
||||
try:
|
||||
async with GiteaClient() as gitea:
|
||||
user = await gitea.get_current_user()
|
||||
logger.info("gitea_connected", extra={"bot_user": user.get("login", "unknown")})
|
||||
except Exception:
|
||||
logger.error("gitea_connection_failed")
|
||||
raise
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event() -> None:
|
||||
"""Cleanup on server shutdown."""
|
||||
logger.info("Shutting down AegisGitea MCP Server")
|
||||
"""Log server shutdown event."""
|
||||
logger.info("server_stopping")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root() -> Dict[str, Any]:
|
||||
"""Root endpoint with server information."""
|
||||
async def root() -> dict[str, Any]:
|
||||
"""Root endpoint with server metadata."""
|
||||
return {
|
||||
"name": "AegisGitea MCP Server",
|
||||
"version": "0.1.0",
|
||||
"version": "0.2.0",
|
||||
"status": "running",
|
||||
"mcp_version": "1.0",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Dict[str, str]:
|
||||
async def health() -> dict[str, str]:
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.get("/metrics")
async def metrics() -> PlainTextResponse:
    """Prometheus-compatible metrics endpoint.

    Returns 404 when metrics are disabled in settings.
    """
    if not get_settings().metrics_enabled:
        raise HTTPException(status_code=404, detail="Metrics endpoint disabled")
    payload = get_metrics_registry().render_prometheus()
    return PlainTextResponse(content=payload, media_type="text/plain; version=0.0.4")
|
||||
|
||||
|
||||
@app.post("/automation/webhook")
async def automation_webhook(request: AutomationWebhookRequest) -> JSONResponse:
    """Ingest policy-controlled automation webhooks."""
    manager = AutomationManager()
    try:
        outcome = await manager.handle_webhook(
            event_type=request.event_type,
            payload=request.payload,
            repository=request.repository,
        )
    except AutomationError as exc:
        # Automation policy failures surface as 403, not 500.
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    return JSONResponse(content={"success": True, "result": outcome})
|
||||
|
||||
|
||||
@app.post("/automation/jobs/run")
async def automation_run_job(request: AutomationJobRequest) -> JSONResponse:
    """Execute a policy-controlled automation job for a repository."""
    manager = AutomationManager()
    try:
        outcome = await manager.run_job(
            job_name=request.job_name,
            owner=request.owner,
            repo=request.repo,
            finding_title=request.finding_title,
            finding_body=request.finding_body,
        )
    except AutomationError as exc:
        # Automation policy failures surface as 403, not 500.
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    return JSONResponse(content={"success": True, "result": outcome})
|
||||
|
||||
|
||||
@app.get("/mcp/tools")
|
||||
async def list_tools() -> JSONResponse:
|
||||
"""List all available MCP tools.
|
||||
|
||||
Returns:
|
||||
JSON response with list of tool definitions
|
||||
"""
|
||||
"""List all available MCP tools."""
|
||||
response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
|
||||
return JSONResponse(content=response.model_dump(by_alias=True))
|
||||
return JSONResponse(content=response.model_dump())
|
||||
|
||||
|
||||
async def _execute_tool_call(
    tool_name: str, arguments: dict[str, Any], correlation_id: str
) -> dict[str, Any]:
    """Execute tool call with policy checks and standardized response sanitization.

    Pipeline: resolve tool definition and handler, authorize via the policy
    engine (logging denials to the audit trail), run the handler against a
    fresh Gitea client, sanitize the outbound payload, and record the call
    in metrics regardless of outcome.

    Args:
        tool_name: Invoked MCP tool name.
        arguments: Raw tool arguments from the request.
        correlation_id: Audit correlation id for this invocation.

    Returns:
        Handler result payload (sanitized unless secret detection is off).

    Raises:
        HTTPException: 404 for unknown tools, 500 for missing handlers,
            403 for policy denials.
    """
    settings = get_settings()
    audit = get_audit_logger()
    metrics = get_metrics_registry()

    tool_def = get_tool_by_name(tool_name)
    if not tool_def:
        raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")

    handler = TOOL_HANDLERS.get(tool_name)
    if not handler:
        raise HTTPException(
            status_code=500, detail=f"Tool '{tool_name}' has no handler implementation"
        )

    # Policy check happens before any Gitea traffic; denials are audited.
    repository = extract_repository(arguments)
    target_path = extract_target_path(arguments)
    decision = get_policy_engine().authorize(
        tool_name=tool_name,
        is_write=tool_def.write_operation,
        repository=repository,
        target_path=target_path,
    )
    if not decision.allowed:
        audit.log_access_denied(
            tool_name=tool_name,
            repository=repository,
            reason=decision.reason,
            correlation_id=correlation_id,
        )
        raise HTTPException(status_code=403, detail=f"Policy denied request: {decision.reason}")

    started_at = monotonic_seconds()
    # Pessimistic default: anything short of a clean return counts as error.
    status = "error"

    try:
        async with GiteaClient() as gitea:
            result = await handler(gitea, arguments)

        if settings.secret_detection_mode != "off":
            # Security decision: sanitize outbound payloads to prevent accidental secret exfiltration.
            result = sanitize_data(result, mode=settings.secret_detection_mode)

        status = "success"
        return result
    finally:
        # Metrics record every attempt, including handler exceptions.
        duration = max(monotonic_seconds() - started_at, 0.0)
        metrics.record_tool_call(tool_name, status, duration)
|
||||
|
||||
|
||||
@app.post("/mcp/tool/call")
|
||||
async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
|
||||
"""Execute an MCP tool call.
|
||||
|
||||
Args:
|
||||
request: Tool call request with tool name and arguments
|
||||
|
||||
Returns:
|
||||
JSON response with tool execution result
|
||||
"""
|
||||
"""Execute an MCP tool call."""
|
||||
settings = get_settings()
|
||||
audit = get_audit_logger()
|
||||
|
||||
correlation_id = request.correlation_id or audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
params=request.arguments,
|
||||
)
|
||||
|
||||
try:
|
||||
# Validate tool exists
|
||||
tool_def = get_tool_by_name(request.tool)
|
||||
if not tool_def:
|
||||
error_msg = f"Tool '{request.tool}' not found"
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=404, detail=error_msg)
|
||||
|
||||
# Get tool handler
|
||||
handler = TOOL_HANDLERS.get(request.tool)
|
||||
if not handler:
|
||||
error_msg = f"Tool '{request.tool}' has no handler implementation"
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# Execute tool with Gitea client
|
||||
async with GiteaClient() as gitea:
|
||||
result = await handler(gitea, request.arguments)
|
||||
|
||||
result = await _execute_tool_call(request.tool, request.arguments, correlation_id)
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="success",
|
||||
)
|
||||
|
||||
response = MCPToolCallResponse(
|
||||
success=True,
|
||||
result=result,
|
||||
correlation_id=correlation_id,
|
||||
return JSONResponse(
|
||||
content=MCPToolCallResponse(
|
||||
success=True,
|
||||
result=result,
|
||||
correlation_id=correlation_id,
|
||||
).model_dump()
|
||||
)
|
||||
return JSONResponse(content=response.model_dump())
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions (like 404) without catching them
|
||||
except HTTPException as exc:
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=str(exc.detail),
|
||||
)
|
||||
raise
|
||||
|
||||
except ValidationError as e:
|
||||
error_msg = f"Invalid arguments: {str(e)}"
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
except ValidationError as exc:
|
||||
error_message = "Invalid tool arguments"
|
||||
if settings.expose_error_details:
|
||||
error_message = f"{error_message}: {exc}"
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
error="validation_error",
|
||||
)
|
||||
response = MCPToolCallResponse(
|
||||
success=False,
|
||||
error=error_msg,
|
||||
raise HTTPException(status_code=400, detail=error_message) from exc
|
||||
|
||||
except Exception:
|
||||
# Security decision: do not leak stack traces or raw exception messages.
|
||||
error_message = "Internal server error"
|
||||
if settings.expose_error_details:
|
||||
error_message = "Internal server error (details hidden unless explicitly enabled)"
|
||||
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error="internal_error",
|
||||
)
|
||||
logger.exception("tool_execution_failed")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=MCPToolCallResponse(
|
||||
success=False,
|
||||
error=error_message,
|
||||
correlation_id=correlation_id,
|
||||
).model_dump(),
|
||||
)
|
||||
return JSONResponse(content=response.model_dump(), status_code=500)
|
||||
|
||||
|
||||
@app.get("/mcp/sse")
|
||||
async def sse_endpoint(request: Request) -> StreamingResponse:
|
||||
"""Server-Sent Events endpoint for MCP protocol.
|
||||
"""Server-Sent Events endpoint for MCP transport."""
|
||||
|
||||
This enables real-time communication with ChatGPT using SSE.
|
||||
async def event_stream() -> AsyncGenerator[str, None]:
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps(
|
||||
{"event": "connected", "server": "AegisGitea MCP", "version": "0.2.0"},
|
||||
separators=(",", ":"),
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
|
||||
Returns:
|
||||
Streaming SSE response
|
||||
"""
|
||||
|
||||
async def event_stream():
|
||||
"""Generate SSE events."""
|
||||
# Send initial connection event
|
||||
yield f"data: {{'event': 'connected', 'server': 'AegisGitea MCP', 'version': '0.1.0'}}\n\n"
|
||||
|
||||
# Keep connection alive
|
||||
try:
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
# Heartbeat every 30 seconds
|
||||
yield f"data: {{'event': 'heartbeat'}}\n\n"
|
||||
|
||||
# Wait for next heartbeat (in production, this would handle actual events)
|
||||
import asyncio
|
||||
|
||||
yield 'data: {"event":"heartbeat"}\n\n'
|
||||
await asyncio.sleep(30)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SSE stream error: {e}")
|
||||
except Exception:
|
||||
logger.exception("sse_stream_error")
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
@@ -302,21 +477,12 @@ async def sse_endpoint(request: Request) -> StreamingResponse:
|
||||
|
||||
@app.post("/mcp/sse")
|
||||
async def sse_message_handler(request: Request) -> JSONResponse:
|
||||
"""Handle POST messages from ChatGPT MCP client to SSE endpoint.
|
||||
"""Handle POST messages for MCP SSE transport."""
|
||||
settings = get_settings()
|
||||
audit = get_audit_logger()
|
||||
|
||||
The MCP SSE transport uses:
|
||||
- GET /mcp/sse for server-to-client streaming
|
||||
- POST /mcp/sse for client-to-server messages
|
||||
|
||||
Returns:
|
||||
JSON response acknowledging the message
|
||||
"""
|
||||
try:
|
||||
audit = get_audit_logger()
|
||||
body = await request.json()
|
||||
logger.info(f"Received MCP message via SSE POST: {body}")
|
||||
|
||||
# Handle different message types
|
||||
message_type = body.get("type") or body.get("method")
|
||||
message_id = body.get("id")
|
||||
|
||||
@@ -328,87 +494,71 @@ async def sse_message_handler(request: Request) -> JSONResponse:
|
||||
"result": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "AegisGitea MCP", "version": "0.1.0"},
|
||||
"serverInfo": {"name": "AegisGitea MCP", "version": "0.2.0"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
elif message_type == "tools/list":
|
||||
# Return the list of available tools
|
||||
if message_type == "tools/list":
|
||||
response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": response.model_dump(by_alias=True),
|
||||
"result": response.model_dump(),
|
||||
}
|
||||
)
|
||||
|
||||
elif message_type == "tools/call":
|
||||
# Handle tool execution
|
||||
if message_type == "tools/call":
|
||||
tool_name = body.get("params", {}).get("name")
|
||||
tool_args = body.get("params", {}).get("arguments", {})
|
||||
|
||||
correlation_id = audit.log_tool_invocation(
|
||||
tool_name=tool_name,
|
||||
params=tool_args,
|
||||
)
|
||||
|
||||
correlation_id = audit.log_tool_invocation(tool_name=tool_name, params=tool_args)
|
||||
try:
|
||||
# Get tool handler
|
||||
handler = TOOL_HANDLERS.get(tool_name)
|
||||
if not handler:
|
||||
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
|
||||
|
||||
# Execute tool with Gitea client
|
||||
async with GiteaClient() as gitea:
|
||||
result = await handler(gitea, tool_args)
|
||||
|
||||
result = await _execute_tool_call(str(tool_name), tool_args, correlation_id)
|
||||
audit.log_tool_invocation(
|
||||
tool_name=tool_name,
|
||||
tool_name=str(tool_name),
|
||||
correlation_id=correlation_id,
|
||||
result_status="success",
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": {"content": [{"type": "text", "text": str(result)}]},
|
||||
"result": {"content": [{"type": "text", "text": json.dumps(result)}]},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
except Exception as exc:
|
||||
audit.log_tool_invocation(
|
||||
tool_name=tool_name,
|
||||
tool_name=str(tool_name),
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
error=str(exc),
|
||||
)
|
||||
message = "Internal server error"
|
||||
if settings.expose_error_details:
|
||||
message = str(exc)
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"error": {"code": -32603, "message": error_msg},
|
||||
"error": {"code": -32603, "message": message},
|
||||
}
|
||||
)
|
||||
|
||||
# Handle notifications (no response needed)
|
||||
elif message_type and message_type.startswith("notifications/"):
|
||||
logger.info(f"Received notification: {message_type}")
|
||||
if isinstance(message_type, str) and message_type.startswith("notifications/"):
|
||||
return JSONResponse(content={})
|
||||
|
||||
# Acknowledge other message types
|
||||
return JSONResponse(
|
||||
content={"jsonrpc": "2.0", "id": message_id, "result": {"acknowledged": True}}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling SSE POST message: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400, content={"error": "Invalid message format", "detail": str(e)}
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("sse_message_handler_error")
|
||||
message = "Invalid message format"
|
||||
if settings.expose_error_details:
|
||||
message = "Invalid message format (details hidden unless explicitly enabled)"
|
||||
return JSONResponse(status_code=400, content={"error": message})
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
@@ -1,15 +1,53 @@
|
||||
"""MCP tool implementations for AegisGitea."""
|
||||
"""MCP tool implementation exports."""
|
||||
|
||||
from aegis_gitea_mcp.tools.read_tools import (
|
||||
compare_refs_tool,
|
||||
get_commit_diff_tool,
|
||||
get_issue_tool,
|
||||
get_pull_request_tool,
|
||||
list_commits_tool,
|
||||
list_issues_tool,
|
||||
list_labels_tool,
|
||||
list_pull_requests_tool,
|
||||
list_releases_tool,
|
||||
list_tags_tool,
|
||||
search_code_tool,
|
||||
)
|
||||
from aegis_gitea_mcp.tools.repository import (
|
||||
get_file_contents_tool,
|
||||
get_file_tree_tool,
|
||||
get_repository_info_tool,
|
||||
list_repositories_tool,
|
||||
)
|
||||
from aegis_gitea_mcp.tools.write_tools import (
|
||||
add_labels_tool,
|
||||
assign_issue_tool,
|
||||
create_issue_comment_tool,
|
||||
create_issue_tool,
|
||||
create_pr_comment_tool,
|
||||
update_issue_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"list_repositories_tool",
|
||||
"get_repository_info_tool",
|
||||
"get_file_tree_tool",
|
||||
"get_file_contents_tool",
|
||||
"search_code_tool",
|
||||
"list_commits_tool",
|
||||
"get_commit_diff_tool",
|
||||
"compare_refs_tool",
|
||||
"list_issues_tool",
|
||||
"get_issue_tool",
|
||||
"list_pull_requests_tool",
|
||||
"get_pull_request_tool",
|
||||
"list_labels_tool",
|
||||
"list_tags_tool",
|
||||
"list_releases_tool",
|
||||
"create_issue_tool",
|
||||
"update_issue_tool",
|
||||
"create_issue_comment_tool",
|
||||
"create_pr_comment_tool",
|
||||
"add_labels_tool",
|
||||
"assign_issue_tool",
|
||||
]
|
||||
|
||||
208
src/aegis_gitea_mcp/tools/arguments.py
Normal file
208
src/aegis_gitea_mcp/tools/arguments.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Pydantic argument models for MCP tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
_REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$"
|
||||
|
||||
|
||||
class StrictBaseModel(BaseModel):
|
||||
"""Strict model base that rejects unexpected fields."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ListRepositoriesArgs(StrictBaseModel):
|
||||
"""Arguments for list_repositories tool."""
|
||||
|
||||
|
||||
class RepositoryArgs(StrictBaseModel):
|
||||
"""Common repository locator arguments."""
|
||||
|
||||
owner: str = Field(..., pattern=_REPO_PART_PATTERN)
|
||||
repo: str = Field(..., pattern=_REPO_PART_PATTERN)
|
||||
|
||||
|
||||
class FileTreeArgs(RepositoryArgs):
|
||||
"""Arguments for get_file_tree."""
|
||||
|
||||
ref: str = Field(default="main", min_length=1, max_length=200)
|
||||
recursive: bool = Field(default=False)
|
||||
|
||||
|
||||
class FileContentsArgs(RepositoryArgs):
|
||||
"""Arguments for get_file_contents."""
|
||||
|
||||
filepath: str = Field(..., min_length=1, max_length=1024)
|
||||
ref: str = Field(default="main", min_length=1, max_length=200)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_filepath(self) -> FileContentsArgs:
|
||||
"""Validate path safety constraints."""
|
||||
normalized = self.filepath.replace("\\", "/")
|
||||
# Security decision: block traversal and absolute paths.
|
||||
if normalized.startswith("/") or ".." in normalized.split("/"):
|
||||
raise ValueError("filepath must be a relative path without traversal")
|
||||
if "\x00" in normalized:
|
||||
raise ValueError("filepath cannot contain null bytes")
|
||||
return self
|
||||
|
||||
|
||||
class SearchCodeArgs(RepositoryArgs):
|
||||
"""Arguments for search_code."""
|
||||
|
||||
query: str = Field(..., min_length=1, max_length=256)
|
||||
ref: str = Field(default="main", min_length=1, max_length=200)
|
||||
page: int = Field(default=1, ge=1, le=10_000)
|
||||
limit: int = Field(default=25, ge=1, le=100)
|
||||
|
||||
|
||||
class ListCommitsArgs(RepositoryArgs):
|
||||
"""Arguments for list_commits."""
|
||||
|
||||
ref: str = Field(default="main", min_length=1, max_length=200)
|
||||
page: int = Field(default=1, ge=1, le=10_000)
|
||||
limit: int = Field(default=25, ge=1, le=100)
|
||||
|
||||
|
||||
class CommitDiffArgs(RepositoryArgs):
|
||||
"""Arguments for get_commit_diff."""
|
||||
|
||||
sha: str = Field(..., min_length=7, max_length=64)
|
||||
|
||||
|
||||
class CompareRefsArgs(RepositoryArgs):
|
||||
"""Arguments for compare_refs."""
|
||||
|
||||
base: str = Field(..., min_length=1, max_length=200)
|
||||
head: str = Field(..., min_length=1, max_length=200)
|
||||
|
||||
|
||||
class ListIssuesArgs(RepositoryArgs):
|
||||
"""Arguments for list_issues."""
|
||||
|
||||
state: Literal["open", "closed", "all"] = Field(default="open")
|
||||
page: int = Field(default=1, ge=1, le=10_000)
|
||||
limit: int = Field(default=25, ge=1, le=100)
|
||||
labels: list[str] = Field(default_factory=list, max_length=20)
|
||||
|
||||
|
||||
class IssueArgs(RepositoryArgs):
|
||||
"""Arguments for get_issue."""
|
||||
|
||||
issue_number: int = Field(..., ge=1)
|
||||
|
||||
|
||||
class ListPullRequestsArgs(RepositoryArgs):
|
||||
"""Arguments for list_pull_requests."""
|
||||
|
||||
state: Literal["open", "closed", "all"] = Field(default="open")
|
||||
page: int = Field(default=1, ge=1, le=10_000)
|
||||
limit: int = Field(default=25, ge=1, le=100)
|
||||
|
||||
|
||||
class PullRequestArgs(RepositoryArgs):
|
||||
"""Arguments for get_pull_request."""
|
||||
|
||||
pull_number: int = Field(..., ge=1)
|
||||
|
||||
|
||||
class ListLabelsArgs(RepositoryArgs):
|
||||
"""Arguments for list_labels."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=10_000)
|
||||
limit: int = Field(default=50, ge=1, le=100)
|
||||
|
||||
|
||||
class ListTagsArgs(RepositoryArgs):
|
||||
"""Arguments for list_tags."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=10_000)
|
||||
limit: int = Field(default=50, ge=1, le=100)
|
||||
|
||||
|
||||
class ListReleasesArgs(RepositoryArgs):
|
||||
"""Arguments for list_releases."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=10_000)
|
||||
limit: int = Field(default=25, ge=1, le=100)
|
||||
|
||||
|
||||
class CreateIssueArgs(RepositoryArgs):
|
||||
"""Arguments for create_issue."""
|
||||
|
||||
title: str = Field(..., min_length=1, max_length=256)
|
||||
body: str = Field(default="", max_length=20_000)
|
||||
labels: list[str] = Field(default_factory=list, max_length=20)
|
||||
assignees: list[str] = Field(default_factory=list, max_length=20)
|
||||
|
||||
|
||||
class UpdateIssueArgs(RepositoryArgs):
|
||||
"""Arguments for update_issue."""
|
||||
|
||||
issue_number: int = Field(..., ge=1)
|
||||
title: str | None = Field(default=None, min_length=1, max_length=256)
|
||||
body: str | None = Field(default=None, max_length=20_000)
|
||||
state: Literal["open", "closed"] | None = Field(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_change(self) -> UpdateIssueArgs:
|
||||
"""Require at least one mutable field in update payload."""
|
||||
if self.title is None and self.body is None and self.state is None:
|
||||
raise ValueError("At least one of title, body, or state must be provided")
|
||||
return self
|
||||
|
||||
|
||||
class CreateIssueCommentArgs(RepositoryArgs):
|
||||
"""Arguments for create_issue_comment."""
|
||||
|
||||
issue_number: int = Field(..., ge=1)
|
||||
body: str = Field(..., min_length=1, max_length=10_000)
|
||||
|
||||
|
||||
class CreatePrCommentArgs(RepositoryArgs):
|
||||
"""Arguments for create_pr_comment."""
|
||||
|
||||
pull_number: int = Field(..., ge=1)
|
||||
body: str = Field(..., min_length=1, max_length=10_000)
|
||||
|
||||
|
||||
class AddLabelsArgs(RepositoryArgs):
|
||||
"""Arguments for add_labels."""
|
||||
|
||||
issue_number: int = Field(..., ge=1)
|
||||
labels: list[str] = Field(..., min_length=1, max_length=20)
|
||||
|
||||
|
||||
class AssignIssueArgs(RepositoryArgs):
|
||||
"""Arguments for assign_issue."""
|
||||
|
||||
issue_number: int = Field(..., ge=1)
|
||||
assignees: list[str] = Field(..., min_length=1, max_length=20)
|
||||
|
||||
|
||||
def extract_repository(arguments: dict[str, object]) -> str | None:
|
||||
"""Extract `owner/repo` from raw argument mapping.
|
||||
|
||||
Args:
|
||||
arguments: Raw tool arguments.
|
||||
|
||||
Returns:
|
||||
`owner/repo` or None when arguments are incomplete.
|
||||
"""
|
||||
owner = arguments.get("owner")
|
||||
repo = arguments.get("repo")
|
||||
if isinstance(owner, str) and isinstance(repo, str) and owner and repo:
|
||||
return f"{owner}/{repo}"
|
||||
return None
|
||||
|
||||
|
||||
def extract_target_path(arguments: dict[str, object]) -> str | None:
|
||||
"""Extract optional target path argument for policy path checks."""
|
||||
filepath = arguments.get("filepath")
|
||||
if isinstance(filepath, str) and filepath:
|
||||
return filepath
|
||||
return None
|
||||
402
src/aegis_gitea_mcp/tools/read_tools.py
Normal file
402
src/aegis_gitea_mcp/tools/read_tools.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""Extended read-only MCP tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from aegis_gitea_mcp.gitea_client import GiteaClient, GiteaError
|
||||
from aegis_gitea_mcp.response_limits import limit_items, limit_text
|
||||
from aegis_gitea_mcp.tools.arguments import (
|
||||
CommitDiffArgs,
|
||||
CompareRefsArgs,
|
||||
IssueArgs,
|
||||
ListCommitsArgs,
|
||||
ListIssuesArgs,
|
||||
ListLabelsArgs,
|
||||
ListPullRequestsArgs,
|
||||
ListReleasesArgs,
|
||||
ListTagsArgs,
|
||||
PullRequestArgs,
|
||||
SearchCodeArgs,
|
||||
)
|
||||
|
||||
|
||||
async def search_code_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Search repository code and return bounded result snippets."""
|
||||
parsed = SearchCodeArgs.model_validate(arguments)
|
||||
try:
|
||||
raw = await gitea.search_code(
|
||||
parsed.owner,
|
||||
parsed.repo,
|
||||
parsed.query,
|
||||
ref=parsed.ref,
|
||||
page=parsed.page,
|
||||
limit=parsed.limit,
|
||||
)
|
||||
hits_raw = raw.get("data", raw.get("hits", [])) if isinstance(raw, dict) else []
|
||||
if not isinstance(hits_raw, list):
|
||||
hits_raw = []
|
||||
|
||||
normalized_hits = []
|
||||
for item in hits_raw:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
snippet = str(item.get("content", item.get("snippet", "")))
|
||||
normalized_hits.append(
|
||||
{
|
||||
"path": item.get("filename", item.get("path", "")),
|
||||
"sha": item.get("sha", ""),
|
||||
"ref": parsed.ref,
|
||||
"snippet": limit_text(snippet),
|
||||
"score": item.get("score", 0),
|
||||
}
|
||||
)
|
||||
|
||||
bounded, omitted = limit_items(normalized_hits, configured_limit=parsed.limit)
|
||||
return {
|
||||
"owner": parsed.owner,
|
||||
"repo": parsed.repo,
|
||||
"query": parsed.query,
|
||||
"ref": parsed.ref,
|
||||
"results": bounded,
|
||||
"count": len(bounded),
|
||||
"omitted": omitted,
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to search code: {exc}") from exc
|
||||
|
||||
|
||||
async def list_commits_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""List commits for a repository reference."""
|
||||
parsed = ListCommitsArgs.model_validate(arguments)
|
||||
try:
|
||||
commits = await gitea.list_commits(
|
||||
parsed.owner,
|
||||
parsed.repo,
|
||||
ref=parsed.ref,
|
||||
page=parsed.page,
|
||||
limit=parsed.limit,
|
||||
)
|
||||
normalized = [
|
||||
{
|
||||
"sha": commit.get("sha", ""),
|
||||
"message": limit_text(str(commit.get("commit", {}).get("message", ""))),
|
||||
"author": commit.get("author", {}).get("login", ""),
|
||||
"created": commit.get("commit", {}).get("author", {}).get("date", ""),
|
||||
"url": commit.get("html_url", ""),
|
||||
}
|
||||
for commit in commits
|
||||
if isinstance(commit, dict)
|
||||
]
|
||||
bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
|
||||
return {
|
||||
"owner": parsed.owner,
|
||||
"repo": parsed.repo,
|
||||
"ref": parsed.ref,
|
||||
"commits": bounded,
|
||||
"count": len(bounded),
|
||||
"omitted": omitted,
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to list commits: {exc}") from exc
|
||||
|
||||
|
||||
async def get_commit_diff_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return commit-level file diff metadata."""
|
||||
parsed = CommitDiffArgs.model_validate(arguments)
|
||||
try:
|
||||
commit = await gitea.get_commit_diff(parsed.owner, parsed.repo, parsed.sha)
|
||||
files = commit.get("files", []) if isinstance(commit, dict) else []
|
||||
normalized_files = []
|
||||
if isinstance(files, list):
|
||||
for item in files:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
normalized_files.append(
|
||||
{
|
||||
"filename": item.get("filename", ""),
|
||||
"status": item.get("status", ""),
|
||||
"additions": item.get("additions", 0),
|
||||
"deletions": item.get("deletions", 0),
|
||||
"changes": item.get("changes", 0),
|
||||
"patch": limit_text(str(item.get("patch", ""))),
|
||||
}
|
||||
)
|
||||
bounded, omitted = limit_items(normalized_files)
|
||||
|
||||
return {
|
||||
"owner": parsed.owner,
|
||||
"repo": parsed.repo,
|
||||
"sha": parsed.sha,
|
||||
"message": limit_text(
|
||||
str(commit.get("message", commit.get("commit", {}).get("message", "")))
|
||||
),
|
||||
"files": bounded,
|
||||
"count": len(bounded),
|
||||
"omitted": omitted,
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to get commit diff: {exc}") from exc
|
||||
|
||||
|
||||
async def compare_refs_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Compare two refs and return bounded commit/file changes."""
|
||||
parsed = CompareRefsArgs.model_validate(arguments)
|
||||
try:
|
||||
comparison = await gitea.compare_refs(parsed.owner, parsed.repo, parsed.base, parsed.head)
|
||||
commits_raw = comparison.get("commits", []) if isinstance(comparison, dict) else []
|
||||
files_raw = comparison.get("files", []) if isinstance(comparison, dict) else []
|
||||
|
||||
commits = [
|
||||
{
|
||||
"sha": commit.get("sha", ""),
|
||||
"message": limit_text(str(commit.get("commit", {}).get("message", ""))),
|
||||
}
|
||||
for commit in commits_raw
|
||||
if isinstance(commit, dict)
|
||||
]
|
||||
commit_items, commit_omitted = limit_items(commits)
|
||||
|
||||
files = [
|
||||
{
|
||||
"filename": item.get("filename", ""),
|
||||
"status": item.get("status", ""),
|
||||
"additions": item.get("additions", 0),
|
||||
"deletions": item.get("deletions", 0),
|
||||
}
|
||||
for item in files_raw
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
file_items, file_omitted = limit_items(files)
|
||||
|
||||
return {
|
||||
"owner": parsed.owner,
|
||||
"repo": parsed.repo,
|
||||
"base": parsed.base,
|
||||
"head": parsed.head,
|
||||
"commits": commit_items,
|
||||
"files": file_items,
|
||||
"commit_count": len(commit_items),
|
||||
"file_count": len(file_items),
|
||||
"omitted_commits": commit_omitted,
|
||||
"omitted_files": file_omitted,
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to compare refs: {exc}") from exc
|
||||
|
||||
|
||||
async def list_issues_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""List issues for repository."""
|
||||
parsed = ListIssuesArgs.model_validate(arguments)
|
||||
try:
|
||||
issues = await gitea.list_issues(
|
||||
parsed.owner,
|
||||
parsed.repo,
|
||||
state=parsed.state,
|
||||
page=parsed.page,
|
||||
limit=parsed.limit,
|
||||
labels=parsed.labels,
|
||||
)
|
||||
normalized = [
|
||||
{
|
||||
"number": issue.get("number", 0),
|
||||
"title": limit_text(str(issue.get("title", ""))),
|
||||
"state": issue.get("state", ""),
|
||||
"author": issue.get("user", {}).get("login", ""),
|
||||
"labels": [label.get("name", "") for label in issue.get("labels", [])],
|
||||
"created_at": issue.get("created_at", ""),
|
||||
"updated_at": issue.get("updated_at", ""),
|
||||
"url": issue.get("html_url", ""),
|
||||
}
|
||||
for issue in issues
|
||||
if isinstance(issue, dict)
|
||||
]
|
||||
bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
|
||||
return {
|
||||
"owner": parsed.owner,
|
||||
"repo": parsed.repo,
|
||||
"state": parsed.state,
|
||||
"issues": bounded,
|
||||
"count": len(bounded),
|
||||
"omitted": omitted,
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to list issues: {exc}") from exc
|
||||
|
||||
|
||||
async def get_issue_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Get issue details."""
|
||||
parsed = IssueArgs.model_validate(arguments)
|
||||
try:
|
||||
issue = await gitea.get_issue(parsed.owner, parsed.repo, parsed.issue_number)
|
||||
return {
|
||||
"number": issue.get("number", 0),
|
||||
"title": limit_text(str(issue.get("title", ""))),
|
||||
"body": limit_text(str(issue.get("body", ""))),
|
||||
"state": issue.get("state", ""),
|
||||
"author": issue.get("user", {}).get("login", ""),
|
||||
"labels": [label.get("name", "") for label in issue.get("labels", [])],
|
||||
"assignees": [assignee.get("login", "") for assignee in issue.get("assignees", [])],
|
||||
"created_at": issue.get("created_at", ""),
|
||||
"updated_at": issue.get("updated_at", ""),
|
||||
"url": issue.get("html_url", ""),
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to get issue: {exc}") from exc
|
||||
|
||||
|
||||
async def list_pull_requests_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""List pull requests."""
|
||||
parsed = ListPullRequestsArgs.model_validate(arguments)
|
||||
try:
|
||||
pull_requests = await gitea.list_pull_requests(
|
||||
parsed.owner,
|
||||
parsed.repo,
|
||||
state=parsed.state,
|
||||
page=parsed.page,
|
||||
limit=parsed.limit,
|
||||
)
|
||||
normalized = [
|
||||
{
|
||||
"number": pull.get("number", 0),
|
||||
"title": limit_text(str(pull.get("title", ""))),
|
||||
"state": pull.get("state", ""),
|
||||
"author": pull.get("user", {}).get("login", ""),
|
||||
"draft": pull.get("draft", False),
|
||||
"mergeable": pull.get("mergeable", False),
|
||||
"created_at": pull.get("created_at", ""),
|
||||
"updated_at": pull.get("updated_at", ""),
|
||||
"url": pull.get("html_url", ""),
|
||||
}
|
||||
for pull in pull_requests
|
||||
if isinstance(pull, dict)
|
||||
]
|
||||
bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
|
||||
return {
|
||||
"owner": parsed.owner,
|
||||
"repo": parsed.repo,
|
||||
"state": parsed.state,
|
||||
"pull_requests": bounded,
|
||||
"count": len(bounded),
|
||||
"omitted": omitted,
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to list pull requests: {exc}") from exc
|
||||
|
||||
|
||||
async def get_pull_request_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Get pull request details."""
|
||||
parsed = PullRequestArgs.model_validate(arguments)
|
||||
try:
|
||||
pull = await gitea.get_pull_request(parsed.owner, parsed.repo, parsed.pull_number)
|
||||
return {
|
||||
"number": pull.get("number", 0),
|
||||
"title": limit_text(str(pull.get("title", ""))),
|
||||
"body": limit_text(str(pull.get("body", ""))),
|
||||
"state": pull.get("state", ""),
|
||||
"draft": pull.get("draft", False),
|
||||
"mergeable": pull.get("mergeable", False),
|
||||
"author": pull.get("user", {}).get("login", ""),
|
||||
"base": pull.get("base", {}).get("ref", ""),
|
||||
"head": pull.get("head", {}).get("ref", ""),
|
||||
"created_at": pull.get("created_at", ""),
|
||||
"updated_at": pull.get("updated_at", ""),
|
||||
"url": pull.get("html_url", ""),
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to get pull request: {exc}") from exc
|
||||
|
||||
|
||||
async def list_labels_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""List labels configured on repository."""
|
||||
parsed = ListLabelsArgs.model_validate(arguments)
|
||||
try:
|
||||
labels = await gitea.list_labels(
|
||||
parsed.owner, parsed.repo, page=parsed.page, limit=parsed.limit
|
||||
)
|
||||
normalized = [
|
||||
{
|
||||
"id": label.get("id", 0),
|
||||
"name": label.get("name", ""),
|
||||
"color": label.get("color", ""),
|
||||
"description": limit_text(str(label.get("description", ""))),
|
||||
}
|
||||
for label in labels
|
||||
if isinstance(label, dict)
|
||||
]
|
||||
bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
|
||||
return {
|
||||
"owner": parsed.owner,
|
||||
"repo": parsed.repo,
|
||||
"labels": bounded,
|
||||
"count": len(bounded),
|
||||
"omitted": omitted,
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to list labels: {exc}") from exc
|
||||
|
||||
|
||||
async def list_tags_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""List repository tags."""
|
||||
parsed = ListTagsArgs.model_validate(arguments)
|
||||
try:
|
||||
tags = await gitea.list_tags(
|
||||
parsed.owner, parsed.repo, page=parsed.page, limit=parsed.limit
|
||||
)
|
||||
normalized = [
|
||||
{
|
||||
"name": tag.get("name", ""),
|
||||
"commit": tag.get("commit", {}).get("sha", ""),
|
||||
"zipball_url": tag.get("zipball_url", ""),
|
||||
"tarball_url": tag.get("tarball_url", ""),
|
||||
}
|
||||
for tag in tags
|
||||
if isinstance(tag, dict)
|
||||
]
|
||||
bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
|
||||
return {
|
||||
"owner": parsed.owner,
|
||||
"repo": parsed.repo,
|
||||
"tags": bounded,
|
||||
"count": len(bounded),
|
||||
"omitted": omitted,
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to list tags: {exc}") from exc
|
||||
|
||||
|
||||
async def list_releases_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""List repository releases."""
|
||||
parsed = ListReleasesArgs.model_validate(arguments)
|
||||
try:
|
||||
releases = await gitea.list_releases(
|
||||
parsed.owner,
|
||||
parsed.repo,
|
||||
page=parsed.page,
|
||||
limit=parsed.limit,
|
||||
)
|
||||
normalized = [
|
||||
{
|
||||
"id": release.get("id", 0),
|
||||
"tag_name": release.get("tag_name", ""),
|
||||
"name": limit_text(str(release.get("name", ""))),
|
||||
"draft": release.get("draft", False),
|
||||
"prerelease": release.get("prerelease", False),
|
||||
"body": limit_text(str(release.get("body", ""))),
|
||||
"created_at": release.get("created_at", ""),
|
||||
"published_at": release.get("published_at", ""),
|
||||
"url": release.get("html_url", ""),
|
||||
}
|
||||
for release in releases
|
||||
if isinstance(release, dict)
|
||||
]
|
||||
bounded, omitted = limit_items(normalized, configured_limit=parsed.limit)
|
||||
return {
|
||||
"owner": parsed.owner,
|
||||
"repo": parsed.repo,
|
||||
"releases": bounded,
|
||||
"count": len(bounded),
|
||||
"omitted": omitted,
|
||||
}
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to list releases: {exc}") from exc
|
||||
@@ -1,26 +1,36 @@
|
||||
"""Repository-related MCP tool implementations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict
|
||||
import binascii
|
||||
from typing import Any
|
||||
|
||||
from aegis_gitea_mcp.gitea_client import GiteaClient, GiteaError
|
||||
from aegis_gitea_mcp.response_limits import limit_items, limit_text
|
||||
from aegis_gitea_mcp.security import sanitize_untrusted_text
|
||||
from aegis_gitea_mcp.tools.arguments import (
|
||||
FileContentsArgs,
|
||||
FileTreeArgs,
|
||||
ListRepositoriesArgs,
|
||||
RepositoryArgs,
|
||||
)
|
||||
|
||||
|
||||
async def list_repositories_tool(gitea: GiteaClient, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""List all repositories visible to the bot user.
|
||||
async def list_repositories_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""List repositories visible to the bot user.
|
||||
|
||||
Args:
|
||||
gitea: Initialized Gitea client
|
||||
arguments: Tool arguments (empty for this tool)
|
||||
gitea: Initialized Gitea client.
|
||||
arguments: Tool arguments.
|
||||
|
||||
Returns:
|
||||
Dict containing list of repositories with metadata
|
||||
Response payload with bounded repository list.
|
||||
"""
|
||||
ListRepositoriesArgs.model_validate(arguments)
|
||||
try:
|
||||
repos = await gitea.list_repositories()
|
||||
|
||||
# Transform to simplified format
|
||||
simplified_repos = [
|
||||
repositories = await gitea.list_repositories()
|
||||
simplified = [
|
||||
{
|
||||
"owner": repo.get("owner", {}).get("login", ""),
|
||||
"name": repo.get("name", ""),
|
||||
@@ -32,39 +42,24 @@ async def list_repositories_tool(gitea: GiteaClient, arguments: Dict[str, Any])
|
||||
"stars": repo.get("stars_count", 0),
|
||||
"url": repo.get("html_url", ""),
|
||||
}
|
||||
for repo in repos
|
||||
for repo in repositories
|
||||
]
|
||||
|
||||
bounded, omitted = limit_items(simplified)
|
||||
return {
|
||||
"repositories": simplified_repos,
|
||||
"count": len(simplified_repos),
|
||||
"repositories": bounded,
|
||||
"count": len(bounded),
|
||||
"omitted": omitted,
|
||||
}
|
||||
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to list repositories: {exc}") from exc
|
||||
|
||||
|
||||
async def get_repository_info_tool(
|
||||
gitea: GiteaClient, arguments: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Get detailed information about a specific repository.
|
||||
|
||||
Args:
|
||||
gitea: Initialized Gitea client
|
||||
arguments: Tool arguments with 'owner' and 'repo'
|
||||
|
||||
Returns:
|
||||
Dict containing repository information
|
||||
"""
|
||||
owner = arguments.get("owner")
|
||||
repo = arguments.get("repo")
|
||||
|
||||
if not owner or not repo:
|
||||
raise ValueError("Both 'owner' and 'repo' arguments are required")
|
||||
|
||||
async def get_repository_info_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Get detailed metadata for a repository."""
|
||||
parsed = RepositoryArgs.model_validate(arguments)
|
||||
try:
|
||||
repo_data = await gitea.get_repository(owner, repo)
|
||||
|
||||
repo_data = await gitea.get_repository(parsed.owner, parsed.repo)
|
||||
return {
|
||||
"owner": repo_data.get("owner", {}).get("login", ""),
|
||||
"name": repo_data.get("name", ""),
|
||||
@@ -83,107 +78,82 @@ async def get_repository_info_tool(
|
||||
"url": repo_data.get("html_url", ""),
|
||||
"clone_url": repo_data.get("clone_url", ""),
|
||||
}
|
||||
|
||||
except GiteaError as exc:
|
||||
raise RuntimeError(f"Failed to get repository info: {exc}") from exc
|
||||
|
||||
|
||||
async def get_file_tree_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Get repository file tree at selected ref.

    Args:
        gitea: Initialized Gitea client.
        arguments: Tool arguments with 'owner', 'repo', optional 'ref' and 'recursive'.

    Returns:
        Response payload with a bounded tree entry list.

    Raises:
        RuntimeError: If the Gitea API call fails.
    """
    parsed = FileTreeArgs.model_validate(arguments)
    try:
        tree_data = await gitea.get_tree(parsed.owner, parsed.repo, parsed.ref, parsed.recursive)
        # `or []` guards a present-but-null "tree" key (e.g. empty repository) —
        # .get(key, default) alone would hand None to the comprehension.
        tree_entries = tree_data.get("tree") or []
        # Consistency with the tag/release tools: skip malformed non-dict
        # entries instead of raising on unexpected API payloads.
        simplified = [
            {
                "path": entry.get("path", ""),
                "type": entry.get("type", ""),  # 'blob' (file) or 'tree' (directory)
                "size": entry.get("size", 0),
                "sha": entry.get("sha", ""),
            }
            for entry in tree_entries
            if isinstance(entry, dict)
        ]
        bounded, omitted = limit_items(simplified)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "ref": parsed.ref,
            "recursive": parsed.recursive,
            "tree": bounded,
            "count": len(bounded),
            "omitted": omitted,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to get file tree: {exc}") from exc
|
||||
|
||||
|
||||
async def get_file_contents_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Read file contents from a repository ref.

    Security notes:
    - Repository content is treated as untrusted data and never executed.
    - Text output is truncated to configured limits to reduce prompt-stuffing risk.
    """
    parsed = FileContentsArgs.model_validate(arguments)
    try:
        file_data = await gitea.get_file_contents(
            parsed.owner, parsed.repo, parsed.filepath, parsed.ref
        )
        raw_content = file_data.get("content", "")
        encoding = file_data.get("encoding", "base64")
        content = str(raw_content)
        if encoding == "base64":
            try:
                # Edge case: binary files (bad base64 or non-UTF-8 bytes) should
                # remain encoded instead of forcing invalid text.
                content = base64.b64decode(raw_content).decode("utf-8")
            except (binascii.Error, ValueError, UnicodeDecodeError):
                content = str(raw_content)
        # Validation logic: keep untrusted content bounded before returning it to LLM clients.
        content = sanitize_untrusted_text(content, max_chars=200_000)
        content = limit_text(content)
        return {
            "owner": parsed.owner,
            "repo": parsed.repo,
            "filepath": parsed.filepath,
            "ref": parsed.ref,
            "content": content,
            "encoding": encoding,
            "size": file_data.get("size", 0),
            "sha": file_data.get("sha", ""),
            "url": file_data.get("html_url", ""),
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to get file contents: {exc}") from exc
|
||||
|
||||
141
src/aegis_gitea_mcp/tools/write_tools.py
Normal file
141
src/aegis_gitea_mcp/tools/write_tools.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Write-mode MCP tool implementations (disabled by default)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from aegis_gitea_mcp.gitea_client import GiteaClient, GiteaError
|
||||
from aegis_gitea_mcp.response_limits import limit_text
|
||||
from aegis_gitea_mcp.tools.arguments import (
|
||||
AddLabelsArgs,
|
||||
AssignIssueArgs,
|
||||
CreateIssueArgs,
|
||||
CreateIssueCommentArgs,
|
||||
CreatePrCommentArgs,
|
||||
UpdateIssueArgs,
|
||||
)
|
||||
|
||||
|
||||
async def create_issue_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Create a new issue in write mode."""
    parsed = CreateIssueArgs.model_validate(arguments)
    try:
        created = await gitea.create_issue(
            parsed.owner,
            parsed.repo,
            title=parsed.title,
            body=parsed.body,
            labels=parsed.labels,
            assignees=parsed.assignees,
        )
    except GiteaError as exc:
        raise RuntimeError(f"Failed to create issue: {exc}") from exc
    # Echo back a minimal, text-bounded summary of the created issue.
    return {
        "number": created.get("number", 0),
        "title": limit_text(str(created.get("title", ""))),
        "state": created.get("state", ""),
        "url": created.get("html_url", ""),
    }
|
||||
|
||||
|
||||
async def update_issue_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Update issue fields in write mode."""
    parsed = UpdateIssueArgs.model_validate(arguments)
    try:
        updated = await gitea.update_issue(
            parsed.owner,
            parsed.repo,
            parsed.issue_number,
            title=parsed.title,
            body=parsed.body,
            state=parsed.state,
        )
    except GiteaError as exc:
        raise RuntimeError(f"Failed to update issue: {exc}") from exc
    # Fall back to the requested number if the API response omits it.
    return {
        "number": updated.get("number", parsed.issue_number),
        "title": limit_text(str(updated.get("title", ""))),
        "state": updated.get("state", ""),
        "url": updated.get("html_url", ""),
    }
|
||||
|
||||
|
||||
async def create_issue_comment_tool(
    gitea: GiteaClient, arguments: dict[str, Any]
) -> dict[str, Any]:
    """Create issue comment in write mode."""
    parsed = CreateIssueCommentArgs.model_validate(arguments)
    try:
        posted = await gitea.create_issue_comment(
            parsed.owner, parsed.repo, parsed.issue_number, parsed.body
        )
    except GiteaError as exc:
        raise RuntimeError(f"Failed to create issue comment: {exc}") from exc
    # Comment body is bounded before being returned to the client.
    return {
        "id": posted.get("id", 0),
        "issue_number": parsed.issue_number,
        "body": limit_text(str(posted.get("body", ""))),
        "url": posted.get("html_url", ""),
    }
|
||||
|
||||
|
||||
async def create_pr_comment_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Create PR discussion comment in write mode."""
    parsed = CreatePrCommentArgs.model_validate(arguments)
    try:
        posted = await gitea.create_pr_comment(
            parsed.owner, parsed.repo, parsed.pull_number, parsed.body
        )
    except GiteaError as exc:
        raise RuntimeError(f"Failed to create PR comment: {exc}") from exc
    # Comment body is bounded before being returned to the client.
    return {
        "id": posted.get("id", 0),
        "pull_number": parsed.pull_number,
        "body": limit_text(str(posted.get("body", ""))),
        "url": posted.get("html_url", ""),
    }
|
||||
|
||||
|
||||
async def add_labels_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Add labels to an issue or pull request.

    Args:
        gitea: Initialized Gitea client.
        arguments: Tool arguments with 'owner', 'repo', 'issue_number', 'labels'.

    Returns:
        Payload echoing the issue number and the resulting label names
        (falls back to the requested labels if the response shape is unknown).

    Raises:
        RuntimeError: If the Gitea API call fails.
    """
    parsed = AddLabelsArgs.model_validate(arguments)
    try:
        result = await gitea.add_labels(
            parsed.owner, parsed.repo, parsed.issue_number, parsed.labels
        )
        label_names: list[str] = []
        if isinstance(result, dict):
            # Issue-shaped response: `labels` may be null, so `or []` guards it;
            # non-dict entries are skipped instead of raising.
            label_names = [
                label.get("name", "")
                for label in result.get("labels") or []
                if isinstance(label, dict)
            ]
        elif isinstance(result, list):
            # The raw Gitea labels endpoint returns the label array directly —
            # handle that shape too instead of silently discarding it.
            label_names = [
                label.get("name", "") for label in result if isinstance(label, dict)
            ]
        return {
            "issue_number": parsed.issue_number,
            "labels": label_names or parsed.labels,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to add labels: {exc}") from exc
|
||||
|
||||
|
||||
async def assign_issue_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
    """Assign users to an issue or pull request.

    Args:
        gitea: Initialized Gitea client.
        arguments: Tool arguments with 'owner', 'repo', 'issue_number', 'assignees'.

    Returns:
        Payload echoing the issue number and the resulting assignee logins
        (falls back to the requested assignees if none can be read back).

    Raises:
        RuntimeError: If the Gitea API call fails.
    """
    parsed = AssignIssueArgs.model_validate(arguments)
    try:
        result = await gitea.assign_issue(
            parsed.owner,
            parsed.repo,
            parsed.issue_number,
            parsed.assignees,
        )
        assignees: list[str] = []
        if isinstance(result, dict):
            # Gitea serializes an unassigned issue as `"assignees": null`, not [] —
            # `or []` prevents iterating None; non-dict entries are skipped.
            assignees = [
                assignee.get("login", "")
                for assignee in result.get("assignees") or []
                if isinstance(assignee, dict)
            ]
        return {
            "issue_number": parsed.issue_number,
            "assignees": assignees or parsed.assignees,
        }
    except GiteaError as exc:
        raise RuntimeError(f"Failed to assign issue: {exc}") from exc
|
||||
Reference in New Issue
Block a user