"""OAuth proxy helpers for signed state, redirect validation, and DCR storage."""

from __future__ import annotations

import base64
import hashlib
import hmac
import json
import secrets
import time
from fnmatch import fnmatchcase
from pathlib import Path
from typing import Any
from urllib.parse import ParseResult, urlparse, urlunparse

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

# Redirect targets that are always accepted; compared against the output of
# ``_canonicalize_url`` (lowercased scheme/host, no trailing slash).
_CLAUDE_CALLBACK_URIS = {
    "https://claude.ai/api/mcp/auth_callback",
    "https://claude.com/api/mcp/auth_callback",
}
# Hostnames treated as loopback for plain ``http://`` redirect URIs.
_LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "::1"}
_SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS = {"none", "client_secret_post"}
_SUPPORTED_GRANT_TYPES = {"authorization_code", "refresh_token"}
_SUPPORTED_RESPONSE_TYPES = {"code"}
# Single error message for every state-validation failure, so callers cannot
# tell which check rejected a given state value.
_INVALID_STATE_MESSAGE = "Invalid or missing state parameter"


class OAuthRegistrationRequest(BaseModel):
    """Incoming RFC 7591 client registration request.

    Only the metadata subset this proxy supports is accepted; unknown fields
    are rejected because of ``extra="forbid"``.
    """

    client_name: str | None = Field(default=None, max_length=200)
    redirect_uris: list[str] = Field(..., min_length=1)
    grant_types: list[str] = Field(default_factory=lambda: ["authorization_code", "refresh_token"])
    response_types: list[str] = Field(default_factory=lambda: ["code"])
    token_endpoint_auth_method: str = Field(default="none", max_length=64)
    scope: str | None = Field(default=None, max_length=512)

    model_config = ConfigDict(extra="forbid")

    @field_validator("redirect_uris")
    @classmethod
    def validate_redirect_uris(cls, value: list[str]) -> list[str]:
        """Strip whitespace from redirect URIs and drop blank entries.

        Raises:
            ValueError: If no non-empty URI remains.
        """
        uris = [uri.strip() for uri in value if isinstance(uri, str) and uri.strip()]
        if not uris:
            raise ValueError("redirect_uris must contain at least one non-empty URI")
        return uris

    @field_validator("grant_types")
    @classmethod
    def validate_grant_types(cls, value: list[str]) -> list[str]:
        """Restrict grant types to authorization code and refresh token.

        Raises:
            ValueError: If the list is empty after stripping, or contains an
                unsupported grant type.
        """
        normalized = [item.strip() for item in value if item.strip()]
        if not normalized:
            raise ValueError("grant_types must not be empty")
        if any(item not in _SUPPORTED_GRANT_TYPES for item in normalized):
            raise ValueError("Unsupported grant_types requested")
        return normalized

    @field_validator("response_types")
    @classmethod
    def validate_response_types(cls, value: list[str]) -> list[str]:
        """Restrict response types to ``code``.

        Raises:
            ValueError: If the list is empty after stripping, or contains an
                unsupported response type.
        """
        normalized = [item.strip() for item in value if item.strip()]
        if not normalized:
            raise ValueError("response_types must not be empty")
        if any(item not in _SUPPORTED_RESPONSE_TYPES for item in normalized):
            raise ValueError("Unsupported response_types requested")
        return normalized

    @field_validator("token_endpoint_auth_method")
    @classmethod
    def validate_token_endpoint_auth_method(cls, value: str) -> str:
        """Lowercase the auth method and restrict it to the supported subset.

        Raises:
            ValueError: If the method is not ``none`` or ``client_secret_post``.
        """
        normalized = value.strip().lower()
        if normalized not in _SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS:
            raise ValueError("Unsupported token_endpoint_auth_method requested")
        return normalized

    @model_validator(mode="after")
    def validate_pkce_ready(self) -> OAuthRegistrationRequest:
        """Ensure the request is usable for the PKCE authorization-code flow.

        Raises:
            ValueError: If ``authorization_code`` or ``code`` is missing.
        """
        if "authorization_code" not in self.grant_types:
            raise ValueError("authorization_code grant is required")
        if "code" not in self.response_types:
            raise ValueError("code response type is required")
        return self


class OAuthClientRecord(BaseModel):
    """Persisted OAuth client registration record."""

    client_id: str
    client_name: str | None = None
    redirect_uris: list[str]
    grant_types: list[str]
    response_types: list[str]
    token_endpoint_auth_method: str
    # Unix timestamp (seconds) at registration time.
    client_id_issued_at: int
    # 0 means the secret never expires (RFC 7591 convention).
    client_secret_expires_at: int = 0
    # SHA-256 hex digest of the client secret; None for public clients.
    client_secret_hash: str | None = None
    scope: str | None = None

    model_config = ConfigDict(extra="forbid")


