refactor: lifespan handlers, module-level imports, bounded scope cache

Replace deprecated @app.on_event startup/shutdown handlers with a FastAPI
lifespan context manager, move the inline hashlib/time imports in the auth
middleware to module top, and back the unbounded _api_scope_cache with a new
size- and TTL-bounded BoundedTTLCache utility.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 18:39:06 +02:00
parent 71c993e4cd
commit ed3130ef74
3 changed files with 178 additions and 19 deletions
+71
View File
@@ -0,0 +1,71 @@
"""Bounded, TTL-based in-memory caches with size eviction.
Provides a small dependency-free cache used by the auth middleware and the
per-user authorization layer. Entries expire after a TTL and the cache is
bounded by a maximum size to prevent unbounded memory growth from untrusted
key cardinality (e.g. one entry per distinct token or per (user, repo) pair).
"""
from __future__ import annotations
import time
from collections import OrderedDict
from typing import Generic, TypeVar
K = TypeVar("K")
V = TypeVar("V")


class BoundedTTLCache(Generic[K, V]):
    """A size-bounded cache whose entries expire after a fixed TTL.

    Entries are kept in insertion order. Re-setting a key moves it to the
    back, so eviction is first-in-first-out by most recent ``set``; reads via
    ``get`` do not change the ordering. Because every entry shares the same
    TTL and expiry times are derived from ``time.monotonic()``, insertion
    order is also expiry order: anything expired sits at the front.

    Expired entries are removed lazily in ``get`` and proactively on every
    ``set``; after that the oldest entries are evicted until at most
    ``max_size`` remain.

    ``None`` is the "miss" result of ``get``, so a stored ``None`` value is
    indistinguishable from an absent key. The class does no locking of its
    own.
    """

    def __init__(self, *, ttl_seconds: float, max_size: int = 1024) -> None:
        """Initialize the cache with a TTL and maximum entry count.

        Args:
            ttl_seconds: Lifetime of each entry in seconds; must be > 0.
            max_size: Maximum number of entries kept; must be at least 1
                after integer conversion.

        Raises:
            ValueError: If ``ttl_seconds`` is not a positive number (NaN
                included) or ``max_size`` converts to less than 1.
        """
        ttl = float(ttl_seconds)
        # ``not ttl > 0`` (rather than ``ttl <= 0``) also rejects NaN, which
        # would otherwise create entries that never expire.
        if not ttl > 0:
            raise ValueError("ttl_seconds must be positive")
        # Validate the converted value: 0 < max_size < 1 truncates to 0 and
        # would yield a cache that evicts every entry as soon as it is stored.
        size = int(max_size)
        if size <= 0:
            raise ValueError("max_size must be positive")
        self._ttl = ttl
        self._max_size = size
        self._store: OrderedDict[K, tuple[V, float]] = OrderedDict()

    def get(self, key: K) -> V | None:
        """Return the cached value for ``key`` or ``None`` if absent/expired.

        An entry found to be expired is removed before returning ``None``.
        """
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expiry = entry
        if time.monotonic() >= expiry:
            # Lazily evict the expired entry.
            self._store.pop(key, None)
            return None
        return value

    def set(self, key: K, value: V) -> None:
        """Store ``value`` under ``key`` with the configured TTL.

        Re-setting an existing key restarts its TTL and moves it to the back
        of the eviction order.
        """
        now = time.monotonic()
        # Drop any existing entry first so the key is re-appended at the back;
        # plain assignment to an existing key would keep its old position.
        self._store.pop(key, None)
        self._store[key] = (value, now + self._ttl)
        self._evict(now)

    def _evict(self, now: float) -> None:
        """Remove expired entries, then enforce the size bound (FIFO).

        Args:
            now: The ``time.monotonic()`` reading to compare expiries against.
        """
        store = self._store
        # Expiry order equals insertion order (fixed TTL, monotonic clock,
        # re-set keys move to the back), so expired entries form a prefix:
        # pop from the front and stop at the first live entry instead of
        # scanning the whole store on every ``set``.
        while store:
            _, expiry = next(iter(store.values()))
            if now < expiry:
                break
            store.popitem(last=False)
        while len(store) > self._max_size:
            store.popitem(last=False)

    def clear(self) -> None:
        """Remove all entries (primarily for tests)."""
        self._store.clear()

    def __len__(self) -> int:
        """Return the number of stored (not necessarily live) entries."""
        return len(self._store)
