# File: AegisGitea-MCP/src/aegis_gitea_mcp/oauth_flow.py
# (Repository viewer metadata: 381 lines, 14 KiB, Python)
"""OAuth proxy helpers for signed state, redirect validation, and DCR storage."""
from __future__ import annotations
import base64
import hashlib
import hmac
import json
import secrets
import time
from fnmatch import fnmatchcase
from pathlib import Path
from typing import Any
from urllib.parse import ParseResult, urlparse, urlunparse
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
# Built-in Claude OAuth callback URLs. is_claude_redirect_uri() compares a
# canonicalized redirect URI against this set.
_CLAUDE_CALLBACK_URIS = {
"https://claude.ai/api/mcp/auth_callback",
"https://claude.com/api/mcp/auth_callback",
}
# Hostnames that is_loopback_redirect_uri() accepts for plain-http redirect URIs.
_LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "::1"}
# Values accepted by the OAuthRegistrationRequest validators for the
# token_endpoint_auth_method, grant_types and response_types fields.
_SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS = {"none", "client_secret_post"}
_SUPPORTED_GRANT_TYPES = {"authorization_code", "refresh_token"}
_SUPPORTED_RESPONSE_TYPES = {"code"}
class OAuthRegistrationRequest(BaseModel):
    """Payload of an incoming RFC 7591 dynamic client registration call.

    Unknown fields are rejected (``extra="forbid"``). List fields are
    whitespace-trimmed by their validators, and the request as a whole must
    include the ``authorization_code`` grant and the ``code`` response type.
    """

    client_name: str | None = Field(default=None, max_length=200)
    redirect_uris: list[str] = Field(..., min_length=1)
    grant_types: list[str] = Field(default_factory=lambda: ["authorization_code", "refresh_token"])
    response_types: list[str] = Field(default_factory=lambda: ["code"])
    token_endpoint_auth_method: str = Field(default="none", max_length=64)
    scope: str | None = Field(default=None, max_length=512)

    model_config = ConfigDict(extra="forbid")

    @field_validator("redirect_uris")
    @classmethod
    def validate_redirect_uris(cls, value: list[str]) -> list[str]:
        """Trim every redirect URI and drop blank entries.

        Raises:
            ValueError: If nothing but blank entries was supplied.
        """
        cleaned: list[str] = []
        for entry in value:
            if not isinstance(entry, str):
                continue
            trimmed = entry.strip()
            if trimmed:
                cleaned.append(trimmed)
        if cleaned:
            return cleaned
        raise ValueError("redirect_uris must contain at least one non-empty URI")

    @field_validator("grant_types")
    @classmethod
    def validate_grant_types(cls, value: list[str]) -> list[str]:
        """Trim grant types and reject anything outside the supported set.

        Raises:
            ValueError: If the list is effectively empty or names an
                unsupported grant type.
        """
        requested = [trimmed for trimmed in (item.strip() for item in value) if trimmed]
        if not requested:
            raise ValueError("grant_types must not be empty")
        if not _SUPPORTED_GRANT_TYPES.issuperset(requested):
            raise ValueError("Unsupported grant_types requested")
        return requested

    @field_validator("response_types")
    @classmethod
    def validate_response_types(cls, value: list[str]) -> list[str]:
        """Trim response types and reject anything outside the supported set.

        Raises:
            ValueError: If the list is effectively empty or names an
                unsupported response type.
        """
        requested = [trimmed for trimmed in (item.strip() for item in value) if trimmed]
        if not requested:
            raise ValueError("response_types must not be empty")
        if not _SUPPORTED_RESPONSE_TYPES.issuperset(requested):
            raise ValueError("Unsupported response_types requested")
        return requested

    @field_validator("token_endpoint_auth_method")
    @classmethod
    def validate_token_endpoint_auth_method(cls, value: str) -> str:
        """Lower-case the auth method and require it to be a supported one.

        Raises:
            ValueError: If the method is not in the supported subset.
        """
        method = value.strip().lower()
        if method in _SUPPORTED_TOKEN_ENDPOINT_AUTH_METHODS:
            return method
        raise ValueError("Unsupported token_endpoint_auth_method requested")

    @model_validator(mode="after")
    def validate_pkce_ready(self) -> OAuthRegistrationRequest:
        """Require the pieces needed for the authorization code flow.

        Raises:
            ValueError: If ``authorization_code`` is missing from the grant
                types or ``code`` is missing from the response types.
        """
        has_code_grant = "authorization_code" in self.grant_types
        if not has_code_grant:
            raise ValueError("authorization_code grant is required")
        has_code_response = "code" in self.response_types
        if not has_code_response:
            raise ValueError("code response type is required")
        return self
