"""In-memory request rate limiting for MCP endpoints.""" from __future__ import annotations import hashlib import time from collections import defaultdict, deque from dataclasses import dataclass from aegis_gitea_mcp.audit import get_audit_logger from aegis_gitea_mcp.config import get_settings @dataclass(frozen=True) class RateLimitDecision: """Result of request rate-limit checks.""" allowed: bool reason: str class SlidingWindowLimiter: """Sliding-window limiter keyed by arbitrary identifiers.""" def __init__(self, max_requests: int, window_seconds: int) -> None: """Initialize a fixed-window limiter. Args: max_requests: Maximum allowed requests within window. window_seconds: Rolling time window length. """ self.max_requests = max_requests self.window_seconds = window_seconds self._events: dict[str, deque[float]] = defaultdict(deque) def allow(self, key: str) -> bool: """Check and record request for the provided key.""" now = time.time() boundary = now - self.window_seconds events = self._events[key] while events and events[0] < boundary: events.popleft() if len(events) >= self.max_requests: return False events.append(now) return True class RequestRateLimiter: """Combined per-IP and per-token request limiter.""" def __init__(self) -> None: """Initialize with current settings.""" settings = get_settings() self._audit = get_audit_logger() self._ip_limiter = SlidingWindowLimiter(settings.rate_limit_per_minute, 60) self._token_limiter = SlidingWindowLimiter(settings.token_rate_limit_per_minute, 60) def check(self, client_ip: str, token: str | None) -> RateLimitDecision: """Evaluate request against IP and token limits. Args: client_ip: Request source IP. token: Optional authenticated API token. Returns: Rate limit decision. """ if not self._ip_limiter.allow(client_ip): self._audit.log_security_event( event_type="rate_limit_ip_exceeded", description="Per-IP request rate limit exceeded", severity="medium", metadata={"client_ip": client_ip}, ) return RateLimitDecision(False, "Per-IP rate limit exceeded") if token: # Hash token before using it as a key to avoid storing secrets in memory maps. token_key = hashlib.sha256(token.encode("utf-8")).hexdigest() if not self._token_limiter.allow(token_key): self._audit.log_security_event( event_type="rate_limit_token_exceeded", description="Per-token request rate limit exceeded", severity="high", metadata={"client_ip": client_ip}, ) return RateLimitDecision(False, "Per-token rate limit exceeded") return RateLimitDecision(True, "within limits") _rate_limiter: RequestRateLimiter | None = None def get_rate_limiter() -> RequestRateLimiter: """Get global request limiter.""" global _rate_limiter if _rate_limiter is None: _rate_limiter = RequestRateLimiter() return _rate_limiter def reset_rate_limiter() -> None: """Reset global limiter (primarily for tests).""" global _rate_limiter _rate_limiter = None