def _canonicalize_url(value: str) -> str:
    """Normalize a URL for comparison.

    Lowercases scheme and network location, substitutes ``/`` for an empty
    path, drops any fragment, and strips trailing slashes from the result.

    Returns:
        The normalized URL, or ``""`` when *value* lacks a scheme or network
        location or cannot be parsed.
    """
    try:
        parsed = urlparse(value.strip())
    except ValueError:
        # urlparse rejects some malformed input (e.g. an unbalanced IPv6
        # bracket). Callers pass untrusted URIs/origins, so treat this as
        # "not a URL" instead of letting the exception escape.
        return ""
    if not parsed.scheme or not parsed.netloc:
        return ""
    normalized = ParseResult(
        scheme=parsed.scheme.lower(),
        netloc=parsed.netloc.lower(),
        path=parsed.path or "/",
        params=parsed.params,
        query=parsed.query,
        fragment="",
    )
    return urlunparse(normalized).rstrip("/")


def is_loopback_redirect_uri(redirect_uri: str) -> bool:
    """Return whether *redirect_uri* is ``http://`` on a loopback host.

    Any port and path are accepted. Unparseable input returns ``False``.
    """
    try:
        parsed = urlparse(redirect_uri.strip())
    except ValueError:
        return False
    if parsed.scheme != "http":
        return False
    return (parsed.hostname or "").lower() in _LOOPBACK_HOSTS


def is_claude_redirect_uri(redirect_uri: str) -> bool:
    """Return whether a redirect URI is a built-in Claude callback URL."""
    return _canonicalize_url(redirect_uri) in _CLAUDE_CALLBACK_URIS


def is_redirect_uri_allowed(redirect_uri: str, allowlist: list[str]) -> bool:
    """Return whether a redirect URI is allowed by policy.

    Loopback and built-in Claude callbacks are always allowed. Otherwise each
    non-blank allowlist entry is tried as an ``fnmatchcase`` pattern against
    both the canonical and the raw (stripped) URI.
    """
    normalized = _canonicalize_url(redirect_uri)
    if not normalized:
        return False
    if is_loopback_redirect_uri(redirect_uri) or is_claude_redirect_uri(redirect_uri):
        return True
    # NOTE(review): fnmatch's ``*`` also matches "/", "." and "@", so a pattern
    # such as "https://*.example.com/cb" also matches
    # "https://evil.test/x.example.com/cb". Keep allowlist wildcards confined
    # to path suffixes, or tighten matching here if host wildcards are needed.
    for pattern in allowlist:
        candidate = pattern.strip()
        if not candidate:
            continue
        if fnmatchcase(normalized, _canonicalize_url(candidate) or candidate):
            return True
        if fnmatchcase(redirect_uri.strip(), candidate):
            return True
    return False


def is_origin_allowed(origin: str, request_base: str, public_base: str | None) -> bool:
    """Return whether a browser Origin is allowed for MCP transport requests.

    The canonicalized *origin* must equal *request_base* or *public_base*,
    taken either with trailing slashes stripped or in canonical form.
    """
    normalized_origin = _canonicalize_url(origin)
    if not normalized_origin:
        return False
    expected_bases = {request_base.rstrip("/")}
    if public_base:
        expected_bases.add(public_base.rstrip("/"))
    # The origin is canonicalized (lowercased scheme/host), so also accept the
    # canonical form of each base; otherwise a base configured or received
    # with different letter case could never match.
    expected_bases |= {_canonicalize_url(base) for base in expected_bases}
    expected_bases.discard("")
    return normalized_origin in expected_bases


def _sign_state_payload(secret: str, payload: dict[str, Any]) -> str:
    """Return the hex HMAC-SHA256 of *payload*'s canonical JSON under *secret*.

    Signing and verification share this helper, so both use the same
    serialization (sorted keys, compact separators).
    """
    canonical_payload = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hmac.new(
        secret.encode("utf-8"), canonical_payload.encode("utf-8"), hashlib.sha256
    ).hexdigest()