class OAuthClientRecord(BaseModel):
    """One stored OAuth client registration.

    Instances are what the client registry keeps in memory and serializes to
    its storage file; unknown fields are rejected (``extra="forbid"``).
    """

    # Identifier issued to the client at registration time.
    client_id: str
    client_name: str | None = Field(default=None)
    # Registration metadata copied from the validated registration request.
    redirect_uris: list[str]
    grant_types: list[str]
    response_types: list[str]
    token_endpoint_auth_method: str
    # Unix timestamp (seconds) recorded when the client was registered.
    client_id_issued_at: int
    client_secret_expires_at: int = Field(default=0)
    # Hash of the client secret; None when no secret was issued.
    client_secret_hash: str | None = Field(default=None)
    scope: str | None = Field(default=None)

    model_config = ConfigDict(extra="forbid")
def _canonicalize_url(value: str) -> str:
    """Normalize a URL for comparison.

    The scheme and network location are lower-cased, an empty path becomes
    ``/``, the fragment is dropped, and trailing slashes are stripped from the
    reassembled URL.

    Args:
        value: URL text; surrounding whitespace is ignored.

    Returns:
        The normalized URL, or ``""`` when the value has no scheme or network
        location, or cannot be parsed at all.
    """
    try:
        parsed = urlparse(value.strip())
    except ValueError:
        # urlparse raises on malformed bracketed hosts such as "http://[::1".
        # This helper is fed untrusted redirect URIs, so treat those inputs
        # like any other value that cannot be canonicalized.
        return ""
    if not parsed.scheme or not parsed.netloc:
        return ""
    normalized = ParseResult(
        scheme=parsed.scheme.lower(),
        netloc=parsed.netloc.lower(),
        path=parsed.path or "/",
        params=parsed.params,
        query=parsed.query,
        fragment="",
    )
    return urlunparse(normalized).rstrip("/")
def is_loopback_redirect_uri(redirect_uri: str) -> bool:
    """Return whether a redirect URI uses a loopback host.

    Only plain ``http`` URIs whose hostname is listed in ``_LOOPBACK_HOSTS``
    qualify.

    Args:
        redirect_uri: Candidate redirect URI; surrounding whitespace is ignored.

    Returns:
        True for an http loopback URI, False otherwise, including when the
        value cannot be parsed as a URL.
    """
    try:
        parsed = urlparse(redirect_uri.strip())
    except ValueError:
        # Malformed bracketed hosts (e.g. "http://[::1") make urlparse raise;
        # an unparseable URI is simply not a loopback URI.
        return False
    if parsed.scheme != "http":
        return False
    host = (parsed.hostname or "").lower()
    return host in _LOOPBACK_HOSTS
def is_claude_redirect_uri(redirect_uri: str) -> bool:
    """Tell whether the URI canonicalizes to one of the built-in Claude callbacks."""
    canonical = _canonicalize_url(redirect_uri)
    return canonical in _CLAUDE_CALLBACK_URIS
def is_redirect_uri_allowed(redirect_uri: str, allowlist: list[str]) -> bool:
    """Decide whether policy permits the given redirect URI.

    A URI that cannot be canonicalized is always rejected. Loopback URIs and
    the built-in Claude callbacks are always accepted. Otherwise the URI must
    match one of the allowlist entries, which are treated as case-sensitive
    glob patterns and tried against both the canonical and the raw
    (whitespace-trimmed) form of the URI.

    Args:
        redirect_uri: Redirect URI to check.
        allowlist: Glob patterns of additionally permitted URIs; blank
            entries are skipped.

    Returns:
        True when the URI is permitted, False otherwise.
    """
    canonical_uri = _canonicalize_url(redirect_uri)
    if not canonical_uri:
        return False
    if is_loopback_redirect_uri(redirect_uri):
        return True
    if is_claude_redirect_uri(redirect_uri):
        return True

    raw_uri = redirect_uri.strip()
    for entry in allowlist:
        glob = entry.strip()
        if not glob:
            continue
        # Fall back to the raw pattern when it does not canonicalize to a URL.
        canonical_glob = _canonicalize_url(glob) or glob
        if fnmatchcase(canonical_uri, canonical_glob) or fnmatchcase(raw_uri, glob):
            return True
    return False
