111 lines
3.5 KiB
Python
111 lines
3.5 KiB
Python
"""In-memory request rate limiting for MCP endpoints."""
|
|
|
|
from __future__ import annotations

import hashlib
import time
from collections import defaultdict, deque
from collections.abc import Callable
from dataclasses import dataclass

from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.config import get_settings
|
|
|
|
|
|
@dataclass(frozen=True)
class RateLimitDecision:
    """Immutable outcome of evaluating a request against the rate limits.

    Attributes:
        allowed: True when the request may proceed, False when it was rejected.
        reason: Short human-readable explanation of the decision.
    """

    # Field order matters: callers construct this positionally as (allowed, reason).
    allowed: bool
    reason: str
|
|
|
|
|
|
class SlidingWindowLimiter:
    """Sliding-window limiter keyed by arbitrary identifiers.

    Each key keeps a deque of timestamps for the requests it was allowed to
    make. A request is allowed while fewer than ``max_requests`` of those
    timestamps fall inside the trailing ``window_seconds`` window. Rejected
    requests are not recorded.

    No locking is performed, so an instance shared across threads needs
    external synchronization.
    """

    def __init__(
        self,
        max_requests: int,
        window_seconds: int,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        """Initialize a sliding-window limiter.

        Args:
            max_requests: Maximum allowed requests within the window. A value
                of 0 or less denies every request.
            window_seconds: Length of the trailing time window, in seconds.
            clock: Zero-argument callable returning the current time in
                seconds. Defaults to ``time.monotonic`` so system clock
                adjustments cannot stretch, shrink, or reset windows.
                Override it for deterministic tests.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._clock = clock
        self._events: dict[str, deque[float]] = defaultdict(deque)
        # Time of the last full sweep for stale keys (see _sweep).
        self._last_sweep = clock()

    def allow(self, key: str) -> bool:
        """Check and record a request for ``key``.

        Args:
            key: Identifier the limit applies to (e.g. an IP or a token hash).

        Returns:
            True if the request is within the limit (and it is recorded),
            False if the limit is already reached (nothing is recorded).
        """
        now = self._clock()
        boundary = now - self.window_seconds

        # Amortised cleanup: at most once per window, drop keys that have gone
        # quiet. Without this, every distinct key ever seen (e.g. every client
        # IP) would keep a deque alive forever.
        if now - self._last_sweep >= self.window_seconds:
            self._sweep(boundary)
            self._last_sweep = now

        events = self._events[key]
        # Discard timestamps that have slid out of the window.
        while events and events[0] < boundary:
            events.popleft()

        if len(events) >= self.max_requests:
            return False

        events.append(now)
        return True

    def _sweep(self, boundary: float) -> None:
        """Remove keys whose newest recorded request is older than ``boundary``."""
        # Collect first, then delete: the dict must not change size while iterating.
        stale = [
            key
            for key, events in self._events.items()
            if not events or events[-1] < boundary
        ]
        for key in stale:
            del self._events[key]
|
|
|
|
|
|
class RequestRateLimiter:
    """Enforce a per-client-IP limit and, for authenticated calls, a per-token limit.

    Both limits use 60-second sliding windows sized from the current settings.
    """

    def __init__(self) -> None:
        """Build the IP and token limiters from the current settings."""
        settings = get_settings()
        self._audit = get_audit_logger()
        self._ip_limiter = SlidingWindowLimiter(settings.rate_limit_per_minute, 60)
        self._token_limiter = SlidingWindowLimiter(settings.token_rate_limit_per_minute, 60)

    def check(self, client_ip: str, token: str | None) -> RateLimitDecision:
        """Decide whether a request may proceed.

        The IP limit is evaluated first. The token limit is only consulted when
        the IP check passes and a non-empty token is supplied. Each rejection is
        reported to the audit logger.

        Args:
            client_ip: Address the request originated from.
            token: API token presented by the caller, if any.

        Returns:
            A decision describing whether the request is allowed and why.
        """
        if not self._ip_limiter.allow(client_ip):
            return self._reject(
                event_type="rate_limit_ip_exceeded",
                description="Per-IP request rate limit exceeded",
                severity="medium",
                client_ip=client_ip,
                reason="Per-IP rate limit exceeded",
            )

        if not token:
            return RateLimitDecision(True, "within limits")

        # Key the token bucket by its SHA-256 digest so the raw secret is never
        # retained as a dictionary key.
        digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
        if self._token_limiter.allow(digest):
            return RateLimitDecision(True, "within limits")

        return self._reject(
            event_type="rate_limit_token_exceeded",
            description="Per-token request rate limit exceeded",
            severity="high",
            client_ip=client_ip,
            reason="Per-token rate limit exceeded",
        )

    def _reject(
        self,
        *,
        event_type: str,
        description: str,
        severity: str,
        client_ip: str,
        reason: str,
    ) -> RateLimitDecision:
        """Record a security event for a rejected request and build the denial."""
        self._audit.log_security_event(
            event_type=event_type,
            description=description,
            severity=severity,
            metadata={"client_ip": client_ip},
        )
        return RateLimitDecision(False, reason)
|
|
|
|
|
|
# Lazily-created, module-wide limiter instance; None until first requested.
_rate_limiter: RequestRateLimiter | None = None


def get_rate_limiter() -> RequestRateLimiter:
    """Return the shared request limiter, constructing it on first use."""
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = RequestRateLimiter()
    return _rate_limiter
|
|
|
|
|
|
def reset_rate_limiter() -> None:
    """Discard the shared limiter so the next lookup constructs a new one.

    Mainly intended for tests that need a clean limiter state.
    """
    global _rate_limiter
    _rate_limiter = None
|