K = TypeVar("K")
V = TypeVar("V")


class BoundedTTLCache(Generic[K, V]):
    """A size-bounded cache whose entries expire after a fixed TTL.

    Every entry gets the same TTL and is (re)inserted at the tail of the
    underlying ``OrderedDict``, so entries are always ordered by expiry time,
    oldest first. Eviction relies on that ordering:

    * expired entries are dropped lazily by :meth:`get` and proactively on
      every :meth:`set`;
    * once more than ``max_size`` entries remain, the least-recently-inserted
      ones are evicted (FIFO). Reads do not refresh an entry's position.

    ``None`` is the "missing" marker returned by :meth:`get`, so a stored
    ``None`` value cannot be told apart from a miss.
    """

    def __init__(self, *, ttl_seconds: float, max_size: int = 1024) -> None:
        """Initialize the cache with a TTL and maximum entry count.

        Args:
            ttl_seconds: Lifetime of each entry, in seconds; must be positive.
            max_size: Maximum number of entries kept; must be positive.

        Raises:
            ValueError: If ``ttl_seconds`` or ``max_size`` is not positive.
        """
        if ttl_seconds <= 0:
            raise ValueError("ttl_seconds must be positive")
        if max_size <= 0:
            raise ValueError("max_size must be positive")
        self._ttl = float(ttl_seconds)
        self._max_size = int(max_size)
        # Maps key -> (value, monotonic expiry time), oldest insertion first.
        self._store: OrderedDict[K, tuple[V, float]] = OrderedDict()

    def get(self, key: K) -> V | None:
        """Return the cached value for ``key`` or ``None`` if absent/expired."""
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expiry = entry
        if time.monotonic() >= expiry:
            # Lazily evict the expired entry so it stops occupying a slot.
            self._store.pop(key, None)
            return None
        return value

    def set(self, key: K, value: V) -> None:
        """Store ``value`` under ``key`` with the configured TTL."""
        now = time.monotonic()
        # Drop any existing entry first so the key moves to the tail; this
        # keeps the store ordered by expiry time, which ``_evict`` relies on.
        self._store.pop(key, None)
        self._store[key] = (value, now + self._ttl)
        self._evict(now)

    def _evict(self, now: float) -> None:
        """Remove expired entries, then enforce the size bound (FIFO).

        Expiry times never decrease from head to tail (single TTL, monotonic
        clock, tail insertion), so all expired entries sit at the head. Only
        that prefix is inspected instead of scanning the whole store.
        """
        store = self._store
        while store:
            oldest_key = next(iter(store))
            if now >= store[oldest_key][1]:
                del store[oldest_key]
            else:
                # First live entry reached; everything after it is newer.
                break
        while len(store) > self._max_size:
            store.popitem(last=False)

    def clear(self) -> None:
        """Remove all entries (primarily for tests)."""
        self._store.clear()

    def __len__(self) -> int:
        """Return the number of stored (not necessarily live) entries."""
        return len(self._store)
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    """Bridge the startup/shutdown hooks onto the FastAPI lifespan protocol.

    ``startup_event`` is awaited before control is handed to the application;
    ``shutdown_event`` is awaited when the application context exits, whether
    it exits normally or with an error.
    """
    # A failure here propagates before the yield, so the shutdown hook only
    # runs for an application whose startup hook completed.
    await startup_event()
    try:
        yield None
    finally:
        await shutdown_event()
