feat: harden Claude MCP OAuth transport

This commit is contained in:
2026-06-13 21:05:11 +02:00
parent ed3130ef74
commit 541124e92a
11 changed files with 1377 additions and 68 deletions
+57 -4
View File
@@ -106,12 +106,12 @@ class Settings(BaseSettings):
description="Secret detection mode: off, mask, or block",
)
# OAuth2 configuration (for ChatGPT per-user Gitea authentication)
# OAuth2 configuration (for per-client Gitea authentication)
oauth_mode: bool = Field(
default=False,
description=(
"Enable per-user OAuth2 authentication mode. "
"When true, each ChatGPT user authenticates with their own Gitea account. "
"When true, each client user authenticates with their own Gitea account. "
"GITEA_TOKEN and MCP_API_KEYS are not required in this mode."
),
)
@@ -126,8 +126,9 @@ class Settings(BaseSettings):
oauth_expected_audience: str = Field(
default="",
description=(
"Expected OIDC audience for access tokens. "
"Defaults to GITEA_OAUTH_CLIENT_ID when unset."
"Additional expected OIDC audience for access tokens. The canonical MCP "
"resource URL and the Gitea OAuth client id are always accepted; set this "
"to require an extra audience value."
),
)
oauth_cache_ttl_seconds: int = Field(
@@ -139,6 +140,37 @@ class Settings(BaseSettings):
default="https://hiddenden.cafe/docs/mcp-gitea",
description="Public documentation URL for OAuth-protected MCP resource behavior",
)
oauth_state_secret: str = Field(
default="",
description=(
"Server secret used to HMAC-sign the OAuth proxy state parameter. "
"Required when OAUTH_MODE=true so callback state is tamper-evident."
),
)
oauth_redirect_allowlist_raw: str = Field(
default="",
description=(
"Comma-separated additional allowed client redirect URIs for the OAuth "
"callback proxy. Claude's callback URLs and loopback URIs are always allowed."
),
alias="OAUTH_REDIRECT_ALLOWLIST",
)
dcr_enabled: bool = Field(
default=True,
description=(
"Enable RFC 7591 Dynamic Client Registration at /register. Claude's "
"connectors register dynamically; disable to require manual client_id/secret."
),
)
dcr_storage_path: Path = Field(
default=Path("/var/lib/aegis-mcp/dcr_clients.json"),
description="Path to the JSON file that persists dynamically registered clients",
)
repo_authz_cache_ttl_seconds: int = Field(
default=60,
description="TTL (seconds) for cached per-user repository permission decisions",
ge=1,
)
# Authentication configuration
auth_enabled: bool = Field(
@@ -269,12 +301,28 @@ class Settings(BaseSettings):
"Set ALLOW_INSECURE_BIND=true to explicitly permit this."
)
extra_redirect_uris: list[str] = []
if self.oauth_redirect_allowlist_raw.strip():
extra_redirect_uris = [
value.strip()
for value in self.oauth_redirect_allowlist_raw.split(",")
if value.strip()
]
object.__setattr__(self, "_oauth_redirect_allowlist", extra_redirect_uris)
if self.oauth_mode:
# In OAuth mode, per-user Gitea tokens are used; no shared bot token or API keys needed.
if not self.gitea_oauth_client_id.strip():
raise ValueError("GITEA_OAUTH_CLIENT_ID is required when OAUTH_MODE=true.")
if not self.gitea_oauth_client_secret.strip():
raise ValueError("GITEA_OAUTH_CLIENT_SECRET is required when OAUTH_MODE=true.")
# The proxy state parameter carries the client's redirect_uri across the Gitea
# round-trip; it must be HMAC-signed, which requires a server-held secret.
if not self.oauth_state_secret.strip():
raise ValueError(
"OAUTH_STATE_SECRET is required when OAUTH_MODE=true so the OAuth "
"proxy state parameter can be HMAC-signed and verified."
)
else:
# Standard API key mode: require bot token and at least one API key.
if not self.gitea_token.strip():
@@ -308,6 +356,11 @@ class Settings(BaseSettings):
"""Get parsed list of repositories allowed for write-mode operations."""
return list(getattr(self, "_write_repository_whitelist", []))
@property
def oauth_redirect_allowlist(self) -> list[str]:
"""Get parsed list of additional allowed client redirect URIs."""
return list(getattr(self, "_oauth_redirect_allowlist", []))
@property
def gitea_base_url(self) -> str:
"""Get Gitea base URL as normalized string."""
+1 -1
View File
@@ -63,7 +63,7 @@ def _tool(
AVAILABLE_TOOLS: list[MCPTool] = [
_tool(
"list_repositories",
"List repositories visible to the configured bot account.",
"List repositories visible to the authenticated Gitea API token.",
{"type": "object", "properties": {}, "required": []},
),
_tool(
+26 -6
View File
@@ -177,6 +177,29 @@ class GiteaOAuthValidator:
self._jwks_cache[jwks_uri] = (jwks, now + self.settings.oauth_cache_ttl_seconds)
return jwks
def _acceptable_audiences(self) -> list[str]:
"""Return the set of OIDC audiences this MCP server will accept.
Per the MCP authorization spec (RFC 8707 / RFC 9728) tokens are bound to
the MCP server's canonical resource URL, so the configured public base is
the primary accepted audience. The upstream Gitea OAuth client id is also
accepted because Gitea — the actual token issuer behind this proxy —
stamps ``aud`` with the client id rather than the MCP resource URL. An
operator may add a further required audience via OAUTH_EXPECTED_AUDIENCE.
"""
audiences: list[str] = []
canonical_resource = self.settings.public_base
if canonical_resource:
audiences.append(canonical_resource)
gitea_client_id = self.settings.gitea_oauth_client_id.strip()
if gitea_client_id:
audiences.append(gitea_client_id)
configured = self.settings.oauth_expected_audience.strip()
if configured:
audiences.append(configured)
# Preserve order while removing duplicates.
return list(dict.fromkeys(audiences))
async def _validate_jwt(self, token: str) -> dict[str, Any]:
"""Validate JWT access token using OIDC discovery and JWKS."""
discovery = await self._get_discovery_document()
@@ -216,19 +239,16 @@ class GiteaOAuthValidator:
"oauth_jwt_invalid_jwk",
) from exc
expected_audience = (
self.settings.oauth_expected_audience.strip()
or self.settings.gitea_oauth_client_id.strip()
)
accepted_audiences = self._acceptable_audiences()
decode_options = cast(Any, {"verify_aud": bool(expected_audience)})
decode_options = cast(Any, {"verify_aud": bool(accepted_audiences)})
try:
claims = jwt.decode(
token,
key=cast(Any, public_key),
algorithms=["RS256"],
issuer=issuer,
audience=expected_audience or None,
audience=accepted_audiences or None,
options=decode_options,
)
except InvalidTokenError as exc:
+380
View File
@@ -0,0 +1,380 @@
"""OAuth proxy helpers for signed state, redirect validation, and DCR storage."""
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import secrets
import time
from fnmatch import fnmatchcase
from pathlib import Path
from typing import Any
from urllib.parse import ParseResult, urlparse, urlunparse
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
_CLAUDE_CALLBACK_URIS = {
"https://claude.ai/api/mcp/auth_callback",
"https://claude.com/api/mcp/auth_callback",
}
_LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "::1"}
_SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS = {"none", "client_secret_post"}
_SUPPORTED_GRANT_TYPES = {"authorization_code", "refresh_token"}
_SUPPORTED_RESPONSE_TYPES = {"code"}
class OAuthRegistrationRequest(BaseModel):
"""Incoming RFC 7591 client registration request."""
client_name: str | None = Field(default=None, max_length=200)
redirect_uris: list[str] = Field(..., min_length=1)
grant_types: list[str] = Field(default_factory=lambda: ["authorization_code", "refresh_token"])
response_types: list[str] = Field(default_factory=lambda: ["code"])
token_endpoint_auth_method: str = Field(default="none", max_length=64)
scope: str | None = Field(default=None, max_length=512)
model_config = ConfigDict(extra="forbid")
@field_validator("redirect_uris")
@classmethod
def validate_redirect_uris(cls, value: list[str]) -> list[str]:
"""Normalize and validate redirect URIs."""
uris = [uri.strip() for uri in value if isinstance(uri, str) and uri.strip()]
if not uris:
raise ValueError("redirect_uris must contain at least one non-empty URI")
return uris
@field_validator("grant_types")
@classmethod
def validate_grant_types(cls, value: list[str]) -> list[str]:
"""Restrict supported grant types to authorization code and refresh token."""
normalized = [item.strip() for item in value if item.strip()]
if not normalized:
raise ValueError("grant_types must not be empty")
if any(item not in _SUPPORTED_GRANT_TYPES for item in normalized):
raise ValueError("Unsupported grant_types requested")
return normalized
@field_validator("response_types")
@classmethod
def validate_response_types(cls, value: list[str]) -> list[str]:
"""Restrict supported response types to authorization code."""
normalized = [item.strip() for item in value if item.strip()]
if not normalized:
raise ValueError("response_types must not be empty")
if any(item not in _SUPPORTED_RESPONSE_TYPES for item in normalized):
raise ValueError("Unsupported response_types requested")
return normalized
@field_validator("token_endpoint_auth_method")
@classmethod
def validate_token_endpoint_auth_method(cls, value: str) -> str:
"""Restrict token endpoint auth methods to the supported subset."""
normalized = value.strip().lower()
if normalized not in _SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS:
raise ValueError("Unsupported token_endpoint_auth_method requested")
return normalized
@model_validator(mode="after")
def validate_pkce_ready(self) -> OAuthRegistrationRequest:
"""Ensure the request is usable for PKCE-based authorization code flow."""
if "authorization_code" not in self.grant_types:
raise ValueError("authorization_code grant is required")
if "code" not in self.response_types:
raise ValueError("code response type is required")
return self
class OAuthClientRecord(BaseModel):
"""Persisted OAuth client registration record."""
client_id: str
client_name: str | None = None
redirect_uris: list[str]
grant_types: list[str]
response_types: list[str]
token_endpoint_auth_method: str
client_id_issued_at: int
client_secret_expires_at: int = 0
client_secret_hash: str | None = None
scope: str | None = None
model_config = ConfigDict(extra="forbid")
def _canonicalize_url(value: str) -> str:
"""Normalize a URL for comparison."""
parsed = urlparse(value.strip())
if not parsed.scheme or not parsed.netloc:
return ""
normalized = ParseResult(
scheme=parsed.scheme.lower(),
netloc=parsed.netloc.lower(),
path=parsed.path or "/",
params=parsed.params,
query=parsed.query,
fragment="",
)
return urlunparse(normalized).rstrip("/")
def is_loopback_redirect_uri(redirect_uri: str) -> bool:
"""Return whether a redirect URI uses a loopback host."""
parsed = urlparse(redirect_uri.strip())
if parsed.scheme != "http":
return False
host = (parsed.hostname or "").lower()
return host in _LOOPBACK_HOSTS
def is_claude_redirect_uri(redirect_uri: str) -> bool:
"""Return whether a redirect URI is a built-in Claude callback URL."""
return _canonicalize_url(redirect_uri) in _CLAUDE_CALLBACK_URIS
def is_redirect_uri_allowed(redirect_uri: str, allowlist: list[str]) -> bool:
"""Return whether a redirect URI is allowed by policy."""
normalized = _canonicalize_url(redirect_uri)
if not normalized:
return False
if is_loopback_redirect_uri(redirect_uri) or is_claude_redirect_uri(redirect_uri):
return True
for pattern in allowlist:
candidate = pattern.strip()
if not candidate:
continue
if fnmatchcase(normalized, _canonicalize_url(candidate) or candidate):
return True
if fnmatchcase(redirect_uri.strip(), candidate):
return True
return False
def is_origin_allowed(origin: str, request_base: str, public_base: str | None) -> bool:
"""Return whether a browser Origin is allowed for MCP transport requests."""
normalized_origin = _canonicalize_url(origin)
if not normalized_origin:
return False
expected_bases = [request_base.rstrip("/")]
if public_base:
expected_bases.append(public_base.rstrip("/"))
return normalized_origin in expected_bases
def encode_proxy_state(
secret: str,
redirect_uri: str,
original_state: str,
*,
ttl_seconds: int = 600,
) -> str:
"""Create a signed OAuth state wrapper for the proxy callback round-trip."""
payload = {
"redirect_uri": redirect_uri,
"state": original_state,
"issued_at": int(time.time()),
"nonce": secrets.token_urlsafe(16),
"ttl_seconds": ttl_seconds,
}
canonical_payload = json.dumps(payload, sort_keys=True, separators=(",", ":"))
signature = hmac.new(secret.encode("utf-8"), canonical_payload.encode("utf-8"), hashlib.sha256)
envelope = {
"payload": payload,
"signature": signature.hexdigest(),
}
return base64.urlsafe_b64encode(
json.dumps(envelope, sort_keys=True, separators=(",", ":")).encode("utf-8")
).decode("ascii")
def decode_proxy_state(secret: str, encoded_state: str) -> dict[str, str]:
"""Verify and unpack a signed OAuth state wrapper."""
try:
raw = base64.urlsafe_b64decode(encoded_state.encode("ascii"))
envelope = json.loads(raw)
except Exception as exc: # pragma: no cover - guarded by tests
raise ValueError("Invalid or missing state parameter") from exc
if not isinstance(envelope, dict):
raise ValueError("Invalid or missing state parameter")
payload = envelope.get("payload")
signature = envelope.get("signature")
if not isinstance(payload, dict) or not isinstance(signature, str):
raise ValueError("Invalid or missing state parameter")
canonical_payload = json.dumps(payload, sort_keys=True, separators=(",", ":"))
expected_signature = hmac.new(
secret.encode("utf-8"), canonical_payload.encode("utf-8"), hashlib.sha256
).hexdigest()
if not hmac.compare_digest(signature, expected_signature):
raise ValueError("Invalid or missing state parameter")
issued_at = payload.get("issued_at")
ttl_seconds = payload.get("ttl_seconds")
now = int(time.time())
if not isinstance(issued_at, int) or not isinstance(ttl_seconds, int):
raise ValueError("Invalid or missing state parameter")
if issued_at > now or now - issued_at > max(ttl_seconds, 1):
raise ValueError("Invalid or missing state parameter")
redirect_uri = payload.get("redirect_uri")
if not isinstance(redirect_uri, str) or not redirect_uri.strip():
raise ValueError("Invalid or missing state parameter")
original_state = payload.get("state")
if not isinstance(original_state, str):
raise ValueError("Invalid or missing state parameter")
return {"redirect_uri": redirect_uri, "state": original_state}
class OAuthClientRegistry:
"""Persisted OAuth client registry for dynamic client registration."""
def __init__(self, storage_path: Path) -> None:
"""Initialize registry storage."""
self.storage_path = storage_path
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
self._clients: dict[str, OAuthClientRecord] = {}
self._loaded = False
@staticmethod
def _hash_secret(secret: str) -> str:
"""Hash client secrets before persistence."""
return hashlib.sha256(secret.encode("utf-8")).hexdigest()
def _load(self) -> None:
"""Load persisted registrations from disk once."""
if self._loaded:
return
self._loaded = True
if not self.storage_path.exists():
self._clients = {}
return
raw = json.loads(self.storage_path.read_text(encoding="utf-8"))
if not isinstance(raw, dict):
raise ValueError("Persisted DCR storage must be a JSON object")
clients: dict[str, OAuthClientRecord] = {}
for client_id, payload in raw.items():
if not isinstance(client_id, str):
raise ValueError("Persisted client id must be a string")
if not isinstance(payload, dict):
raise ValueError(f"Persisted client record for {client_id} must be a mapping")
record = OAuthClientRecord.model_validate({"client_id": client_id, **payload})
clients[client_id] = record
self._clients = clients
def _persist(self) -> None:
"""Write registrations atomically."""
payload = {
client_id: record.model_dump(mode="json", exclude={"client_id"})
for client_id, record in self._clients.items()
}
tmp_path = self.storage_path.with_suffix(self.storage_path.suffix + ".tmp")
tmp_path.write_text(json.dumps(payload, sort_keys=True, indent=2), encoding="utf-8")
tmp_path.replace(self.storage_path)
def get(self, client_id: str) -> OAuthClientRecord | None:
"""Look up a registered client by identifier."""
self._load()
return self._clients.get(client_id)
def is_known_client(
self,
client_id: str,
*,
fallback_client_id: str = "",
fallback_client_secret: str = "",
) -> bool:
"""Return whether a client is recognized by the registry or environment."""
if not client_id.strip():
return False
if fallback_client_id.strip() and client_id == fallback_client_id.strip():
return True
return self.get(client_id) is not None
def validate_client_secret(
self,
client_id: str,
client_secret: str | None,
*,
fallback_client_id: str = "",
fallback_client_secret: str = "",
) -> bool:
"""Validate a client identifier and optional secret."""
if fallback_client_id.strip() and client_id == fallback_client_id.strip():
if not fallback_client_secret.strip():
return True
if not client_secret:
return False
return hmac.compare_digest(
self._hash_secret(client_secret), self._hash_secret(fallback_client_secret.strip())
)
record = self.get(client_id)
if record is None:
return False
if record.client_secret_hash is None:
return True
if not client_secret:
return False
return hmac.compare_digest(self._hash_secret(client_secret), record.client_secret_hash)
def register(self, request: OAuthRegistrationRequest) -> dict[str, Any]:
"""Persist a new OAuth client registration and return its public metadata."""
self._load()
client_id = secrets.token_urlsafe(24)
client_secret: str | None = None
client_secret_hash: str | None = None
if request.token_endpoint_auth_method != "none":
client_secret = secrets.token_urlsafe(32)
client_secret_hash = self._hash_secret(client_secret)
record = OAuthClientRecord(
client_id=client_id,
client_name=request.client_name,
redirect_uris=list(request.redirect_uris),
grant_types=list(request.grant_types),
response_types=list(request.response_types),
token_endpoint_auth_method=request.token_endpoint_auth_method,
client_id_issued_at=int(time.time()),
client_secret_hash=client_secret_hash,
scope=request.scope,
)
self._clients[client_id] = record
self._persist()
response: dict[str, Any] = record.model_dump(exclude={"client_secret_hash"})
if client_secret is not None:
response["client_secret"] = client_secret
response["client_secret_expires_at"] = 0
return response
_oauth_client_registry: OAuthClientRegistry | None = None
def get_oauth_client_registry(storage_path: Path) -> OAuthClientRegistry:
"""Get or create the global OAuth client registry."""
global _oauth_client_registry
if _oauth_client_registry is None or _oauth_client_registry.storage_path != storage_path:
_oauth_client_registry = OAuthClientRegistry(storage_path)
return _oauth_client_registry
def reset_oauth_client_registry() -> None:
"""Reset the global OAuth client registry (primarily for tests)."""
global _oauth_client_registry
_oauth_client_registry = None
+385 -43
View File
@@ -3,10 +3,10 @@
from __future__ import annotations
import asyncio
import base64
import hashlib
import json
import logging
import time
import urllib.parse
import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable
@@ -36,11 +36,20 @@ from aegis_gitea_mcp.mcp_protocol import (
get_tool_by_name,
)
from aegis_gitea_mcp.oauth import get_oauth_validator
from aegis_gitea_mcp.oauth_flow import (
OAuthRegistrationRequest,
decode_proxy_state,
encode_proxy_state,
get_oauth_client_registry,
is_origin_allowed,
is_redirect_uri_allowed,
)
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
from aegis_gitea_mcp.rate_limit import get_rate_limiter
from aegis_gitea_mcp.request_context import (
clear_gitea_auth_context,
get_gitea_user_login,
get_gitea_user_scopes,
get_gitea_user_token,
set_gitea_user_login,
@@ -94,9 +103,40 @@ _api_scope_cache: BoundedTTLCache[str, bool] = BoundedTTLCache(
_REAUTH_GUIDANCE = (
"Your OAuth token lacks Gitea API scopes (e.g. read:repository). "
"Revoke the authorization in Gitea (Settings > Applications > Authorized OAuth2 Applications) "
"and in ChatGPT (Settings > Connected apps), then re-authorize."
"and in your client, then re-authorize."
)
_repo_authz_cache: BoundedTTLCache[str, bool] | None = None
def _get_repo_authz_cache() -> BoundedTTLCache[str, bool]:
"""Get the bounded cache for per-user repository permission checks."""
global _repo_authz_cache
settings = get_settings()
if _repo_authz_cache is None:
_repo_authz_cache = BoundedTTLCache(
ttl_seconds=settings.repo_authz_cache_ttl_seconds,
max_size=2048,
)
return _repo_authz_cache
def reset_repo_authz_cache() -> None:
"""Reset the repository authorization cache (primarily for tests)."""
global _repo_authz_cache
_repo_authz_cache = None
def _repo_authz_cache_key(login: str, repository: str, required_scope: str) -> str:
"""Build a bounded cache key for a user/repository permission check."""
normalized_login = login.strip().lower()
return f"{normalized_login}:{repository.lower()}:{required_scope}"
def _is_mcp_transport_path(path: str) -> bool:
"""Return whether a request targets the MCP transport surface."""
return path in {"/mcp", "/mcp/sse"} or path.startswith("/mcp/")
def _has_required_scope(required_scope: str, granted_scopes: set[str]) -> bool:
"""Return whether granted scopes satisfy the required MCP tool scope."""
@@ -116,6 +156,129 @@ def _has_required_scope(required_scope: str, granted_scopes: set[str]) -> bool:
return required_scope in expanded
def _repo_permission_satisfied(permission: dict[str, Any], required_scope: str) -> bool:
"""Return whether a repository permission payload satisfies the requested scope."""
permission_name = str(permission.get("permission", "")).lower().strip()
if permission_name in {"admin", "owner"}:
return True
if required_scope == WRITE_SCOPE and permission_name == "write":
return True
if required_scope == READ_SCOPE and permission_name in {"read", "write"}:
return True
nested_permissions = permission.get("permissions")
if isinstance(nested_permissions, dict):
return _repo_permission_satisfied(nested_permissions, required_scope)
if required_scope == WRITE_SCOPE:
return bool(permission.get("push") or permission.get("admin"))
return bool(permission.get("pull") or permission.get("push") or permission.get("admin"))
async def _verify_user_repository_access(
*,
repository: str,
required_scope: str,
user_login: str,
correlation_id: str,
tool_name: str,
) -> None:
"""Verify the authenticated user can access the target repository before PAT fallback."""
settings = get_settings()
audit = get_audit_logger()
service_token = settings.gitea_token.strip()
if not service_token:
raise HTTPException(status_code=500, detail="Repository authorization misconfigured")
if not user_login.strip() or user_login == "unknown":
audit.log_access_denied(
tool_name=tool_name,
repository=repository,
reason="repository_permission_missing_user",
correlation_id=correlation_id,
)
raise HTTPException(
status_code=403,
detail="Unable to verify repository permission for this user.",
)
cache_key = _repo_authz_cache_key(user_login, repository, required_scope)
cached = _get_repo_authz_cache().get(cache_key)
if cached is True:
return
owner, repo = repository.split("/", 1)
encoded_owner = urllib.parse.quote(owner, safe="")
encoded_repo = urllib.parse.quote(repo, safe="")
encoded_user = urllib.parse.quote(user_login, safe="")
permission_url = (
f"{settings.gitea_base_url}/api/v1/repos/{encoded_owner}/{encoded_repo}"
f"/collaborators/{encoded_user}/permission"
)
try:
async with httpx.AsyncClient(timeout=settings.request_timeout_seconds) as client:
response = await client.get(
permission_url,
headers={"Authorization": f"token {service_token}", "Accept": "application/json"},
)
except httpx.RequestError as exc:
audit.log_access_denied(
tool_name=tool_name,
repository=repository,
reason="repository_permission_probe_failed",
correlation_id=correlation_id,
)
raise HTTPException(
status_code=403,
detail="Unable to verify repository permission for this user.",
) from exc
if response.status_code != 200:
audit.log_access_denied(
tool_name=tool_name,
repository=repository,
reason=f"repository_permission_probe:{response.status_code}",
correlation_id=correlation_id,
)
raise HTTPException(
status_code=403,
detail="User does not have permission for the requested repository.",
)
try:
permission_payload = response.json()
except ValueError as exc:
audit.log_access_denied(
tool_name=tool_name,
repository=repository,
reason="repository_permission_invalid_json",
correlation_id=correlation_id,
)
raise HTTPException(
status_code=403,
detail="Unable to verify repository permission for this user.",
) from exc
if isinstance(permission_payload, dict) and _repo_permission_satisfied(
permission_payload, required_scope
):
_get_repo_authz_cache().set(cache_key, True)
return
audit.log_access_denied(
tool_name=tool_name,
repository=repository,
reason="repository_permission_denied",
correlation_id=correlation_id,
)
raise HTTPException(
status_code=403,
detail="User does not have permission for the requested repository.",
)
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
"""Run startup and shutdown hooks via the FastAPI lifespan protocol."""
@@ -243,6 +406,67 @@ async def request_context_middleware(
metrics.record_http_request(request.method, request.url.path, status_code)
def _cors_headers(origin: str) -> dict[str, str]:
"""Build strict CORS headers for a validated browser origin."""
return {
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "GET,POST,OPTIONS",
"Access-Control-Allow-Headers": "Authorization,Content-Type,MCP-Protocol-Version,X-Request-ID",
"Access-Control-Expose-Headers": "X-Request-ID,WWW-Authenticate",
"Vary": "Origin",
}
@app.middleware("http")
async def strict_origin_and_cors_middleware(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
"""Enforce strict browser origins for MCP transport requests."""
if request.url.path not in {"/mcp", "/mcp/sse"}:
return await call_next(request)
settings = get_settings()
origin = request.headers.get("origin")
expected_base = settings.public_base or str(request.base_url).rstrip("/")
if origin and not is_origin_allowed(origin, expected_base, settings.public_base):
return JSONResponse(
status_code=403,
content={
"error": "Origin not allowed",
"message": "The request origin is not allowed for this MCP transport.",
"request_id": getattr(request.state, "request_id", "-"),
},
)
if request.method == "OPTIONS":
response = Response(status_code=204)
else:
response = await call_next(request)
if origin and is_origin_allowed(origin, expected_base, settings.public_base):
for header, value in _cors_headers(origin).items():
response.headers[header] = value
return response
def _oauth_invalid_client_response() -> JSONResponse:
"""Return an RFC 6749 invalid_client error for token endpoint failures."""
response = JSONResponse(status_code=401, content={"error": "invalid_client"})
response.headers["WWW-Authenticate"] = 'Basic realm="oauth"'
return response
def _jsonrpc_error(message_id: Any, code: int, message: str) -> JSONResponse:
"""Build a JSON-RPC error response envelope."""
return JSONResponse(
content={"jsonrpc": "2.0", "id": message_id, "error": {"code": code, "message": message}}
)
@app.middleware("http")
async def authenticate_and_rate_limit(
request: Request,
@@ -255,11 +479,14 @@ async def authenticate_and_rate_limit(
if request.url.path in {"/", "/health"}:
return await call_next(request)
if request.method == "OPTIONS" and request.url.path in {"/mcp", "/mcp/sse"}:
return await call_next(request)
if request.url.path == "/metrics" and settings.metrics_enabled:
# Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
return await call_next(request)
# OAuth discovery and token endpoints must be public so ChatGPT can complete the flow.
# OAuth discovery and token endpoints must be public so MCP clients can complete the flow.
if request.url.path in {
"/oauth/token",
"/.well-known/oauth-protected-resource",
@@ -268,7 +495,11 @@ async def authenticate_and_rate_limit(
}:
return await call_next(request)
if not (request.url.path.startswith("/mcp/") or request.url.path.startswith("/automation/")):
if not (
request.url.path in {"/mcp/tools"}
or _is_mcp_transport_path(request.url.path)
or request.url.path.startswith("/automation/")
):
return await call_next(request)
oauth_validator = get_oauth_validator()
@@ -296,7 +527,7 @@ async def authenticate_and_rate_limit(
return await call_next(request)
if not access_token:
if request.url.path.startswith("/mcp/"):
if _is_mcp_transport_path(request.url.path):
return _oauth_unauthorized_response(
request,
"Provide Authorization: Bearer <token>.",
@@ -315,7 +546,7 @@ async def authenticate_and_rate_limit(
access_token, client_ip, user_agent
)
if not is_valid:
if request.url.path.startswith("/mcp/"):
if _is_mcp_transport_path(request.url.path):
return _oauth_unauthorized_response(
request,
error_message or "Invalid or expired OAuth token.",
@@ -400,7 +631,7 @@ async def authenticate_and_rate_limit(
"OAuth token is valid but lacks required Gitea API access. "
"Re-authorize this OAuth app in Gitea and try again."
)
if request.url.path.startswith("/mcp/"):
if _is_mcp_transport_path(request.url.path):
return _oauth_unauthorized_response(
request,
message,
@@ -508,9 +739,14 @@ async def health() -> dict[str, str]:
async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
"""OAuth 2.0 Protected Resource Metadata (RFC 9728).
Required by the MCP Authorization spec so that OAuth clients (e.g. ChatGPT)
can discover the authorization server that protects this resource.
ChatGPT fetches this endpoint when it first connects to the MCP server via SSE.
Required by the MCP Authorization spec so that OAuth clients (Claude's
connector infrastructure) can discover the authorization server that
protects this resource. Claude fetches this endpoint when it first connects.
The ``resource`` value MUST be THIS server's own canonical public URL: the
MCP client verifies that the resource identifier matches the origin it
derived the MCP server URL from (RFC 9728 / RFC 8707). Returning the upstream
Gitea URL here would fail that check.
"""
settings = get_settings()
gitea_base = settings.gitea_base_url
@@ -521,7 +757,7 @@ async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
return JSONResponse(
content={
"resource": gitea_base,
"resource": base_url,
"authorization_servers": authorization_servers,
"bearer_methods_supported": ["header"],
"scopes_supported": [READ_SCOPE, WRITE_SCOPE],
@@ -534,24 +770,52 @@ async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
async def oauth_authorize_proxy(request: Request) -> RedirectResponse:
"""Proxy OAuth authorization to Gitea, replacing redirect_uri with our own callback.
Clients (ChatGPT, Claude, etc.) send their own redirect_uri which Gitea doesn't know
Clients (Claude, Claude Code, Cowork, etc.) send their own redirect_uri which Gitea doesn't know
about. This endpoint intercepts the request, encodes the original redirect_uri and
state into a new state parameter, and forwards the request to Gitea using the MCP
server's own callback URI — the only URI that needs to be registered in Gitea.
"""
settings = get_settings()
base_url = settings.public_base or str(request.base_url).rstrip("/")
registry = get_oauth_client_registry(settings.dcr_storage_path)
params = dict(request.query_params)
client_redirect_uri = params.pop("redirect_uri", "")
client_redirect_uri = params.pop("redirect_uri", "").strip()
client_id = params.get("client_id", "").strip() or settings.gitea_oauth_client_id.strip()
original_state = params.get("state", "")
params.pop("client_secret", None)
# Encode the client's redirect_uri + original state into a tamper-evident wrapper.
# We simply base64-encode a JSON blob; Gitea will echo it back on the callback.
proxy_state_data = {"redirect_uri": client_redirect_uri, "state": original_state}
proxy_state = base64.urlsafe_b64encode(json.dumps(proxy_state_data).encode()).decode()
if not client_id:
raise HTTPException(status_code=400, detail="Missing client_id")
if not registry.is_known_client(
client_id,
fallback_client_id=settings.gitea_oauth_client_id,
):
raise HTTPException(status_code=401, detail="invalid_client")
if not client_redirect_uri:
raise HTTPException(status_code=400, detail="Missing redirect_uri")
if not is_redirect_uri_allowed(client_redirect_uri, settings.oauth_redirect_allowlist):
raise HTTPException(status_code=400, detail="redirect_uri is not allowed")
code_challenge = params.get("code_challenge", "").strip()
code_challenge_method = params.get("code_challenge_method", "S256").strip().upper()
if not code_challenge:
raise HTTPException(status_code=400, detail="PKCE code_challenge is required")
if code_challenge_method != "S256":
raise HTTPException(status_code=400, detail="PKCE code_challenge_method must be S256")
proxy_state = encode_proxy_state(
settings.oauth_state_secret,
client_redirect_uri,
original_state,
ttl_seconds=600,
)
params["client_id"] = settings.gitea_oauth_client_id
params["state"] = proxy_state
params["code_challenge"] = code_challenge
params["code_challenge_method"] = "S256"
params["redirect_uri"] = f"{base_url}/oauth/callback"
gitea_authorize_url = f"{settings.gitea_base_url}/login/oauth/authorize"
@@ -568,14 +832,17 @@ async def oauth_callback_proxy(request: Request) -> RedirectResponse:
error_description = request.query_params.get("error_description", "")
try:
state_data = json.loads(base64.urlsafe_b64decode(proxy_state.encode()))
state_data = decode_proxy_state(get_settings().oauth_state_secret, proxy_state)
client_redirect_uri = state_data["redirect_uri"]
original_state = state_data["state"]
except Exception as exc:
except ValueError as exc:
raise HTTPException(status_code=400, detail="Invalid or missing state parameter") from exc
settings = get_settings()
if not client_redirect_uri:
raise HTTPException(status_code=400, detail="No client redirect_uri in state")
if not is_redirect_uri_allowed(client_redirect_uri, settings.oauth_redirect_allowlist):
raise HTTPException(status_code=400, detail="redirect_uri is not allowed")
result_params: dict[str, str] = {}
if error:
@@ -595,26 +862,31 @@ async def oauth_callback_proxy(request: Request) -> RedirectResponse:
async def oauth_authorization_server_metadata(request: Request) -> JSONResponse:
"""OAuth 2.0 Authorization Server Metadata (RFC 8414).
Proxies Gitea's OAuth authorization server metadata so that ChatGPT can
discover the authorize URL, token URL, and supported features directly
from this server without needing to know the Gitea URL upfront.
Advertises this server's OAuth proxy endpoints so that Claude's connector
infrastructure can discover the authorize URL, token URL, and dynamic client
registration endpoint directly from this server without knowing the Gitea URL
upfront. The authorize/token endpoints are this server's proxy routes because
Gitea does not know Claude's redirect_uri.
"""
settings = get_settings()
base_url = settings.public_base or str(request.base_url).rstrip("/")
gitea_base = settings.gitea_base_url
return JSONResponse(
content={
"issuer": gitea_base,
"authorization_endpoint": f"{base_url}/oauth/authorize",
"token_endpoint": f"{base_url}/oauth/token",
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code"],
"code_challenge_methods_supported": ["S256"],
"scopes_supported": [READ_SCOPE, WRITE_SCOPE],
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
}
)
metadata: dict[str, Any] = {
"issuer": gitea_base,
"authorization_endpoint": f"{base_url}/oauth/authorize",
"token_endpoint": f"{base_url}/oauth/token",
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"code_challenge_methods_supported": ["S256"],
"scopes_supported": [READ_SCOPE, WRITE_SCOPE],
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
}
if settings.dcr_enabled:
# RFC 7591 dynamic client registration endpoint (Claude registers here).
metadata["registration_endpoint"] = f"{base_url}/register"
return JSONResponse(content=metadata)
@app.get("/.well-known/openid-configuration")
@@ -649,15 +921,47 @@ async def openid_configuration(request: Request) -> JSONResponse:
)
@app.post("/register")
async def oauth_dynamic_client_registration(request: Request) -> JSONResponse:
"""Persist a new OAuth client registration for Claude and similar MCP clients."""
settings = get_settings()
if not settings.dcr_enabled:
raise HTTPException(status_code=404, detail="Dynamic client registration is disabled")
content_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
if content_type != "application/json":
raise HTTPException(status_code=415, detail="Content-Type must be application/json")
registry = get_oauth_client_registry(settings.dcr_storage_path)
try:
payload = await request.json()
registration_request = OAuthRegistrationRequest.model_validate(payload)
except ValidationError as exc:
raise HTTPException(status_code=400, detail="Invalid registration payload") from exc
except Exception as exc:
raise HTTPException(status_code=400, detail="Invalid registration payload") from exc
for redirect_uri in registration_request.redirect_uris:
if not is_redirect_uri_allowed(redirect_uri, settings.oauth_redirect_allowlist):
raise HTTPException(status_code=400, detail="redirect_uri is not allowed")
response = registry.register(registration_request)
response["client_id_issued_at"] = int(time.time())
response["client_secret_expires_at"] = 0
return JSONResponse(content=response)
@app.post("/oauth/token")
async def oauth_token_proxy(request: Request) -> JSONResponse:
"""Proxy OAuth2 token exchange to Gitea.
ChatGPT sends the authorization code here after the user logs in to Gitea.
The client sends the authorization code here after the user logs in to Gitea.
This endpoint forwards the code to Gitea's token endpoint and returns the
access_token to ChatGPT, completing the OAuth2 Authorization Code flow.
access_token to the client, completing the OAuth2 Authorization Code flow.
"""
settings = get_settings()
registry = get_oauth_client_registry(settings.dcr_storage_path)
try:
form_data = await request.form()
@@ -683,32 +987,45 @@ async def oauth_token_proxy(request: Request) -> JSONResponse:
# URI to Gitea, we must use the same URI here — not the client's original redirect_uri.
base_url = settings.public_base or str(request.base_url).rstrip("/")
if not client_id:
return _oauth_invalid_client_response()
if not registry.validate_client_secret(
client_id,
client_secret or None,
fallback_client_id=settings.gitea_oauth_client_id,
fallback_client_secret=settings.gitea_oauth_client_secret,
):
return _oauth_invalid_client_response()
gitea_token_url = f"{settings.gitea_base_url}/login/oauth/access_token"
upstream_client_id = settings.gitea_oauth_client_id
upstream_client_secret = settings.gitea_oauth_client_secret
if grant_type == "refresh_token":
if not refresh_token:
raise HTTPException(status_code=400, detail="Missing refresh_token")
payload: dict[str, str] = {
"client_id": client_id,
"client_secret": client_secret,
"client_id": upstream_client_id,
"client_secret": upstream_client_secret,
"grant_type": "refresh_token",
"refresh_token": refresh_token,
}
else:
if not code:
raise HTTPException(status_code=400, detail="Missing authorization code")
if not code_verifier:
raise HTTPException(status_code=400, detail="Missing code_verifier")
payload = {
"client_id": client_id,
"client_secret": client_secret,
"client_id": upstream_client_id,
"client_secret": upstream_client_secret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": f"{base_url}/oauth/callback",
}
if code_verifier:
payload["code_verifier"] = code_verifier
payload["code_verifier"] = code_verifier
try:
async with httpx.AsyncClient(timeout=30) as client:
async with httpx.AsyncClient(timeout=settings.request_timeout_seconds) as client:
response = await client.post(
gitea_token_url,
data=payload,
@@ -846,6 +1163,29 @@ async def _execute_tool_call(
if not user_token:
raise HTTPException(status_code=401, detail="Missing authenticated user token context")
if settings.gitea_token.strip():
if not repository:
audit.log_access_denied(
tool_name=tool_name,
reason="service_pat_requires_repository_target",
correlation_id=correlation_id,
)
raise HTTPException(
status_code=403,
detail=(
"Service PAT mode requires a repository target so per-user "
"permission can be verified."
),
)
user_login = get_gitea_user_login()
await _verify_user_repository_access(
repository=repository,
required_scope=required_scope,
user_login=user_login or "",
correlation_id=correlation_id,
tool_name=tool_name,
)
# In OAuth mode, Gitea OIDC access_tokens can't call the Gitea REST API
# (they only carry OIDC scopes). If a service PAT is configured via
# GITEA_TOKEN, use that for API calls while OIDC handles identity/authz.
@@ -970,6 +1310,7 @@ async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
)
@app.get("/mcp")
@app.get("/mcp/sse")
async def sse_endpoint(request: Request) -> StreamingResponse:
"""Server-Sent Events endpoint for MCP transport."""
@@ -1004,6 +1345,7 @@ async def sse_endpoint(request: Request) -> StreamingResponse:
)
@app.post("/mcp")
@app.post("/mcp/sse")
async def sse_message_handler(request: Request) -> JSONResponse:
"""Handle POST messages for MCP SSE transport."""
+1 -1
View File
@@ -18,7 +18,7 @@ from aegis_gitea_mcp.tools.arguments import (
async def list_repositories_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
"""List repositories visible to the bot user.
"""List repositories visible to the active Gitea API token.
Args:
gitea: Initialized Gitea client.
+7
View File
@@ -9,9 +9,11 @@ from aegis_gitea_mcp.audit import reset_audit_logger
from aegis_gitea_mcp.auth import reset_validator
from aegis_gitea_mcp.config import reset_settings
from aegis_gitea_mcp.oauth import reset_oauth_validator
from aegis_gitea_mcp.oauth_flow import reset_oauth_client_registry
from aegis_gitea_mcp.observability import reset_metrics_registry
from aegis_gitea_mcp.policy import reset_policy_engine
from aegis_gitea_mcp.rate_limit import reset_rate_limiter
from aegis_gitea_mcp.server import reset_repo_authz_cache
@pytest.fixture(autouse=True)
@@ -22,6 +24,8 @@ def reset_globals(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Generator[
reset_audit_logger()
reset_validator()
reset_oauth_validator()
reset_oauth_client_registry()
reset_repo_authz_cache()
reset_policy_engine()
reset_rate_limiter()
reset_metrics_registry()
@@ -37,6 +41,8 @@ def reset_globals(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Generator[
reset_audit_logger()
reset_validator()
reset_oauth_validator()
reset_oauth_client_registry()
reset_repo_authz_cache()
reset_policy_engine()
reset_rate_limiter()
reset_metrics_registry()
@@ -66,4 +72,5 @@ def mock_env_oauth(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OAUTH_MODE", "true")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
+1
View File
@@ -28,6 +28,7 @@ def full_env(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OAUTH_MODE", "true")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
monkeypatch.setenv("ENVIRONMENT", "test")
monkeypatch.setenv("MCP_HOST", "127.0.0.1")
monkeypatch.setenv("MCP_PORT", "8080")
+240 -4
View File
@@ -10,6 +10,7 @@ from fastapi.testclient import TestClient
from aegis_gitea_mcp.config import reset_settings
from aegis_gitea_mcp.oauth import GiteaOAuthValidator, get_oauth_validator, reset_oauth_validator
from aegis_gitea_mcp.oauth_flow import OAuthClientRegistry, OAuthRegistrationRequest
from aegis_gitea_mcp.request_context import (
get_gitea_user_login,
get_gitea_user_token,
@@ -40,6 +41,7 @@ def mock_env_oauth(monkeypatch):
monkeypatch.setenv("OAUTH_MODE", "true")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
@@ -57,6 +59,24 @@ def oauth_client(mock_env_oauth):
return TestClient(app, raise_server_exceptions=False)
def _register_public_client(oauth_client: TestClient, redirect_uri: str) -> dict[str, str]:
"""Register a public OAuth client for test flows."""
response = oauth_client.post(
"/register",
json={
"client_name": "pytest-client",
"redirect_uris": [redirect_uri],
"token_endpoint_auth_method": "none",
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
},
)
assert response.status_code == 200
payload = response.json()
assert "client_id" in payload
return payload
# ---------------------------------------------------------------------------
# GiteaOAuthValidator unit tests
# ---------------------------------------------------------------------------
@@ -248,19 +268,39 @@ def test_oauth_token_endpoint_available_when_oauth_mode_false(monkeypatch):
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
with TestClient(app, raise_server_exceptions=False) as client:
response = client.post("/oauth/token", data={"code": "abc123"})
registration = client.post(
"/register",
json={
"client_name": "pytest-client",
"redirect_uris": ["http://127.0.0.1:8080/callback"],
"token_endpoint_auth_method": "none",
"grant_types": ["authorization_code"],
"response_types": ["code"],
},
)
assert registration.status_code == 200
client_id = registration.json()["client_id"]
response = client.post(
"/oauth/token",
data={"client_id": client_id, "code": "abc123", "code_verifier": "pkce"},
)
assert response.status_code == 200
def test_oauth_token_endpoint_missing_code(oauth_client):
"""POST /oauth/token without a code returns 400."""
response = oauth_client.post("/oauth/token", data={})
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
response = oauth_client.post(
"/oauth/token",
data={"client_id": client_data["client_id"], "code_verifier": "pkce"},
)
assert response.status_code == 400
def test_oauth_token_endpoint_proxy_success(oauth_client):
"""POST /oauth/token proxies successfully to Gitea and returns access_token."""
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
@@ -276,7 +316,11 @@ def test_oauth_token_endpoint_proxy_success(oauth_client):
response = oauth_client.post(
"/oauth/token",
data={"code": "auth-code-123", "redirect_uri": "https://chat.openai.com/callback"},
data={
"client_id": client_data["client_id"],
"code": "auth-code-123",
"code_verifier": "pkce-verifier",
},
)
assert response.status_code == 200
@@ -286,6 +330,7 @@ def test_oauth_token_endpoint_proxy_success(oauth_client):
def test_oauth_token_endpoint_gitea_error(oauth_client):
"""POST /oauth/token propagates Gitea error status."""
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
mock_response = MagicMock()
mock_response.status_code = 400
mock_response.json.return_value = {"error": "invalid_grant"}
@@ -296,11 +341,202 @@ def test_oauth_token_endpoint_gitea_error(oauth_client):
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
response = oauth_client.post("/oauth/token", data={"code": "bad-code"})
response = oauth_client.post(
"/oauth/token",
data={
"client_id": client_data["client_id"],
"code": "bad-code",
"code_verifier": "pkce-verifier",
},
)
assert response.status_code == 400
def test_oauth_authorize_and_callback_round_trip(oauth_client):
"""OAuth authorize/callback round-trip preserves the original redirect URI and state."""
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
authorize_response = oauth_client.get(
"/oauth/authorize",
params={
"client_id": client_data["client_id"],
"redirect_uri": "http://127.0.0.1:8080/callback",
"state": "original-state",
"code_challenge": "pkce-challenge",
"code_challenge_method": "S256",
},
follow_redirects=False,
)
assert authorize_response.status_code == 302
location = authorize_response.headers["location"]
assert "state=" in location
assert "redirect_uri=http%3A%2F%2F127.0.0.1%3A8080%2Fcallback" not in location
from urllib.parse import parse_qs, urlparse
parsed = urlparse(location)
query = parse_qs(parsed.query)
proxy_state = query["state"][0]
callback_response = oauth_client.get(
"/oauth/callback",
params={"state": proxy_state, "code": "auth-code-123"},
follow_redirects=False,
)
assert callback_response.status_code == 302
callback_location = callback_response.headers["location"]
assert callback_location.startswith("http://127.0.0.1:8080/callback?")
assert "code=auth-code-123" in callback_location
assert "state=original-state" in callback_location
def test_oauth_callback_rejects_tampered_state(oauth_client):
"""OAuth callback rejects modified signed proxy state."""
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
authorize_response = oauth_client.get(
"/oauth/authorize",
params={
"client_id": client_data["client_id"],
"redirect_uri": "http://127.0.0.1:8080/callback",
"state": "original-state",
"code_challenge": "pkce-challenge",
"code_challenge_method": "S256",
},
follow_redirects=False,
)
from urllib.parse import parse_qs, urlparse
proxy_state = parse_qs(urlparse(authorize_response.headers["location"]).query)["state"][0]
tampered_state = proxy_state[:-1] + ("A" if proxy_state[-1] != "A" else "B")
callback_response = oauth_client.get(
"/oauth/callback",
params={"state": tampered_state, "code": "auth-code-123"},
)
assert callback_response.status_code == 400
@pytest.mark.parametrize(
"redirect_uri",
[
"https://claude.ai/api/mcp/auth_callback",
"https://claude.com/api/mcp/auth_callback",
],
)
def test_dcr_accepts_default_claude_callbacks(oauth_client, redirect_uri):
"""Claude's hosted connector callback URLs are allowed by default."""
response = oauth_client.post(
"/register",
json={
"client_name": "claude-client",
"redirect_uris": [redirect_uri],
"token_endpoint_auth_method": "none",
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
},
)
assert response.status_code == 200
def test_oauth_authorize_rejects_unknown_client(oauth_client):
"""OAuth authorize returns invalid_client for unregistered client IDs."""
response = oauth_client.get(
"/oauth/authorize",
params={
"client_id": "unknown-client",
"redirect_uri": "http://127.0.0.1:8080/callback",
"state": "x",
"code_challenge": "pkce-challenge",
"code_challenge_method": "S256",
},
)
assert response.status_code == 401
assert response.json()["detail"] == "invalid_client"
def test_oauth_token_rejects_unknown_dcr_client(oauth_client):
"""Unknown dynamic clients receive RFC 6749 invalid_client from token endpoint."""
response = oauth_client.post(
"/oauth/token",
data={
"client_id": "deleted-or-unknown-client",
"code": "auth-code-123",
"code_verifier": "pkce-verifier",
},
)
assert response.status_code == 401
assert response.json() == {"error": "invalid_client"}
def test_oauth_authorize_requires_pkce_s256(oauth_client):
"""Authorization endpoint enforces PKCE S256 for public clients."""
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
missing_challenge = oauth_client.get(
"/oauth/authorize",
params={
"client_id": client_data["client_id"],
"redirect_uri": "http://127.0.0.1:8080/callback",
"state": "x",
},
)
plain_method = oauth_client.get(
"/oauth/authorize",
params={
"client_id": client_data["client_id"],
"redirect_uri": "http://127.0.0.1:8080/callback",
"state": "x",
"code_challenge": "pkce-challenge",
"code_challenge_method": "plain",
},
)
assert missing_challenge.status_code == 400
assert plain_method.status_code == 400
def test_register_rejects_foreign_redirect_uri(oauth_client):
"""DCR rejects redirect URIs outside the allowlist and loopback/Claude patterns."""
response = oauth_client.post(
"/register",
json={
"client_name": "pytest-client",
"redirect_uris": ["https://evil.example.com/callback"],
"token_endpoint_auth_method": "none",
"grant_types": ["authorization_code"],
"response_types": ["code"],
},
)
assert response.status_code == 400
def test_dcr_registry_persists_registered_clients(tmp_path):
"""Registered OAuth clients survive registry reloads."""
storage_path = tmp_path / "dcr_clients.json"
registry = OAuthClientRegistry(storage_path)
request = OAuthRegistrationRequest.model_validate(
{
"client_name": "persisted-client",
"redirect_uris": ["http://127.0.0.1:8080/callback"],
"token_endpoint_auth_method": "none",
"grant_types": ["authorization_code"],
"response_types": ["code"],
}
)
response = registry.register(request)
reloaded = OAuthClientRegistry(storage_path)
assert reloaded.get(response["client_id"]) is not None
# ---------------------------------------------------------------------------
# Config validation tests
# ---------------------------------------------------------------------------
+67 -2
View File
@@ -25,13 +25,14 @@ def reset_state(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OAUTH_MODE", "true")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
monkeypatch.setenv("OAUTH_CACHE_TTL_SECONDS", "600")
yield
reset_settings()
reset_oauth_validator()
def _build_jwt_fixture() -> tuple[str, dict[str, object]]:
def _build_jwt_fixture(aud: str = "test-client-id") -> tuple[str, dict[str, object]]:
"""Generate RS256 access token and matching JWKS payload."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key = private_key.public_key()
@@ -44,7 +45,7 @@ def _build_jwt_fixture() -> tuple[str, dict[str, object]]:
"sub": "user-1",
"preferred_username": "alice",
"scope": "read:repository write:repository",
"aud": "test-client-id",
"aud": aud,
"iss": "https://gitea.example.com",
"iat": now,
"exp": now + 3600,
@@ -56,6 +57,70 @@ def _build_jwt_fixture() -> tuple[str, dict[str, object]]:
return token, {"keys": [jwk]}
async def _validate_with_jwks(
validator: GiteaOAuthValidator, token: str, jwks: dict[str, object]
) -> tuple[bool, str | None, dict[str, object] | None]:
"""Drive a JWT validation with mocked discovery + JWKS responses."""
discovery_response = MagicMock()
discovery_response.status_code = 200
discovery_response.json.return_value = {
"issuer": "https://gitea.example.com",
"jwks_uri": "https://gitea.example.com/login/oauth/keys",
}
jwks_response = MagicMock()
jwks_response.status_code = 200
jwks_response.json.return_value = jwks
with patch("aegis_gitea_mcp.oauth.httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client.get = AsyncMock(side_effect=[discovery_response, jwks_response])
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
return await validator.validate_oauth_token(token, "127.0.0.1", "TestAgent")
def test_acceptable_audiences_includes_resource_and_client_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""The canonical MCP resource and the Gitea client id are accepted audiences."""
monkeypatch.setenv("PUBLIC_BASE_URL", "https://mcp.example.com")
reset_settings()
reset_oauth_validator()
audiences = GiteaOAuthValidator()._acceptable_audiences()
assert "https://mcp.example.com" in audiences
assert "test-client-id" in audiences
@pytest.mark.asyncio
async def test_jwt_with_canonical_resource_audience_is_accepted(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A token whose aud is the canonical MCP resource URL validates (P4)."""
monkeypatch.setenv("PUBLIC_BASE_URL", "https://mcp.example.com")
reset_settings()
reset_oauth_validator()
token, jwks = _build_jwt_fixture(aud="https://mcp.example.com")
valid, error, principal = await _validate_with_jwks(GiteaOAuthValidator(), token, jwks)
assert valid is True
assert error is None
assert principal is not None
@pytest.mark.asyncio
async def test_jwt_with_foreign_audience_is_rejected() -> None:
"""A token minted for a different audience is rejected (audience binding)."""
token, jwks = _build_jwt_fixture(aud="some-other-service")
# Foreign-audience JWT fails JWT validation, then falls back to userinfo, which
# is not mocked here and raises a network error -> overall failure.
with patch("aegis_gitea_mcp.oauth.GiteaOAuthValidator._validate_userinfo") as mock_userinfo:
from aegis_gitea_mcp.oauth import OAuthTokenValidationError
mock_userinfo.side_effect = OAuthTokenValidationError("Invalid", "userinfo_denied")
valid, _error, principal = await _validate_with_jwks(GiteaOAuthValidator(), token, jwks)
assert valid is False
assert principal is None
@pytest.mark.asyncio
async def test_validate_oauth_token_with_oidc_jwt_and_cache() -> None:
"""JWT token validation uses discovery + JWKS and caches both documents."""
+212 -7
View File
@@ -29,6 +29,7 @@ def oauth_env(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OAUTH_MODE", "true")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
monkeypatch.setenv("ENVIRONMENT", "test")
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
monkeypatch.setenv("WRITE_MODE", "false")
@@ -84,12 +85,13 @@ def test_health_endpoint(client: TestClient) -> None:
def test_oauth_protected_resource_metadata(client: TestClient) -> None:
"""OAuth protected-resource metadata contains required OpenAI-compatible fields."""
"""PRM advertises THIS server's canonical URL as the protected resource."""
response = client.get("/.well-known/oauth-protected-resource")
assert response.status_code == 200
data = response.json()
assert data["resource"] == "https://gitea.example.com"
# RFC 9728/8707: the resource identifier is the MCP server's own URL, not Gitea's.
assert data["resource"] == "http://testserver"
assert data["authorization_servers"] == [
"http://testserver",
"https://gitea.example.com",
@@ -100,12 +102,15 @@ def test_oauth_protected_resource_metadata(client: TestClient) -> None:
def test_oauth_authorization_server_metadata(client: TestClient) -> None:
"""Auth server metadata includes expected OAuth endpoints and scopes."""
"""Auth server metadata advertises this server's proxy OAuth endpoints."""
response = client.get("/.well-known/oauth-authorization-server")
assert response.status_code == 200
payload = response.json()
assert payload["authorization_endpoint"].endswith("/login/oauth/authorize")
assert payload["token_endpoint"].endswith("/oauth/token")
# Claude must be sent to our proxy authorize endpoint (Gitea does not know
# Claude's redirect_uri), so the endpoint lives on this server.
assert payload["authorization_endpoint"] == "http://testserver/oauth/authorize"
assert payload["token_endpoint"] == "http://testserver/oauth/token"
assert payload["registration_endpoint"] == "http://testserver/register"
assert payload["scopes_supported"] == ["read:repository", "write:repository"]
@@ -115,8 +120,8 @@ def test_openid_configuration_metadata(client: TestClient) -> None:
assert response.status_code == 200
payload = response.json()
assert payload["issuer"] == "https://gitea.example.com"
assert payload["authorization_endpoint"].endswith("/login/oauth/authorize")
assert payload["token_endpoint"].endswith("/oauth/token")
assert payload["authorization_endpoint"] == "http://testserver/oauth/authorize"
assert payload["token_endpoint"] == "http://testserver/oauth/token"
assert payload["userinfo_endpoint"].endswith("/login/oauth/userinfo")
assert payload["jwks_uri"].endswith("/login/oauth/keys")
assert "read:repository" in payload["scopes_supported"]
@@ -129,6 +134,7 @@ def test_oauth_metadata_uses_public_base_url(monkeypatch: pytest.MonkeyPatch) ->
monkeypatch.setenv("OAUTH_MODE", "true")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
monkeypatch.setenv("PUBLIC_BASE_URL", "https://mcp.example.com")
monkeypatch.setenv("ENVIRONMENT", "test")
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
@@ -149,6 +155,8 @@ def test_oauth_metadata_uses_public_base_url(monkeypatch: pytest.MonkeyPatch) ->
protected_response = client.get("/.well-known/oauth-protected-resource")
assert protected_response.status_code == 200
protected_payload = protected_response.json()
# P4: the protected resource identifier must equal this server's public base.
assert protected_payload["resource"] == "https://mcp.example.com"
assert protected_payload["authorization_servers"] == [
"https://mcp.example.com",
"https://gitea.example.com",
@@ -166,6 +174,201 @@ def test_oauth_metadata_uses_public_base_url(monkeypatch: pytest.MonkeyPatch) ->
)
def test_mcp_streamable_http_path_works(client: TestClient) -> None:
"""The spec path /mcp exposes the same transport behavior as the SSE alias."""
response = client.post(
"/mcp",
headers={"Authorization": "Bearer valid-read"},
json={"jsonrpc": "2.0", "id": "init-1", "method": "initialize"},
)
assert response.status_code == 200
payload = response.json()
assert payload["result"]["protocolVersion"] == "2024-11-05"
def test_mcp_preflight_allows_same_origin(client: TestClient) -> None:
"""Same-origin preflight requests to /mcp return strict CORS headers."""
response = client.options(
"/mcp",
headers={
"Origin": "http://testserver",
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "authorization,content-type",
},
)
assert response.status_code == 204
assert response.headers["Access-Control-Allow-Origin"] == "http://testserver"
def test_mcp_preflight_rejects_cross_origin(client: TestClient) -> None:
"""Cross-origin browser requests to /mcp are denied."""
response = client.options(
"/mcp",
headers={
"Origin": "https://evil.example.com",
"Access-Control-Request-Method": "POST",
},
)
assert response.status_code == 403
def test_service_pat_requests_verify_user_repo_access_before_execution(
oauth_env: None, mock_oauth_validation: None, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Service PAT fallback checks the user's repository permission before executing tools."""
from aegis_gitea_mcp import server
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
server._api_scope_cache.clear()
server.reset_repo_authz_cache()
probe_response = MagicMock()
probe_response.status_code = 200
repo_response = MagicMock()
repo_response.status_code = 403
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client.get = AsyncMock(side_effect=[probe_response, repo_response])
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
from aegis_gitea_mcp.server import app
client = TestClient(app, raise_server_exceptions=False)
response = client.post(
"/mcp/tool/call",
headers={"Authorization": "Bearer valid-read"},
json={
"tool": "get_repository_info",
"arguments": {"owner": "acme", "repo": "demo"},
},
)
assert response.status_code == 403
assert "permission" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_service_pat_repo_authz_allows_user_with_read_permission(
oauth_env: None, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Read-level collaborator permission allows service PAT execution to proceed."""
from aegis_gitea_mcp import server
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
server.reset_repo_authz_cache()
permission_response = MagicMock()
permission_response.status_code = 200
permission_response.json.return_value = {"permission": "read"}
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=permission_response)
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
await server._verify_user_repository_access(
repository="acme/demo",
required_scope=server.READ_SCOPE,
user_login="alice",
correlation_id="corr-1",
tool_name="get_repository_info",
)
mock_client.get.assert_awaited_once()
requested_url = mock_client.get.await_args.args[0]
requested_headers = mock_client.get.await_args.kwargs["headers"]
assert requested_url.endswith("/api/v1/repos/acme/demo/collaborators/alice/permission")
assert requested_headers["Authorization"] == "token service-pat-token"
@pytest.mark.asyncio
async def test_service_pat_repo_authz_denies_read_user_for_write_tool(
oauth_env: None, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Read permission is insufficient for write tools in service PAT mode."""
from fastapi import HTTPException
from aegis_gitea_mcp import server
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
server.reset_repo_authz_cache()
permission_response = MagicMock()
permission_response.status_code = 200
permission_response.json.return_value = {"permission": "read"}
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=permission_response)
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
with pytest.raises(HTTPException) as exc_info:
await server._verify_user_repository_access(
repository="acme/demo",
required_scope=server.WRITE_SCOPE,
user_login="alice",
correlation_id="corr-1",
tool_name="create_issue",
)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_service_pat_repo_authz_cache_hit_and_expiry(
oauth_env: None, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Repository permission decisions are cached briefly and rechecked after expiry."""
from aegis_gitea_mcp import cache as cache_module
from aegis_gitea_mcp import server
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
monkeypatch.setenv("REPO_AUTHZ_CACHE_TTL_SECONDS", "1")
server.reset_repo_authz_cache()
now = 1000.0
monkeypatch.setattr(cache_module.time, "monotonic", lambda: now)
permission_response = MagicMock()
permission_response.status_code = 200
permission_response.json.return_value = {"permission": "read"}
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=permission_response)
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
for _ in range(2):
await server._verify_user_repository_access(
repository="acme/demo",
required_scope=server.READ_SCOPE,
user_login="alice",
correlation_id="corr-1",
tool_name="get_repository_info",
)
assert mock_client.get.await_count == 1
now = 1002.0
await server._verify_user_repository_access(
repository="acme/demo",
required_scope=server.READ_SCOPE,
user_login="alice",
correlation_id="corr-1",
tool_name="get_repository_info",
)
assert mock_client.get.await_count == 2
def test_scope_compatibility_write_implies_read() -> None:
"""write:repository grants read-level access for read tools."""
from aegis_gitea_mcp.server import READ_SCOPE, _has_required_scope
@@ -348,6 +551,7 @@ async def test_startup_event_fails_when_discovery_unreachable(
monkeypatch.setenv("OAUTH_MODE", "true")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "true")
from aegis_gitea_mcp import server
@@ -377,6 +581,7 @@ async def test_startup_event_succeeds_when_discovery_ready(
monkeypatch.setenv("OAUTH_MODE", "true")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "true")
from aegis_gitea_mcp import server