207 lines
7.0 KiB
Python
207 lines
7.0 KiB
Python
"""Authentication module for MCP server API key validation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import hmac
|
|
import secrets
|
|
from datetime import datetime, timezone
|
|
|
|
from aegis_gitea_mcp.audit import get_audit_logger
|
|
from aegis_gitea_mcp.config import get_settings
|
|
|
|
|
|
class AuthenticationError(Exception):
    """Signal that a request could not be authenticated.

    Carries an optional human-readable message like any ``Exception``.
    """
|
|
|
|
|
|
class APIKeyValidator:
    """Validate API keys for MCP server access.

    Combines an allow-list of keys (``settings.mcp_api_keys``) with a
    per-identifier sliding-window limit on authentication failures. Outcomes
    are reported through the audit logger.
    """

    def __init__(self) -> None:
        """Initialize API key validator state from the global settings."""
        self.settings = get_settings()
        self.audit = get_audit_logger()
        # Identifier (the client IP in ``validate_api_key``) -> timestamps of
        # its recent failed authentication attempts.
        self._failed_attempts: dict[str, list[datetime]] = {}

    def _constant_time_compare(self, candidate: str, expected: str) -> bool:
        """Compare API keys in constant time to mitigate timing attacks.

        Both values are UTF-8 encoded first: ``hmac.compare_digest`` raises
        ``TypeError`` for ``str`` arguments containing non-ASCII characters,
        which would turn a malformed client key into an unhandled server
        error instead of an ordinary authentication failure.
        """
        return hmac.compare_digest(candidate.encode("utf-8"), expected.encode("utf-8"))

    def _check_rate_limit(self, identifier: str) -> bool:
        """Return whether *identifier* is still below the failure threshold.

        Failures older than ``settings.auth_failure_window`` seconds are
        discarded. When no recent failures remain, the identifier is removed
        from the tracking map so expired entries do not accumulate.
        """
        boundary = datetime.now(timezone.utc).timestamp() - self.settings.auth_failure_window
        recent = [
            attempt
            for attempt in self._failed_attempts.get(identifier, [])
            if attempt.timestamp() > boundary
        ]
        if recent:
            self._failed_attempts[identifier] = recent
        else:
            self._failed_attempts.pop(identifier, None)
        return len(recent) < self.settings.max_auth_failures

    def _record_failed_attempt(self, identifier: str) -> None:
        """Record a failed authentication attempt for rate limiting.

        Emits a high-severity security event whenever the identifier's
        recorded failure count reaches ``settings.max_auth_failures``.
        """
        attempt_time = datetime.now(timezone.utc)
        self._failed_attempts.setdefault(identifier, []).append(attempt_time)

        if len(self._failed_attempts[identifier]) >= self.settings.max_auth_failures:
            self.audit.log_security_event(
                event_type="auth_rate_limit_exceeded",
                description="Authentication failure threshold exceeded",
                severity="high",
                metadata={
                    "identifier": identifier,
                    "failure_count": len(self._failed_attempts[identifier]),
                    "window_seconds": self.settings.auth_failure_window,
                },
            )

    def validate_api_key(
        self,
        provided_key: str | None,
        client_ip: str,
        user_agent: str,
    ) -> tuple[bool, str | None]:
        """Validate an API key.

        Checks run in this order:

        1. Bypass when auth is disabled.
        2. Per-IP failure rate limit.
        3. Missing key.
        4. Minimum length of 32 characters.
        5. Server key configuration.
        6. Constant-time comparison against every configured key.

        Missing, short and wrong keys are recorded as failed attempts
        against ``client_ip``.

        Args:
            provided_key: API key provided by client.
            client_ip: Request source IP address; used as the rate-limit
                identifier.
            user_agent: Request user agent.

        Returns:
            Tuple of ``(is_valid, error_message)``; ``error_message`` is
            ``None`` on success.
        """
        if not self.settings.auth_enabled:
            # Security note: auth-disabled mode is explicit and should be monitored.
            self.audit.log_security_event(
                event_type="auth_disabled",
                description="Authentication disabled; request was allowed",
                severity="critical",
                metadata={"client_ip": client_ip},
            )
            return True, None

        if not self._check_rate_limit(client_ip):
            self.audit.log_access_denied(
                tool_name="api_authentication",
                reason="rate_limit_exceeded",
            )
            return False, "Too many failed authentication attempts. Please try again later."

        if not provided_key:
            self._record_failed_attempt(client_ip)
            self.audit.log_access_denied(
                tool_name="api_authentication",
                reason="missing_api_key",
            )
            return False, "Authorization header missing. Required: Authorization: Bearer <api-key>"

        if len(provided_key) < 32:
            # Validation logic: reject short keys early to reduce brute force surface.
            self._record_failed_attempt(client_ip)
            self.audit.log_access_denied(
                tool_name="api_authentication",
                reason="invalid_key_format",
            )
            return False, "Invalid API key format"

        valid_keys = self.settings.mcp_api_keys
        if not valid_keys:
            self.audit.log_security_event(
                event_type="no_api_keys_configured",
                description="No API keys configured while auth is enabled",
                severity="critical",
                metadata={"client_ip": client_ip},
            )
            return False, "Server configuration error: No API keys configured"

        # Compare against every configured key without short-circuiting so the
        # amount of work does not depend on which key (if any) matched.
        is_valid = False
        for valid_key in valid_keys:
            if self._constant_time_compare(provided_key, valid_key):
                is_valid = True

        if is_valid:
            # Log only a short hash prefix so the raw key never reaches audit logs.
            key_fingerprint = hashlib.sha256(provided_key.encode("utf-8")).hexdigest()[:12]
            self.audit.log_tool_invocation(
                tool_name="api_authentication",
                result_status="success",
                params={
                    "client_ip": client_ip,
                    "user_agent": user_agent,
                    "key_fingerprint": key_fingerprint,
                },
            )
            return True, None

        self._record_failed_attempt(client_ip)
        self.audit.log_access_denied(
            tool_name="api_authentication",
            reason="invalid_api_key",
        )
        self.audit.log_security_event(
            event_type="invalid_api_key_attempt",
            description="Invalid API key was presented",
            severity="medium",
            metadata={"client_ip": client_ip, "user_agent": user_agent},
        )
        return False, "Invalid API key"

    def extract_bearer_token(self, authorization_header: str | None) -> str | None:
        """Extract API token from `Authorization: Bearer <token>` header.

        Returns ``None`` when the header is missing, is not exactly two
        single-space-separated parts, uses a scheme other than ``Bearer``, or
        carries an empty token.

        Security note:
            The scheme is case-sensitive by policy (`Bearer`) to prevent accepting
            ambiguous client implementations and to align strict API contracts.
        """
        if not authorization_header:
            return None

        parts = authorization_header.split(" ")
        if len(parts) != 2:
            return None

        scheme, token = parts
        if scheme != "Bearer":
            return None
        if not token.strip():
            return None

        return token.strip()
|
|
|
|
|
|
def generate_api_key(length: int = 64) -> str:
    """Generate a cryptographically secure API key.

    The key is a lowercase hexadecimal string produced with :mod:`secrets`.

    Args:
        length: Exact length of the key in characters; odd lengths are
            supported. Note that ``APIKeyValidator`` rejects keys shorter
            than 32 characters.

    Returns:
        Generated API key string of exactly ``length`` characters.

    Raises:
        ValueError: If ``length`` is less than 1.
    """
    if length < 1:
        raise ValueError(f"API key length must be positive, got {length}")
    # token_hex(n) yields 2*n hex characters: round up, then trim so odd
    # lengths are honored instead of silently yielding length - 1 characters.
    return secrets.token_hex((length + 1) // 2)[:length]
|
|
|
|
|
|
def hash_api_key(api_key: str) -> str:
    """Return the hex-encoded SHA-256 digest of *api_key*.

    The key is UTF-8 encoded before hashing, so the result is a 64-character
    lowercase hex string suitable for storage instead of the raw key.
    """
    hasher = hashlib.sha256()
    hasher.update(api_key.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
|
# Lazily created module-wide validator instance; populated by get_validator()
# and cleared by reset_validator().
_validator: APIKeyValidator | None = None
|
|
|
|
|
|
def get_validator() -> APIKeyValidator:
    """Return the module-wide API key validator, building it on first use.

    Subsequent calls return the same cached instance until
    ``reset_validator`` clears it.
    """
    global _validator
    if _validator is not None:
        return _validator
    _validator = APIKeyValidator()
    return _validator
|
|
|
|
|
|
def reset_validator() -> None:
    """Drop the cached module-wide validator.

    The next ``get_validator`` call constructs a fresh instance; this is
    mostly useful in tests that change settings between cases.
    """
    global _validator
    _validator = None
|