refactor: lifespan handlers, module-level imports, bounded scope cache
Replace the deprecated @app.on_event startup/shutdown handlers with a FastAPI lifespan context manager, move the inline hashlib/time imports in the auth middleware to the top of the module, and back the previously unbounded _api_scope_cache with a new size- and TTL-bounded BoundedTTLCache utility.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,71 @@
|
||||
"""Bounded, TTL-based in-memory caches with size eviction.
|
||||
|
||||
Provides a small dependency-free cache used by the auth middleware and the
|
||||
per-user authorization layer. Entries expire after a TTL and the cache is
|
||||
bounded by a maximum size to prevent unbounded memory growth from untrusted
|
||||
key cardinality (e.g. one entry per distinct token or per (user, repo) pair).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
K = TypeVar("K")
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
class BoundedTTLCache(Generic[K, V]):
|
||||
"""A size-bounded cache whose entries expire after a fixed TTL.
|
||||
|
||||
Eviction is least-recently-inserted (FIFO) once ``max_size`` is reached.
|
||||
Expired entries are removed lazily on access and proactively when the
|
||||
cache is full, so the cache never exceeds ``max_size`` live entries.
|
||||
"""
|
||||
|
||||
def __init__(self, *, ttl_seconds: float, max_size: int = 1024) -> None:
|
||||
"""Initialize the cache with a TTL and maximum entry count."""
|
||||
if ttl_seconds <= 0:
|
||||
raise ValueError("ttl_seconds must be positive")
|
||||
if max_size <= 0:
|
||||
raise ValueError("max_size must be positive")
|
||||
self._ttl = float(ttl_seconds)
|
||||
self._max_size = int(max_size)
|
||||
self._store: OrderedDict[K, tuple[V, float]] = OrderedDict()
|
||||
|
||||
def get(self, key: K) -> V | None:
|
||||
"""Return the cached value for ``key`` or ``None`` if absent/expired."""
|
||||
entry = self._store.get(key)
|
||||
if entry is None:
|
||||
return None
|
||||
value, expiry = entry
|
||||
if time.monotonic() >= expiry:
|
||||
# Lazily evict expired entry.
|
||||
self._store.pop(key, None)
|
||||
return None
|
||||
return value
|
||||
|
||||
def set(self, key: K, value: V) -> None:
|
||||
"""Store ``value`` under ``key`` with the configured TTL."""
|
||||
now = time.monotonic()
|
||||
# Drop the existing entry so reinsertion refreshes ordering.
|
||||
self._store.pop(key, None)
|
||||
self._store[key] = (value, now + self._ttl)
|
||||
self._evict(now)
|
||||
|
||||
def _evict(self, now: float) -> None:
|
||||
"""Remove expired entries, then enforce the size bound (FIFO)."""
|
||||
expired = [key for key, (_, expiry) in self._store.items() if now >= expiry]
|
||||
for key in expired:
|
||||
self._store.pop(key, None)
|
||||
while len(self._store) > self._max_size:
|
||||
self._store.popitem(last=False)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all entries (primarily for tests)."""
|
||||
self._store.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of stored (not necessarily live) entries."""
|
||||
return len(self._store)
|
||||
@@ -4,11 +4,13 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -18,6 +20,7 @@ from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from aegis_gitea_mcp.audit import get_audit_logger
|
||||
from aegis_gitea_mcp.automation import AutomationError, AutomationManager
|
||||
from aegis_gitea_mcp.cache import BoundedTTLCache
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
from aegis_gitea_mcp.gitea_client import (
|
||||
GiteaAuthenticationError,
|
||||
@@ -81,9 +84,12 @@ READ_SCOPE = "read:repository"
|
||||
WRITE_SCOPE = "write:repository"
|
||||
|
||||
# Cache of tokens verified to have Gitea API scope.
|
||||
# Key: hash of token prefix, Value: monotonic expiry time.
|
||||
_api_scope_cache: dict[str, float] = {}
|
||||
# Key: hash of token prefix, Value: sentinel marking the token as probe-verified.
|
||||
# Bounded by size and TTL so untrusted token cardinality cannot grow it without limit.
|
||||
_API_SCOPE_CACHE_TTL = 60 # seconds
|
||||
_api_scope_cache: BoundedTTLCache[str, bool] = BoundedTTLCache(
|
||||
ttl_seconds=_API_SCOPE_CACHE_TTL, max_size=4096
|
||||
)
|
||||
|
||||
_REAUTH_GUIDANCE = (
|
||||
"Your OAuth token lacks Gitea API scopes (e.g. read:repository). "
|
||||
@@ -110,10 +116,21 @@ def _has_required_scope(required_scope: str, granted_scopes: set[str]) -> bool:
|
||||
return required_scope in expanded
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    """Bridge the startup/shutdown hooks onto the FastAPI lifespan protocol.

    Awaits ``startup_event`` before yielding control, then awaits
    ``shutdown_event`` when the context is exited. The application instance
    passed in as ``_app`` is not used.
    """
    await startup_event()
    try:
        yield
    finally:
        # Runs on normal exit and when an exception is raised at the yield.
        await shutdown_event()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="AegisGitea MCP Server",
|
||||
description="Security-first MCP server for controlled AI access to self-hosted Gitea",
|
||||
version="0.2.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
||||
@@ -335,22 +352,18 @@ async def authenticate_and_rate_limit(
|
||||
# Probe: verify the token actually works for Gitea's repository API.
|
||||
# Try both "token" and "Bearer" header formats since Gitea may
|
||||
# accept OAuth tokens differently depending on version/config.
|
||||
import hashlib
|
||||
import time as _time
|
||||
|
||||
token_hash = hashlib.sha256(access_token.encode()).hexdigest()[:16]
|
||||
now = _time.monotonic()
|
||||
probe_result = "skip:cached"
|
||||
token_type = "jwt" if access_token.count(".") == 2 else "opaque"
|
||||
|
||||
if token_hash not in _api_scope_cache or now >= _api_scope_cache[token_hash]:
|
||||
if _api_scope_cache.get(token_hash) is None:
|
||||
# JWT tokens (OIDC) are already cryptographically validated via JWKS above.
|
||||
# Gitea's OIDC access_tokens cannot access the REST API without additional
|
||||
# Gitea-specific scope configuration, so we skip the probe for them and
|
||||
# rely on per-call API errors for actual permission enforcement.
|
||||
if token_type == "jwt":
|
||||
probe_result = "skip:jwt"
|
||||
_api_scope_cache[token_hash] = now + _API_SCOPE_CACHE_TTL
|
||||
_api_scope_cache.set(token_hash, True)
|
||||
else:
|
||||
try:
|
||||
probe_status = None
|
||||
@@ -403,7 +416,7 @@ async def authenticate_and_rate_limit(
|
||||
)
|
||||
else:
|
||||
probe_result = "pass"
|
||||
_api_scope_cache[token_hash] = now + _API_SCOPE_CACHE_TTL
|
||||
_api_scope_cache.set(token_hash, True)
|
||||
except httpx.RequestError:
|
||||
probe_result = "skip:error"
|
||||
logger.debug("oauth_api_scope_probe_network_error")
|
||||
@@ -422,7 +435,6 @@ async def authenticate_and_rate_limit(
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
"""Initialize server state on startup."""
|
||||
settings = get_settings()
|
||||
@@ -470,7 +482,6 @@ async def startup_event() -> None:
|
||||
logger.info("gitea_oidc_discovery_ready", extra={"issuer": settings.gitea_base_url})
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event() -> None:
|
||||
"""Log server shutdown event."""
|
||||
logger.info("server_stopping")
|
||||
@@ -653,14 +664,19 @@ async def oauth_token_proxy(request: Request) -> JSONResponse:
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid request body") from exc
|
||||
|
||||
grant_type = form_data.get("grant_type", "authorization_code")
|
||||
code = form_data.get("code")
|
||||
refresh_token = form_data.get("refresh_token")
|
||||
code_verifier = form_data.get("code_verifier", "")
|
||||
# ChatGPT sends the client_id and client_secret (that were configured in the GPT Action
|
||||
# settings) in the POST body. Use those directly; fall back to env vars if not provided.
|
||||
client_id = form_data.get("client_id") or settings.gitea_oauth_client_id
|
||||
client_secret = form_data.get("client_secret") or settings.gitea_oauth_client_secret
|
||||
def _field(name: str, default: str = "") -> str:
|
||||
"""Read a string form field, ignoring uploaded-file parts."""
|
||||
value = form_data.get(name, default)
|
||||
return value if isinstance(value, str) else default
|
||||
|
||||
grant_type = _field("grant_type", "authorization_code")
|
||||
code = _field("code")
|
||||
refresh_token = _field("refresh_token")
|
||||
code_verifier = _field("code_verifier")
|
||||
# The MCP client (Claude) sends client_id and, for confidential clients, client_secret
|
||||
# in the POST body. Use those directly; fall back to env vars if not provided.
|
||||
client_id = _field("client_id") or settings.gitea_oauth_client_id
|
||||
client_secret = _field("client_secret") or settings.gitea_oauth_client_secret
|
||||
|
||||
# Gitea validates that redirect_uri in the token exchange matches the one used during
|
||||
# authorization. Because our /oauth/authorize proxy always forwards our own callback
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Tests for the bounded TTL cache utility."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from aegis_gitea_mcp.cache import BoundedTTLCache
|
||||
|
||||
|
||||
def test_set_and_get_returns_value() -> None:
    """Reading a key straight after writing it yields the stored value."""
    store: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60, max_size=8)
    store.set("a", 1)
    fetched = store.get("a")
    assert fetched == 1
|
||||
|
||||
|
||||
def test_missing_key_returns_none() -> None:
    """Looking up a key that was never stored yields None."""
    empty: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60)
    lookup = empty.get("missing")
    assert lookup is None
|
||||
|
||||
|
||||
def test_entry_expires_after_ttl() -> None:
    """A value is readable inside its TTL window and gone once it has passed."""
    short_lived: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=0.05, max_size=8)
    short_lived.set("a", 1)
    before_expiry = short_lived.get("a")
    assert before_expiry == 1
    # Sleep slightly longer than the 50 ms TTL so the entry is past its expiry.
    time.sleep(0.06)
    after_expiry = short_lived.get("a")
    assert after_expiry is None
|
||||
|
||||
|
||||
def test_size_bound_evicts_oldest() -> None:
    """Inserting beyond max_size drops the earliest-inserted keys first."""
    bounded: BoundedTTLCache[int, int] = BoundedTTLCache(ttl_seconds=60, max_size=3)
    for n in range(5):
        bounded.set(n, n)
    assert len(bounded) == 3
    # Keys 0 and 1 were pushed out by the later inserts; 2, 3 and 4 survive.
    for evicted in (0, 1):
        assert bounded.get(evicted) is None
    assert bounded.get(4) == 4
|
||||
|
||||
|
||||
def test_reinsert_refreshes_recency() -> None:
    """Writing an existing key again moves it to the back of the eviction queue."""
    store: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60, max_size=2)
    # The second write to "a" refreshes it, leaving "b" as the oldest entry
    # by the time "c" overflows the two-slot cache.
    for key, value in (("a", 1), ("b", 2), ("a", 3), ("c", 4)):
        store.set(key, value)
    assert store.get("b") is None
    assert store.get("a") == 3
    assert store.get("c") == 4
|
||||
|
||||
|
||||
def test_clear_empties_cache() -> None:
    """After clear() the cache holds nothing and reports a length of zero."""
    store: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60)
    store.set("a", 1)
    store.clear()
    lookup = store.get("a")
    assert lookup is None
    assert len(store) == 0
|
||||
|
||||
|
||||
def test_invalid_constructor_args() -> None:
    """The constructor raises ValueError for a zero TTL and for a zero size."""
    rejected_kwargs = (
        {"ttl_seconds": 0},
        {"ttl_seconds": 60, "max_size": 0},
    )
    for kwargs in rejected_kwargs:
        with pytest.raises(ValueError):
            BoundedTTLCache(**kwargs)
|
||||
Reference in New Issue
Block a user