Introduce a GiteaOAuthValidator for JWT and userinfo validation and fallbacks, add /oauth/token proxy, and thread per-user tokens through the request context and automation paths. Update config and .env.example for OAuth-first mode, add OpenAPI, extensive unit/integration tests, GitHub/Gitea CI workflows, docs, and lint/test enforcement (>=80% cov).
367 lines
14 KiB
Python
367 lines
14 KiB
Python
"""OAuth2/OIDC token validation for per-user Gitea authentication."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from typing import Any, cast
|
|
|
|
import httpx
|
|
import jwt
|
|
from jwt import InvalidTokenError
|
|
from jwt.algorithms import RSAAlgorithm
|
|
|
|
from aegis_gitea_mcp.audit import get_audit_logger
|
|
from aegis_gitea_mcp.config import get_settings
|
|
|
|
|
|
class OAuthTokenValidationError(RuntimeError):
    """Signal that a presented OAuth token could not be validated.

    Instances carry two pieces of text: ``public_message`` (also used as the
    exception message, so it is what ``str(error)`` yields) and ``reason``
    (a short identifier kept separately on the instance).
    """

    def __init__(self, public_message: str, reason: str) -> None:
        """Store the message and reason on the exception.

        Args:
            public_message: Text forwarded to ``RuntimeError`` as the message.
            reason: Identifier describing why validation failed.
        """
        super().__init__(public_message)
        self.public_message, self.reason = public_message, reason
|
|
|
|
|
|
class GiteaOAuthValidator:
    """Validate per-user OAuth access tokens issued by Gitea.

    Tokens with JWT segment structure are verified locally against the OIDC
    discovery document and JWKS fetched from ``settings.gitea_base_url``. Any
    other token, and any JWT-shaped token that fails local verification, is
    checked against the Gitea userinfo endpoint instead. Failed attempts are
    counted per caller identifier and compared with
    ``settings.max_auth_failures`` within ``settings.auth_failure_window``
    seconds.
    """

    # Messages handed back to the caller; the specific cause travels separately
    # in ``OAuthTokenValidationError.reason``.
    _MSG_UNAVAILABLE = "Unable to validate OAuth token at this time."
    _MSG_INVALID = "Invalid or expired OAuth token."

    def __init__(self) -> None:
        """Initialize OAuth validator state and caches."""
        self.settings = get_settings()
        self.audit = get_audit_logger()
        # Timestamps of failed attempts, keyed by caller identifier.
        self._failed_attempts: dict[str, list[datetime]] = {}
        # (document, expiry on the time.monotonic() clock), or None before the first fetch.
        self._discovery_cache: tuple[dict[str, Any], float] | None = None
        # JWKS URI -> (document, expiry on the time.monotonic() clock).
        self._jwks_cache: dict[str, tuple[dict[str, Any], float]] = {}

    @staticmethod
    def extract_bearer_token(authorization_header: str | None) -> str | None:
        """Extract token from an ``Authorization: Bearer <token>`` header value.

        The scheme is matched case-insensitively, so ``bearer`` and ``BEARER``
        are accepted as well as ``Bearer``.

        Args:
            authorization_header: Raw header value, or ``None`` when absent.

        Returns:
            The token with surrounding whitespace removed, or ``None`` when the
            header is missing, uses another scheme, carries no token, or the
            token itself contains a space.
        """
        if not authorization_header:
            return None
        scheme, separator, token = authorization_header.partition(" ")
        # HTTP authentication scheme names are case-insensitive.
        if separator != " " or scheme.lower() != "bearer":
            return None
        stripped = token.strip()
        if not stripped or " " in stripped:
            return None
        return stripped

    def _check_rate_limit(self, identifier: str) -> bool:
        """Check whether authentication failures stay below the configured threshold.

        Failures older than ``settings.auth_failure_window`` seconds are
        discarded first. Identifiers left with no recent failures are removed
        from the table so it does not accumulate one empty entry per caller
        ever seen.

        Args:
            identifier: Key the failures are tracked under.

        Returns:
            ``True`` when the number of recent failures is below
            ``settings.max_auth_failures``, otherwise ``False``.
        """
        boundary = datetime.now(timezone.utc).timestamp() - self.settings.auth_failure_window
        recent = [
            attempt
            for attempt in self._failed_attempts.get(identifier, ())
            if attempt.timestamp() > boundary
        ]
        if recent:
            self._failed_attempts[identifier] = recent
        else:
            self._failed_attempts.pop(identifier, None)
        return len(recent) < self.settings.max_auth_failures

    def _record_failed_attempt(self, identifier: str) -> None:
        """Record a failed authentication attempt for rate limiting.

        Emits an ``oauth_rate_limit_exceeded`` security event whenever the
        stored failure count for ``identifier`` is at or above
        ``settings.max_auth_failures`` after recording.

        Args:
            identifier: Key the failure is tracked under.
        """
        attempts = self._failed_attempts.setdefault(identifier, [])
        attempts.append(datetime.now(timezone.utc))

        if len(attempts) >= self.settings.max_auth_failures:
            self.audit.log_security_event(
                event_type="oauth_rate_limit_exceeded",
                description="OAuth authentication failure threshold exceeded",
                severity="high",
                metadata={
                    "identifier": identifier,
                    "failure_count": len(attempts),
                    "window_seconds": self.settings.auth_failure_window,
                },
            )

    @staticmethod
    def _looks_like_jwt(token: str) -> bool:
        """Return True when token has JWT segment structure (exactly two dots)."""
        return token.count(".") == 2

    @staticmethod
    def _normalize_scopes(raw: Any) -> set[str]:
        """Normalize scope claim variations to a set.

        Args:
            raw: A space-delimited string, a list of scope values, or anything
                else (which yields an empty set).

        Returns:
            The non-empty scope names found in ``raw``.
        """
        normalized: set[str] = set()
        if isinstance(raw, str):
            normalized.update(scope for scope in raw.split(" ") if scope)
        elif isinstance(raw, list):
            normalized.update(str(scope).strip() for scope in raw if str(scope).strip())
        return normalized

    def _extract_scopes(self, payload: dict[str, Any]) -> set[str]:
        """Extract scopes from JWT or userinfo payload.

        The ``scope``, ``scopes`` and ``scp`` entries are all consulted and
        their normalized values merged.
        """
        scopes: set[str] = set()
        for claim_name in ("scope", "scopes", "scp"):
            scopes.update(self._normalize_scopes(payload.get(claim_name)))
        return scopes

    @staticmethod
    def _claim_text(value: Any) -> str:
        """Render a claim value as stripped text, treating ``None`` as empty."""
        if value is None:
            return ""
        return str(value).strip()

    def _build_principal(
        self,
        payload: dict[str, Any],
        login_claims: tuple[str, ...],
    ) -> dict[str, Any]:
        """Build the principal mapping shared by JWT and userinfo validation.

        Args:
            payload: Verified JWT claims or the userinfo response body.
            login_claims: Claim names tried in order for the login; the first
                one with non-blank text wins.

        Returns:
            A dict with ``login``, ``subject`` and sorted ``scopes``. The login
            falls back to ``sub`` and then to ``"unknown"``; the subject falls
            back to the login. Null or blank claims count as absent, so neither
            field is ever empty or the literal string ``"None"``.
        """
        subject = self._claim_text(payload.get("sub"))
        candidates = (self._claim_text(payload.get(name)) for name in login_claims)
        login = next((text for text in candidates if text), "") or subject or "unknown"
        return {
            "login": login,
            "subject": subject or login,
            "scopes": sorted(self._extract_scopes(payload)),
        }

    async def _fetch_json_document(self, url: str) -> dict[str, Any]:
        """Fetch a JSON document from a trusted OAuth endpoint.

        Args:
            url: Endpoint to GET with ``Accept: application/json``.

        Returns:
            The decoded JSON object.

        Raises:
            OAuthTokenValidationError: On a request error, a non-200 status,
                a body that is not valid JSON, or JSON that is not an object.
        """
        try:
            async with httpx.AsyncClient(timeout=self.settings.request_timeout_seconds) as client:
                response = await client.get(url, headers={"Accept": "application/json"})
        except httpx.RequestError as exc:
            raise OAuthTokenValidationError(
                self._MSG_UNAVAILABLE,
                "oauth_network_error",
            ) from exc

        if response.status_code != 200:
            raise OAuthTokenValidationError(
                self._MSG_UNAVAILABLE,
                "oauth_metadata_unavailable",
            )

        try:
            data = response.json()
        except ValueError as exc:
            raise OAuthTokenValidationError(
                self._MSG_UNAVAILABLE,
                "oauth_metadata_invalid_json",
            ) from exc

        if not isinstance(data, dict):
            raise OAuthTokenValidationError(
                self._MSG_UNAVAILABLE,
                "oauth_metadata_invalid_type",
            )
        return data

    async def _get_discovery_document(self) -> dict[str, Any]:
        """Get cached OIDC discovery metadata.

        The document is refetched once its cache entry, valid for
        ``settings.oauth_cache_ttl_seconds``, has expired.

        Returns:
            The discovery document, guaranteed to hold non-blank string
            ``issuer`` and ``jwks_uri`` entries.

        Raises:
            OAuthTokenValidationError: When the document cannot be fetched or
                lacks a usable ``issuer`` or ``jwks_uri``.
        """
        now = time.monotonic()
        if self._discovery_cache and now < self._discovery_cache[1]:
            return self._discovery_cache[0]

        discovery_url = f"{self.settings.gitea_base_url}/.well-known/openid-configuration"
        discovery = await self._fetch_json_document(discovery_url)
        issuer = discovery.get("issuer")
        jwks_uri = discovery.get("jwks_uri")
        if not isinstance(issuer, str) or not issuer.strip():
            raise OAuthTokenValidationError(
                self._MSG_UNAVAILABLE,
                "oauth_discovery_missing_issuer",
            )
        if not isinstance(jwks_uri, str) or not jwks_uri.strip():
            raise OAuthTokenValidationError(
                self._MSG_UNAVAILABLE,
                "oauth_discovery_missing_jwks_uri",
            )

        # Only a document that passed the checks above is cached.
        self._discovery_cache = (discovery, now + self.settings.oauth_cache_ttl_seconds)
        return discovery

    async def _get_jwks(self, jwks_uri: str) -> dict[str, Any]:
        """Get cached JWKS document.

        Args:
            jwks_uri: Location of the key set; also the cache key.

        Returns:
            The JWKS document, whose ``keys`` entry is a non-empty list.

        Raises:
            OAuthTokenValidationError: When the document cannot be fetched or
                has no keys.
        """
        now = time.monotonic()
        cached = self._jwks_cache.get(jwks_uri)
        if cached and now < cached[1]:
            return cached[0]

        jwks = await self._fetch_json_document(jwks_uri)
        keys = jwks.get("keys")
        if not isinstance(keys, list) or not keys:
            raise OAuthTokenValidationError(
                self._MSG_UNAVAILABLE,
                "oauth_jwks_missing_keys",
            )
        self._jwks_cache[jwks_uri] = (jwks, now + self.settings.oauth_cache_ttl_seconds)
        return jwks

    async def _validate_jwt(self, token: str) -> dict[str, Any]:
        """Validate JWT access token using OIDC discovery and JWKS.

        Only ``RS256`` tokens whose ``kid`` matches a published key are
        accepted. The issuer must equal the discovery issuer with trailing
        slashes removed. The audience is verified only when
        ``settings.oauth_expected_audience`` or
        ``settings.gitea_oauth_client_id`` is non-blank.

        Args:
            token: The JWT-shaped access token.

        Returns:
            Principal mapping with ``login``, ``subject`` and ``scopes``.

        Raises:
            OAuthTokenValidationError: When metadata is unavailable or the
                token fails any verification step.
        """
        discovery = await self._get_discovery_document()
        issuer = str(discovery["issuer"]).rstrip("/")
        jwks_uri = str(discovery["jwks_uri"])
        jwks = await self._get_jwks(jwks_uri)

        try:
            header = jwt.get_unverified_header(token)
        except InvalidTokenError as exc:
            raise OAuthTokenValidationError(self._MSG_INVALID, "oauth_jwt_header") from exc

        algorithm = header.get("alg")
        key_id = header.get("kid")
        if algorithm != "RS256":
            raise OAuthTokenValidationError(self._MSG_INVALID, "oauth_jwt_alg")
        if not isinstance(key_id, str) or not key_id.strip():
            raise OAuthTokenValidationError(self._MSG_INVALID, "oauth_jwt_kid")

        matching_key = next(
            (
                key
                for key in jwks.get("keys", [])
                if isinstance(key, dict) and key.get("kid") == key_id
            ),
            None,
        )
        if matching_key is None:
            raise OAuthTokenValidationError(self._MSG_INVALID, "oauth_jwt_key_not_found")

        try:
            public_key = RSAAlgorithm.from_jwk(json.dumps(matching_key))
        except Exception as exc:
            # Any failure to load the published key is reported as a metadata
            # problem rather than as a bad token.
            raise OAuthTokenValidationError(
                self._MSG_UNAVAILABLE,
                "oauth_jwt_invalid_jwk",
            ) from exc

        expected_audience = (
            self.settings.oauth_expected_audience.strip()
            or self.settings.gitea_oauth_client_id.strip()
        )

        decode_options = cast(Any, {"verify_aud": bool(expected_audience)})
        try:
            claims = jwt.decode(
                token,
                key=cast(Any, public_key),
                algorithms=["RS256"],
                issuer=issuer,
                audience=expected_audience or None,
                options=decode_options,
            )
        except InvalidTokenError as exc:
            raise OAuthTokenValidationError(self._MSG_INVALID, "oauth_jwt_invalid") from exc

        if not isinstance(claims, dict):
            raise OAuthTokenValidationError(self._MSG_INVALID, "oauth_jwt_claims")

        return self._build_principal(claims, ("preferred_username", "name"))

    async def _validate_userinfo(self, token: str) -> dict[str, Any]:
        """Validate token via Gitea userinfo endpoint (opaque token fallback).

        Args:
            token: Access token sent as a bearer credential.

        Returns:
            Principal mapping with ``login``, ``subject`` and ``scopes``.

        Raises:
            OAuthTokenValidationError: On a request error, a 401/403 response,
                any other non-200 status, or a body that is not a JSON object.
        """
        userinfo_url = f"{self.settings.gitea_base_url}/login/oauth/userinfo"
        try:
            async with httpx.AsyncClient(timeout=self.settings.request_timeout_seconds) as client:
                response = await client.get(
                    userinfo_url,
                    headers={
                        "Authorization": f"Bearer {token}",
                        "Accept": "application/json",
                    },
                )
        except httpx.RequestError as exc:
            raise OAuthTokenValidationError(
                self._MSG_UNAVAILABLE,
                "oauth_userinfo_network",
            ) from exc

        if response.status_code in {401, 403}:
            raise OAuthTokenValidationError(self._MSG_INVALID, "oauth_userinfo_denied")
        if response.status_code != 200:
            raise OAuthTokenValidationError(
                self._MSG_UNAVAILABLE,
                "oauth_userinfo_unavailable",
            )

        try:
            payload = response.json()
        except ValueError as exc:
            raise OAuthTokenValidationError(self._MSG_INVALID, "oauth_userinfo_json") from exc

        if not isinstance(payload, dict):
            raise OAuthTokenValidationError(self._MSG_INVALID, "oauth_userinfo_type")

        return self._build_principal(payload, ("preferred_username", "login", "name"))

    async def validate_oauth_token(
        self,
        token: str | None,
        client_ip: str,
        user_agent: str,
    ) -> tuple[bool, str | None, dict[str, Any] | None]:
        """Validate an incoming OAuth token and return principal context.

        Args:
            token: Bearer token taken from the request, or ``None``/empty.
            client_ip: Caller address; used as the rate-limit identifier.
            user_agent: Caller user agent, recorded in the success audit entry.

        Returns:
            ``(True, None, principal)`` on success, otherwise
            ``(False, message, None)`` where ``message`` is safe to show to the
            caller.
        """
        if not self._check_rate_limit(client_ip):
            return False, "Too many failed authentication attempts. Try again later.", None

        if not token:
            self._record_failed_attempt(client_ip)
            return False, "Authorization header missing or empty.", None

        try:
            if self._looks_like_jwt(token):
                try:
                    principal = await self._validate_jwt(token)
                except OAuthTokenValidationError:
                    # Some providers issue opaque access tokens; verify those via userinfo.
                    principal = await self._validate_userinfo(token)
            else:
                principal = await self._validate_userinfo(token)
        except OAuthTokenValidationError as exc:
            self._record_failed_attempt(client_ip)
            self.audit.log_access_denied(
                tool_name="oauth_authentication",
                reason=exc.reason,
            )
            return False, exc.public_message, None

        self.audit.log_tool_invocation(
            tool_name="oauth_authentication",
            result_status="success",
            params={
                "client_ip": client_ip,
                "user_agent": user_agent,
                "gitea_user": principal.get("login", "unknown"),
            },
        )
        return True, None, principal
|
|
|
|
|
|
# Module-level validator instance shared by get_oauth_validator(); stays None
# until that function first constructs it, and is set back to None by
# reset_oauth_validator().
_oauth_validator: GiteaOAuthValidator | None = None
|
|
|
|
|
|
def get_oauth_validator() -> GiteaOAuthValidator:
    """Return the module-wide OAuth validator, building it on first use.

    Returns:
        The instance stored in ``_oauth_validator``; a new
        ``GiteaOAuthValidator`` is created and stored when none exists yet.
    """
    global _oauth_validator
    validator = _oauth_validator
    if validator is None:
        validator = _oauth_validator = GiteaOAuthValidator()
    return validator
|
|
|
|
|
|
def reset_oauth_validator() -> None:
    """Drop the module-wide OAuth validator instance (primarily for testing).

    Sets ``_oauth_validator`` back to ``None``.
    """
    global _oauth_validator
    _oauth_validator = None
|