"""OAuth2/OIDC token validation for per-user Gitea authentication."""

from __future__ import annotations

import json
import time
from datetime import datetime, timezone
from typing import Any, cast

import httpx
import jwt
from jwt import InvalidTokenError
from jwt.algorithms import RSAAlgorithm

from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.config import get_settings

# Client-facing messages; the machine-readable reason travels separately.
_UNAVAILABLE_MESSAGE = "Unable to validate OAuth token at this time."
_INVALID_MESSAGE = "Invalid or expired OAuth token."


class OAuthTokenValidationError(RuntimeError):
    """Raised when a provided OAuth token cannot be trusted.

    Attributes:
        public_message: Text intended to be returned to the client.
        reason: Short machine-readable code recorded in the audit log.
    """

    def __init__(self, public_message: str, reason: str) -> None:
        """Initialize validation error details."""
        super().__init__(public_message)
        self.public_message = public_message
        self.reason = reason


class GiteaOAuthValidator:
    """Validate per-user OAuth access tokens issued by Gitea.

    JWT-shaped tokens are verified locally (RS256 only) against the
    provider's OIDC discovery metadata and JWKS. Opaque tokens, and JWTs
    that fail local verification, are checked by calling Gitea's userinfo
    endpoint with the token. Failed attempts are rate-limited per client
    identifier.
    """

    def __init__(self) -> None:
        """Initialize OAuth validator state and caches."""
        self.settings = get_settings()
        self.audit = get_audit_logger()
        # identifier (client IP) -> UTC timestamps of recent failed attempts
        self._failed_attempts: dict[str, list[datetime]] = {}
        # (discovery document, time.monotonic() expiry)
        self._discovery_cache: tuple[dict[str, Any], float] | None = None
        # jwks_uri -> (JWKS document, time.monotonic() expiry)
        self._jwks_cache: dict[str, tuple[dict[str, Any], float]] = {}

    @staticmethod
    def extract_bearer_token(authorization_header: str | None) -> str | None:
        """Extract the token from an ``Authorization: Bearer <token>`` header.

        The auth-scheme is matched case-insensitively (RFC 7235, section 2.1).

        Returns:
            The token, or None when the header is missing, uses another
            scheme, or does not carry exactly one non-empty token.
        """
        if not authorization_header:
            return None
        scheme, separator, token = authorization_header.partition(" ")
        if separator != " " or scheme.lower() != "bearer":
            return None
        stripped = token.strip()
        if not stripped or " " in stripped:
            return None
        return stripped

    def _base_url(self) -> str:
        """Return the configured Gitea base URL without a trailing slash."""
        # Avoids "//" in endpoint URLs when the setting ends with "/".
        return str(self.settings.gitea_base_url).rstrip("/")

    def _check_rate_limit(self, identifier: str) -> bool:
        """Return True while *identifier* is below the failure threshold.

        Attempts older than ``auth_failure_window`` seconds are discarded.
        An identifier with no remaining attempts is removed entirely, so the
        tracking dict does not keep an empty entry per client forever.
        """
        boundary = (
            datetime.now(timezone.utc).timestamp() - self.settings.auth_failure_window
        )
        recent = [
            attempt
            for attempt in self._failed_attempts.get(identifier, [])
            if attempt.timestamp() > boundary
        ]
        if recent:
            self._failed_attempts[identifier] = recent
        else:
            self._failed_attempts.pop(identifier, None)
        return len(recent) < self.settings.max_auth_failures

    def _record_failed_attempt(self, identifier: str) -> None:
        """Record a failed authentication attempt for rate limiting.

        Emits a high-severity security event each time the recorded count
        for *identifier* is at or above ``max_auth_failures``.
        """
        attempts = self._failed_attempts.setdefault(identifier, [])
        attempts.append(datetime.now(timezone.utc))
        if len(attempts) >= self.settings.max_auth_failures:
            self.audit.log_security_event(
                event_type="oauth_rate_limit_exceeded",
                description="OAuth authentication failure threshold exceeded",
                severity="high",
                metadata={
                    "identifier": identifier,
                    "failure_count": len(attempts),
                    "window_seconds": self.settings.auth_failure_window,
                },
            )

    @staticmethod
    def _looks_like_jwt(token: str) -> bool:
        """Return True when token has JWT segment structure (three dot-separated parts)."""
        return token.count(".") == 2

    @staticmethod
    def _normalize_scopes(raw: Any) -> set[str]:
        """Normalize a scope claim (whitespace-delimited string or list) to a set."""
        if isinstance(raw, str):
            # str.split() with no argument drops empty pieces and handles any whitespace.
            return set(raw.split())
        if isinstance(raw, list):
            return {str(scope).strip() for scope in raw if str(scope).strip()}
        return set()

    def _extract_scopes(self, payload: dict[str, Any]) -> set[str]:
        """Extract scopes from JWT or userinfo payload (``scope``, ``scopes``, ``scp``)."""
        scopes: set[str] = set()
        for claim in ("scope", "scopes", "scp"):
            scopes.update(self._normalize_scopes(payload.get(claim)))
        return scopes

    @staticmethod
    def _claim_text(payload: dict[str, Any], key: str) -> str:
        """Return a claim as stripped text, or "" when it is absent or null.

        Using ``str(payload.get(key, ""))`` directly would turn a JSON
        ``null`` into the literal string "None".
        """
        value = payload.get(key)
        if value is None:
            return ""
        return str(value).strip()

    @classmethod
    def _build_principal(
        cls, payload: dict[str, Any], login_claims: tuple[str, ...]
    ) -> dict[str, Any]:
        """Build the principal dict from verified claims.

        ``login`` is the first non-blank claim in *login_claims*, else
        "unknown". ``subject`` is the ``sub`` claim, falling back to ``login``.
        """
        login = ""
        for key in login_claims:
            login = cls._claim_text(payload, key)
            if login:
                break
        if not login:
            login = "unknown"
        subject = cls._claim_text(payload, "sub") or login
        scopes: set[str] = set()
        for claim in ("scope", "scopes", "scp"):
            scopes.update(cls._normalize_scopes(payload.get(claim)))
        return {
            "login": login,
            "subject": subject,
            "scopes": sorted(scopes),
        }

    async def _fetch_json_document(self, url: str) -> dict[str, Any]:
        """Fetch a JSON object from a trusted OAuth metadata endpoint.

        Raises:
            OAuthTokenValidationError: On network failure, a non-200 status,
                an unparsable body, or a body that is not a JSON object.
        """
        try:
            async with httpx.AsyncClient(timeout=self.settings.request_timeout_seconds) as client:
                response = await client.get(url, headers={"Accept": "application/json"})
        except httpx.RequestError as exc:
            raise OAuthTokenValidationError(_UNAVAILABLE_MESSAGE, "oauth_network_error") from exc

        if response.status_code != 200:
            raise OAuthTokenValidationError(_UNAVAILABLE_MESSAGE, "oauth_metadata_unavailable")

        try:
            data = response.json()
        except ValueError as exc:
            raise OAuthTokenValidationError(
                _UNAVAILABLE_MESSAGE, "oauth_metadata_invalid_json"
            ) from exc

        if not isinstance(data, dict):
            raise OAuthTokenValidationError(_UNAVAILABLE_MESSAGE, "oauth_metadata_invalid_type")
        return data

    async def _get_discovery_document(self) -> dict[str, Any]:
        """Get OIDC discovery metadata, cached for ``oauth_cache_ttl_seconds``.

        Raises:
            OAuthTokenValidationError: When the document lacks a non-blank
                ``issuer`` or ``jwks_uri`` string, or cannot be fetched.
        """
        now = time.monotonic()
        if self._discovery_cache is not None and now < self._discovery_cache[1]:
            return self._discovery_cache[0]

        discovery_url = f"{self._base_url()}/.well-known/openid-configuration"
        discovery = await self._fetch_json_document(discovery_url)

        issuer = discovery.get("issuer")
        jwks_uri = discovery.get("jwks_uri")
        if not isinstance(issuer, str) or not issuer.strip():
            raise OAuthTokenValidationError(_UNAVAILABLE_MESSAGE, "oauth_discovery_missing_issuer")
        if not isinstance(jwks_uri, str) or not jwks_uri.strip():
            raise OAuthTokenValidationError(
                _UNAVAILABLE_MESSAGE, "oauth_discovery_missing_jwks_uri"
            )

        self._discovery_cache = (discovery, now + self.settings.oauth_cache_ttl_seconds)
        return discovery

    async def _get_jwks(self, jwks_uri: str) -> dict[str, Any]:
        """Get the JWKS document for *jwks_uri*, cached per URI.

        Raises:
            OAuthTokenValidationError: When ``keys`` is missing, not a list,
                or empty, or the document cannot be fetched.
        """
        now = time.monotonic()
        cached = self._jwks_cache.get(jwks_uri)
        if cached and now < cached[1]:
            return cached[0]

        jwks = await self._fetch_json_document(jwks_uri)
        keys = jwks.get("keys")
        if not isinstance(keys, list) or not keys:
            raise OAuthTokenValidationError(_UNAVAILABLE_MESSAGE, "oauth_jwks_missing_keys")

        self._jwks_cache[jwks_uri] = (jwks, now + self.settings.oauth_cache_ttl_seconds)
        return jwks

    async def _validate_jwt(self, token: str) -> dict[str, Any]:
        """Validate a JWT access token using OIDC discovery and JWKS.

        Only RS256 tokens with a ``kid`` present in the JWKS are accepted.
        The signature, ``exp`` (required), ``iss`` (required; trailing-slash
        insensitive) and, when an audience is configured, ``aud`` are all
        verified.

        Raises:
            OAuthTokenValidationError: When the token cannot be trusted or
                the provider metadata is unusable.
        """
        discovery = await self._get_discovery_document()
        expected_issuer = str(discovery["issuer"]).rstrip("/")
        jwks = await self._get_jwks(str(discovery["jwks_uri"]))

        try:
            header = jwt.get_unverified_header(token)
        except InvalidTokenError as exc:
            raise OAuthTokenValidationError(_INVALID_MESSAGE, "oauth_jwt_header") from exc

        if header.get("alg") != "RS256":
            raise OAuthTokenValidationError(_INVALID_MESSAGE, "oauth_jwt_alg")
        key_id = header.get("kid")
        if not isinstance(key_id, str) or not key_id.strip():
            raise OAuthTokenValidationError(_INVALID_MESSAGE, "oauth_jwt_kid")

        matching_key = next(
            (
                key
                for key in jwks.get("keys", [])
                if isinstance(key, dict) and key.get("kid") == key_id
            ),
            None,
        )
        if matching_key is None:
            raise OAuthTokenValidationError(_INVALID_MESSAGE, "oauth_jwt_key_not_found")

        try:
            public_key = RSAAlgorithm.from_jwk(json.dumps(matching_key))
        except Exception as exc:  # malformed JWK from the provider
            raise OAuthTokenValidationError(_UNAVAILABLE_MESSAGE, "oauth_jwt_invalid_jwk") from exc

        expected_audience = (
            self.settings.oauth_expected_audience.strip()
            or self.settings.gitea_oauth_client_id.strip()
        )
        # "exp" is required so a token without expiry is never trusted forever;
        # "iss" is required here and compared manually below.
        decode_options = cast(
            Any,
            {"verify_aud": bool(expected_audience), "require": ["exp", "iss"]},
        )
        try:
            claims = jwt.decode(
                token,
                key=cast(Any, public_key),
                algorithms=["RS256"],
                audience=expected_audience or None,
                options=decode_options,
            )
        except InvalidTokenError as exc:
            raise OAuthTokenValidationError(_INVALID_MESSAGE, "oauth_jwt_invalid") from exc

        if not isinstance(claims, dict):
            raise OAuthTokenValidationError(_INVALID_MESSAGE, "oauth_jwt_claims")

        # OIDC requires iss to match the discovery issuer exactly. Normalizing
        # the trailing slash on BOTH sides accepts the exact match (which the
        # previous one-sided rstrip rejected for slash-terminated issuers).
        token_issuer = claims.get("iss")
        if not isinstance(token_issuer, str) or token_issuer.rstrip("/") != expected_issuer:
            raise OAuthTokenValidationError(_INVALID_MESSAGE, "oauth_jwt_invalid")

        return self._build_principal(claims, ("preferred_username", "name", "sub"))

    async def _validate_userinfo(self, token: str) -> dict[str, Any]:
        """Validate token via Gitea userinfo endpoint (opaque token fallback).

        Raises:
            OAuthTokenValidationError: "invalid" for 401/403 or a malformed
                body; "unavailable" for network errors or other statuses.
        """
        userinfo_url = f"{self._base_url()}/login/oauth/userinfo"
        try:
            async with httpx.AsyncClient(timeout=self.settings.request_timeout_seconds) as client:
                response = await client.get(
                    userinfo_url,
                    headers={
                        "Authorization": f"Bearer {token}",
                        "Accept": "application/json",
                    },
                )
        except httpx.RequestError as exc:
            raise OAuthTokenValidationError(_UNAVAILABLE_MESSAGE, "oauth_userinfo_network") from exc

        if response.status_code in {401, 403}:
            raise OAuthTokenValidationError(_INVALID_MESSAGE, "oauth_userinfo_denied")
        if response.status_code != 200:
            raise OAuthTokenValidationError(_UNAVAILABLE_MESSAGE, "oauth_userinfo_unavailable")

        try:
            payload = response.json()
        except ValueError as exc:
            raise OAuthTokenValidationError(_INVALID_MESSAGE, "oauth_userinfo_json") from exc

        if not isinstance(payload, dict):
            raise OAuthTokenValidationError(_INVALID_MESSAGE, "oauth_userinfo_type")

        return self._build_principal(payload, ("preferred_username", "login", "name", "sub"))

    async def validate_oauth_token(
        self,
        token: str | None,
        client_ip: str,
        user_agent: str,
    ) -> tuple[bool, str | None, dict[str, Any] | None]:
        """Validate an incoming OAuth token and return principal context.

        Returns:
            ``(True, None, principal)`` on success, otherwise
            ``(False, public_error_message, None)``. Failures are recorded
            against *client_ip* for rate limiting and audited.
        """
        if not self._check_rate_limit(client_ip):
            return False, "Too many failed authentication attempts. Try again later.", None

        if not token:
            self._record_failed_attempt(client_ip)
            return False, "Authorization header missing or empty.", None

        try:
            if self._looks_like_jwt(token):
                try:
                    principal = await self._validate_jwt(token)
                except OAuthTokenValidationError:
                    # Some providers issue opaque access tokens; verify those via userinfo.
                    principal = await self._validate_userinfo(token)
            else:
                principal = await self._validate_userinfo(token)
        except OAuthTokenValidationError as exc:
            self._record_failed_attempt(client_ip)
            self.audit.log_access_denied(
                tool_name="oauth_authentication",
                reason=exc.reason,
            )
            return False, exc.public_message, None

        self.audit.log_tool_invocation(
            tool_name="oauth_authentication",
            result_status="success",
            params={
                "client_ip": client_ip,
                "user_agent": user_agent,
                "gitea_user": principal.get("login", "unknown"),
            },
        )
        return True, None, principal


_oauth_validator: GiteaOAuthValidator | None = None


def get_oauth_validator() -> GiteaOAuthValidator:
    """Get or create the global OAuth validator instance."""
    global _oauth_validator
    if _oauth_validator is None:
        _oauth_validator = GiteaOAuthValidator()
    return _oauth_validator


def reset_oauth_validator() -> None:
    """Reset the global OAuth validator instance (primarily for testing)."""
    global _oauth_validator
    _oauth_validator = None