feat: harden Claude MCP OAuth transport
This commit is contained in:
@@ -106,12 +106,12 @@ class Settings(BaseSettings):
|
|||||||
description="Secret detection mode: off, mask, or block",
|
description="Secret detection mode: off, mask, or block",
|
||||||
)
|
)
|
||||||
|
|
||||||
# OAuth2 configuration (for ChatGPT per-user Gitea authentication)
|
# OAuth2 configuration (for per-client Gitea authentication)
|
||||||
oauth_mode: bool = Field(
|
oauth_mode: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description=(
|
description=(
|
||||||
"Enable per-user OAuth2 authentication mode. "
|
"Enable per-user OAuth2 authentication mode. "
|
||||||
"When true, each ChatGPT user authenticates with their own Gitea account. "
|
"When true, each client user authenticates with their own Gitea account. "
|
||||||
"GITEA_TOKEN and MCP_API_KEYS are not required in this mode."
|
"GITEA_TOKEN and MCP_API_KEYS are not required in this mode."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -126,8 +126,9 @@ class Settings(BaseSettings):
|
|||||||
oauth_expected_audience: str = Field(
|
oauth_expected_audience: str = Field(
|
||||||
default="",
|
default="",
|
||||||
description=(
|
description=(
|
||||||
"Expected OIDC audience for access tokens. "
|
"Additional expected OIDC audience for access tokens. The canonical MCP "
|
||||||
"Defaults to GITEA_OAUTH_CLIENT_ID when unset."
|
"resource URL and the Gitea OAuth client id are always accepted; set this "
|
||||||
|
"to require an extra audience value."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
oauth_cache_ttl_seconds: int = Field(
|
oauth_cache_ttl_seconds: int = Field(
|
||||||
@@ -139,6 +140,37 @@ class Settings(BaseSettings):
|
|||||||
default="https://hiddenden.cafe/docs/mcp-gitea",
|
default="https://hiddenden.cafe/docs/mcp-gitea",
|
||||||
description="Public documentation URL for OAuth-protected MCP resource behavior",
|
description="Public documentation URL for OAuth-protected MCP resource behavior",
|
||||||
)
|
)
|
||||||
|
oauth_state_secret: str = Field(
|
||||||
|
default="",
|
||||||
|
description=(
|
||||||
|
"Server secret used to HMAC-sign the OAuth proxy state parameter. "
|
||||||
|
"Required when OAUTH_MODE=true so callback state is tamper-evident."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
oauth_redirect_allowlist_raw: str = Field(
|
||||||
|
default="",
|
||||||
|
description=(
|
||||||
|
"Comma-separated additional allowed client redirect URIs for the OAuth "
|
||||||
|
"callback proxy. Claude's callback URLs and loopback URIs are always allowed."
|
||||||
|
),
|
||||||
|
alias="OAUTH_REDIRECT_ALLOWLIST",
|
||||||
|
)
|
||||||
|
dcr_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description=(
|
||||||
|
"Enable RFC 7591 Dynamic Client Registration at /register. Claude's "
|
||||||
|
"connectors register dynamically; disable to require manual client_id/secret."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
dcr_storage_path: Path = Field(
|
||||||
|
default=Path("/var/lib/aegis-mcp/dcr_clients.json"),
|
||||||
|
description="Path to the JSON file that persists dynamically registered clients",
|
||||||
|
)
|
||||||
|
repo_authz_cache_ttl_seconds: int = Field(
|
||||||
|
default=60,
|
||||||
|
description="TTL (seconds) for cached per-user repository permission decisions",
|
||||||
|
ge=1,
|
||||||
|
)
|
||||||
|
|
||||||
# Authentication configuration
|
# Authentication configuration
|
||||||
auth_enabled: bool = Field(
|
auth_enabled: bool = Field(
|
||||||
@@ -269,12 +301,28 @@ class Settings(BaseSettings):
|
|||||||
"Set ALLOW_INSECURE_BIND=true to explicitly permit this."
|
"Set ALLOW_INSECURE_BIND=true to explicitly permit this."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
extra_redirect_uris: list[str] = []
|
||||||
|
if self.oauth_redirect_allowlist_raw.strip():
|
||||||
|
extra_redirect_uris = [
|
||||||
|
value.strip()
|
||||||
|
for value in self.oauth_redirect_allowlist_raw.split(",")
|
||||||
|
if value.strip()
|
||||||
|
]
|
||||||
|
object.__setattr__(self, "_oauth_redirect_allowlist", extra_redirect_uris)
|
||||||
|
|
||||||
if self.oauth_mode:
|
if self.oauth_mode:
|
||||||
# In OAuth mode, per-user Gitea tokens are used; no shared bot token or API keys needed.
|
# In OAuth mode, per-user Gitea tokens are used; no shared bot token or API keys needed.
|
||||||
if not self.gitea_oauth_client_id.strip():
|
if not self.gitea_oauth_client_id.strip():
|
||||||
raise ValueError("GITEA_OAUTH_CLIENT_ID is required when OAUTH_MODE=true.")
|
raise ValueError("GITEA_OAUTH_CLIENT_ID is required when OAUTH_MODE=true.")
|
||||||
if not self.gitea_oauth_client_secret.strip():
|
if not self.gitea_oauth_client_secret.strip():
|
||||||
raise ValueError("GITEA_OAUTH_CLIENT_SECRET is required when OAUTH_MODE=true.")
|
raise ValueError("GITEA_OAUTH_CLIENT_SECRET is required when OAUTH_MODE=true.")
|
||||||
|
# The proxy state parameter carries the client's redirect_uri across the Gitea
|
||||||
|
# round-trip; it must be HMAC-signed, which requires a server-held secret.
|
||||||
|
if not self.oauth_state_secret.strip():
|
||||||
|
raise ValueError(
|
||||||
|
"OAUTH_STATE_SECRET is required when OAUTH_MODE=true so the OAuth "
|
||||||
|
"proxy state parameter can be HMAC-signed and verified."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Standard API key mode: require bot token and at least one API key.
|
# Standard API key mode: require bot token and at least one API key.
|
||||||
if not self.gitea_token.strip():
|
if not self.gitea_token.strip():
|
||||||
@@ -308,6 +356,11 @@ class Settings(BaseSettings):
|
|||||||
"""Get parsed list of repositories allowed for write-mode operations."""
|
"""Get parsed list of repositories allowed for write-mode operations."""
|
||||||
return list(getattr(self, "_write_repository_whitelist", []))
|
return list(getattr(self, "_write_repository_whitelist", []))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def oauth_redirect_allowlist(self) -> list[str]:
|
||||||
|
"""Get parsed list of additional allowed client redirect URIs."""
|
||||||
|
return list(getattr(self, "_oauth_redirect_allowlist", []))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def gitea_base_url(self) -> str:
|
def gitea_base_url(self) -> str:
|
||||||
"""Get Gitea base URL as normalized string."""
|
"""Get Gitea base URL as normalized string."""
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ def _tool(
|
|||||||
AVAILABLE_TOOLS: list[MCPTool] = [
|
AVAILABLE_TOOLS: list[MCPTool] = [
|
||||||
_tool(
|
_tool(
|
||||||
"list_repositories",
|
"list_repositories",
|
||||||
"List repositories visible to the configured bot account.",
|
"List repositories visible to the authenticated Gitea API token.",
|
||||||
{"type": "object", "properties": {}, "required": []},
|
{"type": "object", "properties": {}, "required": []},
|
||||||
),
|
),
|
||||||
_tool(
|
_tool(
|
||||||
|
|||||||
@@ -177,6 +177,29 @@ class GiteaOAuthValidator:
|
|||||||
self._jwks_cache[jwks_uri] = (jwks, now + self.settings.oauth_cache_ttl_seconds)
|
self._jwks_cache[jwks_uri] = (jwks, now + self.settings.oauth_cache_ttl_seconds)
|
||||||
return jwks
|
return jwks
|
||||||
|
|
||||||
|
def _acceptable_audiences(self) -> list[str]:
|
||||||
|
"""Return the set of OIDC audiences this MCP server will accept.
|
||||||
|
|
||||||
|
Per the MCP authorization spec (RFC 8707 / RFC 9728) tokens are bound to
|
||||||
|
the MCP server's canonical resource URL, so the configured public base is
|
||||||
|
the primary accepted audience. The upstream Gitea OAuth client id is also
|
||||||
|
accepted because Gitea — the actual token issuer behind this proxy —
|
||||||
|
stamps ``aud`` with the client id rather than the MCP resource URL. An
|
||||||
|
operator may add a further required audience via OAUTH_EXPECTED_AUDIENCE.
|
||||||
|
"""
|
||||||
|
audiences: list[str] = []
|
||||||
|
canonical_resource = self.settings.public_base
|
||||||
|
if canonical_resource:
|
||||||
|
audiences.append(canonical_resource)
|
||||||
|
gitea_client_id = self.settings.gitea_oauth_client_id.strip()
|
||||||
|
if gitea_client_id:
|
||||||
|
audiences.append(gitea_client_id)
|
||||||
|
configured = self.settings.oauth_expected_audience.strip()
|
||||||
|
if configured:
|
||||||
|
audiences.append(configured)
|
||||||
|
# Preserve order while removing duplicates.
|
||||||
|
return list(dict.fromkeys(audiences))
|
||||||
|
|
||||||
async def _validate_jwt(self, token: str) -> dict[str, Any]:
|
async def _validate_jwt(self, token: str) -> dict[str, Any]:
|
||||||
"""Validate JWT access token using OIDC discovery and JWKS."""
|
"""Validate JWT access token using OIDC discovery and JWKS."""
|
||||||
discovery = await self._get_discovery_document()
|
discovery = await self._get_discovery_document()
|
||||||
@@ -216,19 +239,16 @@ class GiteaOAuthValidator:
|
|||||||
"oauth_jwt_invalid_jwk",
|
"oauth_jwt_invalid_jwk",
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
expected_audience = (
|
accepted_audiences = self._acceptable_audiences()
|
||||||
self.settings.oauth_expected_audience.strip()
|
|
||||||
or self.settings.gitea_oauth_client_id.strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
decode_options = cast(Any, {"verify_aud": bool(expected_audience)})
|
decode_options = cast(Any, {"verify_aud": bool(accepted_audiences)})
|
||||||
try:
|
try:
|
||||||
claims = jwt.decode(
|
claims = jwt.decode(
|
||||||
token,
|
token,
|
||||||
key=cast(Any, public_key),
|
key=cast(Any, public_key),
|
||||||
algorithms=["RS256"],
|
algorithms=["RS256"],
|
||||||
issuer=issuer,
|
issuer=issuer,
|
||||||
audience=expected_audience or None,
|
audience=accepted_audiences or None,
|
||||||
options=decode_options,
|
options=decode_options,
|
||||||
)
|
)
|
||||||
except InvalidTokenError as exc:
|
except InvalidTokenError as exc:
|
||||||
|
|||||||
@@ -0,0 +1,380 @@
|
|||||||
|
"""OAuth proxy helpers for signed state, redirect validation, and DCR storage."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from fnmatch import fnmatchcase
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import ParseResult, urlparse, urlunparse
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
_CLAUDE_CALLBACK_URIS = {
|
||||||
|
"https://claude.ai/api/mcp/auth_callback",
|
||||||
|
"https://claude.com/api/mcp/auth_callback",
|
||||||
|
}
|
||||||
|
_LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "::1"}
|
||||||
|
_SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS = {"none", "client_secret_post"}
|
||||||
|
_SUPPORTED_GRANT_TYPES = {"authorization_code", "refresh_token"}
|
||||||
|
_SUPPORTED_RESPONSE_TYPES = {"code"}
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthRegistrationRequest(BaseModel):
|
||||||
|
"""Incoming RFC 7591 client registration request."""
|
||||||
|
|
||||||
|
client_name: str | None = Field(default=None, max_length=200)
|
||||||
|
redirect_uris: list[str] = Field(..., min_length=1)
|
||||||
|
grant_types: list[str] = Field(default_factory=lambda: ["authorization_code", "refresh_token"])
|
||||||
|
response_types: list[str] = Field(default_factory=lambda: ["code"])
|
||||||
|
token_endpoint_auth_method: str = Field(default="none", max_length=64)
|
||||||
|
scope: str | None = Field(default=None, max_length=512)
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
@field_validator("redirect_uris")
|
||||||
|
@classmethod
|
||||||
|
def validate_redirect_uris(cls, value: list[str]) -> list[str]:
|
||||||
|
"""Normalize and validate redirect URIs."""
|
||||||
|
uris = [uri.strip() for uri in value if isinstance(uri, str) and uri.strip()]
|
||||||
|
if not uris:
|
||||||
|
raise ValueError("redirect_uris must contain at least one non-empty URI")
|
||||||
|
return uris
|
||||||
|
|
||||||
|
@field_validator("grant_types")
|
||||||
|
@classmethod
|
||||||
|
def validate_grant_types(cls, value: list[str]) -> list[str]:
|
||||||
|
"""Restrict supported grant types to authorization code and refresh token."""
|
||||||
|
normalized = [item.strip() for item in value if item.strip()]
|
||||||
|
if not normalized:
|
||||||
|
raise ValueError("grant_types must not be empty")
|
||||||
|
if any(item not in _SUPPORTED_GRANT_TYPES for item in normalized):
|
||||||
|
raise ValueError("Unsupported grant_types requested")
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
@field_validator("response_types")
|
||||||
|
@classmethod
|
||||||
|
def validate_response_types(cls, value: list[str]) -> list[str]:
|
||||||
|
"""Restrict supported response types to authorization code."""
|
||||||
|
normalized = [item.strip() for item in value if item.strip()]
|
||||||
|
if not normalized:
|
||||||
|
raise ValueError("response_types must not be empty")
|
||||||
|
if any(item not in _SUPPORTED_RESPONSE_TYPES for item in normalized):
|
||||||
|
raise ValueError("Unsupported response_types requested")
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
@field_validator("token_endpoint_auth_method")
|
||||||
|
@classmethod
|
||||||
|
def validate_token_endpoint_auth_method(cls, value: str) -> str:
|
||||||
|
"""Restrict token endpoint auth methods to the supported subset."""
|
||||||
|
normalized = value.strip().lower()
|
||||||
|
if normalized not in _SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS:
|
||||||
|
raise ValueError("Unsupported token_endpoint_auth_method requested")
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_pkce_ready(self) -> OAuthRegistrationRequest:
|
||||||
|
"""Ensure the request is usable for PKCE-based authorization code flow."""
|
||||||
|
if "authorization_code" not in self.grant_types:
|
||||||
|
raise ValueError("authorization_code grant is required")
|
||||||
|
if "code" not in self.response_types:
|
||||||
|
raise ValueError("code response type is required")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientRecord(BaseModel):
|
||||||
|
"""Persisted OAuth client registration record."""
|
||||||
|
|
||||||
|
client_id: str
|
||||||
|
client_name: str | None = None
|
||||||
|
redirect_uris: list[str]
|
||||||
|
grant_types: list[str]
|
||||||
|
response_types: list[str]
|
||||||
|
token_endpoint_auth_method: str
|
||||||
|
client_id_issued_at: int
|
||||||
|
client_secret_expires_at: int = 0
|
||||||
|
client_secret_hash: str | None = None
|
||||||
|
scope: str | None = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
|
||||||
|
def _canonicalize_url(value: str) -> str:
|
||||||
|
"""Normalize a URL for comparison."""
|
||||||
|
parsed = urlparse(value.strip())
|
||||||
|
if not parsed.scheme or not parsed.netloc:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
normalized = ParseResult(
|
||||||
|
scheme=parsed.scheme.lower(),
|
||||||
|
netloc=parsed.netloc.lower(),
|
||||||
|
path=parsed.path or "/",
|
||||||
|
params=parsed.params,
|
||||||
|
query=parsed.query,
|
||||||
|
fragment="",
|
||||||
|
)
|
||||||
|
return urlunparse(normalized).rstrip("/")
|
||||||
|
|
||||||
|
|
||||||
|
def is_loopback_redirect_uri(redirect_uri: str) -> bool:
|
||||||
|
"""Return whether a redirect URI uses a loopback host."""
|
||||||
|
parsed = urlparse(redirect_uri.strip())
|
||||||
|
if parsed.scheme != "http":
|
||||||
|
return False
|
||||||
|
host = (parsed.hostname or "").lower()
|
||||||
|
return host in _LOOPBACK_HOSTS
|
||||||
|
|
||||||
|
|
||||||
|
def is_claude_redirect_uri(redirect_uri: str) -> bool:
|
||||||
|
"""Return whether a redirect URI is a built-in Claude callback URL."""
|
||||||
|
return _canonicalize_url(redirect_uri) in _CLAUDE_CALLBACK_URIS
|
||||||
|
|
||||||
|
|
||||||
|
def is_redirect_uri_allowed(redirect_uri: str, allowlist: list[str]) -> bool:
|
||||||
|
"""Return whether a redirect URI is allowed by policy."""
|
||||||
|
normalized = _canonicalize_url(redirect_uri)
|
||||||
|
if not normalized:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if is_loopback_redirect_uri(redirect_uri) or is_claude_redirect_uri(redirect_uri):
|
||||||
|
return True
|
||||||
|
|
||||||
|
for pattern in allowlist:
|
||||||
|
candidate = pattern.strip()
|
||||||
|
if not candidate:
|
||||||
|
continue
|
||||||
|
if fnmatchcase(normalized, _canonicalize_url(candidate) or candidate):
|
||||||
|
return True
|
||||||
|
if fnmatchcase(redirect_uri.strip(), candidate):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_origin_allowed(origin: str, request_base: str, public_base: str | None) -> bool:
|
||||||
|
"""Return whether a browser Origin is allowed for MCP transport requests."""
|
||||||
|
normalized_origin = _canonicalize_url(origin)
|
||||||
|
if not normalized_origin:
|
||||||
|
return False
|
||||||
|
|
||||||
|
expected_bases = [request_base.rstrip("/")]
|
||||||
|
if public_base:
|
||||||
|
expected_bases.append(public_base.rstrip("/"))
|
||||||
|
return normalized_origin in expected_bases
|
||||||
|
|
||||||
|
|
||||||
|
def encode_proxy_state(
|
||||||
|
secret: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
original_state: str,
|
||||||
|
*,
|
||||||
|
ttl_seconds: int = 600,
|
||||||
|
) -> str:
|
||||||
|
"""Create a signed OAuth state wrapper for the proxy callback round-trip."""
|
||||||
|
payload = {
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"state": original_state,
|
||||||
|
"issued_at": int(time.time()),
|
||||||
|
"nonce": secrets.token_urlsafe(16),
|
||||||
|
"ttl_seconds": ttl_seconds,
|
||||||
|
}
|
||||||
|
canonical_payload = json.dumps(payload, sort_keys=True, separators=(",", ":"))
|
||||||
|
signature = hmac.new(secret.encode("utf-8"), canonical_payload.encode("utf-8"), hashlib.sha256)
|
||||||
|
envelope = {
|
||||||
|
"payload": payload,
|
||||||
|
"signature": signature.hexdigest(),
|
||||||
|
}
|
||||||
|
return base64.urlsafe_b64encode(
|
||||||
|
json.dumps(envelope, sort_keys=True, separators=(",", ":")).encode("utf-8")
|
||||||
|
).decode("ascii")
|
||||||
|
|
||||||
|
|
||||||
|
def decode_proxy_state(secret: str, encoded_state: str) -> dict[str, str]:
|
||||||
|
"""Verify and unpack a signed OAuth state wrapper."""
|
||||||
|
try:
|
||||||
|
raw = base64.urlsafe_b64decode(encoded_state.encode("ascii"))
|
||||||
|
envelope = json.loads(raw)
|
||||||
|
except Exception as exc: # pragma: no cover - guarded by tests
|
||||||
|
raise ValueError("Invalid or missing state parameter") from exc
|
||||||
|
|
||||||
|
if not isinstance(envelope, dict):
|
||||||
|
raise ValueError("Invalid or missing state parameter")
|
||||||
|
|
||||||
|
payload = envelope.get("payload")
|
||||||
|
signature = envelope.get("signature")
|
||||||
|
if not isinstance(payload, dict) or not isinstance(signature, str):
|
||||||
|
raise ValueError("Invalid or missing state parameter")
|
||||||
|
|
||||||
|
canonical_payload = json.dumps(payload, sort_keys=True, separators=(",", ":"))
|
||||||
|
expected_signature = hmac.new(
|
||||||
|
secret.encode("utf-8"), canonical_payload.encode("utf-8"), hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
if not hmac.compare_digest(signature, expected_signature):
|
||||||
|
raise ValueError("Invalid or missing state parameter")
|
||||||
|
|
||||||
|
issued_at = payload.get("issued_at")
|
||||||
|
ttl_seconds = payload.get("ttl_seconds")
|
||||||
|
now = int(time.time())
|
||||||
|
if not isinstance(issued_at, int) or not isinstance(ttl_seconds, int):
|
||||||
|
raise ValueError("Invalid or missing state parameter")
|
||||||
|
if issued_at > now or now - issued_at > max(ttl_seconds, 1):
|
||||||
|
raise ValueError("Invalid or missing state parameter")
|
||||||
|
|
||||||
|
redirect_uri = payload.get("redirect_uri")
|
||||||
|
if not isinstance(redirect_uri, str) or not redirect_uri.strip():
|
||||||
|
raise ValueError("Invalid or missing state parameter")
|
||||||
|
|
||||||
|
original_state = payload.get("state")
|
||||||
|
if not isinstance(original_state, str):
|
||||||
|
raise ValueError("Invalid or missing state parameter")
|
||||||
|
|
||||||
|
return {"redirect_uri": redirect_uri, "state": original_state}
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientRegistry:
|
||||||
|
"""Persisted OAuth client registry for dynamic client registration."""
|
||||||
|
|
||||||
|
def __init__(self, storage_path: Path) -> None:
|
||||||
|
"""Initialize registry storage."""
|
||||||
|
self.storage_path = storage_path
|
||||||
|
self.storage_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._clients: dict[str, OAuthClientRecord] = {}
|
||||||
|
self._loaded = False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _hash_secret(secret: str) -> str:
|
||||||
|
"""Hash client secrets before persistence."""
|
||||||
|
return hashlib.sha256(secret.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
def _load(self) -> None:
|
||||||
|
"""Load persisted registrations from disk once."""
|
||||||
|
if self._loaded:
|
||||||
|
return
|
||||||
|
self._loaded = True
|
||||||
|
if not self.storage_path.exists():
|
||||||
|
self._clients = {}
|
||||||
|
return
|
||||||
|
|
||||||
|
raw = json.loads(self.storage_path.read_text(encoding="utf-8"))
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
raise ValueError("Persisted DCR storage must be a JSON object")
|
||||||
|
|
||||||
|
clients: dict[str, OAuthClientRecord] = {}
|
||||||
|
for client_id, payload in raw.items():
|
||||||
|
if not isinstance(client_id, str):
|
||||||
|
raise ValueError("Persisted client id must be a string")
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
raise ValueError(f"Persisted client record for {client_id} must be a mapping")
|
||||||
|
record = OAuthClientRecord.model_validate({"client_id": client_id, **payload})
|
||||||
|
clients[client_id] = record
|
||||||
|
|
||||||
|
self._clients = clients
|
||||||
|
|
||||||
|
def _persist(self) -> None:
|
||||||
|
"""Write registrations atomically."""
|
||||||
|
payload = {
|
||||||
|
client_id: record.model_dump(mode="json", exclude={"client_id"})
|
||||||
|
for client_id, record in self._clients.items()
|
||||||
|
}
|
||||||
|
tmp_path = self.storage_path.with_suffix(self.storage_path.suffix + ".tmp")
|
||||||
|
tmp_path.write_text(json.dumps(payload, sort_keys=True, indent=2), encoding="utf-8")
|
||||||
|
tmp_path.replace(self.storage_path)
|
||||||
|
|
||||||
|
def get(self, client_id: str) -> OAuthClientRecord | None:
|
||||||
|
"""Look up a registered client by identifier."""
|
||||||
|
self._load()
|
||||||
|
return self._clients.get(client_id)
|
||||||
|
|
||||||
|
def is_known_client(
|
||||||
|
self,
|
||||||
|
client_id: str,
|
||||||
|
*,
|
||||||
|
fallback_client_id: str = "",
|
||||||
|
fallback_client_secret: str = "",
|
||||||
|
) -> bool:
|
||||||
|
"""Return whether a client is recognized by the registry or environment."""
|
||||||
|
if not client_id.strip():
|
||||||
|
return False
|
||||||
|
if fallback_client_id.strip() and client_id == fallback_client_id.strip():
|
||||||
|
return True
|
||||||
|
return self.get(client_id) is not None
|
||||||
|
|
||||||
|
def validate_client_secret(
|
||||||
|
self,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str | None,
|
||||||
|
*,
|
||||||
|
fallback_client_id: str = "",
|
||||||
|
fallback_client_secret: str = "",
|
||||||
|
) -> bool:
|
||||||
|
"""Validate a client identifier and optional secret."""
|
||||||
|
if fallback_client_id.strip() and client_id == fallback_client_id.strip():
|
||||||
|
if not fallback_client_secret.strip():
|
||||||
|
return True
|
||||||
|
if not client_secret:
|
||||||
|
return False
|
||||||
|
return hmac.compare_digest(
|
||||||
|
self._hash_secret(client_secret), self._hash_secret(fallback_client_secret.strip())
|
||||||
|
)
|
||||||
|
|
||||||
|
record = self.get(client_id)
|
||||||
|
if record is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if record.client_secret_hash is None:
|
||||||
|
return True
|
||||||
|
if not client_secret:
|
||||||
|
return False
|
||||||
|
return hmac.compare_digest(self._hash_secret(client_secret), record.client_secret_hash)
|
||||||
|
|
||||||
|
def register(self, request: OAuthRegistrationRequest) -> dict[str, Any]:
|
||||||
|
"""Persist a new OAuth client registration and return its public metadata."""
|
||||||
|
self._load()
|
||||||
|
client_id = secrets.token_urlsafe(24)
|
||||||
|
client_secret: str | None = None
|
||||||
|
client_secret_hash: str | None = None
|
||||||
|
|
||||||
|
if request.token_endpoint_auth_method != "none":
|
||||||
|
client_secret = secrets.token_urlsafe(32)
|
||||||
|
client_secret_hash = self._hash_secret(client_secret)
|
||||||
|
|
||||||
|
record = OAuthClientRecord(
|
||||||
|
client_id=client_id,
|
||||||
|
client_name=request.client_name,
|
||||||
|
redirect_uris=list(request.redirect_uris),
|
||||||
|
grant_types=list(request.grant_types),
|
||||||
|
response_types=list(request.response_types),
|
||||||
|
token_endpoint_auth_method=request.token_endpoint_auth_method,
|
||||||
|
client_id_issued_at=int(time.time()),
|
||||||
|
client_secret_hash=client_secret_hash,
|
||||||
|
scope=request.scope,
|
||||||
|
)
|
||||||
|
self._clients[client_id] = record
|
||||||
|
self._persist()
|
||||||
|
|
||||||
|
response: dict[str, Any] = record.model_dump(exclude={"client_secret_hash"})
|
||||||
|
if client_secret is not None:
|
||||||
|
response["client_secret"] = client_secret
|
||||||
|
response["client_secret_expires_at"] = 0
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
_oauth_client_registry: OAuthClientRegistry | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_oauth_client_registry(storage_path: Path) -> OAuthClientRegistry:
|
||||||
|
"""Get or create the global OAuth client registry."""
|
||||||
|
global _oauth_client_registry
|
||||||
|
if _oauth_client_registry is None or _oauth_client_registry.storage_path != storage_path:
|
||||||
|
_oauth_client_registry = OAuthClientRegistry(storage_path)
|
||||||
|
return _oauth_client_registry
|
||||||
|
|
||||||
|
|
||||||
|
def reset_oauth_client_registry() -> None:
|
||||||
|
"""Reset the global OAuth client registry (primarily for tests)."""
|
||||||
|
global _oauth_client_registry
|
||||||
|
_oauth_client_registry = None
|
||||||
+385
-43
@@ -3,10 +3,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||||
@@ -36,11 +36,20 @@ from aegis_gitea_mcp.mcp_protocol import (
|
|||||||
get_tool_by_name,
|
get_tool_by_name,
|
||||||
)
|
)
|
||||||
from aegis_gitea_mcp.oauth import get_oauth_validator
|
from aegis_gitea_mcp.oauth import get_oauth_validator
|
||||||
|
from aegis_gitea_mcp.oauth_flow import (
|
||||||
|
OAuthRegistrationRequest,
|
||||||
|
decode_proxy_state,
|
||||||
|
encode_proxy_state,
|
||||||
|
get_oauth_client_registry,
|
||||||
|
is_origin_allowed,
|
||||||
|
is_redirect_uri_allowed,
|
||||||
|
)
|
||||||
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
|
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
|
||||||
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
|
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
|
||||||
from aegis_gitea_mcp.rate_limit import get_rate_limiter
|
from aegis_gitea_mcp.rate_limit import get_rate_limiter
|
||||||
from aegis_gitea_mcp.request_context import (
|
from aegis_gitea_mcp.request_context import (
|
||||||
clear_gitea_auth_context,
|
clear_gitea_auth_context,
|
||||||
|
get_gitea_user_login,
|
||||||
get_gitea_user_scopes,
|
get_gitea_user_scopes,
|
||||||
get_gitea_user_token,
|
get_gitea_user_token,
|
||||||
set_gitea_user_login,
|
set_gitea_user_login,
|
||||||
@@ -94,9 +103,40 @@ _api_scope_cache: BoundedTTLCache[str, bool] = BoundedTTLCache(
|
|||||||
_REAUTH_GUIDANCE = (
|
_REAUTH_GUIDANCE = (
|
||||||
"Your OAuth token lacks Gitea API scopes (e.g. read:repository). "
|
"Your OAuth token lacks Gitea API scopes (e.g. read:repository). "
|
||||||
"Revoke the authorization in Gitea (Settings > Applications > Authorized OAuth2 Applications) "
|
"Revoke the authorization in Gitea (Settings > Applications > Authorized OAuth2 Applications) "
|
||||||
"and in ChatGPT (Settings > Connected apps), then re-authorize."
|
"and in your client, then re-authorize."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_repo_authz_cache: BoundedTTLCache[str, bool] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_repo_authz_cache() -> BoundedTTLCache[str, bool]:
|
||||||
|
"""Get the bounded cache for per-user repository permission checks."""
|
||||||
|
global _repo_authz_cache
|
||||||
|
settings = get_settings()
|
||||||
|
if _repo_authz_cache is None:
|
||||||
|
_repo_authz_cache = BoundedTTLCache(
|
||||||
|
ttl_seconds=settings.repo_authz_cache_ttl_seconds,
|
||||||
|
max_size=2048,
|
||||||
|
)
|
||||||
|
return _repo_authz_cache
|
||||||
|
|
||||||
|
|
||||||
|
def reset_repo_authz_cache() -> None:
|
||||||
|
"""Reset the repository authorization cache (primarily for tests)."""
|
||||||
|
global _repo_authz_cache
|
||||||
|
_repo_authz_cache = None
|
||||||
|
|
||||||
|
|
||||||
|
def _repo_authz_cache_key(login: str, repository: str, required_scope: str) -> str:
|
||||||
|
"""Build a bounded cache key for a user/repository permission check."""
|
||||||
|
normalized_login = login.strip().lower()
|
||||||
|
return f"{normalized_login}:{repository.lower()}:{required_scope}"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_mcp_transport_path(path: str) -> bool:
|
||||||
|
"""Return whether a request targets the MCP transport surface."""
|
||||||
|
return path in {"/mcp", "/mcp/sse"} or path.startswith("/mcp/")
|
||||||
|
|
||||||
|
|
||||||
def _has_required_scope(required_scope: str, granted_scopes: set[str]) -> bool:
|
def _has_required_scope(required_scope: str, granted_scopes: set[str]) -> bool:
|
||||||
"""Return whether granted scopes satisfy the required MCP tool scope."""
|
"""Return whether granted scopes satisfy the required MCP tool scope."""
|
||||||
@@ -116,6 +156,129 @@ def _has_required_scope(required_scope: str, granted_scopes: set[str]) -> bool:
|
|||||||
return required_scope in expanded
|
return required_scope in expanded
|
||||||
|
|
||||||
|
|
||||||
|
def _repo_permission_satisfied(permission: dict[str, Any], required_scope: str) -> bool:
|
||||||
|
"""Return whether a repository permission payload satisfies the requested scope."""
|
||||||
|
permission_name = str(permission.get("permission", "")).lower().strip()
|
||||||
|
if permission_name in {"admin", "owner"}:
|
||||||
|
return True
|
||||||
|
if required_scope == WRITE_SCOPE and permission_name == "write":
|
||||||
|
return True
|
||||||
|
if required_scope == READ_SCOPE and permission_name in {"read", "write"}:
|
||||||
|
return True
|
||||||
|
|
||||||
|
nested_permissions = permission.get("permissions")
|
||||||
|
if isinstance(nested_permissions, dict):
|
||||||
|
return _repo_permission_satisfied(nested_permissions, required_scope)
|
||||||
|
|
||||||
|
if required_scope == WRITE_SCOPE:
|
||||||
|
return bool(permission.get("push") or permission.get("admin"))
|
||||||
|
return bool(permission.get("pull") or permission.get("push") or permission.get("admin"))
|
||||||
|
|
||||||
|
|
||||||
|
async def _verify_user_repository_access(
|
||||||
|
*,
|
||||||
|
repository: str,
|
||||||
|
required_scope: str,
|
||||||
|
user_login: str,
|
||||||
|
correlation_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Verify the authenticated user can access the target repository before PAT fallback."""
|
||||||
|
settings = get_settings()
|
||||||
|
audit = get_audit_logger()
|
||||||
|
|
||||||
|
service_token = settings.gitea_token.strip()
|
||||||
|
if not service_token:
|
||||||
|
raise HTTPException(status_code=500, detail="Repository authorization misconfigured")
|
||||||
|
|
||||||
|
if not user_login.strip() or user_login == "unknown":
|
||||||
|
audit.log_access_denied(
|
||||||
|
tool_name=tool_name,
|
||||||
|
repository=repository,
|
||||||
|
reason="repository_permission_missing_user",
|
||||||
|
correlation_id=correlation_id,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Unable to verify repository permission for this user.",
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_key = _repo_authz_cache_key(user_login, repository, required_scope)
|
||||||
|
cached = _get_repo_authz_cache().get(cache_key)
|
||||||
|
if cached is True:
|
||||||
|
return
|
||||||
|
|
||||||
|
owner, repo = repository.split("/", 1)
|
||||||
|
encoded_owner = urllib.parse.quote(owner, safe="")
|
||||||
|
encoded_repo = urllib.parse.quote(repo, safe="")
|
||||||
|
encoded_user = urllib.parse.quote(user_login, safe="")
|
||||||
|
permission_url = (
|
||||||
|
f"{settings.gitea_base_url}/api/v1/repos/{encoded_owner}/{encoded_repo}"
|
||||||
|
f"/collaborators/{encoded_user}/permission"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=settings.request_timeout_seconds) as client:
|
||||||
|
response = await client.get(
|
||||||
|
permission_url,
|
||||||
|
headers={"Authorization": f"token {service_token}", "Accept": "application/json"},
|
||||||
|
)
|
||||||
|
except httpx.RequestError as exc:
|
||||||
|
audit.log_access_denied(
|
||||||
|
tool_name=tool_name,
|
||||||
|
repository=repository,
|
||||||
|
reason="repository_permission_probe_failed",
|
||||||
|
correlation_id=correlation_id,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Unable to verify repository permission for this user.",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
audit.log_access_denied(
|
||||||
|
tool_name=tool_name,
|
||||||
|
repository=repository,
|
||||||
|
reason=f"repository_permission_probe:{response.status_code}",
|
||||||
|
correlation_id=correlation_id,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="User does not have permission for the requested repository.",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
permission_payload = response.json()
|
||||||
|
except ValueError as exc:
|
||||||
|
audit.log_access_denied(
|
||||||
|
tool_name=tool_name,
|
||||||
|
repository=repository,
|
||||||
|
reason="repository_permission_invalid_json",
|
||||||
|
correlation_id=correlation_id,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Unable to verify repository permission for this user.",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
if isinstance(permission_payload, dict) and _repo_permission_satisfied(
|
||||||
|
permission_payload, required_scope
|
||||||
|
):
|
||||||
|
_get_repo_authz_cache().set(cache_key, True)
|
||||||
|
return
|
||||||
|
|
||||||
|
audit.log_access_denied(
|
||||||
|
tool_name=tool_name,
|
||||||
|
repository=repository,
|
||||||
|
reason="repository_permission_denied",
|
||||||
|
correlation_id=correlation_id,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="User does not have permission for the requested repository.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
"""Run startup and shutdown hooks via the FastAPI lifespan protocol."""
|
"""Run startup and shutdown hooks via the FastAPI lifespan protocol."""
|
||||||
@@ -243,6 +406,67 @@ async def request_context_middleware(
|
|||||||
metrics.record_http_request(request.method, request.url.path, status_code)
|
metrics.record_http_request(request.method, request.url.path, status_code)
|
||||||
|
|
||||||
|
|
||||||
|
def _cors_headers(origin: str) -> dict[str, str]:
|
||||||
|
"""Build strict CORS headers for a validated browser origin."""
|
||||||
|
return {
|
||||||
|
"Access-Control-Allow-Origin": origin,
|
||||||
|
"Access-Control-Allow-Credentials": "true",
|
||||||
|
"Access-Control-Allow-Methods": "GET,POST,OPTIONS",
|
||||||
|
"Access-Control-Allow-Headers": "Authorization,Content-Type,MCP-Protocol-Version,X-Request-ID",
|
||||||
|
"Access-Control-Expose-Headers": "X-Request-ID,WWW-Authenticate",
|
||||||
|
"Vary": "Origin",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def strict_origin_and_cors_middleware(
|
||||||
|
request: Request,
|
||||||
|
call_next: Callable[[Request], Awaitable[Response]],
|
||||||
|
) -> Response:
|
||||||
|
"""Enforce strict browser origins for MCP transport requests."""
|
||||||
|
if request.url.path not in {"/mcp", "/mcp/sse"}:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
expected_base = settings.public_base or str(request.base_url).rstrip("/")
|
||||||
|
|
||||||
|
if origin and not is_origin_allowed(origin, expected_base, settings.public_base):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={
|
||||||
|
"error": "Origin not allowed",
|
||||||
|
"message": "The request origin is not allowed for this MCP transport.",
|
||||||
|
"request_id": getattr(request.state, "request_id", "-"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
response = Response(status_code=204)
|
||||||
|
else:
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
if origin and is_origin_allowed(origin, expected_base, settings.public_base):
|
||||||
|
for header, value in _cors_headers(origin).items():
|
||||||
|
response.headers[header] = value
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def _oauth_invalid_client_response() -> JSONResponse:
|
||||||
|
"""Return an RFC 6749 invalid_client error for token endpoint failures."""
|
||||||
|
response = JSONResponse(status_code=401, content={"error": "invalid_client"})
|
||||||
|
response.headers["WWW-Authenticate"] = 'Basic realm="oauth"'
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def _jsonrpc_error(message_id: Any, code: int, message: str) -> JSONResponse:
|
||||||
|
"""Build a JSON-RPC error response envelope."""
|
||||||
|
return JSONResponse(
|
||||||
|
content={"jsonrpc": "2.0", "id": message_id, "error": {"code": code, "message": message}}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def authenticate_and_rate_limit(
|
async def authenticate_and_rate_limit(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -255,11 +479,14 @@ async def authenticate_and_rate_limit(
|
|||||||
if request.url.path in {"/", "/health"}:
|
if request.url.path in {"/", "/health"}:
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
if request.method == "OPTIONS" and request.url.path in {"/mcp", "/mcp/sse"}:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
if request.url.path == "/metrics" and settings.metrics_enabled:
|
if request.url.path == "/metrics" and settings.metrics_enabled:
|
||||||
# Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
|
# Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# OAuth discovery and token endpoints must be public so ChatGPT can complete the flow.
|
# OAuth discovery and token endpoints must be public so MCP clients can complete the flow.
|
||||||
if request.url.path in {
|
if request.url.path in {
|
||||||
"/oauth/token",
|
"/oauth/token",
|
||||||
"/.well-known/oauth-protected-resource",
|
"/.well-known/oauth-protected-resource",
|
||||||
@@ -268,7 +495,11 @@ async def authenticate_and_rate_limit(
|
|||||||
}:
|
}:
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
if not (request.url.path.startswith("/mcp/") or request.url.path.startswith("/automation/")):
|
if not (
|
||||||
|
request.url.path in {"/mcp/tools"}
|
||||||
|
or _is_mcp_transport_path(request.url.path)
|
||||||
|
or request.url.path.startswith("/automation/")
|
||||||
|
):
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
oauth_validator = get_oauth_validator()
|
oauth_validator = get_oauth_validator()
|
||||||
@@ -296,7 +527,7 @@ async def authenticate_and_rate_limit(
|
|||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
if not access_token:
|
if not access_token:
|
||||||
if request.url.path.startswith("/mcp/"):
|
if _is_mcp_transport_path(request.url.path):
|
||||||
return _oauth_unauthorized_response(
|
return _oauth_unauthorized_response(
|
||||||
request,
|
request,
|
||||||
"Provide Authorization: Bearer <token>.",
|
"Provide Authorization: Bearer <token>.",
|
||||||
@@ -315,7 +546,7 @@ async def authenticate_and_rate_limit(
|
|||||||
access_token, client_ip, user_agent
|
access_token, client_ip, user_agent
|
||||||
)
|
)
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
if request.url.path.startswith("/mcp/"):
|
if _is_mcp_transport_path(request.url.path):
|
||||||
return _oauth_unauthorized_response(
|
return _oauth_unauthorized_response(
|
||||||
request,
|
request,
|
||||||
error_message or "Invalid or expired OAuth token.",
|
error_message or "Invalid or expired OAuth token.",
|
||||||
@@ -400,7 +631,7 @@ async def authenticate_and_rate_limit(
|
|||||||
"OAuth token is valid but lacks required Gitea API access. "
|
"OAuth token is valid but lacks required Gitea API access. "
|
||||||
"Re-authorize this OAuth app in Gitea and try again."
|
"Re-authorize this OAuth app in Gitea and try again."
|
||||||
)
|
)
|
||||||
if request.url.path.startswith("/mcp/"):
|
if _is_mcp_transport_path(request.url.path):
|
||||||
return _oauth_unauthorized_response(
|
return _oauth_unauthorized_response(
|
||||||
request,
|
request,
|
||||||
message,
|
message,
|
||||||
@@ -508,9 +739,14 @@ async def health() -> dict[str, str]:
|
|||||||
async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
|
async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
|
||||||
"""OAuth 2.0 Protected Resource Metadata (RFC 9728).
|
"""OAuth 2.0 Protected Resource Metadata (RFC 9728).
|
||||||
|
|
||||||
Required by the MCP Authorization spec so that OAuth clients (e.g. ChatGPT)
|
Required by the MCP Authorization spec so that OAuth clients (Claude's
|
||||||
can discover the authorization server that protects this resource.
|
connector infrastructure) can discover the authorization server that
|
||||||
ChatGPT fetches this endpoint when it first connects to the MCP server via SSE.
|
protects this resource. Claude fetches this endpoint when it first connects.
|
||||||
|
|
||||||
|
The ``resource`` value MUST be THIS server's own canonical public URL: the
|
||||||
|
MCP client verifies that the resource identifier matches the origin it
|
||||||
|
derived the MCP server URL from (RFC 9728 / RFC 8707). Returning the upstream
|
||||||
|
Gitea URL here would fail that check.
|
||||||
"""
|
"""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
gitea_base = settings.gitea_base_url
|
gitea_base = settings.gitea_base_url
|
||||||
@@ -521,7 +757,7 @@ async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
|
|||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={
|
content={
|
||||||
"resource": gitea_base,
|
"resource": base_url,
|
||||||
"authorization_servers": authorization_servers,
|
"authorization_servers": authorization_servers,
|
||||||
"bearer_methods_supported": ["header"],
|
"bearer_methods_supported": ["header"],
|
||||||
"scopes_supported": [READ_SCOPE, WRITE_SCOPE],
|
"scopes_supported": [READ_SCOPE, WRITE_SCOPE],
|
||||||
@@ -534,24 +770,52 @@ async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
|
|||||||
async def oauth_authorize_proxy(request: Request) -> RedirectResponse:
|
async def oauth_authorize_proxy(request: Request) -> RedirectResponse:
|
||||||
"""Proxy OAuth authorization to Gitea, replacing redirect_uri with our own callback.
|
"""Proxy OAuth authorization to Gitea, replacing redirect_uri with our own callback.
|
||||||
|
|
||||||
Clients (ChatGPT, Claude, etc.) send their own redirect_uri which Gitea doesn't know
|
Clients (Claude, Claude Code, Cowork, etc.) send their own redirect_uri which Gitea doesn't know
|
||||||
about. This endpoint intercepts the request, encodes the original redirect_uri and
|
about. This endpoint intercepts the request, encodes the original redirect_uri and
|
||||||
state into a new state parameter, and forwards the request to Gitea using the MCP
|
state into a new state parameter, and forwards the request to Gitea using the MCP
|
||||||
server's own callback URI — the only URI that needs to be registered in Gitea.
|
server's own callback URI — the only URI that needs to be registered in Gitea.
|
||||||
"""
|
"""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
base_url = settings.public_base or str(request.base_url).rstrip("/")
|
base_url = settings.public_base or str(request.base_url).rstrip("/")
|
||||||
|
registry = get_oauth_client_registry(settings.dcr_storage_path)
|
||||||
|
|
||||||
params = dict(request.query_params)
|
params = dict(request.query_params)
|
||||||
client_redirect_uri = params.pop("redirect_uri", "")
|
client_redirect_uri = params.pop("redirect_uri", "").strip()
|
||||||
|
client_id = params.get("client_id", "").strip() or settings.gitea_oauth_client_id.strip()
|
||||||
original_state = params.get("state", "")
|
original_state = params.get("state", "")
|
||||||
|
params.pop("client_secret", None)
|
||||||
|
|
||||||
# Encode the client's redirect_uri + original state into a tamper-evident wrapper.
|
if not client_id:
|
||||||
# We simply base64-encode a JSON blob; Gitea will echo it back on the callback.
|
raise HTTPException(status_code=400, detail="Missing client_id")
|
||||||
proxy_state_data = {"redirect_uri": client_redirect_uri, "state": original_state}
|
if not registry.is_known_client(
|
||||||
proxy_state = base64.urlsafe_b64encode(json.dumps(proxy_state_data).encode()).decode()
|
client_id,
|
||||||
|
fallback_client_id=settings.gitea_oauth_client_id,
|
||||||
|
):
|
||||||
|
raise HTTPException(status_code=401, detail="invalid_client")
|
||||||
|
|
||||||
|
if not client_redirect_uri:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing redirect_uri")
|
||||||
|
if not is_redirect_uri_allowed(client_redirect_uri, settings.oauth_redirect_allowlist):
|
||||||
|
raise HTTPException(status_code=400, detail="redirect_uri is not allowed")
|
||||||
|
|
||||||
|
code_challenge = params.get("code_challenge", "").strip()
|
||||||
|
code_challenge_method = params.get("code_challenge_method", "S256").strip().upper()
|
||||||
|
if not code_challenge:
|
||||||
|
raise HTTPException(status_code=400, detail="PKCE code_challenge is required")
|
||||||
|
if code_challenge_method != "S256":
|
||||||
|
raise HTTPException(status_code=400, detail="PKCE code_challenge_method must be S256")
|
||||||
|
|
||||||
|
proxy_state = encode_proxy_state(
|
||||||
|
settings.oauth_state_secret,
|
||||||
|
client_redirect_uri,
|
||||||
|
original_state,
|
||||||
|
ttl_seconds=600,
|
||||||
|
)
|
||||||
|
|
||||||
|
params["client_id"] = settings.gitea_oauth_client_id
|
||||||
params["state"] = proxy_state
|
params["state"] = proxy_state
|
||||||
|
params["code_challenge"] = code_challenge
|
||||||
|
params["code_challenge_method"] = "S256"
|
||||||
params["redirect_uri"] = f"{base_url}/oauth/callback"
|
params["redirect_uri"] = f"{base_url}/oauth/callback"
|
||||||
|
|
||||||
gitea_authorize_url = f"{settings.gitea_base_url}/login/oauth/authorize"
|
gitea_authorize_url = f"{settings.gitea_base_url}/login/oauth/authorize"
|
||||||
@@ -568,14 +832,17 @@ async def oauth_callback_proxy(request: Request) -> RedirectResponse:
|
|||||||
error_description = request.query_params.get("error_description", "")
|
error_description = request.query_params.get("error_description", "")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
state_data = json.loads(base64.urlsafe_b64decode(proxy_state.encode()))
|
state_data = decode_proxy_state(get_settings().oauth_state_secret, proxy_state)
|
||||||
client_redirect_uri = state_data["redirect_uri"]
|
client_redirect_uri = state_data["redirect_uri"]
|
||||||
original_state = state_data["state"]
|
original_state = state_data["state"]
|
||||||
except Exception as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=400, detail="Invalid or missing state parameter") from exc
|
raise HTTPException(status_code=400, detail="Invalid or missing state parameter") from exc
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
if not client_redirect_uri:
|
if not client_redirect_uri:
|
||||||
raise HTTPException(status_code=400, detail="No client redirect_uri in state")
|
raise HTTPException(status_code=400, detail="No client redirect_uri in state")
|
||||||
|
if not is_redirect_uri_allowed(client_redirect_uri, settings.oauth_redirect_allowlist):
|
||||||
|
raise HTTPException(status_code=400, detail="redirect_uri is not allowed")
|
||||||
|
|
||||||
result_params: dict[str, str] = {}
|
result_params: dict[str, str] = {}
|
||||||
if error:
|
if error:
|
||||||
@@ -595,26 +862,31 @@ async def oauth_callback_proxy(request: Request) -> RedirectResponse:
|
|||||||
async def oauth_authorization_server_metadata(request: Request) -> JSONResponse:
|
async def oauth_authorization_server_metadata(request: Request) -> JSONResponse:
|
||||||
"""OAuth 2.0 Authorization Server Metadata (RFC 8414).
|
"""OAuth 2.0 Authorization Server Metadata (RFC 8414).
|
||||||
|
|
||||||
Proxies Gitea's OAuth authorization server metadata so that ChatGPT can
|
Advertises this server's OAuth proxy endpoints so that Claude's connector
|
||||||
discover the authorize URL, token URL, and supported features directly
|
infrastructure can discover the authorize URL, token URL, and dynamic client
|
||||||
from this server without needing to know the Gitea URL upfront.
|
registration endpoint directly from this server without knowing the Gitea URL
|
||||||
|
upfront. The authorize/token endpoints are this server's proxy routes because
|
||||||
|
Gitea does not know Claude's redirect_uri.
|
||||||
"""
|
"""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
base_url = settings.public_base or str(request.base_url).rstrip("/")
|
base_url = settings.public_base or str(request.base_url).rstrip("/")
|
||||||
gitea_base = settings.gitea_base_url
|
gitea_base = settings.gitea_base_url
|
||||||
|
|
||||||
return JSONResponse(
|
metadata: dict[str, Any] = {
|
||||||
content={
|
"issuer": gitea_base,
|
||||||
"issuer": gitea_base,
|
"authorization_endpoint": f"{base_url}/oauth/authorize",
|
||||||
"authorization_endpoint": f"{base_url}/oauth/authorize",
|
"token_endpoint": f"{base_url}/oauth/token",
|
||||||
"token_endpoint": f"{base_url}/oauth/token",
|
"response_types_supported": ["code"],
|
||||||
"response_types_supported": ["code"],
|
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||||
"grant_types_supported": ["authorization_code"],
|
"code_challenge_methods_supported": ["S256"],
|
||||||
"code_challenge_methods_supported": ["S256"],
|
"scopes_supported": [READ_SCOPE, WRITE_SCOPE],
|
||||||
"scopes_supported": [READ_SCOPE, WRITE_SCOPE],
|
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
||||||
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
|
}
|
||||||
}
|
if settings.dcr_enabled:
|
||||||
)
|
# RFC 7591 dynamic client registration endpoint (Claude registers here).
|
||||||
|
metadata["registration_endpoint"] = f"{base_url}/register"
|
||||||
|
|
||||||
|
return JSONResponse(content=metadata)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/.well-known/openid-configuration")
|
@app.get("/.well-known/openid-configuration")
|
||||||
@@ -649,15 +921,47 @@ async def openid_configuration(request: Request) -> JSONResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/register")
|
||||||
|
async def oauth_dynamic_client_registration(request: Request) -> JSONResponse:
|
||||||
|
"""Persist a new OAuth client registration for Claude and similar MCP clients."""
|
||||||
|
settings = get_settings()
|
||||||
|
if not settings.dcr_enabled:
|
||||||
|
raise HTTPException(status_code=404, detail="Dynamic client registration is disabled")
|
||||||
|
|
||||||
|
content_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
|
||||||
|
if content_type != "application/json":
|
||||||
|
raise HTTPException(status_code=415, detail="Content-Type must be application/json")
|
||||||
|
|
||||||
|
registry = get_oauth_client_registry(settings.dcr_storage_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
registration_request = OAuthRegistrationRequest.model_validate(payload)
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid registration payload") from exc
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid registration payload") from exc
|
||||||
|
|
||||||
|
for redirect_uri in registration_request.redirect_uris:
|
||||||
|
if not is_redirect_uri_allowed(redirect_uri, settings.oauth_redirect_allowlist):
|
||||||
|
raise HTTPException(status_code=400, detail="redirect_uri is not allowed")
|
||||||
|
|
||||||
|
response = registry.register(registration_request)
|
||||||
|
response["client_id_issued_at"] = int(time.time())
|
||||||
|
response["client_secret_expires_at"] = 0
|
||||||
|
return JSONResponse(content=response)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/oauth/token")
|
@app.post("/oauth/token")
|
||||||
async def oauth_token_proxy(request: Request) -> JSONResponse:
|
async def oauth_token_proxy(request: Request) -> JSONResponse:
|
||||||
"""Proxy OAuth2 token exchange to Gitea.
|
"""Proxy OAuth2 token exchange to Gitea.
|
||||||
|
|
||||||
ChatGPT sends the authorization code here after the user logs in to Gitea.
|
The client sends the authorization code here after the user logs in to Gitea.
|
||||||
This endpoint forwards the code to Gitea's token endpoint and returns the
|
This endpoint forwards the code to Gitea's token endpoint and returns the
|
||||||
access_token to ChatGPT, completing the OAuth2 Authorization Code flow.
|
access_token to the client, completing the OAuth2 Authorization Code flow.
|
||||||
"""
|
"""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
registry = get_oauth_client_registry(settings.dcr_storage_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
form_data = await request.form()
|
form_data = await request.form()
|
||||||
@@ -683,32 +987,45 @@ async def oauth_token_proxy(request: Request) -> JSONResponse:
|
|||||||
# URI to Gitea, we must use the same URI here — not the client's original redirect_uri.
|
# URI to Gitea, we must use the same URI here — not the client's original redirect_uri.
|
||||||
base_url = settings.public_base or str(request.base_url).rstrip("/")
|
base_url = settings.public_base or str(request.base_url).rstrip("/")
|
||||||
|
|
||||||
|
if not client_id:
|
||||||
|
return _oauth_invalid_client_response()
|
||||||
|
if not registry.validate_client_secret(
|
||||||
|
client_id,
|
||||||
|
client_secret or None,
|
||||||
|
fallback_client_id=settings.gitea_oauth_client_id,
|
||||||
|
fallback_client_secret=settings.gitea_oauth_client_secret,
|
||||||
|
):
|
||||||
|
return _oauth_invalid_client_response()
|
||||||
|
|
||||||
gitea_token_url = f"{settings.gitea_base_url}/login/oauth/access_token"
|
gitea_token_url = f"{settings.gitea_base_url}/login/oauth/access_token"
|
||||||
|
upstream_client_id = settings.gitea_oauth_client_id
|
||||||
|
upstream_client_secret = settings.gitea_oauth_client_secret
|
||||||
|
|
||||||
if grant_type == "refresh_token":
|
if grant_type == "refresh_token":
|
||||||
if not refresh_token:
|
if not refresh_token:
|
||||||
raise HTTPException(status_code=400, detail="Missing refresh_token")
|
raise HTTPException(status_code=400, detail="Missing refresh_token")
|
||||||
payload: dict[str, str] = {
|
payload: dict[str, str] = {
|
||||||
"client_id": client_id,
|
"client_id": upstream_client_id,
|
||||||
"client_secret": client_secret,
|
"client_secret": upstream_client_secret,
|
||||||
"grant_type": "refresh_token",
|
"grant_type": "refresh_token",
|
||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
if not code:
|
if not code:
|
||||||
raise HTTPException(status_code=400, detail="Missing authorization code")
|
raise HTTPException(status_code=400, detail="Missing authorization code")
|
||||||
|
if not code_verifier:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing code_verifier")
|
||||||
payload = {
|
payload = {
|
||||||
"client_id": client_id,
|
"client_id": upstream_client_id,
|
||||||
"client_secret": client_secret,
|
"client_secret": upstream_client_secret,
|
||||||
"code": code,
|
"code": code,
|
||||||
"grant_type": "authorization_code",
|
"grant_type": "authorization_code",
|
||||||
"redirect_uri": f"{base_url}/oauth/callback",
|
"redirect_uri": f"{base_url}/oauth/callback",
|
||||||
}
|
}
|
||||||
if code_verifier:
|
payload["code_verifier"] = code_verifier
|
||||||
payload["code_verifier"] = code_verifier
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30) as client:
|
async with httpx.AsyncClient(timeout=settings.request_timeout_seconds) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
gitea_token_url,
|
gitea_token_url,
|
||||||
data=payload,
|
data=payload,
|
||||||
@@ -846,6 +1163,29 @@ async def _execute_tool_call(
|
|||||||
if not user_token:
|
if not user_token:
|
||||||
raise HTTPException(status_code=401, detail="Missing authenticated user token context")
|
raise HTTPException(status_code=401, detail="Missing authenticated user token context")
|
||||||
|
|
||||||
|
if settings.gitea_token.strip():
|
||||||
|
if not repository:
|
||||||
|
audit.log_access_denied(
|
||||||
|
tool_name=tool_name,
|
||||||
|
reason="service_pat_requires_repository_target",
|
||||||
|
correlation_id=correlation_id,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail=(
|
||||||
|
"Service PAT mode requires a repository target so per-user "
|
||||||
|
"permission can be verified."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
user_login = get_gitea_user_login()
|
||||||
|
await _verify_user_repository_access(
|
||||||
|
repository=repository,
|
||||||
|
required_scope=required_scope,
|
||||||
|
user_login=user_login or "",
|
||||||
|
correlation_id=correlation_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
)
|
||||||
|
|
||||||
# In OAuth mode, Gitea OIDC access_tokens can't call the Gitea REST API
|
# In OAuth mode, Gitea OIDC access_tokens can't call the Gitea REST API
|
||||||
# (they only carry OIDC scopes). If a service PAT is configured via
|
# (they only carry OIDC scopes). If a service PAT is configured via
|
||||||
# GITEA_TOKEN, use that for API calls while OIDC handles identity/authz.
|
# GITEA_TOKEN, use that for API calls while OIDC handles identity/authz.
|
||||||
@@ -970,6 +1310,7 @@ async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/mcp")
|
||||||
@app.get("/mcp/sse")
|
@app.get("/mcp/sse")
|
||||||
async def sse_endpoint(request: Request) -> StreamingResponse:
|
async def sse_endpoint(request: Request) -> StreamingResponse:
|
||||||
"""Server-Sent Events endpoint for MCP transport."""
|
"""Server-Sent Events endpoint for MCP transport."""
|
||||||
@@ -1004,6 +1345,7 @@ async def sse_endpoint(request: Request) -> StreamingResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/mcp")
|
||||||
@app.post("/mcp/sse")
|
@app.post("/mcp/sse")
|
||||||
async def sse_message_handler(request: Request) -> JSONResponse:
|
async def sse_message_handler(request: Request) -> JSONResponse:
|
||||||
"""Handle POST messages for MCP SSE transport."""
|
"""Handle POST messages for MCP SSE transport."""
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from aegis_gitea_mcp.tools.arguments import (
|
|||||||
|
|
||||||
|
|
||||||
async def list_repositories_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
async def list_repositories_tool(gitea: GiteaClient, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""List repositories visible to the bot user.
|
"""List repositories visible to the active Gitea API token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
gitea: Initialized Gitea client.
|
gitea: Initialized Gitea client.
|
||||||
|
|||||||
@@ -9,9 +9,11 @@ from aegis_gitea_mcp.audit import reset_audit_logger
|
|||||||
from aegis_gitea_mcp.auth import reset_validator
|
from aegis_gitea_mcp.auth import reset_validator
|
||||||
from aegis_gitea_mcp.config import reset_settings
|
from aegis_gitea_mcp.config import reset_settings
|
||||||
from aegis_gitea_mcp.oauth import reset_oauth_validator
|
from aegis_gitea_mcp.oauth import reset_oauth_validator
|
||||||
|
from aegis_gitea_mcp.oauth_flow import reset_oauth_client_registry
|
||||||
from aegis_gitea_mcp.observability import reset_metrics_registry
|
from aegis_gitea_mcp.observability import reset_metrics_registry
|
||||||
from aegis_gitea_mcp.policy import reset_policy_engine
|
from aegis_gitea_mcp.policy import reset_policy_engine
|
||||||
from aegis_gitea_mcp.rate_limit import reset_rate_limiter
|
from aegis_gitea_mcp.rate_limit import reset_rate_limiter
|
||||||
|
from aegis_gitea_mcp.server import reset_repo_authz_cache
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -22,6 +24,8 @@ def reset_globals(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Generator[
|
|||||||
reset_audit_logger()
|
reset_audit_logger()
|
||||||
reset_validator()
|
reset_validator()
|
||||||
reset_oauth_validator()
|
reset_oauth_validator()
|
||||||
|
reset_oauth_client_registry()
|
||||||
|
reset_repo_authz_cache()
|
||||||
reset_policy_engine()
|
reset_policy_engine()
|
||||||
reset_rate_limiter()
|
reset_rate_limiter()
|
||||||
reset_metrics_registry()
|
reset_metrics_registry()
|
||||||
@@ -37,6 +41,8 @@ def reset_globals(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Generator[
|
|||||||
reset_audit_logger()
|
reset_audit_logger()
|
||||||
reset_validator()
|
reset_validator()
|
||||||
reset_oauth_validator()
|
reset_oauth_validator()
|
||||||
|
reset_oauth_client_registry()
|
||||||
|
reset_repo_authz_cache()
|
||||||
reset_policy_engine()
|
reset_policy_engine()
|
||||||
reset_rate_limiter()
|
reset_rate_limiter()
|
||||||
reset_metrics_registry()
|
reset_metrics_registry()
|
||||||
@@ -66,4 +72,5 @@ def mock_env_oauth(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||||
|
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ def full_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||||
|
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||||
monkeypatch.setenv("ENVIRONMENT", "test")
|
monkeypatch.setenv("ENVIRONMENT", "test")
|
||||||
monkeypatch.setenv("MCP_HOST", "127.0.0.1")
|
monkeypatch.setenv("MCP_HOST", "127.0.0.1")
|
||||||
monkeypatch.setenv("MCP_PORT", "8080")
|
monkeypatch.setenv("MCP_PORT", "8080")
|
||||||
|
|||||||
+240
-4
@@ -10,6 +10,7 @@ from fastapi.testclient import TestClient
|
|||||||
|
|
||||||
from aegis_gitea_mcp.config import reset_settings
|
from aegis_gitea_mcp.config import reset_settings
|
||||||
from aegis_gitea_mcp.oauth import GiteaOAuthValidator, get_oauth_validator, reset_oauth_validator
|
from aegis_gitea_mcp.oauth import GiteaOAuthValidator, get_oauth_validator, reset_oauth_validator
|
||||||
|
from aegis_gitea_mcp.oauth_flow import OAuthClientRegistry, OAuthRegistrationRequest
|
||||||
from aegis_gitea_mcp.request_context import (
|
from aegis_gitea_mcp.request_context import (
|
||||||
get_gitea_user_login,
|
get_gitea_user_login,
|
||||||
get_gitea_user_token,
|
get_gitea_user_token,
|
||||||
@@ -40,6 +41,7 @@ def mock_env_oauth(monkeypatch):
|
|||||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||||
|
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
||||||
|
|
||||||
|
|
||||||
@@ -57,6 +59,24 @@ def oauth_client(mock_env_oauth):
|
|||||||
return TestClient(app, raise_server_exceptions=False)
|
return TestClient(app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _register_public_client(oauth_client: TestClient, redirect_uri: str) -> dict[str, str]:
|
||||||
|
"""Register a public OAuth client for test flows."""
|
||||||
|
response = oauth_client.post(
|
||||||
|
"/register",
|
||||||
|
json={
|
||||||
|
"client_name": "pytest-client",
|
||||||
|
"redirect_uris": [redirect_uri],
|
||||||
|
"token_endpoint_auth_method": "none",
|
||||||
|
"grant_types": ["authorization_code", "refresh_token"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
payload = response.json()
|
||||||
|
assert "client_id" in payload
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GiteaOAuthValidator unit tests
|
# GiteaOAuthValidator unit tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -248,19 +268,39 @@ def test_oauth_token_endpoint_available_when_oauth_mode_false(monkeypatch):
|
|||||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
with TestClient(app, raise_server_exceptions=False) as client:
|
with TestClient(app, raise_server_exceptions=False) as client:
|
||||||
response = client.post("/oauth/token", data={"code": "abc123"})
|
registration = client.post(
|
||||||
|
"/register",
|
||||||
|
json={
|
||||||
|
"client_name": "pytest-client",
|
||||||
|
"redirect_uris": ["http://127.0.0.1:8080/callback"],
|
||||||
|
"token_endpoint_auth_method": "none",
|
||||||
|
"grant_types": ["authorization_code"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert registration.status_code == 200
|
||||||
|
client_id = registration.json()["client_id"]
|
||||||
|
response = client.post(
|
||||||
|
"/oauth/token",
|
||||||
|
data={"client_id": client_id, "code": "abc123", "code_verifier": "pkce"},
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
def test_oauth_token_endpoint_missing_code(oauth_client):
|
def test_oauth_token_endpoint_missing_code(oauth_client):
|
||||||
"""POST /oauth/token without a code returns 400."""
|
"""POST /oauth/token without a code returns 400."""
|
||||||
response = oauth_client.post("/oauth/token", data={})
|
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||||
|
response = oauth_client.post(
|
||||||
|
"/oauth/token",
|
||||||
|
data={"client_id": client_data["client_id"], "code_verifier": "pkce"},
|
||||||
|
)
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
def test_oauth_token_endpoint_proxy_success(oauth_client):
|
def test_oauth_token_endpoint_proxy_success(oauth_client):
|
||||||
"""POST /oauth/token proxies successfully to Gitea and returns access_token."""
|
"""POST /oauth/token proxies successfully to Gitea and returns access_token."""
|
||||||
|
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
@@ -276,7 +316,11 @@ def test_oauth_token_endpoint_proxy_success(oauth_client):
|
|||||||
|
|
||||||
response = oauth_client.post(
|
response = oauth_client.post(
|
||||||
"/oauth/token",
|
"/oauth/token",
|
||||||
data={"code": "auth-code-123", "redirect_uri": "https://chat.openai.com/callback"},
|
data={
|
||||||
|
"client_id": client_data["client_id"],
|
||||||
|
"code": "auth-code-123",
|
||||||
|
"code_verifier": "pkce-verifier",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -286,6 +330,7 @@ def test_oauth_token_endpoint_proxy_success(oauth_client):
|
|||||||
|
|
||||||
def test_oauth_token_endpoint_gitea_error(oauth_client):
|
def test_oauth_token_endpoint_gitea_error(oauth_client):
|
||||||
"""POST /oauth/token propagates Gitea error status."""
|
"""POST /oauth/token propagates Gitea error status."""
|
||||||
|
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 400
|
mock_response.status_code = 400
|
||||||
mock_response.json.return_value = {"error": "invalid_grant"}
|
mock_response.json.return_value = {"error": "invalid_grant"}
|
||||||
@@ -296,11 +341,202 @@ def test_oauth_token_endpoint_gitea_error(oauth_client):
|
|||||||
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
response = oauth_client.post("/oauth/token", data={"code": "bad-code"})
|
response = oauth_client.post(
|
||||||
|
"/oauth/token",
|
||||||
|
data={
|
||||||
|
"client_id": client_data["client_id"],
|
||||||
|
"code": "bad-code",
|
||||||
|
"code_verifier": "pkce-verifier",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_authorize_and_callback_round_trip(oauth_client):
|
||||||
|
"""OAuth authorize/callback round-trip preserves the original redirect URI and state."""
|
||||||
|
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||||
|
|
||||||
|
authorize_response = oauth_client.get(
|
||||||
|
"/oauth/authorize",
|
||||||
|
params={
|
||||||
|
"client_id": client_data["client_id"],
|
||||||
|
"redirect_uri": "http://127.0.0.1:8080/callback",
|
||||||
|
"state": "original-state",
|
||||||
|
"code_challenge": "pkce-challenge",
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert authorize_response.status_code == 302
|
||||||
|
location = authorize_response.headers["location"]
|
||||||
|
assert "state=" in location
|
||||||
|
assert "redirect_uri=http%3A%2F%2F127.0.0.1%3A8080%2Fcallback" not in location
|
||||||
|
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
parsed = urlparse(location)
|
||||||
|
query = parse_qs(parsed.query)
|
||||||
|
proxy_state = query["state"][0]
|
||||||
|
|
||||||
|
callback_response = oauth_client.get(
|
||||||
|
"/oauth/callback",
|
||||||
|
params={"state": proxy_state, "code": "auth-code-123"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert callback_response.status_code == 302
|
||||||
|
callback_location = callback_response.headers["location"]
|
||||||
|
assert callback_location.startswith("http://127.0.0.1:8080/callback?")
|
||||||
|
assert "code=auth-code-123" in callback_location
|
||||||
|
assert "state=original-state" in callback_location
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_callback_rejects_tampered_state(oauth_client):
|
||||||
|
"""OAuth callback rejects modified signed proxy state."""
|
||||||
|
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||||
|
authorize_response = oauth_client.get(
|
||||||
|
"/oauth/authorize",
|
||||||
|
params={
|
||||||
|
"client_id": client_data["client_id"],
|
||||||
|
"redirect_uri": "http://127.0.0.1:8080/callback",
|
||||||
|
"state": "original-state",
|
||||||
|
"code_challenge": "pkce-challenge",
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
proxy_state = parse_qs(urlparse(authorize_response.headers["location"]).query)["state"][0]
|
||||||
|
tampered_state = proxy_state[:-1] + ("A" if proxy_state[-1] != "A" else "B")
|
||||||
|
|
||||||
|
callback_response = oauth_client.get(
|
||||||
|
"/oauth/callback",
|
||||||
|
params={"state": tampered_state, "code": "auth-code-123"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert callback_response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"redirect_uri",
|
||||||
|
[
|
||||||
|
"https://claude.ai/api/mcp/auth_callback",
|
||||||
|
"https://claude.com/api/mcp/auth_callback",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_dcr_accepts_default_claude_callbacks(oauth_client, redirect_uri):
|
||||||
|
"""Claude's hosted connector callback URLs are allowed by default."""
|
||||||
|
response = oauth_client.post(
|
||||||
|
"/register",
|
||||||
|
json={
|
||||||
|
"client_name": "claude-client",
|
||||||
|
"redirect_uris": [redirect_uri],
|
||||||
|
"token_endpoint_auth_method": "none",
|
||||||
|
"grant_types": ["authorization_code", "refresh_token"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_authorize_rejects_unknown_client(oauth_client):
|
||||||
|
"""OAuth authorize returns invalid_client for unregistered client IDs."""
|
||||||
|
response = oauth_client.get(
|
||||||
|
"/oauth/authorize",
|
||||||
|
params={
|
||||||
|
"client_id": "unknown-client",
|
||||||
|
"redirect_uri": "http://127.0.0.1:8080/callback",
|
||||||
|
"state": "x",
|
||||||
|
"code_challenge": "pkce-challenge",
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json()["detail"] == "invalid_client"
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_token_rejects_unknown_dcr_client(oauth_client):
|
||||||
|
"""Unknown dynamic clients receive RFC 6749 invalid_client from token endpoint."""
|
||||||
|
response = oauth_client.post(
|
||||||
|
"/oauth/token",
|
||||||
|
data={
|
||||||
|
"client_id": "deleted-or-unknown-client",
|
||||||
|
"code": "auth-code-123",
|
||||||
|
"code_verifier": "pkce-verifier",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json() == {"error": "invalid_client"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_oauth_authorize_requires_pkce_s256(oauth_client):
|
||||||
|
"""Authorization endpoint enforces PKCE S256 for public clients."""
|
||||||
|
client_data = _register_public_client(oauth_client, "http://127.0.0.1:8080/callback")
|
||||||
|
missing_challenge = oauth_client.get(
|
||||||
|
"/oauth/authorize",
|
||||||
|
params={
|
||||||
|
"client_id": client_data["client_id"],
|
||||||
|
"redirect_uri": "http://127.0.0.1:8080/callback",
|
||||||
|
"state": "x",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
plain_method = oauth_client.get(
|
||||||
|
"/oauth/authorize",
|
||||||
|
params={
|
||||||
|
"client_id": client_data["client_id"],
|
||||||
|
"redirect_uri": "http://127.0.0.1:8080/callback",
|
||||||
|
"state": "x",
|
||||||
|
"code_challenge": "pkce-challenge",
|
||||||
|
"code_challenge_method": "plain",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert missing_challenge.status_code == 400
|
||||||
|
assert plain_method.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_rejects_foreign_redirect_uri(oauth_client):
|
||||||
|
"""DCR rejects redirect URIs outside the allowlist and loopback/Claude patterns."""
|
||||||
|
response = oauth_client.post(
|
||||||
|
"/register",
|
||||||
|
json={
|
||||||
|
"client_name": "pytest-client",
|
||||||
|
"redirect_uris": ["https://evil.example.com/callback"],
|
||||||
|
"token_endpoint_auth_method": "none",
|
||||||
|
"grant_types": ["authorization_code"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_dcr_registry_persists_registered_clients(tmp_path):
|
||||||
|
"""Registered OAuth clients survive registry reloads."""
|
||||||
|
storage_path = tmp_path / "dcr_clients.json"
|
||||||
|
registry = OAuthClientRegistry(storage_path)
|
||||||
|
request = OAuthRegistrationRequest.model_validate(
|
||||||
|
{
|
||||||
|
"client_name": "persisted-client",
|
||||||
|
"redirect_uris": ["http://127.0.0.1:8080/callback"],
|
||||||
|
"token_endpoint_auth_method": "none",
|
||||||
|
"grant_types": ["authorization_code"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = registry.register(request)
|
||||||
|
reloaded = OAuthClientRegistry(storage_path)
|
||||||
|
|
||||||
|
assert reloaded.get(response["client_id"]) is not None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Config validation tests
|
# Config validation tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -25,13 +25,14 @@ def reset_state(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||||
|
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||||
monkeypatch.setenv("OAUTH_CACHE_TTL_SECONDS", "600")
|
monkeypatch.setenv("OAUTH_CACHE_TTL_SECONDS", "600")
|
||||||
yield
|
yield
|
||||||
reset_settings()
|
reset_settings()
|
||||||
reset_oauth_validator()
|
reset_oauth_validator()
|
||||||
|
|
||||||
|
|
||||||
def _build_jwt_fixture() -> tuple[str, dict[str, object]]:
|
def _build_jwt_fixture(aud: str = "test-client-id") -> tuple[str, dict[str, object]]:
|
||||||
"""Generate RS256 access token and matching JWKS payload."""
|
"""Generate RS256 access token and matching JWKS payload."""
|
||||||
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||||
public_key = private_key.public_key()
|
public_key = private_key.public_key()
|
||||||
@@ -44,7 +45,7 @@ def _build_jwt_fixture() -> tuple[str, dict[str, object]]:
|
|||||||
"sub": "user-1",
|
"sub": "user-1",
|
||||||
"preferred_username": "alice",
|
"preferred_username": "alice",
|
||||||
"scope": "read:repository write:repository",
|
"scope": "read:repository write:repository",
|
||||||
"aud": "test-client-id",
|
"aud": aud,
|
||||||
"iss": "https://gitea.example.com",
|
"iss": "https://gitea.example.com",
|
||||||
"iat": now,
|
"iat": now,
|
||||||
"exp": now + 3600,
|
"exp": now + 3600,
|
||||||
@@ -56,6 +57,70 @@ def _build_jwt_fixture() -> tuple[str, dict[str, object]]:
|
|||||||
return token, {"keys": [jwk]}
|
return token, {"keys": [jwk]}
|
||||||
|
|
||||||
|
|
||||||
|
async def _validate_with_jwks(
|
||||||
|
validator: GiteaOAuthValidator, token: str, jwks: dict[str, object]
|
||||||
|
) -> tuple[bool, str | None, dict[str, object] | None]:
|
||||||
|
"""Drive a JWT validation with mocked discovery + JWKS responses."""
|
||||||
|
discovery_response = MagicMock()
|
||||||
|
discovery_response.status_code = 200
|
||||||
|
discovery_response.json.return_value = {
|
||||||
|
"issuer": "https://gitea.example.com",
|
||||||
|
"jwks_uri": "https://gitea.example.com/login/oauth/keys",
|
||||||
|
}
|
||||||
|
jwks_response = MagicMock()
|
||||||
|
jwks_response.status_code = 200
|
||||||
|
jwks_response.json.return_value = jwks
|
||||||
|
|
||||||
|
with patch("aegis_gitea_mcp.oauth.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(side_effect=[discovery_response, jwks_response])
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
return await validator.validate_oauth_token(token, "127.0.0.1", "TestAgent")
|
||||||
|
|
||||||
|
|
||||||
|
def test_acceptable_audiences_includes_resource_and_client_id(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""The canonical MCP resource and the Gitea client id are accepted audiences."""
|
||||||
|
monkeypatch.setenv("PUBLIC_BASE_URL", "https://mcp.example.com")
|
||||||
|
reset_settings()
|
||||||
|
reset_oauth_validator()
|
||||||
|
audiences = GiteaOAuthValidator()._acceptable_audiences()
|
||||||
|
assert "https://mcp.example.com" in audiences
|
||||||
|
assert "test-client-id" in audiences
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_jwt_with_canonical_resource_audience_is_accepted(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""A token whose aud is the canonical MCP resource URL validates (P4)."""
|
||||||
|
monkeypatch.setenv("PUBLIC_BASE_URL", "https://mcp.example.com")
|
||||||
|
reset_settings()
|
||||||
|
reset_oauth_validator()
|
||||||
|
token, jwks = _build_jwt_fixture(aud="https://mcp.example.com")
|
||||||
|
valid, error, principal = await _validate_with_jwks(GiteaOAuthValidator(), token, jwks)
|
||||||
|
assert valid is True
|
||||||
|
assert error is None
|
||||||
|
assert principal is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_jwt_with_foreign_audience_is_rejected() -> None:
|
||||||
|
"""A token minted for a different audience is rejected (audience binding)."""
|
||||||
|
token, jwks = _build_jwt_fixture(aud="some-other-service")
|
||||||
|
# Foreign-audience JWT fails JWT validation, then falls back to userinfo, which
|
||||||
|
# is not mocked here and raises a network error -> overall failure.
|
||||||
|
with patch("aegis_gitea_mcp.oauth.GiteaOAuthValidator._validate_userinfo") as mock_userinfo:
|
||||||
|
from aegis_gitea_mcp.oauth import OAuthTokenValidationError
|
||||||
|
|
||||||
|
mock_userinfo.side_effect = OAuthTokenValidationError("Invalid", "userinfo_denied")
|
||||||
|
valid, _error, principal = await _validate_with_jwks(GiteaOAuthValidator(), token, jwks)
|
||||||
|
assert valid is False
|
||||||
|
assert principal is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_validate_oauth_token_with_oidc_jwt_and_cache() -> None:
|
async def test_validate_oauth_token_with_oidc_jwt_and_cache() -> None:
|
||||||
"""JWT token validation uses discovery + JWKS and caches both documents."""
|
"""JWT token validation uses discovery + JWKS and caches both documents."""
|
||||||
|
|||||||
+212
-7
@@ -29,6 +29,7 @@ def oauth_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||||
|
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||||
monkeypatch.setenv("ENVIRONMENT", "test")
|
monkeypatch.setenv("ENVIRONMENT", "test")
|
||||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
||||||
monkeypatch.setenv("WRITE_MODE", "false")
|
monkeypatch.setenv("WRITE_MODE", "false")
|
||||||
@@ -84,12 +85,13 @@ def test_health_endpoint(client: TestClient) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_oauth_protected_resource_metadata(client: TestClient) -> None:
|
def test_oauth_protected_resource_metadata(client: TestClient) -> None:
|
||||||
"""OAuth protected-resource metadata contains required OpenAI-compatible fields."""
|
"""PRM advertises THIS server's canonical URL as the protected resource."""
|
||||||
response = client.get("/.well-known/oauth-protected-resource")
|
response = client.get("/.well-known/oauth-protected-resource")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["resource"] == "https://gitea.example.com"
|
# RFC 9728/8707: the resource identifier is the MCP server's own URL, not Gitea's.
|
||||||
|
assert data["resource"] == "http://testserver"
|
||||||
assert data["authorization_servers"] == [
|
assert data["authorization_servers"] == [
|
||||||
"http://testserver",
|
"http://testserver",
|
||||||
"https://gitea.example.com",
|
"https://gitea.example.com",
|
||||||
@@ -100,12 +102,15 @@ def test_oauth_protected_resource_metadata(client: TestClient) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_oauth_authorization_server_metadata(client: TestClient) -> None:
|
def test_oauth_authorization_server_metadata(client: TestClient) -> None:
|
||||||
"""Auth server metadata includes expected OAuth endpoints and scopes."""
|
"""Auth server metadata advertises this server's proxy OAuth endpoints."""
|
||||||
response = client.get("/.well-known/oauth-authorization-server")
|
response = client.get("/.well-known/oauth-authorization-server")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
payload = response.json()
|
payload = response.json()
|
||||||
assert payload["authorization_endpoint"].endswith("/login/oauth/authorize")
|
# Claude must be sent to our proxy authorize endpoint (Gitea does not know
|
||||||
assert payload["token_endpoint"].endswith("/oauth/token")
|
# Claude's redirect_uri), so the endpoint lives on this server.
|
||||||
|
assert payload["authorization_endpoint"] == "http://testserver/oauth/authorize"
|
||||||
|
assert payload["token_endpoint"] == "http://testserver/oauth/token"
|
||||||
|
assert payload["registration_endpoint"] == "http://testserver/register"
|
||||||
assert payload["scopes_supported"] == ["read:repository", "write:repository"]
|
assert payload["scopes_supported"] == ["read:repository", "write:repository"]
|
||||||
|
|
||||||
|
|
||||||
@@ -115,8 +120,8 @@ def test_openid_configuration_metadata(client: TestClient) -> None:
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
payload = response.json()
|
payload = response.json()
|
||||||
assert payload["issuer"] == "https://gitea.example.com"
|
assert payload["issuer"] == "https://gitea.example.com"
|
||||||
assert payload["authorization_endpoint"].endswith("/login/oauth/authorize")
|
assert payload["authorization_endpoint"] == "http://testserver/oauth/authorize"
|
||||||
assert payload["token_endpoint"].endswith("/oauth/token")
|
assert payload["token_endpoint"] == "http://testserver/oauth/token"
|
||||||
assert payload["userinfo_endpoint"].endswith("/login/oauth/userinfo")
|
assert payload["userinfo_endpoint"].endswith("/login/oauth/userinfo")
|
||||||
assert payload["jwks_uri"].endswith("/login/oauth/keys")
|
assert payload["jwks_uri"].endswith("/login/oauth/keys")
|
||||||
assert "read:repository" in payload["scopes_supported"]
|
assert "read:repository" in payload["scopes_supported"]
|
||||||
@@ -129,6 +134,7 @@ def test_oauth_metadata_uses_public_base_url(monkeypatch: pytest.MonkeyPatch) ->
|
|||||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||||
|
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||||
monkeypatch.setenv("PUBLIC_BASE_URL", "https://mcp.example.com")
|
monkeypatch.setenv("PUBLIC_BASE_URL", "https://mcp.example.com")
|
||||||
monkeypatch.setenv("ENVIRONMENT", "test")
|
monkeypatch.setenv("ENVIRONMENT", "test")
|
||||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "false")
|
||||||
@@ -149,6 +155,8 @@ def test_oauth_metadata_uses_public_base_url(monkeypatch: pytest.MonkeyPatch) ->
|
|||||||
protected_response = client.get("/.well-known/oauth-protected-resource")
|
protected_response = client.get("/.well-known/oauth-protected-resource")
|
||||||
assert protected_response.status_code == 200
|
assert protected_response.status_code == 200
|
||||||
protected_payload = protected_response.json()
|
protected_payload = protected_response.json()
|
||||||
|
# P4: the protected resource identifier must equal this server's public base.
|
||||||
|
assert protected_payload["resource"] == "https://mcp.example.com"
|
||||||
assert protected_payload["authorization_servers"] == [
|
assert protected_payload["authorization_servers"] == [
|
||||||
"https://mcp.example.com",
|
"https://mcp.example.com",
|
||||||
"https://gitea.example.com",
|
"https://gitea.example.com",
|
||||||
@@ -166,6 +174,201 @@ def test_oauth_metadata_uses_public_base_url(monkeypatch: pytest.MonkeyPatch) ->
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_streamable_http_path_works(client: TestClient) -> None:
|
||||||
|
"""The spec path /mcp exposes the same transport behavior as the SSE alias."""
|
||||||
|
response = client.post(
|
||||||
|
"/mcp",
|
||||||
|
headers={"Authorization": "Bearer valid-read"},
|
||||||
|
json={"jsonrpc": "2.0", "id": "init-1", "method": "initialize"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
payload = response.json()
|
||||||
|
assert payload["result"]["protocolVersion"] == "2024-11-05"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_preflight_allows_same_origin(client: TestClient) -> None:
|
||||||
|
"""Same-origin preflight requests to /mcp return strict CORS headers."""
|
||||||
|
response = client.options(
|
||||||
|
"/mcp",
|
||||||
|
headers={
|
||||||
|
"Origin": "http://testserver",
|
||||||
|
"Access-Control-Request-Method": "POST",
|
||||||
|
"Access-Control-Request-Headers": "authorization,content-type",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 204
|
||||||
|
assert response.headers["Access-Control-Allow-Origin"] == "http://testserver"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_preflight_rejects_cross_origin(client: TestClient) -> None:
|
||||||
|
"""Cross-origin browser requests to /mcp are denied."""
|
||||||
|
response = client.options(
|
||||||
|
"/mcp",
|
||||||
|
headers={
|
||||||
|
"Origin": "https://evil.example.com",
|
||||||
|
"Access-Control-Request-Method": "POST",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_service_pat_requests_verify_user_repo_access_before_execution(
|
||||||
|
oauth_env: None, mock_oauth_validation: None, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
"""Service PAT fallback checks the user's repository permission before executing tools."""
|
||||||
|
from aegis_gitea_mcp import server
|
||||||
|
|
||||||
|
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
|
||||||
|
server._api_scope_cache.clear()
|
||||||
|
server.reset_repo_authz_cache()
|
||||||
|
|
||||||
|
probe_response = MagicMock()
|
||||||
|
probe_response.status_code = 200
|
||||||
|
|
||||||
|
repo_response = MagicMock()
|
||||||
|
repo_response.status_code = 403
|
||||||
|
|
||||||
|
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(side_effect=[probe_response, repo_response])
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
from aegis_gitea_mcp.server import app
|
||||||
|
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
response = client.post(
|
||||||
|
"/mcp/tool/call",
|
||||||
|
headers={"Authorization": "Bearer valid-read"},
|
||||||
|
json={
|
||||||
|
"tool": "get_repository_info",
|
||||||
|
"arguments": {"owner": "acme", "repo": "demo"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert "permission" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_pat_repo_authz_allows_user_with_read_permission(
|
||||||
|
oauth_env: None, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
"""Read-level collaborator permission allows service PAT execution to proceed."""
|
||||||
|
from aegis_gitea_mcp import server
|
||||||
|
|
||||||
|
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
|
||||||
|
server.reset_repo_authz_cache()
|
||||||
|
|
||||||
|
permission_response = MagicMock()
|
||||||
|
permission_response.status_code = 200
|
||||||
|
permission_response.json.return_value = {"permission": "read"}
|
||||||
|
|
||||||
|
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(return_value=permission_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
await server._verify_user_repository_access(
|
||||||
|
repository="acme/demo",
|
||||||
|
required_scope=server.READ_SCOPE,
|
||||||
|
user_login="alice",
|
||||||
|
correlation_id="corr-1",
|
||||||
|
tool_name="get_repository_info",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.get.assert_awaited_once()
|
||||||
|
requested_url = mock_client.get.await_args.args[0]
|
||||||
|
requested_headers = mock_client.get.await_args.kwargs["headers"]
|
||||||
|
assert requested_url.endswith("/api/v1/repos/acme/demo/collaborators/alice/permission")
|
||||||
|
assert requested_headers["Authorization"] == "token service-pat-token"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_pat_repo_authz_denies_read_user_for_write_tool(
|
||||||
|
oauth_env: None, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
"""Read permission is insufficient for write tools in service PAT mode."""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from aegis_gitea_mcp import server
|
||||||
|
|
||||||
|
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
|
||||||
|
server.reset_repo_authz_cache()
|
||||||
|
|
||||||
|
permission_response = MagicMock()
|
||||||
|
permission_response.status_code = 200
|
||||||
|
permission_response.json.return_value = {"permission": "read"}
|
||||||
|
|
||||||
|
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(return_value=permission_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await server._verify_user_repository_access(
|
||||||
|
repository="acme/demo",
|
||||||
|
required_scope=server.WRITE_SCOPE,
|
||||||
|
user_login="alice",
|
||||||
|
correlation_id="corr-1",
|
||||||
|
tool_name="create_issue",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_pat_repo_authz_cache_hit_and_expiry(
|
||||||
|
oauth_env: None, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
"""Repository permission decisions are cached briefly and rechecked after expiry."""
|
||||||
|
from aegis_gitea_mcp import cache as cache_module
|
||||||
|
from aegis_gitea_mcp import server
|
||||||
|
|
||||||
|
monkeypatch.setenv("GITEA_TOKEN", "service-pat-token")
|
||||||
|
monkeypatch.setenv("REPO_AUTHZ_CACHE_TTL_SECONDS", "1")
|
||||||
|
server.reset_repo_authz_cache()
|
||||||
|
|
||||||
|
now = 1000.0
|
||||||
|
monkeypatch.setattr(cache_module.time, "monotonic", lambda: now)
|
||||||
|
|
||||||
|
permission_response = MagicMock()
|
||||||
|
permission_response.status_code = 200
|
||||||
|
permission_response.json.return_value = {"permission": "read"}
|
||||||
|
|
||||||
|
with patch("aegis_gitea_mcp.server.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.get = AsyncMock(return_value=permission_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
await server._verify_user_repository_access(
|
||||||
|
repository="acme/demo",
|
||||||
|
required_scope=server.READ_SCOPE,
|
||||||
|
user_login="alice",
|
||||||
|
correlation_id="corr-1",
|
||||||
|
tool_name="get_repository_info",
|
||||||
|
)
|
||||||
|
assert mock_client.get.await_count == 1
|
||||||
|
|
||||||
|
now = 1002.0
|
||||||
|
await server._verify_user_repository_access(
|
||||||
|
repository="acme/demo",
|
||||||
|
required_scope=server.READ_SCOPE,
|
||||||
|
user_login="alice",
|
||||||
|
correlation_id="corr-1",
|
||||||
|
tool_name="get_repository_info",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_client.get.await_count == 2
|
||||||
|
|
||||||
|
|
||||||
def test_scope_compatibility_write_implies_read() -> None:
|
def test_scope_compatibility_write_implies_read() -> None:
|
||||||
"""write:repository grants read-level access for read tools."""
|
"""write:repository grants read-level access for read tools."""
|
||||||
from aegis_gitea_mcp.server import READ_SCOPE, _has_required_scope
|
from aegis_gitea_mcp.server import READ_SCOPE, _has_required_scope
|
||||||
@@ -348,6 +551,7 @@ async def test_startup_event_fails_when_discovery_unreachable(
|
|||||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||||
|
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "true")
|
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "true")
|
||||||
|
|
||||||
from aegis_gitea_mcp import server
|
from aegis_gitea_mcp import server
|
||||||
@@ -377,6 +581,7 @@ async def test_startup_event_succeeds_when_discovery_ready(
|
|||||||
monkeypatch.setenv("OAUTH_MODE", "true")
|
monkeypatch.setenv("OAUTH_MODE", "true")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_ID", "test-client-id")
|
||||||
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
monkeypatch.setenv("GITEA_OAUTH_CLIENT_SECRET", "test-client-secret")
|
||||||
|
monkeypatch.setenv("OAUTH_STATE_SECRET", "test-state-secret-0123456789abcdef")
|
||||||
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "true")
|
monkeypatch.setenv("STARTUP_VALIDATE_GITEA", "true")
|
||||||
|
|
||||||
from aegis_gitea_mcp import server
|
from aegis_gitea_mcp import server
|
||||||
|
|||||||
Reference in New Issue
Block a user