Files
AegisGitea-MCP/src/aegis_gitea_mcp/rate_limit.py
T

111 lines
3.5 KiB
Python

"""In-memory request rate limiting for MCP endpoints."""
from __future__ import annotations

import hashlib
import time
from collections import defaultdict, deque
from collections.abc import Callable
from dataclasses import dataclass

from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.config import get_settings
@dataclass(frozen=True)
class RateLimitDecision:
    """Immutable outcome of evaluating one request against the rate limits.

    Fields are positional in the order ``allowed``, ``reason``.
    """

    # True when the request may proceed, False when a limit rejected it.
    allowed: bool
    # Human-readable explanation of the decision.
    reason: str
class SlidingWindowLimiter:
    """Sliding-window limiter keyed by arbitrary identifiers.

    For each key, the timestamps of accepted requests are kept in a deque. A
    new request is accepted only while fewer than ``max_requests`` of those
    timestamps fall inside the trailing ``window_seconds`` window.
    """

    def __init__(
        self,
        max_requests: int,
        window_seconds: int,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        """Initialize a sliding-window limiter.

        Args:
            max_requests: Maximum allowed requests within window.
            window_seconds: Rolling time window length.
            clock: Zero-argument time source returning seconds as a float.
                Defaults to ``time.monotonic`` so wall-clock adjustments
                (NTP corrections, manual changes) cannot stretch or shrink
                the window.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._clock = clock
        self._events: dict[str, deque[float]] = defaultdict(deque)
        # Time of the last stale-key sweep; sweeps run at most once per window.
        self._last_sweep = clock()

    def allow(self, key: str) -> bool:
        """Check and record request for the provided key.

        Args:
            key: Identifier whose request budget is consulted.

        Returns:
            True if the request fits in the window (and is recorded),
            False if the key already has ``max_requests`` recent requests.
        """
        now = self._clock()
        boundary = now - self.window_seconds
        self._evict_stale_keys(now, boundary)
        events = self._events[key]
        while events and events[0] < boundary:
            events.popleft()
        if len(events) >= self.max_requests:
            return False
        events.append(now)
        return True

    def _evict_stale_keys(self, now: float, boundary: float) -> None:
        """Drop keys with no timestamps inside the current window.

        Without this, every distinct key ever seen (e.g. every client IP)
        would keep an entry in ``_events`` forever. A key whose newest
        timestamp is older than ``boundary`` would be emptied by the next
        ``allow`` call anyway, so removing it does not change any decision.
        The full scan runs at most once per window to keep ``allow`` cheap.
        """
        if now - self._last_sweep < self.window_seconds:
            return
        self._last_sweep = now
        stale = [
            key
            for key, events in self._events.items()
            if not events or events[-1] < boundary
        ]
        for key in stale:
            del self._events[key]
class RequestRateLimiter:
    """Applies a per-client-IP limit and, for authenticated calls, a per-token limit."""

    def __init__(self) -> None:
        """Build both per-minute limiters from the current application settings."""
        cfg = get_settings()
        self._audit = get_audit_logger()
        per_minute = 60  # both configured limits are expressed per minute
        self._ip_limiter = SlidingWindowLimiter(cfg.rate_limit_per_minute, per_minute)
        self._token_limiter = SlidingWindowLimiter(
            cfg.token_rate_limit_per_minute, per_minute
        )

    def check(self, client_ip: str, token: str | None) -> RateLimitDecision:
        """Decide whether a request may proceed under the IP and token limits.

        The IP limit is consulted first; the token limit is only consulted
        for requests that pass it and carry a non-empty token. Each rejection
        is reported to the audit logger.

        Args:
            client_ip: Request source IP.
            token: Optional authenticated API token.

        Returns:
            Rate limit decision.
        """
        ip_ok = self._ip_limiter.allow(client_ip)
        if not ip_ok:
            self._audit.log_security_event(
                event_type="rate_limit_ip_exceeded",
                description="Per-IP request rate limit exceeded",
                severity="medium",
                metadata={"client_ip": client_ip},
            )
            return RateLimitDecision(allowed=False, reason="Per-IP rate limit exceeded")

        if not token:
            return RateLimitDecision(allowed=True, reason="within limits")

        # Key the limiter by a SHA-256 digest so raw tokens never sit in the map.
        digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
        if self._token_limiter.allow(digest):
            return RateLimitDecision(allowed=True, reason="within limits")

        self._audit.log_security_event(
            event_type="rate_limit_token_exceeded",
            description="Per-token request rate limit exceeded",
            severity="high",
            metadata={"client_ip": client_ip},
        )
        return RateLimitDecision(allowed=False, reason="Per-token rate limit exceeded")
_rate_limiter: RequestRateLimiter | None = None


def get_rate_limiter() -> RequestRateLimiter:
    """Return the module-wide limiter, creating it lazily on first call."""
    global _rate_limiter
    limiter = _rate_limiter
    if limiter is None:
        limiter = RequestRateLimiter()
        _rate_limiter = limiter
    return limiter
def reset_rate_limiter() -> None:
    """Forget the cached module-wide limiter so the next lookup rebuilds it.

    Mainly useful in tests that change settings between cases.
    """
    global _rate_limiter
    _rate_limiter = None