def is_origin_allowed(origin: str, request_base: str, public_base: str | None) -> bool:
    """Return whether a browser Origin is allowed for MCP transport requests.

    The origin is canonicalized (lower-cased scheme and host, no trailing
    slash) and must equal one of the expected base URLs. Each base is accepted
    both in its raw form (trailing slashes removed) and in canonical form, so
    a base configured with different letter case or a trailing slash still
    matches the origin.

    Args:
        origin: Value of the request's ``Origin`` header.
        request_base: Base URL of the incoming request.
        public_base: Optional externally visible base URL.

    Returns:
        True when the origin matches one of the bases, False otherwise
        (including when the origin cannot be canonicalized).
    """
    normalized_origin = _canonicalize_url(origin)
    if not normalized_origin:
        return False
    expected_bases: set[str] = set()
    for base in (request_base, public_base):
        if not base:
            continue
        # Keep the raw form so anything that matched before still matches.
        expected_bases.add(base.rstrip("/"))
        canonical_base = _canonicalize_url(base)
        if canonical_base:
            expected_bases.add(canonical_base)
    return normalized_origin in expected_bases
def encode_proxy_state(
    secret: str,
    redirect_uri: str,
    original_state: str,
    *,
    ttl_seconds: int = 600,
) -> str:
    """Wrap the client's OAuth state in a signed, URL-safe envelope.

    The envelope carries the redirect URI, the original state, the issue
    time, a random nonce and the TTL. The payload is signed with HMAC-SHA256
    over its compact, key-sorted JSON form.

    Args:
        secret: Key used for the HMAC signature.
        redirect_uri: Redirect URI to carry through the round-trip.
        original_state: State value supplied by the client.
        ttl_seconds: Lifetime recorded in the payload.

    Returns:
        URL-safe base64 text of the JSON envelope.
    """
    claims = {
        "redirect_uri": redirect_uri,
        "state": original_state,
        "issued_at": int(time.time()),
        "nonce": secrets.token_urlsafe(16),
        "ttl_seconds": ttl_seconds,
    }
    signed_bytes = json.dumps(claims, sort_keys=True, separators=(",", ":")).encode("utf-8")
    digest = hmac.new(secret.encode("utf-8"), signed_bytes, hashlib.sha256).hexdigest()
    wrapper = json.dumps(
        {"payload": claims, "signature": digest},
        sort_keys=True,
        separators=(",", ":"),
    )
    return base64.urlsafe_b64encode(wrapper.encode("utf-8")).decode("ascii")