def encode_proxy_state(
    secret: str,
    redirect_uri: str,
    original_state: str,
    *,
    ttl_seconds: int = 600,
) -> str:
    """Create a signed OAuth state wrapper for the proxy callback round-trip.

    Args:
        secret: HMAC key used to sign the payload.
        redirect_uri: Client redirect URI to restore on callback.
        original_state: The client's own ``state`` value, returned unchanged.
        ttl_seconds: Lifetime of the wrapper in seconds.

    Returns:
        URL-safe base64 (with padding) of a JSON ``{"payload", "signature"}``
        envelope.
    """
    payload = {
        "redirect_uri": redirect_uri,
        "state": original_state,
        "issued_at": int(time.time()),
        "nonce": secrets.token_urlsafe(16),
        "ttl_seconds": ttl_seconds,
    }
    envelope = {
        "payload": payload,
        "signature": _sign_state_payload(secret, payload),
    }
    return base64.urlsafe_b64encode(
        json.dumps(envelope, sort_keys=True, separators=(",", ":")).encode("utf-8")
    ).decode("ascii")


def decode_proxy_state(secret: str, encoded_state: str) -> dict[str, str]:
    """Verify and unpack a signed OAuth state wrapper.

    Returns:
        ``{"redirect_uri": ..., "state": ...}`` from the verified payload.

    Raises:
        ValueError: For any malformed, tampered, expired, or future-dated state.
            The message is identical in every case.
    """
    try:
        raw = base64.urlsafe_b64decode(encoded_state.encode("ascii"))
        envelope = json.loads(raw)
    except Exception as exc:  # pragma: no cover - guarded by tests
        # Deliberately broad: non-str input, non-ASCII text, bad base64 and
        # bad JSON must all surface as the same ValueError.
        raise ValueError(_INVALID_STATE_MESSAGE) from exc
    if not isinstance(envelope, dict):
        raise ValueError(_INVALID_STATE_MESSAGE)
    payload = envelope.get("payload")
    signature = envelope.get("signature")
    if not isinstance(payload, dict) or not isinstance(signature, str):
        raise ValueError(_INVALID_STATE_MESSAGE)

    expected_signature = _sign_state_payload(secret, payload)
    # hmac.compare_digest raises TypeError for str arguments containing
    # non-ASCII characters. The signature is attacker-controlled, so reject
    # such values first to keep every failure a ValueError.
    if not signature.isascii() or not hmac.compare_digest(signature, expected_signature):
        raise ValueError(_INVALID_STATE_MESSAGE)

    issued_at = payload.get("issued_at")
    ttl_seconds = payload.get("ttl_seconds")
    if not isinstance(issued_at, int) or not isinstance(ttl_seconds, int):
        raise ValueError(_INVALID_STATE_MESSAGE)
    now = int(time.time())
    # Reject future-dated states and anything older than its TTL; the TTL is
    # floored at 1 second.
    if issued_at > now or now - issued_at > max(ttl_seconds, 1):
        raise ValueError(_INVALID_STATE_MESSAGE)

    redirect_uri = payload.get("redirect_uri")
    if not isinstance(redirect_uri, str) or not redirect_uri.strip():
        raise ValueError(_INVALID_STATE_MESSAGE)
    original_state = payload.get("state")
    if not isinstance(original_state, str):
        raise ValueError(_INVALID_STATE_MESSAGE)
    return {"redirect_uri": redirect_uri, "state": original_state}


