381 lines
14 KiB
Python
381 lines
14 KiB
Python
"""OAuth proxy helpers for signed state, redirect validation, and DCR storage."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import secrets
|
|
import time
|
|
from fnmatch import fnmatchcase
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from urllib.parse import ParseResult, urlparse, urlunparse
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
|
|
# Built-in Claude callback URLs. Redirect URIs are canonicalized before being
# looked up here, so entries carry a lower-case host and no trailing slash.
_CLAUDE_CALLBACK_URIS = {
    "https://claude.com/api/mcp/auth_callback",
    "https://claude.ai/api/mcp/auth_callback",
}

# Host names treated as loopback by the redirect policy.
_LOOPBACK_HOSTS = {
    "127.0.0.1",
    "::1",
    "localhost",
}

# Registration metadata values accepted by the registration request model.
_SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS = {
    "client_secret_post",
    "none",
}
_SUPPORTED_GRANT_TYPES = {
    "refresh_token",
    "authorization_code",
}
_SUPPORTED_RESPONSE_TYPES = {
    "code",
}
|
|
|
|
|
|
class OAuthRegistrationRequest(BaseModel):
    """Dynamic client registration payload (RFC 7591) accepted by the proxy.

    Unknown fields are rejected. List fields are whitespace-trimmed with blank
    entries dropped, and the auth method is lower-cased. Only values from the
    module-level supported sets are accepted.
    """

    model_config = ConfigDict(extra="forbid")

    client_name: str | None = Field(default=None, max_length=200)
    redirect_uris: list[str] = Field(..., min_length=1)
    grant_types: list[str] = Field(default_factory=lambda: ["authorization_code", "refresh_token"])
    response_types: list[str] = Field(default_factory=lambda: ["code"])
    token_endpoint_auth_method: str = Field(default="none", max_length=64)
    scope: str | None = Field(default=None, max_length=512)

    @field_validator("redirect_uris")
    @classmethod
    def validate_redirect_uris(cls, value: list[str]) -> list[str]:
        """Trim each redirect URI and require at least one non-blank entry."""
        cleaned: list[str] = []
        for candidate in value:
            if not isinstance(candidate, str):
                continue
            trimmed = candidate.strip()
            if trimmed:
                cleaned.append(trimmed)
        if cleaned:
            return cleaned
        raise ValueError("redirect_uris must contain at least one non-empty URI")

    @field_validator("grant_types")
    @classmethod
    def validate_grant_types(cls, value: list[str]) -> list[str]:
        """Trim grant types and reject anything outside the supported set."""
        requested = [trimmed for trimmed in (entry.strip() for entry in value) if trimmed]
        if not requested:
            raise ValueError("grant_types must not be empty")
        if set(requested) - _SUPPORTED_GRANT_TYPES:
            raise ValueError("Unsupported grant_types requested")
        return requested

    @field_validator("response_types")
    @classmethod
    def validate_response_types(cls, value: list[str]) -> list[str]:
        """Trim response types and reject anything outside the supported set."""
        requested = [trimmed for trimmed in (entry.strip() for entry in value) if trimmed]
        if not requested:
            raise ValueError("response_types must not be empty")
        if set(requested) - _SUPPORTED_RESPONSE_TYPES:
            raise ValueError("Unsupported response_types requested")
        return requested

    @field_validator("token_endpoint_auth_method")
    @classmethod
    def validate_token_endpoint_auth_method(cls, value: str) -> str:
        """Lower-case the auth method and require a supported value."""
        method = value.strip().lower()
        if method in _SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS:
            return method
        raise ValueError("Unsupported token_endpoint_auth_method requested")

    @model_validator(mode="after")
    def validate_pkce_ready(self) -> OAuthRegistrationRequest:
        """Require the authorization-code grant and the ``code`` response type."""
        requirements = (
            ("authorization_code", self.grant_types, "authorization_code grant is required"),
            ("code", self.response_types, "code response type is required"),
        )
        # Checked in this order so the grant-type error wins when both are missing.
        for needed, offered, message in requirements:
            if needed not in offered:
                raise ValueError(message)
        return self
|
|
|
|
|
|
class OAuthClientRecord(BaseModel):
    """Stored form of a registered OAuth client.

    Unknown fields are rejected when a record is validated.
    """

    model_config = ConfigDict(extra="forbid")

    # Identity
    client_id: str
    client_name: str | None = None

    # Registered metadata
    redirect_uris: list[str]
    grant_types: list[str]
    response_types: list[str]
    token_endpoint_auth_method: str

    # Issuance bookkeeping
    client_id_issued_at: int
    client_secret_expires_at: int = 0

    # Hash of the client secret; None when no secret is stored for the client.
    client_secret_hash: str | None = None
    scope: str | None = None
|
|
|
|
|
|
def _canonicalize_url(value: str) -> str:
    """Normalize a URL for comparison.

    The scheme and network location are lower-cased, the fragment is dropped,
    an empty path becomes ``/``, and trailing slashes are then trimmed from the
    end of the rebuilt URL.

    Args:
        value: URL text; surrounding whitespace is ignored.

    Returns:
        The normalized URL, or ``""`` when the input has no scheme, has no
        network location, or cannot be parsed at all.
    """
    try:
        parsed = urlparse(value.strip())
    except ValueError:
        # urlparse rejects malformed bracketed hosts (e.g. "http://[::1").
        # Callers treat "" as "not a usable URL", so report it the same way
        # instead of letting the exception escape a policy check.
        return ""
    if not parsed.scheme or not parsed.netloc:
        return ""

    normalized = ParseResult(
        scheme=parsed.scheme.lower(),
        netloc=parsed.netloc.lower(),
        path=parsed.path or "/",
        params=parsed.params,
        query=parsed.query,
        fragment="",
    )
    # NOTE: rstrip applies to the whole rebuilt string, so a query that ends
    # in "/" is trimmed as well.
    return urlunparse(normalized).rstrip("/")
|
|
|
|
|
|
def is_loopback_redirect_uri(redirect_uri: str) -> bool:
    """Return whether a redirect URI uses a loopback host.

    Only plain ``http`` URIs qualify, and the host must be one of the entries
    in ``_LOOPBACK_HOSTS``. The port and path are not inspected.

    Args:
        redirect_uri: Redirect URI text; surrounding whitespace is ignored.

    Returns:
        True for an ``http`` URI whose host is a loopback name or address.
        False otherwise, including when the URI cannot be parsed.
    """
    try:
        parsed = urlparse(redirect_uri.strip())
        host = (parsed.hostname or "").lower()
    except ValueError:
        # Malformed bracketed hosts (e.g. "http://[::1") make urlparse raise.
        # An unparseable URI is simply not a loopback URI.
        return False
    if parsed.scheme != "http":
        return False
    return host in _LOOPBACK_HOSTS
|
|
|
|
|
|
def is_claude_redirect_uri(redirect_uri: str) -> bool:
    """Tell whether the URI, once canonicalized, is a built-in Claude callback."""
    canonical_uri = _canonicalize_url(redirect_uri)
    return canonical_uri in _CLAUDE_CALLBACK_URIS
|
|
|
|
|
|
def is_redirect_uri_allowed(redirect_uri: str, allowlist: list[str]) -> bool:
    """Decide whether policy permits redirecting to ``redirect_uri``.

    A URI that does not canonicalize is rejected. Loopback URIs and built-in
    Claude callbacks are always accepted. Otherwise the URI must match one of
    the allowlist entries, which are case-sensitive shell-style patterns tried
    against both the canonical and the whitespace-trimmed raw form of the URI.
    """
    canonical_uri = _canonicalize_url(redirect_uri)
    if not canonical_uri:
        return False

    if is_loopback_redirect_uri(redirect_uri):
        return True
    if is_claude_redirect_uri(redirect_uri):
        return True

    raw_uri = redirect_uri.strip()
    for entry in allowlist:
        glob = entry.strip()
        if not glob:
            # Blank allowlist entries are ignored.
            continue
        # Prefer the canonical form of the pattern; fall back to the pattern
        # as written when it does not canonicalize to a URL.
        canonical_glob = _canonicalize_url(glob) or glob
        if fnmatchcase(canonical_uri, canonical_glob) or fnmatchcase(raw_uri, glob):
            return True
    return False
|
|
|
|
|
|
def is_origin_allowed(origin: str, request_base: str, public_base: str | None) -> bool:
    """Check a browser Origin against the bases accepted for MCP transport requests.

    The origin is canonicalized. The bases are compared as given, with only
    trailing slashes trimmed. ``public_base`` is considered only when it is
    non-empty.
    """
    canonical_origin = _canonicalize_url(origin)
    if not canonical_origin:
        return False

    if canonical_origin == request_base.rstrip("/"):
        return True
    return bool(public_base) and canonical_origin == public_base.rstrip("/")
|
|
|
|
|
|
def encode_proxy_state(
    secret: str,
    redirect_uri: str,
    original_state: str,
    *,
    ttl_seconds: int = 600,
) -> str:
    """Wrap the client's redirect URI and state in a signed, URL-safe token.

    The claims (redirect URI, original state, issue time, random nonce, and
    TTL) are serialized as compact key-sorted JSON and signed with HMAC-SHA256
    under ``secret``. The result is the URL-safe base64 encoding of a JSON
    envelope holding the claims and the hex signature.
    """
    claims = {
        "issued_at": int(time.time()),
        "nonce": secrets.token_urlsafe(16),
        "redirect_uri": redirect_uri,
        "state": original_state,
        "ttl_seconds": ttl_seconds,
    }

    # The signature covers the compact, key-sorted JSON form of the claims.
    signed_text = json.dumps(claims, sort_keys=True, separators=(",", ":"))
    digest = hmac.new(
        secret.encode("utf-8"),
        signed_text.encode("utf-8"),
        hashlib.sha256,
    ).hexdigest()

    wire_text = json.dumps(
        {"payload": claims, "signature": digest},
        sort_keys=True,
        separators=(",", ":"),
    )
    encoded = base64.urlsafe_b64encode(wire_text.encode("utf-8"))
    return encoded.decode("ascii")
|
|
|
|
|
|
def decode_proxy_state(secret: str, encoded_state: str) -> dict[str, str]:
    """Verify and unpack a signed OAuth state wrapper.

    Args:
        secret: HMAC key the wrapper was signed with.
        encoded_state: URL-safe base64 token; base64 padding may be absent.

    Returns:
        A dict with the wrapped ``redirect_uri`` and the client's original
        ``state``.

    Raises:
        ValueError: The token is malformed, its signature does not verify, it
            is expired or issued in the future, or a required field is missing
            or has the wrong type. The message is the same in every case.
    """
    error_message = "Invalid or missing state parameter"

    try:
        # Restore any base64 padding that was stripped in transit.
        padded_state = encoded_state + "=" * (-len(encoded_state) % 4)
        raw = base64.urlsafe_b64decode(padded_state.encode("ascii"))
        envelope = json.loads(raw)
    except Exception as exc:  # pragma: no cover - guarded by tests
        raise ValueError(error_message) from exc

    if not isinstance(envelope, dict):
        raise ValueError(error_message)

    payload = envelope.get("payload")
    signature = envelope.get("signature")
    if not isinstance(payload, dict) or not isinstance(signature, str):
        raise ValueError(error_message)

    canonical_payload = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    expected_signature = hmac.new(
        secret.encode("utf-8"), canonical_payload.encode("utf-8"), hashlib.sha256
    ).hexdigest()

    # hmac.compare_digest raises TypeError for str arguments containing
    # non-ASCII characters, and the signature here is caller-controlled.
    # Compare bytes instead; a non-ASCII signature can never equal a hex digest.
    try:
        provided_signature = signature.encode("ascii")
    except UnicodeEncodeError as exc:
        raise ValueError(error_message) from exc
    if not hmac.compare_digest(provided_signature, expected_signature.encode("ascii")):
        raise ValueError(error_message)

    issued_at = payload.get("issued_at")
    ttl_seconds = payload.get("ttl_seconds")
    now = int(time.time())
    if not isinstance(issued_at, int) or not isinstance(ttl_seconds, int):
        raise ValueError(error_message)
    # Reject future-dated tokens; a non-positive TTL is treated as one second.
    if issued_at > now or now - issued_at > max(ttl_seconds, 1):
        raise ValueError(error_message)

    redirect_uri = payload.get("redirect_uri")
    if not isinstance(redirect_uri, str) or not redirect_uri.strip():
        raise ValueError(error_message)

    original_state = payload.get("state")
    if not isinstance(original_state, str):
        raise ValueError(error_message)

    return {"redirect_uri": redirect_uri, "state": original_state}
|
|
|
|
|
|
class OAuthClientRegistry:
    """Persisted OAuth client registry for dynamic client registration.

    Registrations are held in memory and mirrored to a JSON file that maps
    each client id to its record. The file is read lazily on first access.
    """

    def __init__(self, storage_path: Path) -> None:
        """Initialize registry storage.

        Creates the parent directory of ``storage_path`` if needed. The file
        itself is not read until the first lookup or registration.
        """
        self.storage_path = storage_path
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        self._clients: dict[str, OAuthClientRecord] = {}
        self._loaded = False

    @staticmethod
    def _hash_secret(secret: str) -> str:
        """Return the hex SHA-256 digest of a client secret for storage/comparison."""
        return hashlib.sha256(secret.encode("utf-8")).hexdigest()

    def _load(self) -> None:
        """Load persisted registrations from disk once.

        A missing file yields an empty registry.

        Raises:
            ValueError: The file is not valid JSON, is not a JSON object, or
                holds a malformed record. The registry stays unloaded in that
                case, so the failure is reported again on the next access
                rather than being mistaken for an empty registry (which a
                later ``register`` would then write over the existing file).
        """
        if self._loaded:
            return

        clients: dict[str, OAuthClientRecord] = {}
        try:
            text = self.storage_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            text = None

        if text is not None:
            raw = json.loads(text)
            if not isinstance(raw, dict):
                raise ValueError("Persisted DCR storage must be a JSON object")

            for client_id, payload in raw.items():
                if not isinstance(client_id, str):
                    raise ValueError("Persisted client id must be a string")
                if not isinstance(payload, dict):
                    raise ValueError(f"Persisted client record for {client_id} must be a mapping")
                # The client id is stored as the mapping key, not inside the record.
                clients[client_id] = OAuthClientRecord.model_validate(
                    {"client_id": client_id, **payload}
                )

        self._clients = clients
        # Only mark as loaded once everything above succeeded.
        self._loaded = True

    def _persist(self) -> None:
        """Write registrations atomically.

        The JSON is written to a sibling ``.tmp`` file which then replaces the
        storage file.
        """
        payload = {
            client_id: record.model_dump(mode="json", exclude={"client_id"})
            for client_id, record in self._clients.items()
        }
        tmp_path = self.storage_path.with_suffix(self.storage_path.suffix + ".tmp")
        tmp_path.write_text(json.dumps(payload, sort_keys=True, indent=2), encoding="utf-8")
        tmp_path.replace(self.storage_path)

    def get(self, client_id: str) -> OAuthClientRecord | None:
        """Look up a registered client by identifier; None when unknown."""
        self._load()
        return self._clients.get(client_id)

    def is_known_client(
        self,
        client_id: str,
        *,
        fallback_client_id: str = "",
        fallback_client_secret: str = "",
    ) -> bool:
        """Return whether a client is recognized by the registry or environment.

        A blank ``client_id`` is never known. A non-blank ``fallback_client_id``
        matches by equality after trimming. ``fallback_client_secret`` is
        accepted but not consulted here.
        """
        if not client_id.strip():
            return False
        if fallback_client_id.strip() and client_id == fallback_client_id.strip():
            return True
        return self.get(client_id) is not None

    def validate_client_secret(
        self,
        client_id: str,
        client_secret: str | None,
        *,
        fallback_client_id: str = "",
        fallback_client_secret: str = "",
    ) -> bool:
        """Validate a client identifier and optional secret.

        For the fallback client, any secret is accepted when no fallback secret
        is configured; otherwise the supplied secret must match it. For a
        registered client, a record without a stored hash accepts any secret;
        otherwise the supplied secret's hash must match. Unknown clients fail.
        """
        if fallback_client_id.strip() and client_id == fallback_client_id.strip():
            if not fallback_client_secret.strip():
                return True
            if not client_secret:
                return False
            # Hash both sides so the constant-time comparison sees equal-length input.
            return hmac.compare_digest(
                self._hash_secret(client_secret), self._hash_secret(fallback_client_secret.strip())
            )

        record = self.get(client_id)
        if record is None:
            return False

        if record.client_secret_hash is None:
            return True
        if not client_secret:
            return False
        return hmac.compare_digest(self._hash_secret(client_secret), record.client_secret_hash)

    def register(self, request: OAuthRegistrationRequest) -> dict[str, Any]:
        """Persist a new OAuth client registration and return its public metadata.

        A client secret is generated only when the requested auth method is
        not ``"none"``; only its hash is stored.

        Returns:
            The record's fields without ``client_secret_hash``, plus the
            plaintext ``client_secret`` when one was issued.

        Raises:
            Exception: Whatever loading or persisting raises. If persisting
                fails, the new client is removed from memory again so the
                in-memory view does not drift from the file.
        """
        self._load()
        client_id = secrets.token_urlsafe(24)
        client_secret: str | None = None
        client_secret_hash: str | None = None

        if request.token_endpoint_auth_method != "none":
            client_secret = secrets.token_urlsafe(32)
            client_secret_hash = self._hash_secret(client_secret)

        record = OAuthClientRecord(
            client_id=client_id,
            client_name=request.client_name,
            redirect_uris=list(request.redirect_uris),
            grant_types=list(request.grant_types),
            response_types=list(request.response_types),
            token_endpoint_auth_method=request.token_endpoint_auth_method,
            client_id_issued_at=int(time.time()),
            client_secret_hash=client_secret_hash,
            scope=request.scope,
        )
        self._clients[client_id] = record
        try:
            self._persist()
        except Exception:
            # Do not keep a registration the caller never received.
            self._clients.pop(client_id, None)
            raise

        response: dict[str, Any] = record.model_dump(exclude={"client_secret_hash"})
        if client_secret is not None:
            response["client_secret"] = client_secret
            response["client_secret_expires_at"] = 0
        return response
|
|
|
|
|
|
# Process-wide registry instance, created on demand by get_oauth_client_registry().
_oauth_client_registry: OAuthClientRegistry | None = None


def get_oauth_client_registry(storage_path: Path) -> OAuthClientRegistry:
    """Return the shared registry, rebuilding it when the storage path differs."""
    global _oauth_client_registry
    current = _oauth_client_registry
    if current is not None and current.storage_path == storage_path:
        return current
    # No registry yet, or the existing one points at a different file.
    _oauth_client_registry = OAuthClientRegistry(storage_path)
    return _oauth_client_registry
|
|
|
|
|
|
def reset_oauth_client_registry() -> None:
    """Drop the shared registry so the next lookup builds a fresh one (mainly for tests)."""
    global _oauth_client_registry
    _oauth_client_registry = None
|