refactor: lifespan handlers, module-level imports, bounded scope cache
Replace deprecated @app.on_event startup/shutdown handlers with a FastAPI lifespan context manager, move the inline hashlib/time imports in the auth middleware to module top, and back the unbounded _api_scope_cache with a new size- and TTL-bounded BoundedTTLCache utility. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,71 @@
|
|||||||
|
"""Bounded, TTL-based in-memory caches with size eviction.
|
||||||
|
|
||||||
|
Provides a small dependency-free cache used by the auth middleware and the
|
||||||
|
per-user authorization layer. Entries expire after a TTL and the cache is
|
||||||
|
bounded by a maximum size to prevent unbounded memory growth from untrusted
|
||||||
|
key cardinality (e.g. one entry per distinct token or per (user, repo) pair).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
K = TypeVar("K")
|
||||||
|
V = TypeVar("V")
|
||||||
|
|
||||||
|
|
||||||
|
class BoundedTTLCache(Generic[K, V]):
|
||||||
|
"""A size-bounded cache whose entries expire after a fixed TTL.
|
||||||
|
|
||||||
|
Eviction is least-recently-inserted (FIFO) once ``max_size`` is reached.
|
||||||
|
Expired entries are removed lazily on access and proactively when the
|
||||||
|
cache is full, so the cache never exceeds ``max_size`` live entries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, ttl_seconds: float, max_size: int = 1024) -> None:
|
||||||
|
"""Initialize the cache with a TTL and maximum entry count."""
|
||||||
|
if ttl_seconds <= 0:
|
||||||
|
raise ValueError("ttl_seconds must be positive")
|
||||||
|
if max_size <= 0:
|
||||||
|
raise ValueError("max_size must be positive")
|
||||||
|
self._ttl = float(ttl_seconds)
|
||||||
|
self._max_size = int(max_size)
|
||||||
|
self._store: OrderedDict[K, tuple[V, float]] = OrderedDict()
|
||||||
|
|
||||||
|
def get(self, key: K) -> V | None:
|
||||||
|
"""Return the cached value for ``key`` or ``None`` if absent/expired."""
|
||||||
|
entry = self._store.get(key)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
value, expiry = entry
|
||||||
|
if time.monotonic() >= expiry:
|
||||||
|
# Lazily evict expired entry.
|
||||||
|
self._store.pop(key, None)
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
def set(self, key: K, value: V) -> None:
|
||||||
|
"""Store ``value`` under ``key`` with the configured TTL."""
|
||||||
|
now = time.monotonic()
|
||||||
|
# Drop the existing entry so reinsertion refreshes ordering.
|
||||||
|
self._store.pop(key, None)
|
||||||
|
self._store[key] = (value, now + self._ttl)
|
||||||
|
self._evict(now)
|
||||||
|
|
||||||
|
def _evict(self, now: float) -> None:
|
||||||
|
"""Remove expired entries, then enforce the size bound (FIFO)."""
|
||||||
|
expired = [key for key, (_, expiry) in self._store.items() if now >= expiry]
|
||||||
|
for key in expired:
|
||||||
|
self._store.pop(key, None)
|
||||||
|
while len(self._store) > self._max_size:
|
||||||
|
self._store.popitem(last=False)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Remove all entries (primarily for tests)."""
|
||||||
|
self._store.clear()
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of stored (not necessarily live) entries."""
|
||||||
|
return len(self._store)
|
||||||
@@ -4,11 +4,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -18,6 +20,7 @@ from pydantic import BaseModel, Field, ValidationError
|
|||||||
|
|
||||||
from aegis_gitea_mcp.audit import get_audit_logger
|
from aegis_gitea_mcp.audit import get_audit_logger
|
||||||
from aegis_gitea_mcp.automation import AutomationError, AutomationManager
|
from aegis_gitea_mcp.automation import AutomationError, AutomationManager
|
||||||
|
from aegis_gitea_mcp.cache import BoundedTTLCache
|
||||||
from aegis_gitea_mcp.config import get_settings
|
from aegis_gitea_mcp.config import get_settings
|
||||||
from aegis_gitea_mcp.gitea_client import (
|
from aegis_gitea_mcp.gitea_client import (
|
||||||
GiteaAuthenticationError,
|
GiteaAuthenticationError,
|
||||||
@@ -81,9 +84,12 @@ READ_SCOPE = "read:repository"
|
|||||||
WRITE_SCOPE = "write:repository"
|
WRITE_SCOPE = "write:repository"
|
||||||
|
|
||||||
# Cache of tokens verified to have Gitea API scope.
|
# Cache of tokens verified to have Gitea API scope.
|
||||||
# Key: hash of token prefix, Value: monotonic expiry time.
|
# Key: hash of token prefix, Value: sentinel marking the token as probe-verified.
|
||||||
_api_scope_cache: dict[str, float] = {}
|
# Bounded by size and TTL so untrusted token cardinality cannot grow it without limit.
|
||||||
_API_SCOPE_CACHE_TTL = 60 # seconds
|
_API_SCOPE_CACHE_TTL = 60 # seconds
|
||||||
|
_api_scope_cache: BoundedTTLCache[str, bool] = BoundedTTLCache(
|
||||||
|
ttl_seconds=_API_SCOPE_CACHE_TTL, max_size=4096
|
||||||
|
)
|
||||||
|
|
||||||
_REAUTH_GUIDANCE = (
|
_REAUTH_GUIDANCE = (
|
||||||
"Your OAuth token lacks Gitea API scopes (e.g. read:repository). "
|
"Your OAuth token lacks Gitea API scopes (e.g. read:repository). "
|
||||||
@@ -110,10 +116,21 @@ def _has_required_scope(required_scope: str, granted_scopes: set[str]) -> bool:
|
|||||||
return required_scope in expanded
|
return required_scope in expanded
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
"""Run startup and shutdown hooks via the FastAPI lifespan protocol."""
|
||||||
|
await startup_event()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
await shutdown_event()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="AegisGitea MCP Server",
|
title="AegisGitea MCP Server",
|
||||||
description="Security-first MCP server for controlled AI access to self-hosted Gitea",
|
description="Security-first MCP server for controlled AI access to self-hosted Gitea",
|
||||||
version="0.2.0",
|
version="0.2.0",
|
||||||
|
lifespan=lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -335,22 +352,18 @@ async def authenticate_and_rate_limit(
|
|||||||
# Probe: verify the token actually works for Gitea's repository API.
|
# Probe: verify the token actually works for Gitea's repository API.
|
||||||
# Try both "token" and "Bearer" header formats since Gitea may
|
# Try both "token" and "Bearer" header formats since Gitea may
|
||||||
# accept OAuth tokens differently depending on version/config.
|
# accept OAuth tokens differently depending on version/config.
|
||||||
import hashlib
|
|
||||||
import time as _time
|
|
||||||
|
|
||||||
token_hash = hashlib.sha256(access_token.encode()).hexdigest()[:16]
|
token_hash = hashlib.sha256(access_token.encode()).hexdigest()[:16]
|
||||||
now = _time.monotonic()
|
|
||||||
probe_result = "skip:cached"
|
probe_result = "skip:cached"
|
||||||
token_type = "jwt" if access_token.count(".") == 2 else "opaque"
|
token_type = "jwt" if access_token.count(".") == 2 else "opaque"
|
||||||
|
|
||||||
if token_hash not in _api_scope_cache or now >= _api_scope_cache[token_hash]:
|
if _api_scope_cache.get(token_hash) is None:
|
||||||
# JWT tokens (OIDC) are already cryptographically validated via JWKS above.
|
# JWT tokens (OIDC) are already cryptographically validated via JWKS above.
|
||||||
# Gitea's OIDC access_tokens cannot access the REST API without additional
|
# Gitea's OIDC access_tokens cannot access the REST API without additional
|
||||||
# Gitea-specific scope configuration, so we skip the probe for them and
|
# Gitea-specific scope configuration, so we skip the probe for them and
|
||||||
# rely on per-call API errors for actual permission enforcement.
|
# rely on per-call API errors for actual permission enforcement.
|
||||||
if token_type == "jwt":
|
if token_type == "jwt":
|
||||||
probe_result = "skip:jwt"
|
probe_result = "skip:jwt"
|
||||||
_api_scope_cache[token_hash] = now + _API_SCOPE_CACHE_TTL
|
_api_scope_cache.set(token_hash, True)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
probe_status = None
|
probe_status = None
|
||||||
@@ -403,7 +416,7 @@ async def authenticate_and_rate_limit(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
probe_result = "pass"
|
probe_result = "pass"
|
||||||
_api_scope_cache[token_hash] = now + _API_SCOPE_CACHE_TTL
|
_api_scope_cache.set(token_hash, True)
|
||||||
except httpx.RequestError:
|
except httpx.RequestError:
|
||||||
probe_result = "skip:error"
|
probe_result = "skip:error"
|
||||||
logger.debug("oauth_api_scope_probe_network_error")
|
logger.debug("oauth_api_scope_probe_network_error")
|
||||||
@@ -422,7 +435,6 @@ async def authenticate_and_rate_limit(
|
|||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def startup_event() -> None:
|
async def startup_event() -> None:
|
||||||
"""Initialize server state on startup."""
|
"""Initialize server state on startup."""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
@@ -470,7 +482,6 @@ async def startup_event() -> None:
|
|||||||
logger.info("gitea_oidc_discovery_ready", extra={"issuer": settings.gitea_base_url})
|
logger.info("gitea_oidc_discovery_ready", extra={"issuer": settings.gitea_base_url})
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
|
||||||
async def shutdown_event() -> None:
|
async def shutdown_event() -> None:
|
||||||
"""Log server shutdown event."""
|
"""Log server shutdown event."""
|
||||||
logger.info("server_stopping")
|
logger.info("server_stopping")
|
||||||
@@ -653,14 +664,19 @@ async def oauth_token_proxy(request: Request) -> JSONResponse:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise HTTPException(status_code=400, detail="Invalid request body") from exc
|
raise HTTPException(status_code=400, detail="Invalid request body") from exc
|
||||||
|
|
||||||
grant_type = form_data.get("grant_type", "authorization_code")
|
def _field(name: str, default: str = "") -> str:
|
||||||
code = form_data.get("code")
|
"""Read a string form field, ignoring uploaded-file parts."""
|
||||||
refresh_token = form_data.get("refresh_token")
|
value = form_data.get(name, default)
|
||||||
code_verifier = form_data.get("code_verifier", "")
|
return value if isinstance(value, str) else default
|
||||||
# ChatGPT sends the client_id and client_secret (that were configured in the GPT Action
|
|
||||||
# settings) in the POST body. Use those directly; fall back to env vars if not provided.
|
grant_type = _field("grant_type", "authorization_code")
|
||||||
client_id = form_data.get("client_id") or settings.gitea_oauth_client_id
|
code = _field("code")
|
||||||
client_secret = form_data.get("client_secret") or settings.gitea_oauth_client_secret
|
refresh_token = _field("refresh_token")
|
||||||
|
code_verifier = _field("code_verifier")
|
||||||
|
# The MCP client (Claude) sends client_id and, for confidential clients, client_secret
|
||||||
|
# in the POST body. Use those directly; fall back to env vars if not provided.
|
||||||
|
client_id = _field("client_id") or settings.gitea_oauth_client_id
|
||||||
|
client_secret = _field("client_secret") or settings.gitea_oauth_client_secret
|
||||||
|
|
||||||
# Gitea validates that redirect_uri in the token exchange matches the one used during
|
# Gitea validates that redirect_uri in the token exchange matches the one used during
|
||||||
# authorization. Because our /oauth/authorize proxy always forwards our own callback
|
# authorization. Because our /oauth/authorize proxy always forwards our own callback
|
||||||
|
|||||||
@@ -0,0 +1,72 @@
|
|||||||
|
"""Tests for the bounded TTL cache utility."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from aegis_gitea_mcp.cache import BoundedTTLCache
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_and_get_returns_value() -> None:
|
||||||
|
"""A stored value is returned before it expires."""
|
||||||
|
cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60, max_size=8)
|
||||||
|
cache.set("a", 1)
|
||||||
|
assert cache.get("a") == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_key_returns_none() -> None:
|
||||||
|
"""An unknown key returns None."""
|
||||||
|
cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60)
|
||||||
|
assert cache.get("missing") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_entry_expires_after_ttl() -> None:
|
||||||
|
"""An entry is evicted once its TTL elapses."""
|
||||||
|
cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=0.05, max_size=8)
|
||||||
|
cache.set("a", 1)
|
||||||
|
assert cache.get("a") == 1
|
||||||
|
time.sleep(0.06)
|
||||||
|
assert cache.get("a") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_size_bound_evicts_oldest() -> None:
|
||||||
|
"""The cache never exceeds max_size; oldest entries are evicted first."""
|
||||||
|
cache: BoundedTTLCache[int, int] = BoundedTTLCache(ttl_seconds=60, max_size=3)
|
||||||
|
for i in range(5):
|
||||||
|
cache.set(i, i)
|
||||||
|
assert len(cache) == 3
|
||||||
|
# 0 and 1 were evicted; 2, 3, 4 remain.
|
||||||
|
assert cache.get(0) is None
|
||||||
|
assert cache.get(1) is None
|
||||||
|
assert cache.get(4) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_reinsert_refreshes_recency() -> None:
|
||||||
|
"""Re-setting a key refreshes its position so it is not evicted first."""
|
||||||
|
cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60, max_size=2)
|
||||||
|
cache.set("a", 1)
|
||||||
|
cache.set("b", 2)
|
||||||
|
cache.set("a", 3) # refresh "a"
|
||||||
|
cache.set("c", 4) # should evict "b", the oldest
|
||||||
|
assert cache.get("b") is None
|
||||||
|
assert cache.get("a") == 3
|
||||||
|
assert cache.get("c") == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_empties_cache() -> None:
|
||||||
|
"""clear() removes all entries."""
|
||||||
|
cache: BoundedTTLCache[str, int] = BoundedTTLCache(ttl_seconds=60)
|
||||||
|
cache.set("a", 1)
|
||||||
|
cache.clear()
|
||||||
|
assert cache.get("a") is None
|
||||||
|
assert len(cache) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_constructor_args() -> None:
|
||||||
|
"""Non-positive TTL or size is rejected."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
BoundedTTLCache(ttl_seconds=0)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
BoundedTTLCache(ttl_seconds=60, max_size=0)
|
||||||
Reference in New Issue
Block a user