- import hashlib - import time as _time - token_hash = hashlib.sha256(access_token.encode()).hexdigest()[:16] - now = _time.monotonic() probe_result = "skip:cached" token_type = "jwt" if access_token.count(".") == 2 else "opaque" - if token_hash not in _api_scope_cache or now >= _api_scope_cache[token_hash]: + if _api_scope_cache.get(token_hash) is None: # JWT tokens (OIDC) are already cryptographically validated via JWKS above. # Gitea's OIDC access_tokens cannot access the REST API without additional # Gitea-specific scope configuration, so we skip the probe for them and # rely on per-call API errors for actual permission enforcement. if token_type == "jwt": probe_result = "skip:jwt" - _api_scope_cache[token_hash] = now + _API_SCOPE_CACHE_TTL + _api_scope_cache.set(token_hash, True) else: try: probe_status = None @@ -403,7 +416,7 @@ async def authenticate_and_rate_limit( ) else: probe_result = "pass" - _api_scope_cache[token_hash] = now + _API_SCOPE_CACHE_TTL + _api_scope_cache.set(token_hash, True) except httpx.RequestError: probe_result = "skip:error" logger.debug("oauth_api_scope_probe_network_error") @@ -422,7 +435,6 @@ async def authenticate_and_rate_limit( return await call_next(request) -@app.on_event("startup") async def startup_event() -> None: """Initialize server state on startup.""" settings = get_settings() @@ -470,7 +482,6 @@ async def startup_event() -> None: logger.info("gitea_oidc_discovery_ready", extra={"issuer": settings.gitea_base_url}) -@app.on_event("shutdown") async def shutdown_event() -> None: """Log server shutdown event.""" logger.info("server_stopping") @@ -653,14 +664,19 @@ async def oauth_token_proxy(request: Request) -> JSONResponse: except Exception as exc: raise HTTPException(status_code=400, detail="Invalid request body") from exc - grant_type = form_data.get("grant_type", "authorization_code") - code = form_data.get("code") - refresh_token = form_data.get("refresh_token") - code_verifier = form_data.get("code_verifier", "") - # 
ChatGPT sends the client_id and client_secret (that were configured in the GPT Action - # settings) in the POST body. Use those directly; fall back to env vars if not provided. - client_id = form_data.get("client_id") or settings.gitea_oauth_client_id - client_secret = form_data.get("client_secret") or settings.gitea_oauth_client_secret + def _field(name: str, default: str = "") -> str: + """Read a string form field, ignoring uploaded-file parts.""" + value = form_data.get(name, default) + return value if isinstance(value, str) else default + + grant_type = _field("grant_type", "authorization_code") + code = _field("code") + refresh_token = _field("refresh_token") + code_verifier = _field("code_verifier") + # The MCP client (Claude) sends client_id and, for confidential clients, client_secret + # in the POST body. Use those directly; fall back to env vars if not provided. + client_id = _field("client_id") or settings.gitea_oauth_client_id + client_secret = _field("client_secret") or settings.gitea_oauth_client_secret # Gitea validates that redirect_uri in the token exchange matches the one used during # authorization. 
def test_set_and_get_returns_value() -> None:
    """A stored value is returned before it expires."""
    cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60, max_size=8)
    cache.set("a", 1)
    assert cache.get("a") == 1


def test_missing_key_returns_none() -> None:
    """An unknown key returns None."""
    cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60)
    assert cache.get("missing") is None


def test_entry_expires_after_ttl(monkeypatch: pytest.MonkeyPatch) -> None:
    """An entry is served until its TTL elapses and is evicted afterwards.

    The cache module's ``time`` reference is swapped for a controllable fake
    clock, so the test neither sleeps nor depends on scheduler timing.
    """

    class _FakeClock:
        """Stand-in for the ``time`` module as used by the cache."""

        def __init__(self, start: float) -> None:
            self.now = start

        def monotonic(self) -> float:
            return self.now

    start = time.monotonic()
    clock = _FakeClock(start)
    # Patch only the name looked up inside the cache module; the real
    # ``time`` module stays untouched for everything else.
    monkeypatch.setattr("aegis_gitea_mcp.cache.time", clock)

    cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=5, max_size=8)
    cache.set("a", 1)

    clock.now = start + 4.0  # still inside the TTL window
    assert cache.get("a") == 1

    clock.now = start + 5.0  # exactly at expiry counts as expired
    assert cache.get("a") is None
    assert len(cache) == 0


def test_size_bound_evicts_oldest() -> None:
    """The cache never exceeds max_size; oldest entries are evicted first."""
    cache: BoundedTTLCache[int, int] = BoundedTTLCache(ttl_seconds=60, max_size=3)
    for i in range(5):
        cache.set(i, i)
    assert len(cache) == 3
    # 0 and 1 were evicted; 2, 3, 4 remain.
    assert cache.get(0) is None
    assert cache.get(1) is None
    assert cache.get(4) == 4


def test_reinsert_refreshes_recency() -> None:
    """Re-setting a key refreshes its position so it is not evicted first."""
    cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60, max_size=2)
    cache.set("a", 1)
    cache.set("b", 2)
    cache.set("a", 3)  # refresh "a"
    cache.set("c", 4)  # should evict "b", the oldest
    assert cache.get("b") is None
    assert cache.get("a") == 3
    assert cache.get("c") == 4


def test_clear_empties_cache() -> None:
    """clear() removes all entries."""
    cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60)
    cache.set("a", 1)
    cache.clear()
    assert cache.get("a") is None
    assert len(cache) == 0


def test_invalid_constructor_args() -> None:
    """Non-positive TTL or size is rejected."""
    with pytest.raises(ValueError):
        BoundedTTLCache(ttl_seconds=0)
    with pytest.raises(ValueError):
        BoundedTTLCache(ttl_seconds=60, max_size=0)