"""Authentication module for MCP server API key validation.""" from __future__ import annotations import hashlib import hmac import secrets from datetime import datetime, timezone from aegis_gitea_mcp.audit import get_audit_logger from aegis_gitea_mcp.config import get_settings class AuthenticationError(Exception): """Raised when authentication fails.""" class APIKeyValidator: """Validate API keys for MCP server access.""" def __init__(self) -> None: """Initialize API key validator state.""" self.settings = get_settings() self.audit = get_audit_logger() self._failed_attempts: dict[str, list[datetime]] = {} def _constant_time_compare(self, candidate: str, expected: str) -> bool: """Compare API keys in constant time to mitigate timing attacks.""" return hmac.compare_digest(candidate, expected) def _check_rate_limit(self, identifier: str) -> bool: """Check whether authentication failures exceed configured threshold.""" now = datetime.now(timezone.utc) boundary = now.timestamp() - self.settings.auth_failure_window if identifier in self._failed_attempts: self._failed_attempts[identifier] = [ attempt for attempt in self._failed_attempts[identifier] if attempt.timestamp() > boundary ] return len(self._failed_attempts.get(identifier, [])) < self.settings.max_auth_failures def _record_failed_attempt(self, identifier: str) -> None: """Record a failed authentication attempt for rate limiting.""" attempt_time = datetime.now(timezone.utc) self._failed_attempts.setdefault(identifier, []).append(attempt_time) if len(self._failed_attempts[identifier]) >= self.settings.max_auth_failures: self.audit.log_security_event( event_type="auth_rate_limit_exceeded", description="Authentication failure threshold exceeded", severity="high", metadata={ "identifier": identifier, "failure_count": len(self._failed_attempts[identifier]), "window_seconds": self.settings.auth_failure_window, }, ) def validate_api_key( self, provided_key: str | None, client_ip: str, user_agent: str, ) -> tuple[bool, str | None]: """Validate an API key. Args: provided_key: API key provided by client. client_ip: Request source IP address. user_agent: Request user agent. Returns: Tuple of `(is_valid, error_message)`. """ if not self.settings.auth_enabled: # Security note: auth-disabled mode is explicit and should be monitored. self.audit.log_security_event( event_type="auth_disabled", description="Authentication disabled; request was allowed", severity="critical", metadata={"client_ip": client_ip}, ) return True, None if not self._check_rate_limit(client_ip): self.audit.log_access_denied( tool_name="api_authentication", reason="rate_limit_exceeded", ) return False, "Too many failed authentication attempts. Please try again later." if not provided_key: self._record_failed_attempt(client_ip) self.audit.log_access_denied( tool_name="api_authentication", reason="missing_api_key", ) return False, "Authorization header missing. Required: Authorization: Bearer " if len(provided_key) < 32: # Validation logic: reject short keys early to reduce brute force surface. self._record_failed_attempt(client_ip) self.audit.log_access_denied( tool_name="api_authentication", reason="invalid_key_format", ) return False, "Invalid API key format" valid_keys = self.settings.mcp_api_keys if not valid_keys: self.audit.log_security_event( event_type="no_api_keys_configured", description="No API keys configured while auth is enabled", severity="critical", metadata={"client_ip": client_ip}, ) return False, "Server configuration error: No API keys configured" is_valid = any( self._constant_time_compare(provided_key, valid_key) for valid_key in valid_keys ) if is_valid: key_fingerprint = hashlib.sha256(provided_key.encode("utf-8")).hexdigest()[:12] self.audit.log_tool_invocation( tool_name="api_authentication", result_status="success", params={ "client_ip": client_ip, "user_agent": user_agent, "key_fingerprint": key_fingerprint, }, ) return True, None self._record_failed_attempt(client_ip) self.audit.log_access_denied( tool_name="api_authentication", reason="invalid_api_key", ) self.audit.log_security_event( event_type="invalid_api_key_attempt", description="Invalid API key was presented", severity="medium", metadata={"client_ip": client_ip, "user_agent": user_agent}, ) return False, "Invalid API key" def extract_bearer_token(self, authorization_header: str | None) -> str | None: """Extract API token from `Authorization: Bearer ` header. Security note: The scheme is case-sensitive by policy (`Bearer`) to prevent accepting ambiguous client implementations and to align strict API contracts. """ if not authorization_header: return None parts = authorization_header.split(" ") if len(parts) != 2: return None scheme, token = parts if scheme != "Bearer": return None if not token.strip(): return None return token.strip() def generate_api_key(length: int = 64) -> str: """Generate a cryptographically secure API key. Args: length: Length of key in characters. Returns: Generated API key string. """ return secrets.token_hex(length // 2) def hash_api_key(api_key: str) -> str: """Hash an API key for secure storage and comparison.""" return hashlib.sha256(api_key.encode("utf-8")).hexdigest() _validator: APIKeyValidator | None = None def get_validator() -> APIKeyValidator: """Get or create global API key validator instance.""" global _validator if _validator is None: _validator = APIKeyValidator() return _validator def reset_validator() -> None: """Reset global API key validator instance (primarily for testing).""" global _validator _validator = None