def decode_proxy_state(secret: str, encoded_state: str) -> dict[str, str]:
    """Verify and unpack a signed OAuth state wrapper.

    The value must be URL-safe base64 of a JSON object with a ``payload``
    mapping and a ``signature`` string. The signature is recomputed with
    HMAC-SHA256 over the compact, key-sorted JSON payload and compared in
    constant time. The payload must not be issued in the future and must not
    be older than its own ``ttl_seconds`` (at least one second).

    Args:
        secret: Key used for the HMAC signature.
        encoded_state: State value received on the callback.

    Returns:
        A dict with the original ``redirect_uri`` and ``state``.

    Raises:
        ValueError: For every malformed, forged or expired state. The message
            is always the same so callers do not leak which check failed.
    """
    try:
        raw = base64.urlsafe_b64decode(encoded_state.encode("ascii"))
        envelope = json.loads(raw)
    except Exception as exc:  # pragma: no cover - guarded by tests
        raise ValueError("Invalid or missing state parameter") from exc
    if not isinstance(envelope, dict):
        raise ValueError("Invalid or missing state parameter")
    payload = envelope.get("payload")
    signature = envelope.get("signature")
    if not isinstance(payload, dict) or not isinstance(signature, str):
        raise ValueError("Invalid or missing state parameter")
    canonical_payload = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    expected_signature = hmac.new(
        secret.encode("utf-8"), canonical_payload.encode("utf-8"), hashlib.sha256
    ).hexdigest()
    # hmac.compare_digest raises TypeError for str arguments containing
    # non-ASCII characters. The signature is attacker-controlled, so reject
    # such values up front; a hex digest is always ASCII anyway.
    if not signature.isascii() or not hmac.compare_digest(signature, expected_signature):
        raise ValueError("Invalid or missing state parameter")
    issued_at = payload.get("issued_at")
    ttl_seconds = payload.get("ttl_seconds")
    now = int(time.time())
    if not isinstance(issued_at, int) or not isinstance(ttl_seconds, int):
        raise ValueError("Invalid or missing state parameter")
    # Reject states from the future and states older than their TTL.
    if issued_at > now or now - issued_at > max(ttl_seconds, 1):
        raise ValueError("Invalid or missing state parameter")
    redirect_uri = payload.get("redirect_uri")
    if not isinstance(redirect_uri, str) or not redirect_uri.strip():
        raise ValueError("Invalid or missing state parameter")
    original_state = payload.get("state")
    if not isinstance(original_state, str):
        raise ValueError("Invalid or missing state parameter")
    return {"redirect_uri": redirect_uri, "state": original_state}
