feat: harden Claude MCP OAuth transport
This commit is contained in:
@@ -106,12 +106,12 @@ class Settings(BaseSettings):
|
||||
description="Secret detection mode: off, mask, or block",
|
||||
)
|
||||
|
||||
# OAuth2 configuration (for ChatGPT per-user Gitea authentication)
|
||||
# OAuth2 configuration (for per-client Gitea authentication)
|
||||
oauth_mode: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Enable per-user OAuth2 authentication mode. "
|
||||
"When true, each ChatGPT user authenticates with their own Gitea account. "
|
||||
"When true, each client user authenticates with their own Gitea account. "
|
||||
"GITEA_TOKEN and MCP_API_KEYS are not required in this mode."
|
||||
),
|
||||
)
|
||||
@@ -126,8 +126,9 @@ class Settings(BaseSettings):
|
||||
oauth_expected_audience: str = Field(
|
||||
default="",
|
||||
description=(
|
||||
"Expected OIDC audience for access tokens. "
|
||||
"Defaults to GITEA_OAUTH_CLIENT_ID when unset."
|
||||
"Additional expected OIDC audience for access tokens. The canonical MCP "
|
||||
"resource URL and the Gitea OAuth client id are always accepted; set this "
|
||||
"to require an extra audience value."
|
||||
),
|
||||
)
|
||||
oauth_cache_ttl_seconds: int = Field(
|
||||
@@ -139,6 +140,37 @@ class Settings(BaseSettings):
|
||||
default="https://hiddenden.cafe/docs/mcp-gitea",
|
||||
description="Public documentation URL for OAuth-protected MCP resource behavior",
|
||||
)
|
||||
oauth_state_secret: str = Field(
|
||||
default="",
|
||||
description=(
|
||||
"Server secret used to HMAC-sign the OAuth proxy state parameter. "
|
||||
"Required when OAUTH_MODE=true so callback state is tamper-evident."
|
||||
),
|
||||
)
|
||||
oauth_redirect_allowlist_raw: str = Field(
|
||||
default="",
|
||||
description=(
|
||||
"Comma-separated additional allowed client redirect URIs for the OAuth "
|
||||
"callback proxy. Claude's callback URLs and loopback URIs are always allowed."
|
||||
),
|
||||
alias="OAUTH_REDIRECT_ALLOWLIST",
|
||||
)
|
||||
dcr_enabled: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Enable RFC 7591 Dynamic Client Registration at /register. Claude's "
|
||||
"connectors register dynamically; disable to require manual client_id/secret."
|
||||
),
|
||||
)
|
||||
dcr_storage_path: Path = Field(
|
||||
default=Path("/var/lib/aegis-mcp/dcr_clients.json"),
|
||||
description="Path to the JSON file that persists dynamically registered clients",
|
||||
)
|
||||
repo_authz_cache_ttl_seconds: int = Field(
|
||||
default=60,
|
||||
description="TTL (seconds) for cached per-user repository permission decisions",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
# Authentication configuration
|
||||
auth_enabled: bool = Field(
|
||||
@@ -269,12 +301,28 @@ class Settings(BaseSettings):
|
||||
"Set ALLOW_INSECURE_BIND=true to explicitly permit this."
|
||||
)
|
||||
|
||||
extra_redirect_uris: list[str] = []
|
||||
if self.oauth_redirect_allowlist_raw.strip():
|
||||
extra_redirect_uris = [
|
||||
value.strip()
|
||||
for value in self.oauth_redirect_allowlist_raw.split(",")
|
||||
if value.strip()
|
||||
]
|
||||
object.__setattr__(self, "_oauth_redirect_allowlist", extra_redirect_uris)
|
||||
|
||||
if self.oauth_mode:
|
||||
# In OAuth mode, per-user Gitea tokens are used; no shared bot token or API keys needed.
|
||||
if not self.gitea_oauth_client_id.strip():
|
||||
raise ValueError("GITEA_OAUTH_CLIENT_ID is required when OAUTH_MODE=true.")
|
||||
if not self.gitea_oauth_client_secret.strip():
|
||||
raise ValueError("GITEA_OAUTH_CLIENT_SECRET is required when OAUTH_MODE=true.")
|
||||
# The proxy state parameter carries the client's redirect_uri across the Gitea
|
||||
# round-trip; it must be HMAC-signed, which requires a server-held secret.
|
||||
if not self.oauth_state_secret.strip():
|
||||
raise ValueError(
|
||||
"OAUTH_STATE_SECRET is required when OAUTH_MODE=true so the OAuth "
|
||||
"proxy state parameter can be HMAC-signed and verified."
|
||||
)
|
||||
else:
|
||||
# Standard API key mode: require bot token and at least one API key.
|
||||
if not self.gitea_token.strip():
|
||||
@@ -308,6 +356,11 @@ class Settings(BaseSettings):
|
||||
"""Get parsed list of repositories allowed for write-mode operations."""
|
||||
return list(getattr(self, "_write_repository_whitelist", []))
|
||||
|
||||
@property
|
||||
def oauth_redirect_allowlist(self) -> list[str]:
|
||||
"""Get parsed list of additional allowed client redirect URIs."""
|
||||
return list(getattr(self, "_oauth_redirect_allowlist", []))
|
||||
|
||||
@property
|
||||
def gitea_base_url(self) -> str:
|
||||
"""Get Gitea base URL as normalized string."""
|
||||
|
||||
@@ -63,7 +63,7 @@ def _tool(
|
||||
AVAILABLE_TOOLS: list[MCPTool] = [
|
||||
_tool(
|
||||
"list_repositories",
|
||||
"List repositories visible to the configured bot account.",
|
||||
"List repositories visible to the authenticated Gitea API token.",
|
||||
{"type": "object", "properties": {}, "required": []},
|
||||
),
|
||||
_tool(
|
||||
|
||||
@@ -177,6 +177,29 @@ class GiteaOAuthValidator:
|
||||
self._jwks_cache[jwks_uri] = (jwks, now + self.settings.oauth_cache_ttl_seconds)
|
||||
return jwks
|
||||
|
||||
def _acceptable_audiences(self) -> list[str]:
|
||||
"""Return the set of OIDC audiences this MCP server will accept.
|
||||
|
||||
Per the MCP authorization spec (RFC 8707 / RFC 9728) tokens are bound to
|
||||
the MCP server's canonical resource URL, so the configured public base is
|
||||
the primary accepted audience. The upstream Gitea OAuth client id is also
|
||||
accepted because Gitea — the actual token issuer behind this proxy —
|
||||
stamps ``aud`` with the client id rather than the MCP resource URL. An
|
||||
operator may add a further required audience via OAUTH_EXPECTED_AUDIENCE.
|
||||
"""
|
||||
audiences: list[str] = []
|
||||
canonical_resource = self.settings.public_base
|
||||
if canonical_resource:
|
||||
audiences.append(canonical_resource)
|
||||
gitea_client_id = self.settings.gitea_oauth_client_id.strip()
|
||||
if gitea_client_id:
|
||||
audiences.append(gitea_client_id)
|
||||
configured = self.settings.oauth_expected_audience.strip()
|
||||
if configured:
|
||||
audiences.append(configured)
|
||||
# Preserve order while removing duplicates.
|
||||
return list(dict.fromkeys(audiences))
|
||||
|
||||
async def _validate_jwt(self, token: str) -> dict[str, Any]:
|
||||
"""Validate JWT access token using OIDC discovery and JWKS."""
|
||||
discovery = await self._get_discovery_document()
|
||||
@@ -216,19 +239,16 @@ class GiteaOAuthValidator:
|
||||
"oauth_jwt_invalid_jwk",
|
||||
) from exc
|
||||
|
||||
expected_audience = (
|
||||
self.settings.oauth_expected_audience.strip()
|
||||
or self.settings.gitea_oauth_client_id.strip()
|
||||
)
|
||||
accepted_audiences = self._acceptable_audiences()
|
||||
|
||||
decode_options = cast(Any, {"verify_aud": bool(expected_audience)})
|
||||
decode_options = cast(Any, {"verify_aud": bool(accepted_audiences)})
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
key=cast(Any, public_key),
|
||||
algorithms=["RS256"],
|
||||
issuer=issuer,
|
||||
audience=expected_audience or None,
|
||||
audience=accepted_audiences or None,
|
||||
options=decode_options,
|
||||
)
|
||||
except InvalidTokenError as exc:
|
||||
|
||||
@@ -0,0 +1,380 @@
|
||||
"""OAuth proxy helpers for signed state, redirect validation, and DCR storage."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
from fnmatch import fnmatchcase
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import ParseResult, urlparse, urlunparse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
_CLAUDE_CALLBACK_URIS = {
|
||||
"https://claude.ai/api/mcp/auth_callback",
|
||||
"https://claude.com/api/mcp/auth_callback",
|
||||
}
|
||||
_LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "::1"}
|
||||
_SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS = {"none", "client_secret_post"}
|
||||
_SUPPORTED_GRANT_TYPES = {"authorization_code", "refresh_token"}
|
||||
_SUPPORTED_RESPONSE_TYPES = {"code"}
|
||||
|
||||
|
||||
class OAuthRegistrationRequest(BaseModel):
|
||||
"""Incoming RFC 7591 client registration request."""
|
||||
|
||||
client_name: str | None = Field(default=None, max_length=200)
|
||||
redirect_uris: list[str] = Field(..., min_length=1)
|
||||
grant_types: list[str] = Field(default_factory=lambda: ["authorization_code", "refresh_token"])
|
||||
response_types: list[str] = Field(default_factory=lambda: ["code"])
|
||||
token_endpoint_auth_method: str = Field(default="none", max_length=64)
|
||||
scope: str | None = Field(default=None, max_length=512)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("redirect_uris")
|
||||
@classmethod
|
||||
def validate_redirect_uris(cls, value: list[str]) -> list[str]:
|
||||
"""Normalize and validate redirect URIs."""
|
||||
uris = [uri.strip() for uri in value if isinstance(uri, str) and uri.strip()]
|
||||
if not uris:
|
||||
raise ValueError("redirect_uris must contain at least one non-empty URI")
|
||||
return uris
|
||||
|
||||
@field_validator("grant_types")
|
||||
@classmethod
|
||||
def validate_grant_types(cls, value: list[str]) -> list[str]:
|
||||
"""Restrict supported grant types to authorization code and refresh token."""
|
||||
normalized = [item.strip() for item in value if item.strip()]
|
||||
if not normalized:
|
||||
raise ValueError("grant_types must not be empty")
|
||||
if any(item not in _SUPPORTED_GRANT_TYPES for item in normalized):
|
||||
raise ValueError("Unsupported grant_types requested")
|
||||
return normalized
|
||||
|
||||
@field_validator("response_types")
|
||||
@classmethod
|
||||
def validate_response_types(cls, value: list[str]) -> list[str]:
|
||||
"""Restrict supported response types to authorization code."""
|
||||
normalized = [item.strip() for item in value if item.strip()]
|
||||
if not normalized:
|
||||
raise ValueError("response_types must not be empty")
|
||||
if any(item not in _SUPPORTED_RESPONSE_TYPES for item in normalized):
|
||||
raise ValueError("Unsupported response_types requested")
|
||||
return normalized
|
||||
|
||||
@field_validator("token_endpoint_auth_method")
|
||||
@classmethod
|
||||
def validate_token_endpoint_auth_method(cls, value: str) -> str:
|
||||
"""Restrict token endpoint auth methods to the supported subset."""
|
||||
normalized = value.strip().lower()
|
||||
if normalized not in _SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS:
|
||||
raise ValueError("Unsupported token_endpoint_auth_method requested")
|
||||
return normalized
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_pkce_ready(self) -> OAuthRegistrationRequest:
|
||||
"""Ensure the request is usable for PKCE-based authorization code flow."""
|
||||
if "authorization_code" not in self.grant_types:
|
||||
raise ValueError("authorization_code grant is required")
|
||||
if "code" not in self.response_types:
|
||||
raise ValueError("code response type is required")
|
||||
return self
|
||||
|
||||
|
||||
class OAuthClientRecord(BaseModel):
|
||||
"""Persisted OAuth client registration record."""
|
||||
|
||||
client_id: str
|
||||
client_name: str | None = None
|
||||
redirect_uris: list[str]
|
||||
grant_types: list[str]
|
||||
response_types: list[str]
|
||||
token_endpoint_auth_method: str
|
||||
client_id_issued_at: int
|
||||
client_secret_expires_at: int = 0
|
||||
client_secret_hash: str | None = None
|
||||
scope: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
def _canonicalize_url(value: str) -> str:
|
||||
"""Normalize a URL for comparison."""
|
||||
parsed = urlparse(value.strip())
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
return ""
|
||||
|
||||
normalized = ParseResult(
|
||||
scheme=parsed.scheme.lower(),
|
||||
netloc=parsed.netloc.lower(),
|
||||
path=parsed.path or "/",
|
||||
params=parsed.params,
|
||||
query=parsed.query,
|
||||
fragment="",
|
||||
)
|
||||
return urlunparse(normalized).rstrip("/")
|
||||
|
||||
|
||||
def is_loopback_redirect_uri(redirect_uri: str) -> bool:
|
||||
"""Return whether a redirect URI uses a loopback host."""
|
||||
parsed = urlparse(redirect_uri.strip())
|
||||
if parsed.scheme != "http":
|
||||
return False
|
||||
host = (parsed.hostname or "").lower()
|
||||
return host in _LOOPBACK_HOSTS
|
||||
|
||||
|
||||
def is_claude_redirect_uri(redirect_uri: str) -> bool:
|
||||
"""Return whether a redirect URI is a built-in Claude callback URL."""
|
||||
return _canonicalize_url(redirect_uri) in _CLAUDE_CALLBACK_URIS
|
||||
|
||||
|
||||
def is_redirect_uri_allowed(redirect_uri: str, allowlist: list[str]) -> bool:
|
||||
"""Return whether a redirect URI is allowed by policy."""
|
||||
normalized = _canonicalize_url(redirect_uri)
|
||||
if not normalized:
|
||||
return False
|
||||
|
||||
if is_loopback_redirect_uri(redirect_uri) or is_claude_redirect_uri(redirect_uri):
|
||||
return True
|
||||
|
||||
for pattern in allowlist:
|
||||
candidate = pattern.strip()
|
||||
if not candidate:
|
||||
continue
|
||||
if fnmatchcase(normalized, _canonicalize_url(candidate) or candidate):
|
||||
return True
|
||||
if fnmatchcase(redirect_uri.strip(), candidate):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_origin_allowed(origin: str, request_base: str, public_base: str | None) -> bool:
|
||||
"""Return whether a browser Origin is allowed for MCP transport requests."""
|
||||
normalized_origin = _canonicalize_url(origin)
|
||||
if not normalized_origin:
|
||||
return False
|
||||
|
||||
expected_bases = [request_base.rstrip("/")]
|
||||
if public_base:
|
||||
expected_bases.append(public_base.rstrip("/"))
|
||||
return normalized_origin in expected_bases
|
||||
|
||||
|
||||
def encode_proxy_state(
|
||||
secret: str,
|
||||
redirect_uri: str,
|
||||
original_state: str,
|
||||
*,
|
||||
ttl_seconds: int = 600,
|
||||
) -> str:
|
||||
"""Create a signed OAuth state wrapper for the proxy callback round-trip."""
|
||||
payload = {
|
||||
"redirect_uri": redirect_uri,
|
||||
"state": original_state,
|
||||
"issued_at": int(time.time()),
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
canonical_payload = json.dumps(payload, sort_keys=True, separators=(",", ":"))
|
||||
signature = hmac.new(secret.encode("utf-8"), canonical_payload.encode("utf-8"), hashlib.sha256)
|
||||
envelope = {
|
||||
"payload": payload,
|
||||
"signature": signature.hexdigest(),
|
||||
}
|
||||
return base64.urlsafe_b64encode(
|
||||
json.dumps(envelope, sort_keys=True, separators=(",", ":")).encode("utf-8")
|
||||
).decode("ascii")
|
||||
|
||||
|
||||
def decode_proxy_state(secret: str, encoded_state: str) -> dict[str, str]:
|
||||
"""Verify and unpack a signed OAuth state wrapper."""
|
||||
try:
|
||||
raw = base64.urlsafe_b64decode(encoded_state.encode("ascii"))
|
||||
envelope = json.loads(raw)
|
||||
except Exception as exc: # pragma: no cover - guarded by tests
|
||||
raise ValueError("Invalid or missing state parameter") from exc
|
||||
|
||||
if not isinstance(envelope, dict):
|
||||
raise ValueError("Invalid or missing state parameter")
|
||||
|
||||
payload = envelope.get("payload")
|
||||
signature = envelope.get("signature")
|
||||
if not isinstance(payload, dict) or not isinstance(signature, str):
|
||||
raise ValueError("Invalid or missing state parameter")
|
||||
|
||||
canonical_payload = json.dumps(payload, sort_keys=True, separators=(",", ":"))
|
||||
expected_signature = hmac.new(
|
||||
secret.encode("utf-8"), canonical_payload.encode("utf-8"), hashlib.sha256
|
||||
).hexdigest()
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
raise ValueError("Invalid or missing state parameter")
|
||||
|
||||
issued_at = payload.get("issued_at")
|
||||
ttl_seconds = payload.get("ttl_seconds")
|
||||
now = int(time.time())
|
||||
if not isinstance(issued_at, int) or not isinstance(ttl_seconds, int):
|
||||
raise ValueError("Invalid or missing state parameter")
|
||||
if issued_at > now or now - issued_at > max(ttl_seconds, 1):
|
||||
raise ValueError("Invalid or missing state parameter")
|
||||
|
||||
redirect_uri = payload.get("redirect_uri")
|
||||
if not isinstance(redirect_uri, str) or not redirect_uri.strip():
|
||||
raise ValueError("Invalid or missing state parameter")
|
||||
|
||||
original_state = payload.get("state")
|
||||
if not isinstance(original_state, str):
|
||||
raise ValueError("Invalid or missing state parameter")
|
||||
|
||||
return {"redirect_uri": redirect_uri, "state": original_state}
|
||||
|
||||
|
||||
class OAuthClientRegistry:
|
||||
"""Persisted OAuth client registry for dynamic client registration."""
|
||||
|
||||
def __init__(self, storage_path: Path) -> None:
|
||||
"""Initialize registry storage."""
|
||||
self.storage_path = storage_path
|
||||
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._clients: dict[str, OAuthClientRecord] = {}
|
||||
self._loaded = False
|
||||
|
||||
@staticmethod
|
||||
def _hash_secret(secret: str) -> str:
|
||||
"""Hash client secrets before persistence."""
|
||||
return hashlib.sha256(secret.encode("utf-8")).hexdigest()
|
||||
|
||||
def _load(self) -> None:
|
||||
"""Load persisted registrations from disk once."""
|
||||
if self._loaded:
|
||||
return
|
||||
self._loaded = True
|
||||
if not self.storage_path.exists():
|
||||
self._clients = {}
|
||||
return
|
||||
|
||||
raw = json.loads(self.storage_path.read_text(encoding="utf-8"))
|
||||
if not isinstance(raw, dict):
|
||||
raise ValueError("Persisted DCR storage must be a JSON object")
|
||||
|
||||
clients: dict[str, OAuthClientRecord] = {}
|
||||
for client_id, payload in raw.items():
|
||||
if not isinstance(client_id, str):
|
||||
raise ValueError("Persisted client id must be a string")
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError(f"Persisted client record for {client_id} must be a mapping")
|
||||
record = OAuthClientRecord.model_validate({"client_id": client_id, **payload})
|
||||
clients[client_id] = record
|
||||
|
||||
self._clients = clients
|
||||
|
||||
def _persist(self) -> None:
|
||||
"""Write registrations atomically."""
|
||||
payload = {
|
||||
client_id: record.model_dump(mode="json", exclude={"client_id"})
|
||||
for client_id, record in self._clients.items()
|
||||
}
|
||||
tmp_path = self.storage_path.with_suffix(self.storage_path.suffix + ".tmp")
|
||||
tmp_path.write_text(json.dumps(payload, sort_keys=True, indent=2), encoding="utf-8")
|
||||
tmp_path.replace(self.storage_path)
|
||||
|
||||
def get(self, client_id: str) -> OAuthClientRecord | None:
|
||||
"""Look up a registered client by identifier."""
|
||||
self._load()
|
||||
return self._clients.get(client_id)
|
||||
|
||||
def is_known_client(
|
||||
self,
|
||||
client_id: str,
|
||||
*,
|
||||
fallback_client_id: str = "",
|
||||
fallback_client_secret: str = "",
|
||||
) -> bool:
|
||||
"""Return whether a client is recognized by the registry or environment."""
|
||||
if not client_id.strip():
|
||||
return False
|
||||
if fallback_client_id.strip() and client_id == fallback_client_id.strip():
|
||||
return True
|
||||
return self.get(client_id) is not None
|
||||
|
||||
def validate_client_secret(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str | None,
|
||||
*,
|
||||
fallback_client_id: str = "",
|
||||
fallback_client_secret: str = "",
|
||||
) -> bool:
|
||||
"""Validate a client identifier and optional secret."""
|
||||
if fallback_client_id.strip() and client_id == fallback_client_id.strip():
|
||||
if not fallback_client_secret.strip():
|
||||
return True
|
||||
if not client_secret:
|
||||
return False
|
||||
return hmac.compare_digest(
|
||||
self._hash_secret(client_secret), self._hash_secret(fallback_client_secret.strip())
|
||||
)
|
||||
|
||||
record = self.get(client_id)
|
||||
if record is None:
|
||||
return False
|
||||
|
||||
if record.client_secret_hash is None:
|
||||
return True
|
||||
if not client_secret:
|
||||
return False
|
||||
return hmac.compare_digest(self._hash_secret(client_secret), record.client_secret_hash)
|
||||
|
||||
def register(self, request: OAuthRegistrationRequest) -> dict[str, Any]:
|
||||
"""Persist a new OAuth client registration and return its public metadata."""
|
||||
self._load()
|
||||
client_id = secrets.token_urlsafe(24)
|
||||
client_secret: str | None = None
|
||||
client_secret_hash: str | None = None
|
||||
|
||||
if request.token_endpoint_auth_method != "none":
|
||||
client_secret = secrets.token_urlsafe(32)
|
||||
client_secret_hash = self._hash_secret(client_secret)
|
||||
|
||||
record = OAuthClientRecord(
|
||||
client_id=client_id,
|
||||
client_name=request.client_name,
|
||||
redirect_uris=list(request.redirect_uris),
|
||||
grant_types=list(request.grant_types),
|
||||
response_types=list(request.response_types),
|
||||
token_endpoint_auth_method=request.token_endpoint_auth_method,
|
||||
client_id_issued_at=int(time.time()),
|
||||
client_secret_hash=client_secret_hash,
|
||||
scope=request.scope,
|
||||
)
|
||||
self._clients[client_id] = record
|
||||
self._persist()
|
||||
|
||||
response: dict[str, Any] = record.model_dump(exclude={"client_secret_hash"})
|
||||
if client_secret is not None:
|
||||
response["client_secret"] = client_secret
|
||||
response["client_secret_expires_at"] = 0
|
||||
return response
|
||||
|
||||
|
||||
_oauth_client_registry: OAuthClientRegistry | None = None
|
||||
|
||||
|
||||
def get_oauth_client_registry(storage_path: Path) -> OAuthClientRegistry:
|
||||
"""Get or create the global OAuth client registry."""
|
||||
global _oauth_client_registry
|
||||
if _oauth_client_registry is None or _oauth_client_registry.storage_path != storage_path:
|
||||
_oauth_client_registry = OAuthClientRegistry(storage_path)
|
||||
return _oauth_client_registry
|
||||
|
||||
|
||||
def reset_oauth_client_registry() -> None:
|
||||
"""Reset the global OAuth client registry (primarily for tests)."""
|
||||
global _oauth_client_registry
|
||||
_oauth_client_registry = None
|
||||
+385
-43
@@ -3,10 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
@@ -36,11 +36,20 @@ from aegis_gitea_mcp.mcp_protocol import (
|
||||
get_tool_by_name,
|
||||
)
|
||||
from aegis_gitea_mcp.oauth import get_oauth_validator
|
||||
from aegis_gitea_mcp.oauth_flow import (
|
||||
OAuthRegistrationRequest,
|
||||
decode_proxy_state,
|
||||
encode_proxy_state,
|
||||
get_oauth_client_registry,
|
||||
is_origin_allowed,
|
||||
is_redirect_uri_allowed,
|
||||
)
|
||||
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
|
||||
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
|
||||
from aegis_gitea_mcp.rate_limit import get_rate_limiter
|
||||
from aegis_gitea_mcp.request_context import (
|
||||
clear_gitea_auth_context,
|
||||
get_gitea_user_login,
|
||||
get_gitea_user_scopes,
|
||||
get_gitea_user_token,
|
||||
set_gitea_user_login,
|
||||
@@ -94,9 +103,40 @@ _api_scope_cache: BoundedTTLCache[str, bool] = BoundedTTLCache(
|
||||
_REAUTH_GUIDANCE = (
|
||||
"Your OAuth token lacks Gitea API scopes (e.g. read:repository). "
|
||||
"Revoke the authorization in Gitea (Settings > Applications > Authorized OAuth2 Applications) "
|
||||
"and in ChatGPT (Settings > Connected apps), then re-authorize."
|
||||
"and in your client, then re-authorize."
|
||||
)
|
||||
|
||||
_repo_authz_cache: BoundedTTLCache[str, bool] | None = None
|
||||
|
||||
|
||||
def _get_repo_authz_cache() -> BoundedTTLCache[str, bool]:
|
||||
"""Get the bounded cache for per-user repository permission checks."""
|
||||
global _repo_authz_cache
|
||||
settings = get_settings()
|
||||
if _repo_authz_cache is None:
|
||||
_repo_authz_cache = BoundedTTLCache(
|
||||
ttl_seconds=settings.repo_authz_cache_ttl_seconds,
|
||||
max_size=2048,
|
||||
)
|
||||
return _repo_authz_cache
|
||||
|
||||
|
||||
def reset_repo_authz_cache() -> None:
|
||||
"""Reset the repository authorization cache (primarily for tests)."""
|
||||
global _repo_authz_cache
|
||||
_repo_authz_cache = None
|
||||
|
||||
|
||||
def _repo_authz_cache_key(login: str, repository: str, required_scope: str) -> str:
|
||||
"""Build a bounded cache key for a user/repository permission check."""
|
||||
normalized_login = login.strip().lower()
|
||||
return f"{normalized_login}:{repository.lower()}:{required_scope}"
|
||||
|
||||
|
||||
def _is_mcp_transport_path(path: str) -> bool:
|
||||
"""Return whether a request targets the MCP transport surface."""
|
||||
return path in {"/mcp", "/mcp/sse"} or path.startswith("/mcp/")
|
||||
|
||||
|
||||
def _has_required_scope(required_scope: str, granted_scopes: set[str]) -> bool:
|
||||
"""Return whether granted scopes satisfy the required MCP tool scope."""
|
||||
@@ -116,6 +156,129 @@ def _has_required_scope(required_scope: str, granted_scopes: set[str]) -> bool:
|
||||
return required_scope in expanded
|
||||
|
||||
|
||||
def _repo_permission_satisfied(permission: dict[str, Any], required_scope: str) -> bool:
|
||||
"""Return whether a repository permission payload satisfies the requested scope."""
|
||||
permission_name = str(permission.get("permission", "")).lower().strip()
|
||||
if permission_name in {"admin", "owner"}:
|
||||
return True
|
||||
if required_scope == WRITE_SCOPE and permission_name == "write":
|
||||
return True
|
||||
if required_scope == READ_SCOPE and permission_name in {"read", "write"}:
|
||||
return True
|
||||
|
||||
nested_permissions = permission.get("permissions")
|
||||
if isinstance(nested_permissions, dict):
|
||||
return _repo_permission_satisfied(nested_permissions, required_scope)
|
||||
|
||||
if required_scope == WRITE_SCOPE:
|
||||
return bool(permission.get("push") or permission.get("admin"))
|
||||
return bool(permission.get("pull") or permission.get("push") or permission.get("admin"))
|
||||
|
||||
|
||||
async def _verify_user_repository_access(
|
||||
*,
|
||||
repository: str,
|
||||
required_scope: str,
|
||||
user_login: str,
|
||||
correlation_id: str,
|
||||
tool_name: str,
|
||||
) -> None:
|
||||
"""Verify the authenticated user can access the target repository before PAT fallback."""
|
||||
settings = get_settings()
|
||||
audit = get_audit_logger()
|
||||
|
||||
service_token = settings.gitea_token.strip()
|
||||
if not service_token:
|
||||
raise HTTPException(status_code=500, detail="Repository authorization misconfigured")
|
||||
|
||||
if not user_login.strip() or user_login == "unknown":
|
||||
audit.log_access_denied(
|
||||
tool_name=tool_name,
|
||||
repository=repository,
|
||||
reason="repository_permission_missing_user",
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Unable to verify repository permission for this user.",
|
||||
)
|
||||
|
||||
cache_key = _repo_authz_cache_key(user_login, repository, required_scope)
|
||||
cached = _get_repo_authz_cache().get(cache_key)
|
||||
if cached is True:
|
||||
return
|
||||
|
||||
owner, repo = repository.split("/", 1)
|
||||
encoded_owner = urllib.parse.quote(owner, safe="")
|
||||
encoded_repo = urllib.parse.quote(repo, safe="")
|
||||
encoded_user = urllib.parse.quote(user_login, safe="")
|
||||
permission_url = (
|
||||
f"{settings.gitea_base_url}/api/v1/repos/{encoded_owner}/{encoded_repo}"
|
||||
f"/collaborators/{encoded_user}/permission"
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=settings.request_timeout_seconds) as client:
|
||||
response = await client.get(
|
||||
permission_url,
|
||||
headers={"Authorization": f"token {service_token}", "Accept": "application/json"},
|
||||
)
|
||||
except httpx.RequestError as exc:
|
||||
audit.log_access_denied(
|
||||
tool_name=tool_name,
|
||||
repository=repository,
|
||||
reason="repository_permission_probe_failed",
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Unable to verify repository permission for this user.",
|
||||
) from exc
|
||||
|
||||
if response.status_code != 200:
|
||||
audit.log_access_denied(
|
||||
tool_name=tool_name,
|
||||
repository=repository,
|
||||
reason=f"repository_permission_probe:{response.status_code}",
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User does not have permission for the requested repository.",
|
||||
)
|
||||
|
||||
try:
|
||||
permission_payload = response.json()
|
||||
except ValueError as exc:
|
||||
audit.log_access_denied(
|
||||
tool_name=tool_name,
|
||||
repository=repository,
|
||||
reason="repository_permission_invalid_json",
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Unable to verify repository permission for this user.",
|
||||
) from exc
|
||||
|
||||
if isinstance(permission_payload, dict) and _repo_permission_satisfied(
|
||||
permission_payload, required_scope
|
||||
):
|
||||
_get_repo_authz_cache().set(cache_key, True)
|
||||
return
|
||||
|
||||
audit.log_access_denied(
|
||||
tool_name=tool_name,
|
||||
repository=repository,
|
||||
reason="repository_permission_denied",
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User does not have permission for the requested repository.",
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Run startup and shutdown hooks via the FastAPI lifespan protocol."""
|
||||
@@ -243,6 +406,67 @@ async def request_context_middleware(
|
||||
metrics.record_http_request(request.method, request.url.path, status_code)
|
||||
|
||||
|
||||
def _cors_headers(origin: str) -> dict[str, str]:
|
||||
"""Build strict CORS headers for a validated browser origin."""
|
||||
return {
|
||||
"Access-Control-Allow-Origin": origin,
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Allow-Methods": "GET,POST,OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Authorization,Content-Type,MCP-Protocol-Version,X-Request-ID",
|
||||
"Access-Control-Expose-Headers": "X-Request-ID,WWW-Authenticate",
|
||||
"Vary": "Origin",
|
||||
}
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def strict_origin_and_cors_middleware(
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable[Response]],
|
||||
) -> Response:
|
||||
"""Enforce strict browser origins for MCP transport requests."""
|
||||
if request.url.path not in {"/mcp", "/mcp/sse"}:
|
||||
return await call_next(request)
|
||||
|
||||
settings = get_settings()
|
||||
origin = request.headers.get("origin")
|
||||
expected_base = settings.public_base or str(request.base_url).rstrip("/")
|
||||
|
||||
if origin and not is_origin_allowed(origin, expected_base, settings.public_base):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": "Origin not allowed",
|
||||
"message": "The request origin is not allowed for this MCP transport.",
|
||||
"request_id": getattr(request.state, "request_id", "-"),
|
||||
},
|
||||
)
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
response = Response(status_code=204)
|
||||
else:
|
||||
response = await call_next(request)
|
||||
|
||||
if origin and is_origin_allowed(origin, expected_base, settings.public_base):
|
||||
for header, value in _cors_headers(origin).items():
|
||||
response.headers[header] = value
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def _oauth_invalid_client_response() -> JSONResponse:
|
||||
"""Return an RFC 6749 invalid_client error for token endpoint failures."""
|
||||
response = JSONResponse(status_code=401, content={"error": "invalid_client"})
|
||||
response.headers["WWW-Authenticate"] = 'Basic realm="oauth"'
|
||||
return response
|
||||
|
||||
|
||||
def _jsonrpc_error(message_id: Any, code: int, message: str) -> JSONResponse:
|
||||
"""Build a JSON-RPC error response envelope."""
|
||||
return JSONResponse(
|
||||
content={"jsonrpc": "2.0", "id": message_id, "error": {"code": code, "message": message}}
|
||||
)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def authenticate_and_rate_limit(
|
||||
request: Request,
|
||||
@@ -255,11 +479,14 @@ async def authenticate_and_rate_limit(
|
||||
if request.url.path in {"/", "/health"}:
|
||||
return await call_next(request)
|
||||
|
||||
if request.method == "OPTIONS" and request.url.path in {"/mcp", "/mcp/sse"}:
|
||||
return await call_next(request)
|
||||
|
||||
if request.url.path == "/metrics" and settings.metrics_enabled:
|
||||
# Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
|
||||
return await call_next(request)
|
||||
|
||||
# OAuth discovery and token endpoints must be public so ChatGPT can complete the flow.
|
||||
# OAuth discovery and token endpoints must be public so MCP clients can complete the flow.
|
||||
if request.url.path in {
|
||||
"/oauth/token",
|
||||
"/.well-known/oauth-protected-resource",
|
||||
@@ -268,7 +495,11 @@ async def authenticate_and_rate_limit(
|
||||
}:
|
||||
return await call_next(request)
|
||||
|
||||
if not (request.url.path.startswith("/mcp/") or request.url.path.startswith("/automation/")):
|
||||
if not (
|
||||
request.url.path in {"/mcp/tools"}
|
||||
or _is_mcp_transport_path(request.url.path)
|
||||
or request.url.path.startswith("/automation/")
|
||||
):
|
||||
return await call_next(request)
|
||||
|
||||
oauth_validator = get_oauth_validator()
|
||||
@@ -296,7 +527,7 @@ async def authenticate_and_rate_limit(
|
||||
return await call_next(request)
|
||||
|
||||
if not access_token:
|
||||
if request.url.path.startswith("/mcp/"):
|
||||
if _is_mcp_transport_path(request.url.path):
|
||||
return _oauth_unauthorized_response(
|
||||
request,
|
||||
"Provide Authorization: Bearer <token>.",
|
||||
@@ -315,7 +546,7 @@ async def authenticate_and_rate_limit(
|
||||
access_token, client_ip, user_agent
|
||||
)
|
||||
if not is_valid:
|
||||
if request.url.path.startswith("/mcp/"):
|
||||
if _is_mcp_transport_path(request.url.path):
|
||||
return _oauth_unauthorized_response(
|
||||
request,
|
||||
error_message or "Invalid or expired OAuth token.",
|
||||
@@ -400,7 +631,7 @@ async def authenticate_and_rate_limit(
|
||||
"OAuth token is valid but lacks required Gitea API access. "
|
||||
"Re-authorize this OAuth app in Gitea and try again."
|
||||
)
|
||||
if request.url.path.startswith("/mcp/"):
|
||||
if _is_mcp_transport_path(request.url.path):
|
||||
return _oauth_unauthorized_response(
|
||||
request,
|
||||
message,
|
||||
@@ -508,9 +739,14 @@ async def health() -> dict[str, str]:
|
||||
async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
|
||||
"""OAuth 2.0 Protected Resource Metadata (RFC 9728).
|
||||
|
||||
Required by the MCP Authorization spec so that OAuth clients (e.g. ChatGPT)
|
||||
can discover the authorization server that protects this resource.
|
||||
ChatGPT fetches this endpoint when it first connects to the MCP server via SSE.
|
||||
Required by the MCP Authorization spec so that OAuth clients (Claude's
|
||||
connector infrastructure) can discover the authorization server that
|
||||
protects this resource. Claude fetches this endpoint when it first connects.
|
||||
|
||||
The ``resource`` value MUST be THIS server's own canonical public URL: the
|
||||
MCP client verifies that the resource identifier matches the origin it
|
||||
derived the MCP server URL from (RFC 9728 / RFC 8707). Returning the upstream
|
||||
Gitea URL here would fail that check.
|
||||
"""
|
||||
settings = get_settings()
|
||||
gitea_base = settings.gitea_base_url
|
||||
@@ -521,7 +757,7 @@ async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"resource": gitea_base,
|
||||
"resource": base_url,
|
||||
"authorization_servers": authorization_servers,
|
||||
"bearer_methods_supported": ["header"],
|
||||
"scopes_supported": [READ_SCOPE, WRITE_SCOPE],
|
||||
@@ -534,24 +770,52 @@ async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
|
||||
async def oauth_authorize_proxy(request: Request) -> RedirectResponse:
|
||||
"""Proxy OAuth authorization to Gitea, replacing redirect_uri with our own callback.
|
||||
|
||||
Clients (ChatGPT, Claude, etc.) send their own redirect_uri which Gitea doesn't know
|
||||
Clients (Claude, Claude Code, Cowork, etc.) send their own redirect_uri which Gitea doesn't know
|
||||
about. This endpoint intercepts the request, encodes the original redirect_uri and
|
||||
state into a new state parameter, and forwards the request to Gitea using the MCP
|
||||
server's own callback URI — the only URI that needs to be registered in Gitea.
|
||||
"""
|
||||
settings = get_settings()
|
||||
base_url = settings.public_base or str(request.base_url).rstrip("/")
|
||||
registry = get_oauth_client_registry(settings.dcr_storage_path)
|
||||
|
||||
params = dict(request.query_params)
|
||||
client_redirect_uri = params.pop("redirect_uri", "")
|
||||
client_redirect_uri = params.pop("redirect_uri", "").strip()
|
||||
client_id = params.get("client_id", "").strip() or settings.gitea_oauth_client_id.strip()
|
||||
original_state = params.get("state", "")
|
||||
params.pop("client_secret", None)
|
||||
|
||||
# Encode the client's redirect_uri + original state into a tamper-evident wrapper.
|
||||
# We simply base64-encode a JSON blob; Gitea will echo it back on the callback.
|
||||
proxy_state_data = {"redirect_uri": client_redirect_uri, "state": original_state}
|
||||
proxy_state = base64.urlsafe_b64encode(json.dumps(proxy_state_data).encode()).decode()
|
||||
if not client_id:
|
||||
raise HTTPException(status_code=400, detail="Missing client_id")
|
||||
if not registry.is_known_client(
|
||||
client_id,
|
||||
fallback_client_id=settings.gitea_oauth_client_id,
|
||||
):
|
||||
raise HTTPException(status_code=401, detail="invalid_client")
|
||||
|
||||
if not client_redirect_uri:
|
||||
raise HTTPException(status_code=400, detail="Missing redirect_uri")
|
||||
if not is_redirect_uri_allowed(client_redirect_uri, settings.oauth_redirect_allowlist):
|
||||
raise HTTPException(status_code=400, detail="redirect_uri is not allowed")
|
||||
|
||||
code_challenge = params.get("code_challenge", "").strip()
|
||||
code_challenge_method = params.get("code_challenge_method", "S256").strip().upper()
|
||||
if not code_challenge:
|
||||
raise HTTPException(status_code=400, detail="PKCE code_challenge is required")
|
||||
if code_challenge_method != "S256":
|
||||
raise HTTPException(status_code=400, detail="PKCE code_challenge_method must be S256")
|
||||
|
||||
proxy_state = encode_proxy_state(
|
||||
settings.oauth_state_secret,
|
||||
client_redirect_uri,
|
||||
original_state,
|
||||
ttl_seconds=600,
|
||||
)
|
||||
|
||||
params["client_id"] = settings.gitea_oauth_client_id
|
||||
params["state"] = proxy_state
|
||||
params["code_challenge"] = code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
params["redirect_uri"] = f"{base_url}/oauth/callback"
|
||||
|
||||
gitea_authorize_url = f"{settings.gitea_base_url}/login/oauth/authorize"
|
||||
@@ -568,14 +832,17 @@ async def oauth_callback_proxy(request: Request) -> RedirectResponse:
|
||||
error_description = request.query_params.get("error_description", "")
|
||||
|
||||
try:
|
||||
state_data = json.loads(base64.urlsafe_b64decode(proxy_state.encode()))
|
||||
state_data = decode_proxy_state(get_settings().oauth_state_secret, proxy_state)
|
||||
client_redirect_uri = state_data["redirect_uri"]
|
||||
original_state = state_data["state"]
|
||||
except Exception as exc:
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid or missing state parameter") from exc
|
||||
|
||||
settings = get_settings()
|
||||
if not client_redirect_uri:
|
||||
raise HTTPException(status_code=400, detail="No client redirect_uri in state")
|
||||
if not is_redirect_uri_allowed(client_redirect_uri, settings.oauth_redirect_allowlist):
|
||||
raise HTTPException(status_code=400, detail="redirect_uri is not allowed")
|
||||
|
||||
result_params: dict[str, str] = {}
|
||||
if error:
|
||||
@@ -595,26 +862,31 @@ async def oauth_callback_proxy(request: Request) -> RedirectResponse:
|
||||
async def oauth_authorization_server_metadata(request: Request) -> JSONResponse:
|
||||
"""OAuth 2.0 Authorization Server Metadata (RFC 8414).
|
||||
|
||||
Proxies Gitea's OAuth authorization server metadata so that ChatGPT can
|
||||
discover the authorize URL, token URL, and supported features directly
|
||||
from this server without needing to know the Gitea URL upfront.
|
||||
Advertises this server's OAuth proxy endpoints so that Claude's connector
|
||||
infrastructure can discover the authorize URL, token URL, and dynamic client
|
||||
registration endpoint directly from this server without knowing the Gitea URL
|
||||
upfront. The authorize/token endpoints are this server's proxy routes because
|
||||
Gitea does not know Claude's redirect_uri.
|
||||
"""
|
||||
settings = get_settings()
|
||||
base_url = settings.public_base or str(request.base_url).rstrip("/")
|
||||
gitea_base = settings.gitea_base_url
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"issuer": gitea_base,
|
||||
"authorization_endpoint": f"{base_url}/oauth/authorize",
|
||||
"token_endpoint": f"{base_url}/oauth/token",
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["authorization_code"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"scopes_supported": [READ_SCOPE, WRITE_SCOPE],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
||||
}
|
||||
)
|
||||
metadata: dict[str, Any] = {
|
||||
"issuer": gitea_base,
|
||||
"authorization_endpoint": f"{base_url}/oauth/authorize",
|
||||
"token_endpoint": f"{base_url}/oauth/token",
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"scopes_supported": [READ_SCOPE, WRITE_SCOPE],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
||||
}
|
||||
if settings.dcr_enabled:
|
||||
# RFC 7591 dynamic client registration endpoint (Claude registers here).
|
||||
metadata["registration_endpoint"] = f"{base_url}/register"
|
||||
|
||||
return JSONResponse(content=metadata)
|
||||
|
||||
|
||||
@app.get("/.well-known/openid-configuration")
|
||||
@@ -649,15 +921,47 @@ async def openid_configuration(request: Request) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
@app.post("/register")
|
||||
async def oauth_dynamic_client_registration(request: Request) -> JSONResponse:
|
||||
"""Persist a new OAuth client registration for Claude and similar MCP clients."""
|
||||
settings = get_settings()
|
||||
if not settings.dcr_enabled:
|
||||
raise HTTPException(status_code=404, detail="Dynamic client registration is disabled")
|
||||
|
||||
content_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
|
||||
if content_type != "application/json":
|
||||
raise HTTPException(status_code=415, detail="Content-Type must be application/json")
|
||||
|
||||
registry = get_oauth_client_registry(settings.dcr_storage_path)
|
||||
|
||||
try:
|
||||
payload = await request.json()
|
||||
registration_request = OAuthRegistrationRequest.model_validate(payload)
|
||||
except ValidationError as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid registration payload") from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid registration payload") from exc
|
||||
|
||||
for redirect_uri in registration_request.redirect_uris:
|
||||
if not is_redirect_uri_allowed(redirect_uri, settings.oauth_redirect_allowlist):
|
||||
raise HTTPException(status_code=400, detail="redirect_uri is not allowed")
|
||||
|
||||
response = registry.register(registration_request)
|
||||
response["client_id_issued_at"] = int(time.time())
|
||||
response["client_secret_expires_at"] = 0
|
||||
return JSONResponse(content=response)
|
||||
|
||||
|
||||
@app.post("/oauth/token")
|
||||
async def oauth_token_proxy(request: Request) -> JSONResponse:
|
||||
"""Proxy OAuth2 token exchange to Gitea.
|
||||
|
||||
ChatGPT sends the authorization code here after the user logs in to Gitea.
|
||||
The client sends the authorization code here after the user logs in to Gitea.
|
||||
This endpoint forwards the code to Gitea's token endpoint and returns the
|
||||
access_token to ChatGPT, completing the OAuth2 Authorization Code flow.
|
||||
access_token to the client, completing the OAuth2 Authorization Code flow.
|
||||
"""
|
||||
settings = get_settings()
|
||||
registry = get_oauth_client_registry(settings.dcr_storage_path)
|
||||
|
||||
try:
|
||||
form_data = await request.form()
|
||||
@@ -683,32 +987,45 @@ async def oauth_token_proxy(request: Request) -> JSONResponse:
|
||||
# URI to Gitea, we must use the same URI here — not the client's original redirect_uri.
|
||||
base_url = settings.public_base or str(request.base_url).rstrip("/")
|
||||
|
||||
if not client_id:
|
||||
return _oauth_invalid_client_response()
|
||||
if not registry.validate_client_secret(
|
||||
client_id,
|
||||
client_secret or None,
|
||||
fallback_client_id=settings.gitea_oauth_client_id,
|
||||
fallback_client_secret=settings.gitea_oauth_client_secret,
|
||||
):
|
||||
return _oauth_invalid_client_response()
|
||||
|
||||
gitea_token_url = f"{settings.gitea_base_url}/login/oauth/access_token"
|
||||
upstream_client_id = settings.gitea_oauth_client_id
|
||||
upstream_client_secret = settings.gitea_oauth_client_secret
|
||||
|
||||
if grant_type == "refresh_token":
|
||||
if not refresh_token:
|
||||
raise HTTPException(status_code=400, detail="Missing refresh_token")
|
||||
payload: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"client_id": upstream_client_id,
|
||||
"client_secret": upstream_client_secret,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
else:
|
||||
if not code:
|
||||
raise HTTPException(status_code=400, detail="Missing authorization code")
|
||||
if not code_verifier:
|
||||
raise HTTPException(status_code=400, detail="Missing code_verifier")
|
||||
payload = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"client_id": upstream_client_id,
|
||||
"client_secret": upstream_client_secret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": f"{base_url}/oauth/callback",
|
||||
}
|
||||
if code_verifier:
|
||||
payload["code_verifier"] = code_verifier
|
||||
payload["code_verifier"] = code_verifier
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
async with httpx.AsyncClient(timeout=settings.request_timeout_seconds) as client:
|
||||
response = await client.post(
|
||||
gitea_token_url,
|
||||
data=payload,
|
||||
@@ -846,6 +1163,29 @@ async def _execute_tool_call(
|
||||
if not user_token:
|
||||
raise HTTPException(status_code=401, detail="Missing authenticated user token context")
|
||||
|
||||
if settings.gitea_token.strip():
|
||||
if not repository:
|
||||
audit.log_access_denied(
|
||||
tool_name=tool_name,
|
||||
reason="service_pat_requires_repository_target",
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=(
|
||||
"Service PAT mode requires a repository target so per-user "
|
||||
"permission can be verified."
|
||||
),
|
||||
)
|
||||
user_login = get_gitea_user_login()
|
||||
await _verify_user_repository_access(
|
||||
repository=repository,
|
||||
required_scope=required_scope,
|
||||
user_login=user_login or "",
|
||||
correlation_id=correlation_id,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
|
||||
# In OAuth mode, Gitea OIDC access_tokens can't call the Gitea REST API
|
||||
# (they only carry OIDC scopes). If a service PAT is configured via
|
||||
# GITEA_TOKEN, use that for API calls while OIDC handles identity/authz.
|
||||
@@ -970,6 +1310,7 @@ async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
@app.get("/mcp")
|
||||
@app.get("/mcp/sse")
|
||||
async def sse_endpoint(request: Request) -> StreamingResponse:
|
||||
"""Server-Sent Events endpoint for MCP transport."""
|
||||
@@ -1004,6 +1345,7 @@ async def sse_endpoint(request: Request) -> StreamingResponse:
|
||||
)
|
||||
|
||||
|
||||
@app.post("/mcp")
|
||||
@app.post("/mcp/sse")
|
||||
async def sse_message_handler(request: Request) -> JSONResponse:
|
||||
"""Handle POST messages for MCP SSE transport."""
|
||||
|
||||
@@ -18,7 +18,7 @@ from aegis_gitea_mcp.tools.arguments import (
|
||||
|
||||
|
||||
async def list_repositories_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""List repositories visible to the bot user.
|
||||
"""List repositories visible to the active Gitea API token.
|
||||
|
||||
Args:
|
||||
gitea: Initialized Gitea client.
|
||||
|
||||
@@ -9,9 +9,11 @@ from aegis_gitea_mcp.audit import reset_audit_logger
|
||||
from aegis_gitea_mcp.auth import reset_validator
|
||||
from aegis_gitea_mcp.config import reset_settings
|
||||
from aegis_gitea_mcp.oauth import reset_oauth_validator
|
||||
from aegis_gitea_mcp.oauth_flow import reset_oauth_client_registry
|
||||
from aegis_gitea_mcp.observability import reset_metrics_registry
|
||||
from aegis_gitea_mcp.policy import reset_policy_engine
|
||||
from aegis_gitea_mcp.rate_limit import reset_rate_limiter
|
||||
from aegis_gitea_mcp.server import reset_repo_authz_cache
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -22,6 +24,8 @@ def reset_globals(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Generator[
|
||||
reset_audit_logger()
|
||||
reset_validator()
|
||||
reset_oauth_validator()
|
||||
reset_oauth_client_registry()
|
||||
reset_repo_authz_cache()
|
||||
reset_policy_engine()
|
||||
reset_rate_limiter()
|
||||
reset_metrics_registry()
|
||||
@@ -37,6 +41,8 @@ def reset_globals(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Generator[
|
||||
reset_audit_logger()
|
||||
reset_validator()
|
||||
reset_oauth_validator()
|
||||
reset_oauth_client_registry()
|
||||
reset_repo_authz_cache()
|
||||
reset_policy_engine()
|
||||
reset_rate_limiter()
|
||||
reset_metrics_registry()
|
||||
@@ -66,4 +72,5 @@ def mock_env_oauth(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
||||
|
||||
@@ -28,6 +28,7 @@ def full_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||
monkeypatch.setenv("ENVIRONMENT", "test")
|
||||
monkeypatch.setenv("MCP_HOST", "127.0.0.1")
|
||||
monkeypatch.setenv("MCP_PORT", "8080")
|
||||
|
||||
+240
-4
@@ -10,6 +10,7 @@ from fastapi.testclient import TestClient
|
||||
|
||||
from aegis_gitea_mcp.config import reset_settings
|
||||
from aegis_gitea_mcp.oauth import GiteaOAuthValidator, get_oauth_validator, reset_oauth_validator
|
||||
from aegis_gitea_mcp.oauth_flow import OAuthClientRegistry, OAuthRegistrationRequest
|
||||
from aegis_gitea_mcp.request_context import (
|
||||
get_gitea_user_login,
|
||||
get_gitea_user_token,
|
||||
@@ -40,6 +41,7 @@ def mock_env_oauth(monkeypatch):
|
||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
||||
|
||||
|
||||
@@ -57,6 +59,24 @@ def oauth_client(mock_env_oauth):
|
||||
return TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
|
||||
def _register_public_client(oauth_client: TestClient, redirect_uri: str) -> dict[str, str]:
|
||||
"""Register a public OAuth client for test flows."""
|
||||
response = oauth_client.post(
|
||||
"/register",
|
||||
json={
|
||||
"client_name": "pytest-client",
|
||||
"redirect_uris": [redirect_uri],
|
||||
"token_endpoint_auth_method": "none",
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert "client_id" in payload
|
||||
return payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GiteaOAuthValidator unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -248,19 +268,39 @@ def test_oauth_token_endpoint_available_when_oauth_mode_false(monkeypatch):
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
response = client.post("/oauth/token", data={"code": "abc123"})
|
||||
registration = client.post(
|
||||
"/register",
|
||||
json={
|
||||
"client_name": "pytest-client",
|
||||
"redirect_uris": ["http://127.0.0.1:8080/callback"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
"grant_types": ["authorization_code"],
|
||||
"response_types": ["code"],
|
||||
},
|
||||
)
|
||||
assert registration.status_code == 200
|
||||
client_id = registration.json()["client_id"]
|
||||
response = client.post(
|
||||
"/oauth/token",
|
||||
data={"client_id": client_id, "code": "abc123", "code_verifier": "pkce"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_oauth_token_endpoint_missing_code(oauth_client):
|
||||
"""POST /oauth/token without a code returns 400."""
|
||||
response = oauth_client.post("/oauth/token", data={})
|
||||
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||
response = oauth_client.post(
|
||||
"/oauth/token",
|
||||
data={"client_id": client_data["client_id"], "code_verifier": "pkce"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_oauth_token_endpoint_proxy_success(oauth_client):
|
||||
"""POST /oauth/token proxies successfully to Gitea and returns access_token."""
|
||||
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
@@ -276,7 +316,11 @@ def test_oauth_token_endpoint_proxy_success(oauth_client):
|
||||
|
||||
response = oauth_client.post(
|
||||
"/oauth/token",
|
||||
data={"code": "auth-code-123", "redirect_uri": "https://chat.openai.com/callback"},
|
||||
data={
|
||||
"client_id": client_data["client_id"],
|
||||
"code": "auth-code-123",
|
||||
"code_verifier": "pkce-verifier",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -286,6 +330,7 @@ def test_oauth_token_endpoint_proxy_success(oauth_client):
|
||||
|
||||
def test_oauth_token_endpoint_gitea_error(oauth_client):
|
||||
"""POST /oauth/token propagates Gitea error status."""
|
||||
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.json.return_value = {"error": "invalid_grant"}
|
||||
@@ -296,11 +341,202 @@ def test_oauth_token_endpoint_gitea_error(oauth_client):
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
response = oauth_client.post("/oauth/token", data={"code": "bad-code"})
|
||||
response = oauth_client.post(
|
||||
"/oauth/token",
|
||||
data={
|
||||
"client_id": client_data["client_id"],
|
||||
"code": "bad-code",
|
||||
"code_verifier": "pkce-verifier",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_oauth_authorize_and_callback_round_trip(oauth_client):
|
||||
"""OAuth authorize/callback round-trip preserves the original redirect URI and state."""
|
||||
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||
|
||||
authorize_response = oauth_client.get(
|
||||
"/oauth/authorize",
|
||||
params={
|
||||
"client_id": client_data["client_id"],
|
||||
"redirect_uri": "http://127.0.0.1:8080/callback",
|
||||
"state": "original-state",
|
||||
"code_challenge": "pkce-challenge",
|
||||
"code_challenge_method": "S256",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert authorize_response.status_code == 302
|
||||
location = authorize_response.headers["location"]
|
||||
assert "state=" in location
|
||||
assert "redirect_uri=http%3A%2F%2F127.0.0.1%3A8080%2Fcallback" not in location
|
||||
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
parsed = urlparse(location)
|
||||
query = parse_qs(parsed.query)
|
||||
proxy_state = query["state"][0]
|
||||
|
||||
callback_response = oauth_client.get(
|
||||
"/oauth/callback",
|
||||
params={"state": proxy_state, "code": "auth-code-123"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert callback_response.status_code == 302
|
||||
callback_location = callback_response.headers["location"]
|
||||
assert callback_location.startswith("http://127.0.0.1:8080/callback?")
|
||||
assert "code=auth-code-123" in callback_location
|
||||
assert "state=original-state" in callback_location
|
||||
|
||||
|
||||
def test_oauth_callback_rejects_tampered_state(oauth_client):
|
||||
"""OAuth callback rejects modified signed proxy state."""
|
||||
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||
authorize_response = oauth_client.get(
|
||||
"/oauth/authorize",
|
||||
params={
|
||||
"client_id": client_data["client_id"],
|
||||
"redirect_uri": "http://127.0.0.1:8080/callback",
|
||||
"state": "original-state",
|
||||
"code_challenge": "pkce-challenge",
|
||||
"code_challenge_method": "S256",
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
proxy_state = parse_qs(urlparse(authorize_response.headers["location"]).query)["state"][0]
|
||||
tampered_state = proxy_state[:-1] + ("A" if proxy_state[-1] != "A" else "B")
|
||||
|
||||
callback_response = oauth_client.get(
|
||||
"/oauth/callback",
|
||||
params={"state": tampered_state, "code": "auth-code-123"},
|
||||
)
|
||||
|
||||
assert callback_response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"redirect_uri",
|
||||
[
|
||||
"https://claude.ai/api/mcp/auth_callback",
|
||||
"https://claude.com/api/mcp/auth_callback",
|
||||
],
|
||||
)
|
||||
def test_dcr_accepts_default_claude_callbacks(oauth_client, redirect_uri):
|
||||
"""Claude's hosted connector callback URLs are allowed by default."""
|
||||
response = oauth_client.post(
|
||||
"/register",
|
||||
json={
|
||||
"client_name": "claude-client",
|
||||
"redirect_uris": [redirect_uri],
|
||||
"token_endpoint_auth_method": "none",
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_oauth_authorize_rejects_unknown_client(oauth_client):
|
||||
"""OAuth authorize returns invalid_client for unregistered client IDs."""
|
||||
response = oauth_client.get(
|
||||
"/oauth/authorize",
|
||||
params={
|
||||
"client_id": "unknown-client",
|
||||
"redirect_uri": "http://127.0.0.1:8080/callback",
|
||||
"state": "x",
|
||||
"code_challenge": "pkce-challenge",
|
||||
"code_challenge_method": "S256",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "invalid_client"
|
||||
|
||||
|
||||
def test_oauth_token_rejects_unknown_dcr_client(oauth_client):
|
||||
"""Unknown dynamic clients receive RFC 6749 invalid_client from token endpoint."""
|
||||
response = oauth_client.post(
|
||||
"/oauth/token",
|
||||
data={
|
||||
"client_id": "deleted-or-unknown-client",
|
||||
"code": "auth-code-123",
|
||||
"code_verifier": "pkce-verifier",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {"error": "invalid_client"}
|
||||
|
||||
|
||||
def test_oauth_authorize_requires_pkce_s256(oauth_client):
|
||||
"""Authorization endpoint enforces PKCE S256 for public clients."""
|
||||
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||
missing_challenge = oauth_client.get(
|
||||
"/oauth/authorize",
|
||||
params={
|
||||
"client_id": client_data["client_id"],
|
||||
"redirect_uri": "http://127.0.0.1:8080/callback",
|
||||
"state": "x",
|
||||
},
|
||||
)
|
||||
plain_method = oauth_client.get(
|
||||
"/oauth/authorize",
|
||||
params={
|
||||
"client_id": client_data["client_id"],
|
||||
"redirect_uri": "http://127.0.0.1:8080/callback",
|
||||
"state": "x",
|
||||
"code_challenge": "pkce-challenge",
|
||||
"code_challenge_method": "plain",
|
||||
},
|
||||
)
|
||||
|
||||
assert missing_challenge.status_code == 400
|
||||
assert plain_method.status_code == 400
|
||||
|
||||
|
||||
def test_register_rejects_foreign_redirect_uri(oauth_client):
|
||||
"""DCR rejects redirect URIs outside the allowlist and loopback/Claude patterns."""
|
||||
response = oauth_client.post(
|
||||
"/register",
|
||||
json={
|
||||
"client_name": "pytest-client",
|
||||
"redirect_uris": ["https://evil.example.com/callback"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
"grant_types": ["authorization_code"],
|
||||
"response_types": ["code"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_dcr_registry_persists_registered_clients(tmp_path):
|
||||
"""Registered OAuth clients survive registry reloads."""
|
||||
storage_path = tmp_path / "dcr_clients.json"
|
||||
registry = OAuthClientRegistry(storage_path)
|
||||
request = OAuthRegistrationRequest.model_validate(
|
||||
{
|
||||
"client_name": "persisted-client",
|
||||
"redirect_uris": ["http://127.0.0.1:8080/callback"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
"grant_types": ["authorization_code"],
|
||||
"response_types": ["code"],
|
||||
}
|
||||
)
|
||||
|
||||
response = registry.register(request)
|
||||
reloaded = OAuthClientRegistry(storage_path)
|
||||
|
||||
assert reloaded.get(response["client_id"]) is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config validation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -25,13 +25,14 @@ def reset_state(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||
monkeypatch.setenv("OAUTH_CACHE_TTL_SECONDS", "600")
|
||||
yield
|
||||
reset_settings()
|
||||
reset_oauth_validator()
|
||||
|
||||
|
||||
def _build_jwt_fixture() -> tuple[str, dict[str, object]]:
|
||||
def _build_jwt_fixture(aud: str = "test-client-id") -> tuple[str, dict[str, object]]:
|
||||
"""Generate RS256 access token and matching JWKS payload."""
|
||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
public_key = private_key.public_key()
|
||||
@@ -44,7 +45,7 @@ def _build_jwt_fixture() -> tuple[str, dict[str, object]]:
|
||||
"sub": "user-1",
|
||||
"preferred_username": "alice",
|
||||
"scope": "read:repository write:repository",
|
||||
"aud": "test-client-id",
|
||||
"aud": aud,
|
||||
"iss": "https://gitea.example.com",
|
||||
"iat": now,
|
||||
"exp": now + 3600,
|
||||
@@ -56,6 +57,70 @@ def _build_jwt_fixture() -> tuple[str, dict[str, object]]:
|
||||
return token, {"keys": [jwk]}
|
||||
|
||||
|
||||
async def _validate_with_jwks(
|
||||
validator: GiteaOAuthValidator, token: str, jwks: dict[str, object]
|
||||
) -> tuple[bool, str | None, dict[str, object] | None]:
|
||||
"""Drive a JWT validation with mocked discovery + JWKS responses."""
|
||||
discovery_response = MagicMock()
|
||||
discovery_response.status_code = 200
|
||||
discovery_response.json.return_value = {
|
||||
"issuer": "https://gitea.example.com",
|
||||
"jwks_uri": "https://gitea.example.com/login/oauth/keys",
|
||||
}
|
||||
jwks_response = MagicMock()
|
||||
jwks_response.status_code = 200
|
||||
jwks_response.json.return_value = jwks
|
||||
|
||||
with patch("aegis_gitea_mcp.oauth.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=[discovery_response, jwks_response])
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
return await validator.validate_oauth_token(token, "127.0.0.1", "TestAgent")
|
||||
|
||||
|
||||
def test_acceptable_audiences_includes_resource_and_client_id(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""The canonical MCP resource and the Gitea client id are accepted audiences."""
|
||||
monkeypatch.setenv("PUBLIC_BASE_URL", "https://mcp.example.com")
|
||||
reset_settings()
|
||||
reset_oauth_validator()
|
||||
audiences = GiteaOAuthValidator()._acceptable_audiences()
|
||||
assert "https://mcp.example.com" in audiences
|
||||
assert "test-client-id" in audiences
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jwt_with_canonical_resource_audience_is_accepted(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A token whose aud is the canonical MCP resource URL validates (P4)."""
|
||||
monkeypatch.setenv("PUBLIC_BASE_URL", "https://mcp.example.com")
|
||||
reset_settings()
|
||||
reset_oauth_validator()
|
||||
token, jwks = _build_jwt_fixture(aud="https://mcp.example.com")
|
||||
valid, error, principal = await _validate_with_jwks(GiteaOAuthValidator(), token, jwks)
|
||||
assert valid is True
|
||||
assert error is None
|
||||
assert principal is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jwt_with_foreign_audience_is_rejected() -> None:
|
||||
"""A token minted for a different audience is rejected (audience binding)."""
|
||||
token, jwks = _build_jwt_fixture(aud="some-other-service")
|
||||
# Foreign-audience JWT fails JWT validation, then falls back to userinfo, which
|
||||
# is not mocked here and raises a network error -> overall failure.
|
||||
with patch("aegis_gitea_mcp.oauth.GiteaOAuthValidator._validate_userinfo") as mock_userinfo:
|
||||
from aegis_gitea_mcp.oauth import OAuthTokenValidationError
|
||||
|
||||
mock_userinfo.side_effect = OAuthTokenValidationError("Invalid", "userinfo_denied")
|
||||
valid, _error, principal = await _validate_with_jwks(GiteaOAuthValidator(), token, jwks)
|
||||
assert valid is False
|
||||
assert principal is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_oauth_token_with_oidc_jwt_and_cache() -> None:
|
||||
"""JWT token validation uses discovery + JWKS and caches both documents."""
|
||||
|
||||
+212
-7
@@ -29,6 +29,7 @@ def oauth_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||
monkeypatch.setenv("ENVIRONMENT", "test")
|
||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
||||
monkeypatch.setenv("WRITE_MODE", "false")
|
||||
@@ -84,12 +85,13 @@ def test_health_endpoint(client: TestClient) -> None:
|
||||
|
||||
|
||||
def test_oauth_protected_resource_metadata(client: TestClient) -> None:
|
||||
"""OAuth protected-resource metadata contains required OpenAI-compatible fields."""
|
||||
"""PRM advertises THIS server's canonical URL as the protected resource."""
|
||||
response = client.get("/.well-known/oauth-protected-resource")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["resource"] == "https://gitea.example.com"
|
||||
# RFC 9728/8707: the resource identifier is the MCP server's own URL, not Gitea's.
|
||||
assert data["resource"] == "http://testserver"
|
||||
assert data["authorization_servers"] == [
|
||||
"http://testserver",
|
||||
"https://gitea.example.com",
|
||||
@@ -100,12 +102,15 @@ def test_oauth_protected_resource_metadata(client: TestClient) -> None:
|
||||
|
||||
|
||||
def test_oauth_authorization_server_metadata(client: TestClient) -> None:
|
||||
"""Auth server metadata includes expected OAuth endpoints and scopes."""
|
||||
"""Auth server metadata advertises this server's proxy OAuth endpoints."""
|
||||
response = client.get("/.well-known/oauth-authorization-server")
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["authorization_endpoint"].endswith("/login/oauth/authorize")
|
||||
assert payload["token_endpoint"].endswith("/oauth/token")
|
||||
# Claude must be sent to our proxy authorize endpoint (Gitea does not know
|
||||
# Claude's redirect_uri), so the endpoint lives on this server.
|
||||
assert payload["authorization_endpoint"] == "http://testserver/oauth/authorize"
|
||||
assert payload["token_endpoint"] == "http://testserver/oauth/token"
|
||||
assert payload["registration_endpoint"] == "http://testserver/register"
|
||||
assert payload["scopes_supported"] == ["read:repository", "write:repository"]
|
||||
|
||||
|
||||
@@ -115,8 +120,8 @@ def test_openid_configuration_metadata(client: TestClient) -> None:
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["issuer"] == "https://gitea.example.com"
|
||||
assert payload["authorization_endpoint"].endswith("/login/oauth/authorize")
|
||||
assert payload["token_endpoint"].endswith("/oauth/token")
|
||||
assert payload["authorization_endpoint"] == "http://testserver/oauth/authorize"
|
||||
assert payload["token_endpoint"] == "http://testserver/oauth/token"
|
||||
assert payload["userinfo_endpoint"].endswith("/login/oauth/userinfo")
|
||||
assert payload["jwks_uri"].endswith("/login/oauth/keys")
|
||||
assert "read:repository" in payload["scopes_supported"]
|
||||
@@ -129,6 +134,7 @@ def test_oauth_metadata_uses_public_base_url(monkeypatch: pytest.MonkeyPatch) ->
|
||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||
monkeypatch.setenv("PUBLIC_BASE_URL", "https://mcp.example.com")
|
||||
monkeypatch.setenv("ENVIRONMENT", "test")
|
||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
||||
@@ -149,6 +155,8 @@ def test_oauth_metadata_uses_public_base_url(monkeypatch: pytest.MonkeyPatch) ->
|
||||
protected_response = client.get("/.well-known/oauth-protected-resource")
|
||||
assert protected_response.status_code == 200
|
||||
protected_payload = protected_response.json()
|
||||
# P4: the protected resource identifier must equal this server's public base.
|
||||
assert protected_payload["resource"] == "https://mcp.example.com"
|
||||
assert protected_payload["authorization_servers"] == [
|
||||
"https://mcp.example.com",
|
||||
"https://gitea.example.com",
|
||||
@@ -166,6 +174,201 @@ def test_oauth_metadata_uses_public_base_url(monkeypatch: pytest.MonkeyPatch) ->
|
||||
)
|
||||
|
||||
|
||||
def test_mcp_streamable_http_path_works(client: TestClient) -> None:
|
||||
"""The spec path /mcp exposes the same transport behavior as the SSE alias."""
|
||||
response = client.post(
|
||||
"/mcp",
|
||||
headers={"Authorization": "Bearer valid-read"},
|
||||
json={"jsonrpc": "2.0", "id": "init-1", "method": "initialize"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["result"]["protocolVersion"] == "2024-11-05"
|
||||
|
||||
|
||||
def test_mcp_preflight_allows_same_origin(client: TestClient) -> None:
|
||||
"""Same-origin preflight requests to /mcp return strict CORS headers."""
|
||||
response = client.options(
|
||||
"/mcp",
|
||||
headers={
|
||||
"Origin": "http://testserver",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Access-Control-Request-Headers": "authorization,content-type",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
assert response.headers["Access-Control-Allow-Origin"] == "http://testserver"
|
||||
|
||||
|
||||
def test_mcp_preflight_rejects_cross_origin(client: TestClient) -> None:
|
||||
"""Cross-origin browser requests to /mcp are denied."""
|
||||
response = client.options(
|
||||
"/mcp",
|
||||
headers={
|
||||
"Origin": "https://evil.example.com",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_service_pat_requests_verify_user_repo_access_before_execution(
|
||||
oauth_env: None, mock_oauth_validation: None, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Service PAT fallback checks the user's repository permission before executing tools."""
|
||||
from aegis_gitea_mcp import server
|
||||
|
||||
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
|
||||
server._api_scope_cache.clear()
|
||||
server.reset_repo_authz_cache()
|
||||
|
||||
probe_response = MagicMock()
|
||||
probe_response.status_code = 200
|
||||
|
||||
repo_response = MagicMock()
|
||||
repo_response.status_code = 403
|
||||
|
||||
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=[probe_response, repo_response])
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
from aegis_gitea_mcp.server import app
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.post(
|
||||
"/mcp/tool/call",
|
||||
headers={"Authorization": "Bearer valid-read"},
|
||||
json={
|
||||
"tool": "get_repository_info",
|
||||
"arguments": {"owner": "acme", "repo": "demo"},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "permission" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_pat_repo_authz_allows_user_with_read_permission(
|
||||
oauth_env: None, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Read-level collaborator permission allows service PAT execution to proceed."""
|
||||
from aegis_gitea_mcp import server
|
||||
|
||||
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
|
||||
server.reset_repo_authz_cache()
|
||||
|
||||
permission_response = MagicMock()
|
||||
permission_response.status_code = 200
|
||||
permission_response.json.return_value = {"permission": "read"}
|
||||
|
||||
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=permission_response)
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
await server._verify_user_repository_access(
|
||||
repository="acme/demo",
|
||||
required_scope=server.READ_SCOPE,
|
||||
user_login="alice",
|
||||
correlation_id="corr-1",
|
||||
tool_name="get_repository_info",
|
||||
)
|
||||
|
||||
mock_client.get.assert_awaited_once()
|
||||
requested_url = mock_client.get.await_args.args[0]
|
||||
requested_headers = mock_client.get.await_args.kwargs["headers"]
|
||||
assert requested_url.endswith("/api/v1/repos/acme/demo/collaborators/alice/permission")
|
||||
assert requested_headers["Authorization"] == "token service-pat-token"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_pat_repo_authz_denies_read_user_for_write_tool(
|
||||
oauth_env: None, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Read permission is insufficient for write tools in service PAT mode."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from aegis_gitea_mcp import server
|
||||
|
||||
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
|
||||
server.reset_repo_authz_cache()
|
||||
|
||||
permission_response = MagicMock()
|
||||
permission_response.status_code = 200
|
||||
permission_response.json.return_value = {"permission": "read"}
|
||||
|
||||
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=permission_response)
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await server._verify_user_repository_access(
|
||||
repository="acme/demo",
|
||||
required_scope=server.WRITE_SCOPE,
|
||||
user_login="alice",
|
||||
correlation_id="corr-1",
|
||||
tool_name="create_issue",
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_pat_repo_authz_cache_hit_and_expiry(
|
||||
oauth_env: None, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Repository permission decisions are cached briefly and rechecked after expiry."""
|
||||
from aegis_gitea_mcp import cache as cache_module
|
||||
from aegis_gitea_mcp import server
|
||||
|
||||
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
|
||||
monkeypatch.setenv("REPO_AUTHZ_CACHE_TTL_SECONDS", "1")
|
||||
server.reset_repo_authz_cache()
|
||||
|
||||
now = 1000.0
|
||||
monkeypatch.setattr(cache_module.time, "monotonic", lambda: now)
|
||||
|
||||
permission_response = MagicMock()
|
||||
permission_response.status_code = 200
|
||||
permission_response.json.return_value = {"permission": "read"}
|
||||
|
||||
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=permission_response)
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
for _ in range(2):
|
||||
await server._verify_user_repository_access(
|
||||
repository="acme/demo",
|
||||
required_scope=server.READ_SCOPE,
|
||||
user_login="alice",
|
||||
correlation_id="corr-1",
|
||||
tool_name="get_repository_info",
|
||||
)
|
||||
assert mock_client.get.await_count == 1
|
||||
|
||||
now = 1002.0
|
||||
await server._verify_user_repository_access(
|
||||
repository="acme/demo",
|
||||
required_scope=server.READ_SCOPE,
|
||||
user_login="alice",
|
||||
correlation_id="corr-1",
|
||||
tool_name="get_repository_info",
|
||||
)
|
||||
|
||||
assert mock_client.get.await_count == 2
|
||||
|
||||
|
||||
def test_scope_compatibility_write_implies_read() -> None:
|
||||
"""write:repository grants read-level access for read tools."""
|
||||
from aegis_gitea_mcp.server import READ_SCOPE, _has_required_scope
|
||||
@@ -348,6 +551,7 @@ async def test_startup_event_fails_when_discovery_unreachable(
|
||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "true")
|
||||
|
||||
from aegis_gitea_mcp import server
|
||||
@@ -377,6 +581,7 @@ async def test_startup_event_succeeds_when_discovery_ready(
|
||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "true")
|
||||
|
||||
from aegis_gitea_mcp import server
|
||||
|
||||
Reference in New Issue
Block a user