+35 -19
View File
@@ -4,11 +4,13 @@ from __future__ import annotations
import asyncio import asyncio
import base64 import base64
import hashlib
import json import json
import logging import logging
import urllib.parse import urllib.parse
import uuid import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from typing import Any from typing import Any
import httpx import httpx
@@ -18,6 +20,7 @@ from pydantic import BaseModel, Field, ValidationError
from aegis_gitea_mcp.audit import get_audit_logger from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.automation import AutomationError, AutomationManager from aegis_gitea_mcp.automation import AutomationError, AutomationManager
from aegis_gitea_mcp.cache import BoundedTTLCache
from aegis_gitea_mcp.config import get_settings from aegis_gitea_mcp.config import get_settings
from aegis_gitea_mcp.gitea_client import ( from aegis_gitea_mcp.gitea_client import (
GiteaAuthenticationError, GiteaAuthenticationError,
@@ -81,9 +84,12 @@ READ_SCOPE = "read:repository"
WRITE_SCOPE = "write:repository" WRITE_SCOPE = "write:repository"
# Cache of tokens verified to have Gitea API scope. # Cache of tokens verified to have Gitea API scope.
# Key: hash of token prefix, Value: monotonic expiry time. # Key: first 16 hex chars of the token's SHA-256 digest, Value: sentinel marking the token as probe-verified.
_api_scope_cache: dict[str, float] = {} # Bounded by size and TTL so untrusted token cardinality cannot grow it without limit.
_API_SCOPE_CACHE_TTL = 60 # seconds _API_SCOPE_CACHE_TTL = 60 # seconds
_api_scope_cache: BoundedTTLCache[str, bool] = BoundedTTLCache(
ttl_seconds=_API_SCOPE_CACHE_TTL, max_size=4096
)
_REAUTH_GUIDANCE = ( _REAUTH_GUIDANCE = (
"Your OAuth token lacks Gitea API scopes (e.g. read:repository). " "Your OAuth token lacks Gitea API scopes (e.g. read:repository). "
@@ -110,10 +116,21 @@ def _has_required_scope(required_scope: str, granted_scopes: set[str]) -> bool:
return required_scope in expanded return required_scope in expanded
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    """Run startup and shutdown hooks via the FastAPI lifespan protocol.

    Awaits ``startup_event()`` before yielding control, then awaits
    ``shutdown_event()`` when the context exits.

    Args:
        _app: The application instance handed to the lifespan handler; it is
            not used by this function.
    """
    # Startup runs outside the ``try``: if it raises, the shutdown hook is
    # not invoked.
    await startup_event()
    try:
        yield
    finally:
        # ``finally`` makes the shutdown hook run even when an exception is
        # raised at the yield point.
        await shutdown_event()
app = FastAPI( app = FastAPI(
title="AegisGitea MCP Server", title="AegisGitea MCP Server",
description="Security-first MCP server for controlled AI access to self-hosted Gitea", description="Security-first MCP server for controlled AI access to self-hosted Gitea",
version="0.2.0", version="0.2.0",
lifespan=lifespan,
) )
@@ -335,22 +352,18 @@ async def authenticate_and_rate_limit(
# Probe: verify the token actually works for Gitea's repository API. # Probe: verify the token actually works for Gitea's repository API.
# Try both "token" and "Bearer" header formats since Gitea may # Try both "token" and "Bearer" header formats since Gitea may
# accept OAuth tokens differently depending on version/config. # accept OAuth tokens differently depending on version/config.
import hashlib
import time as _time
token_hash = hashlib.sha256(access_token.encode()).hexdigest()[:16] token_hash = hashlib.sha256(access_token.encode()).hexdigest()[:16]
now = _time.monotonic()
probe_result = "skip:cached" probe_result = "skip:cached"
token_type = "jwt" if access_token.count(".") == 2 else "opaque" token_type = "jwt" if access_token.count(".") == 2 else "opaque"
if token_hash not in _api_scope_cache or now >= _api_scope_cache[token_hash]: if _api_scope_cache.get(token_hash) is None:
# JWT tokens (OIDC) are already cryptographically validated via JWKS above. # JWT tokens (OIDC) are already cryptographically validated via JWKS above.
# Gitea's OIDC access_tokens cannot access the REST API without additional # Gitea's OIDC access_tokens cannot access the REST API without additional
# Gitea-specific scope configuration, so we skip the probe for them and # Gitea-specific scope configuration, so we skip the probe for them and
# rely on per-call API errors for actual permission enforcement. # rely on per-call API errors for actual permission enforcement.
if token_type == "jwt": if token_type == "jwt":
probe_result = "skip:jwt" probe_result = "skip:jwt"
_api_scope_cache[token_hash] = now + _API_SCOPE_CACHE_TTL _api_scope_cache.set(token_hash, True)
else: else:
try: try:
probe_status = None probe_status = None
@@ -403,7 +416,7 @@ async def authenticate_and_rate_limit(
) )
else: else:
probe_result = "pass" probe_result = "pass"
_api_scope_cache[token_hash] = now + _API_SCOPE_CACHE_TTL _api_scope_cache.set(token_hash, True)
except httpx.RequestError: except httpx.RequestError:
probe_result = "skip:error" probe_result = "skip:error"
logger.debug("oauth_api_scope_probe_network_error") logger.debug("oauth_api_scope_probe_network_error")
@@ -422,7 +435,6 @@ async def authenticate_and_rate_limit(
return await call_next(request) return await call_next(request)
@app.on_event("startup")
async def startup_event() -> None: async def startup_event() -> None:
"""Initialize server state on startup.""" """Initialize server state on startup."""
settings = get_settings() settings = get_settings()
@@ -470,7 +482,6 @@ async def startup_event() -> None:
logger.info("gitea_oidc_discovery_ready", extra={"issuer": settings.gitea_base_url}) logger.info("gitea_oidc_discovery_ready", extra={"issuer": settings.gitea_base_url})
@app.on_event("shutdown")
async def shutdown_event() -> None: async def shutdown_event() -> None:
"""Log server shutdown event.""" """Log server shutdown event."""
logger.info("server_stopping") logger.info("server_stopping")
@@ -653,14 +664,19 @@ async def oauth_token_proxy(request: Request) -> JSONResponse:
except Exception as exc: except Exception as exc:
raise HTTPException(status_code=400, detail="Invalid request body") from exc raise HTTPException(status_code=400, detail="Invalid request body") from exc
grant_type = form_data.get("grant_type", "authorization_code") def _field(name: str, default: str = "") -> str:
code = form_data.get("code") """Read a string form field, ignoring uploaded-file parts."""
refresh_token = form_data.get("refresh_token") value = form_data.get(name, default)
code_verifier = form_data.get("code_verifier", "") return value if isinstance(value, str) else default
# ChatGPT sends the client_id and client_secret (that were configured in the GPT Action
# settings) in the POST body. Use those directly; fall back to env vars if not provided. grant_type = _field("grant_type", "authorization_code")
client_id = form_data.get("client_id") or settings.gitea_oauth_client_id code = _field("code")
client_secret = form_data.get("client_secret") or settings.gitea_oauth_client_secret refresh_token = _field("refresh_token")
code_verifier = _field("code_verifier")
# The MCP client (Claude) sends client_id and, for confidential clients, client_secret
# in the POST body. Use those directly; fall back to env vars if not provided.
client_id = _field("client_id") or settings.gitea_oauth_client_id
client_secret = _field("client_secret") or settings.gitea_oauth_client_secret
# Gitea validates that redirect_uri in the token exchange matches the one used during # Gitea validates that redirect_uri in the token exchange matches the one used during
# authorization. Because our /oauth/authorize proxy always forwards our own callback # authorization. Because our /oauth/authorize proxy always forwards our own callback
+72
View File
@@ -0,0 +1,72 @@
"""Tests for the bounded TTL cache utility."""
from __future__ import annotations
import time
import pytest
from aegis_gitea_mcp.cache import BoundedTTLCache
def test_set_and_get_returns_value() -> None:
    """Round-trip: a value written with set() is readable via get() while fresh."""
    ttl_cache: BoundedTTLCache[str, int] = BoundedTTLCache(max_size=8, ttl_seconds=60)
    ttl_cache.set("a", 1)
    fetched = ttl_cache.get("a")
    assert fetched == 1
def test_missing_key_returns_none() -> None:
    """Looking up a key that was never stored yields None."""
    empty_cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60)
    lookup = empty_cache.get("missing")
    assert lookup is None
def test_entry_expires_after_ttl(monkeypatch: pytest.MonkeyPatch) -> None:
    """An entry is readable until its TTL elapses and is evicted afterwards.

    ``time.monotonic`` is replaced with a controllable fake so the test does
    not depend on real sleeps or scheduler timing, and so the exact expiry
    boundary can be checked.
    """
    clock = {"now": 100.0}
    monkeypatch.setattr(time, "monotonic", lambda: clock["now"])

    cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=10, max_size=8)
    cache.set("a", 1)
    assert cache.get("a") == 1

    # Just before the deadline the entry is still live.
    clock["now"] = 109.0
    assert cache.get("a") == 1

    # Exactly at the deadline the entry counts as expired and is removed.
    clock["now"] = 110.0
    assert cache.get("a") is None
    assert len(cache) == 0
def test_size_bound_evicts_oldest() -> None:
    """The cache never exceeds max_size; oldest entries are evicted first."""
    cache: BoundedTTLCache[int, int] = BoundedTTLCache(ttl_seconds=60, max_size=3)
    for i in range(5):
        cache.set(i, i)
        # The bound must hold after every insertion, not only at the end.
        assert len(cache) <= 3
    assert len(cache) == 3
    # 0 and 1 were evicted ...
    assert cache.get(0) is None
    assert cache.get(1) is None
    # ... and every one of 2, 3, 4 remains.
    for survivor in (2, 3, 4):
        assert cache.get(survivor) == survivor
def test_reinsert_refreshes_recency() -> None:
    """Setting an existing key again moves it behind newer keys in eviction order."""
    two_slot: BoundedTTLCache[str, int] = BoundedTTLCache(max_size=2, ttl_seconds=60)
    for name, number in (("a", 1), ("b", 2), ("a", 3), ("c", 4)):
        two_slot.set(name, number)
    # "a" was re-set after "b", so "b" is the oldest entry and the one that
    # inserting "c" pushes out.
    observed = {name: two_slot.get(name) for name in ("b", "a", "c")}
    assert observed == {"b": None, "a": 3, "c": 4}
def test_clear_empties_cache() -> None:
    """After clear() the cache holds nothing and earlier keys miss."""
    populated: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60)
    populated.set("a", 1)
    populated.clear()
    remaining = (populated.get("a"), len(populated))
    assert remaining == (None, 0)
def test_invalid_constructor_args() -> None:
    """A zero TTL and a zero max_size are each rejected with ValueError."""
    rejected_kwargs = (
        {"ttl_seconds": 0},
        {"ttl_seconds": 60, "max_size": 0},
    )
    for kwargs in rejected_kwargs:
        with pytest.raises(ValueError):
            BoundedTTLCache(**kwargs)