class OAuthClientRegistry:
    """Persisted OAuth client registry for dynamic client registration.

    Registrations are kept in memory as ``OAuthClientRecord`` objects keyed by
    client id and mirrored to a JSON file at ``storage_path``. The file is
    read lazily on first use and rewritten after every registration.
    """

    def __init__(self, storage_path: Path) -> None:
        """Initialize registry storage.

        Creates the parent directory of ``storage_path`` if needed; the
        storage file itself is not read until the registry is first used.

        Args:
            storage_path: Location of the JSON file holding registrations.
        """
        self.storage_path = storage_path
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        self._clients: dict[str, OAuthClientRecord] = {}
        self._loaded = False

    @staticmethod
    def _hash_secret(secret: str) -> str:
        """Return the SHA-256 hex digest used to store and compare secrets."""
        return hashlib.sha256(secret.encode("utf-8")).hexdigest()

    def _load(self) -> None:
        """Load persisted registrations from disk once.

        A missing storage file yields an empty registry.

        Raises:
            ValueError: If the file is not valid JSON, is not a JSON object,
                or contains a record that is not a mapping.
        """
        if self._loaded:
            return
        clients: dict[str, OAuthClientRecord] = {}
        if self.storage_path.exists():
            raw = json.loads(self.storage_path.read_text(encoding="utf-8"))
            if not isinstance(raw, dict):
                raise ValueError("Persisted DCR storage must be a JSON object")
            for client_id, payload in raw.items():
                if not isinstance(client_id, str):
                    raise ValueError("Persisted client id must be a string")
                if not isinstance(payload, dict):
                    raise ValueError(f"Persisted client record for {client_id} must be a mapping")
                record = OAuthClientRecord.model_validate({"client_id": client_id, **payload})
                clients[client_id] = record
        self._clients = clients
        # Mark the registry as loaded only after a successful read. Setting the
        # flag earlier would make a later call treat a failed load as an empty
        # registry, and the next register() would then overwrite the existing
        # storage file with just the new client.
        self._loaded = True

    def _persist(self) -> None:
        """Write all registrations to a temporary file, then swap it into place."""
        payload = {
            client_id: record.model_dump(mode="json", exclude={"client_id"})
            for client_id, record in self._clients.items()
        }
        tmp_path = self.storage_path.with_suffix(self.storage_path.suffix + ".tmp")
        tmp_path.write_text(json.dumps(payload, sort_keys=True, indent=2), encoding="utf-8")
        tmp_path.replace(self.storage_path)

    def get(self, client_id: str) -> OAuthClientRecord | None:
        """Look up a registered client by identifier.

        Returns:
            The stored record, or None when the client id is unknown.
        """
        self._load()
        return self._clients.get(client_id)

    def is_known_client(
        self,
        client_id: str,
        *,
        fallback_client_id: str = "",
        fallback_client_secret: str = "",
    ) -> bool:
        """Return whether a client is recognized by the registry or environment.

        Args:
            client_id: Identifier presented by the client.
            fallback_client_id: Statically configured client id that is
                accepted without a registry entry.
            fallback_client_secret: Accepted for signature symmetry with
                ``validate_client_secret``; not consulted here.

        Returns:
            True when the id equals the fallback id or has a registry entry;
            False for a blank id or an unknown client.
        """
        if not client_id.strip():
            return False
        if fallback_client_id.strip() and client_id == fallback_client_id.strip():
            return True
        return self.get(client_id) is not None

    def validate_client_secret(
        self,
        client_id: str,
        client_secret: str | None,
        *,
        fallback_client_id: str = "",
        fallback_client_secret: str = "",
    ) -> bool:
        """Validate a client identifier and optional secret.

        For the fallback client, the secret is only checked when a fallback
        secret is configured. For registered clients, the secret is only
        checked when a secret hash was stored at registration time.

        Args:
            client_id: Identifier presented by the client.
            client_secret: Secret presented by the client, if any.
            fallback_client_id: Statically configured client id.
            fallback_client_secret: Statically configured client secret.

        Returns:
            True when the client is known and its secret (if one is required)
            matches; False otherwise.
        """
        if fallback_client_id.strip() and client_id == fallback_client_id.strip():
            if not fallback_client_secret.strip():
                return True
            if not client_secret:
                return False
            # Compare digests rather than raw values, in constant time.
            return hmac.compare_digest(
                self._hash_secret(client_secret), self._hash_secret(fallback_client_secret.strip())
            )
        record = self.get(client_id)
        if record is None:
            return False
        if record.client_secret_hash is None:
            # No secret was issued for this client, so there is nothing to verify.
            return True
        if not client_secret:
            return False
        return hmac.compare_digest(self._hash_secret(client_secret), record.client_secret_hash)

    def register(self, request: OAuthRegistrationRequest) -> dict[str, Any]:
        """Persist a new OAuth client registration and return its public metadata.

        A random client id is always generated. A random client secret is
        generated only when the requested token endpoint auth method is not
        ``"none"``; only its hash is stored.

        Args:
            request: Validated registration request.

        Returns:
            The stored record as a dict without the secret hash. When a secret
            was issued, the dict also carries the plaintext ``client_secret``
            and ``client_secret_expires_at`` set to 0.
        """
        self._load()
        client_id = secrets.token_urlsafe(24)
        client_secret: str | None = None
        client_secret_hash: str | None = None
        if request.token_endpoint_auth_method != "none":
            client_secret = secrets.token_urlsafe(32)
            client_secret_hash = self._hash_secret(client_secret)
        record = OAuthClientRecord(
            client_id=client_id,
            client_name=request.client_name,
            redirect_uris=list(request.redirect_uris),
            grant_types=list(request.grant_types),
            response_types=list(request.response_types),
            token_endpoint_auth_method=request.token_endpoint_auth_method,
            client_id_issued_at=int(time.time()),
            client_secret_hash=client_secret_hash,
            scope=request.scope,
        )
        self._clients[client_id] = record
        self._persist()
        response: dict[str, Any] = record.model_dump(exclude={"client_secret_hash"})
        if client_secret is not None:
            # The plaintext secret is only ever available in this response.
            response["client_secret"] = client_secret
            response["client_secret_expires_at"] = 0
        return response
# Process-wide registry instance, created on first request.
_oauth_client_registry: OAuthClientRegistry | None = None


def get_oauth_client_registry(storage_path: Path) -> OAuthClientRegistry:
    """Return the shared registry for ``storage_path``, building it if needed.

    The cached instance is reused only while it points at the same storage
    path; asking for a different path replaces it with a new registry.
    """
    global _oauth_client_registry
    cached = _oauth_client_registry
    if cached is not None and cached.storage_path == storage_path:
        return cached
    _oauth_client_registry = OAuthClientRegistry(storage_path)
    return _oauth_client_registry
def reset_oauth_client_registry() -> None:
    """Forget the shared registry instance so the next lookup builds a new one.

    Intended mainly for tests that need a clean module state.
    """
    global _oauth_client_registry
    _oauth_client_registry = None