class OAuthClientRegistry:
    """Persisted OAuth client registry for dynamic client registration.

    Records live in one JSON object keyed by client id and are loaded lazily
    on first access.
    """

    def __init__(self, storage_path: Path) -> None:
        """Initialize registry storage, creating the parent directory if needed."""
        self.storage_path = storage_path
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        self._clients: dict[str, OAuthClientRecord] = {}
        self._loaded = False

    @staticmethod
    def _hash_secret(secret: str) -> str:
        """Return the SHA-256 hex digest of *secret* for persistence/comparison."""
        return hashlib.sha256(secret.encode("utf-8")).hexdigest()

    def _load(self) -> None:
        """Load persisted registrations from disk once.

        Raises:
            ValueError: If the file is not valid JSON, or has an invalid shape
                or record. The registry then stays unloaded, so later calls
                retry and fail again rather than proceeding with an empty cache.
        """
        if self._loaded:
            return
        if not self.storage_path.exists():
            self._clients = {}
            self._loaded = True
            return
        raw = json.loads(self.storage_path.read_text(encoding="utf-8"))
        if not isinstance(raw, dict):
            raise ValueError("Persisted DCR storage must be a JSON object")
        clients: dict[str, OAuthClientRecord] = {}
        for client_id, payload in raw.items():
            if not isinstance(client_id, str):
                raise ValueError("Persisted client id must be a string")
            if not isinstance(payload, dict):
                raise ValueError(f"Persisted client record for {client_id} must be a mapping")
            clients[client_id] = OAuthClientRecord.model_validate({"client_id": client_id, **payload})
        self._clients = clients
        # Mark as loaded only after success. Setting the flag earlier would
        # leave an empty cache after a failed read, and the next register()
        # would persist it over the file, dropping every stored client.
        self._loaded = True

    def _persist(self) -> None:
        """Write registrations via a temp file plus ``Path.replace``."""
        payload = {
            client_id: record.model_dump(mode="json", exclude={"client_id"})
            for client_id, record in self._clients.items()
        }
        tmp_path = self.storage_path.with_suffix(self.storage_path.suffix + ".tmp")
        tmp_path.write_text(json.dumps(payload, sort_keys=True, indent=2), encoding="utf-8")
        tmp_path.replace(self.storage_path)

    def get(self, client_id: str) -> OAuthClientRecord | None:
        """Look up a registered client by identifier."""
        self._load()
        return self._clients.get(client_id)

    def is_known_client(
        self,
        client_id: str,
        *,
        fallback_client_id: str = "",
        fallback_client_secret: str = "",
    ) -> bool:
        """Return whether a client is recognized by the registry or environment.

        ``fallback_client_secret`` is accepted for signature parity with
        ``validate_client_secret`` but is not consulted here.
        """
        if not client_id.strip():
            return False
        if fallback_client_id.strip() and client_id == fallback_client_id.strip():
            return True
        return self.get(client_id) is not None

    def validate_client_secret(
        self,
        client_id: str,
        client_secret: str | None,
        *,
        fallback_client_id: str = "",
        fallback_client_secret: str = "",
    ) -> bool:
        """Validate a client identifier and optional secret.

        The fallback client, and any registered client without a stored secret
        hash, are public: any secret (or none) is accepted for them. Otherwise
        secrets are compared via their SHA-256 digests in constant time.
        """
        if fallback_client_id.strip() and client_id == fallback_client_id.strip():
            if not fallback_client_secret.strip():
                return True
            if not client_secret:
                return False
            return hmac.compare_digest(
                self._hash_secret(client_secret), self._hash_secret(fallback_client_secret.strip())
            )
        record = self.get(client_id)
        if record is None:
            return False
        if record.client_secret_hash is None:
            return True
        if not client_secret:
            return False
        return hmac.compare_digest(self._hash_secret(client_secret), record.client_secret_hash)

    def register(self, request: OAuthRegistrationRequest) -> dict[str, Any]:
        """Persist a new OAuth client registration and return its public metadata.

        A secret is generated unless the auth method is ``none``. Only its hash
        is stored; the plaintext secret appears solely in the returned dict.
        """
        self._load()
        client_id = secrets.token_urlsafe(24)
        client_secret: str | None = None
        client_secret_hash: str | None = None
        if request.token_endpoint_auth_method != "none":
            client_secret = secrets.token_urlsafe(32)
            client_secret_hash = self._hash_secret(client_secret)
        record = OAuthClientRecord(
            client_id=client_id,
            client_name=request.client_name,
            redirect_uris=list(request.redirect_uris),
            grant_types=list(request.grant_types),
            response_types=list(request.response_types),
            token_endpoint_auth_method=request.token_endpoint_auth_method,
            client_id_issued_at=int(time.time()),
            client_secret_hash=client_secret_hash,
            scope=request.scope,
        )
        self._clients[client_id] = record
        try:
            self._persist()
        except Exception:
            # Roll back so the in-memory registry never serves a client whose
            # registration failed and was never written to disk.
            self._clients.pop(client_id, None)
            raise
        response: dict[str, Any] = record.model_dump(exclude={"client_secret_hash"})
        if client_secret is not None:
            response["client_secret"] = client_secret
            response["client_secret_expires_at"] = 0
        return response


_oauth_client_registry: OAuthClientRegistry | None = None


def get_oauth_client_registry(storage_path: Path) -> OAuthClientRegistry:
    """Get or create the global OAuth client registry.

    A new registry replaces the cached one when *storage_path* differs.
    """
    global _oauth_client_registry
    if _oauth_client_registry is None or _oauth_client_registry.storage_path != storage_path:
        _oauth_client_registry = OAuthClientRegistry(storage_path)
    return _oauth_client_registry


def reset_oauth_client_registry() -> None:
    """Reset the global OAuth client registry (primarily for tests)."""
    global _oauth_client_registry
    _oauth_client_registry = None