Files
AegisGitea-MCP/src/aegis_gitea_mcp/oauth.py
latte 59e1ea53a8
Some checks failed
docker / lint (push) Has been cancelled
docker / test (push) Has been cancelled
docker / docker-build (push) Has been cancelled
lint / lint (push) Has been cancelled
test / test (push) Has been cancelled
Add OAuth2/OIDC per-user Gitea authentication
Introduce a GiteaOAuthValidator for JWT and userinfo validation (with
fallbacks), add an /oauth/token proxy, and thread per-user tokens through
the request context and automation paths. Update config and .env.example
for OAuth-first mode, and add an OpenAPI spec, extensive unit/integration
tests, GitHub/Gitea CI workflows, docs, and lint/test enforcement
(>=80% coverage).
2026-02-25 16:54:01 +01:00

367 lines
14 KiB
Python

"""OAuth2/OIDC token validation for per-user Gitea authentication."""
from __future__ import annotations
import json
import time
from datetime import datetime, timezone
from typing import Any, cast
import httpx
import jwt
from jwt import InvalidTokenError
from jwt.algorithms import RSAAlgorithm
from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.config import get_settings
class OAuthTokenValidationError(RuntimeError):
    """Signal that an OAuth token failed validation and must not be trusted.

    Attributes:
        public_message: Generic text that is safe to return to the client.
        reason: Internal, machine-readable failure code (used for audit logs).
    """

    def __init__(self, public_message: str, reason: str) -> None:
        """Store the client-facing message and the internal failure reason."""
        self.public_message = public_message
        self.reason = reason
        super().__init__(public_message)
class GiteaOAuthValidator:
    """Validate per-user OAuth access tokens issued by Gitea.

    Tokens with JWT structure are first verified locally against Gitea's OIDC
    discovery document and JWKS. If that fails, or the token is opaque, the
    token is verified by calling Gitea's userinfo endpoint with it. Failed
    attempts are tracked per client identifier for rate limiting.
    """

    def __init__(self) -> None:
        """Initialize OAuth validator state and caches."""
        self.settings = get_settings()
        self.audit = get_audit_logger()
        # identifier (client IP) -> UTC timestamps of recent failed attempts.
        self._failed_attempts: dict[str, list[datetime]] = {}
        # (discovery document, time.monotonic() expiry)
        self._discovery_cache: tuple[dict[str, Any], float] | None = None
        # jwks_uri -> (JWKS document, time.monotonic() expiry)
        self._jwks_cache: dict[str, tuple[dict[str, Any], float]] = {}

    @staticmethod
    def extract_bearer_token(authorization_header: str | None) -> str | None:
        """Extract the token from an ``Authorization: Bearer <token>`` header.

        The auth scheme is matched case-insensitively (RFC 7235 section 2.1),
        so ``bearer`` and ``BEARER`` are accepted as well as ``Bearer``.

        Args:
            authorization_header: Raw header value, or ``None`` when absent.

        Returns:
            The token, or ``None`` when the header is missing, uses another
            scheme, carries no token, or the token contains whitespace.
        """
        if not authorization_header:
            return None
        scheme, separator, token = authorization_header.partition(" ")
        if separator != " " or scheme.lower() != "bearer":
            return None
        stripped = token.strip()
        # A bearer credential is a single token; reject any embedded whitespace
        # (spaces, tabs, newlines) instead of accepting a malformed value.
        if not stripped or any(char.isspace() for char in stripped):
            return None
        return stripped

    def _check_rate_limit(self, identifier: str) -> bool:
        """Return True while *identifier* is below the failure threshold.

        Attempts older than ``auth_failure_window`` seconds are discarded. An
        identifier whose history becomes empty is removed from the map, so
        memory does not grow without bound across many distinct clients.
        """
        now = datetime.now(timezone.utc)
        boundary = now.timestamp() - self.settings.auth_failure_window
        recent: list[datetime] = []
        if identifier in self._failed_attempts:
            recent = [
                attempt
                for attempt in self._failed_attempts[identifier]
                if attempt.timestamp() > boundary
            ]
            if recent:
                self._failed_attempts[identifier] = recent
            else:
                del self._failed_attempts[identifier]
        return len(recent) < self.settings.max_auth_failures

    def _record_failed_attempt(self, identifier: str) -> None:
        """Record a failed authentication attempt for rate limiting.

        Emits a high-severity security event once the number of recorded
        failures for *identifier* reaches ``max_auth_failures``.
        """
        attempt_time = datetime.now(timezone.utc)
        self._failed_attempts.setdefault(identifier, []).append(attempt_time)
        if len(self._failed_attempts[identifier]) >= self.settings.max_auth_failures:
            self.audit.log_security_event(
                event_type="oauth_rate_limit_exceeded",
                description="OAuth authentication failure threshold exceeded",
                severity="high",
                metadata={
                    "identifier": identifier,
                    "failure_count": len(self._failed_attempts[identifier]),
                    "window_seconds": self.settings.auth_failure_window,
                },
            )

    @staticmethod
    def _looks_like_jwt(token: str) -> bool:
        """Return True when token has JWT segment structure (three dot-separated parts)."""
        return token.count(".") == 2

    @staticmethod
    def _normalize_scopes(raw: Any) -> set[str]:
        """Normalize a scope claim to a set of scope names.

        Accepts a whitespace-delimited string (the RFC 6749 ``scope`` format)
        or a list/tuple of scope values. Any other value yields an empty set.
        """
        if isinstance(raw, str):
            return set(raw.split())
        if isinstance(raw, (list, tuple)):
            return {text for text in (str(scope).strip() for scope in raw) if text}
        return set()

    def _extract_scopes(self, payload: dict[str, Any]) -> set[str]:
        """Extract scopes from the ``scope``, ``scopes`` and ``scp`` claims combined."""
        scopes: set[str] = set()
        scopes.update(self._normalize_scopes(payload.get("scope")))
        scopes.update(self._normalize_scopes(payload.get("scopes")))
        scopes.update(self._normalize_scopes(payload.get("scp")))
        return scopes

    @staticmethod
    def _claim_text(payload: dict[str, Any], key: str) -> str:
        """Return a claim as stripped text, treating missing or ``null`` as empty."""
        value = payload.get(key)
        return "" if value is None else str(value).strip()

    def _build_principal(
        self, payload: dict[str, Any], login_claims: tuple[str, ...]
    ) -> dict[str, Any]:
        """Build the principal dict from validated JWT claims or a userinfo payload.

        Args:
            payload: Verified JWT claims or the userinfo response body.
            login_claims: Claim names to try, in order, for the login name.

        Returns:
            ``{"login", "subject", "scopes"}``. ``login`` is the first non-empty
            claim in *login_claims*. It falls back to ``sub`` and then to
            ``"unknown"``. ``subject`` is ``sub`` and falls back to the login.
            ``scopes`` is sorted. JSON ``null`` claims count as missing rather
            than becoming the string ``"None"``.
        """
        subject = self._claim_text(payload, "sub")
        login = ""
        for claim in login_claims:
            login = self._claim_text(payload, claim)
            if login:
                break
        login = login or subject or "unknown"
        return {
            "login": login,
            "subject": subject or login,
            "scopes": sorted(self._extract_scopes(payload)),
        }

    async def _fetch_json_document(self, url: str) -> dict[str, Any]:
        """Fetch a JSON object from a trusted OAuth endpoint.

        Raises:
            OAuthTokenValidationError: On network failure, a non-200 status,
                an invalid JSON body, or a JSON value that is not an object.
        """
        try:
            async with httpx.AsyncClient(timeout=self.settings.request_timeout_seconds) as client:
                response = await client.get(url, headers={"Accept": "application/json"})
        except httpx.RequestError as exc:
            raise OAuthTokenValidationError(
                "Unable to validate OAuth token at this time.",
                "oauth_network_error",
            ) from exc
        if response.status_code != 200:
            raise OAuthTokenValidationError(
                "Unable to validate OAuth token at this time.",
                "oauth_metadata_unavailable",
            )
        try:
            data = response.json()
        except ValueError as exc:
            raise OAuthTokenValidationError(
                "Unable to validate OAuth token at this time.",
                "oauth_metadata_invalid_json",
            ) from exc
        if not isinstance(data, dict):
            raise OAuthTokenValidationError(
                "Unable to validate OAuth token at this time.",
                "oauth_metadata_invalid_type",
            )
        return data

    async def _get_discovery_document(self) -> dict[str, Any]:
        """Get OIDC discovery metadata, cached for ``oauth_cache_ttl_seconds``.

        Raises:
            OAuthTokenValidationError: If the document cannot be fetched or lacks
                a non-empty ``issuer`` or ``jwks_uri``.
        """
        now = time.monotonic()
        if self._discovery_cache and now < self._discovery_cache[1]:
            return self._discovery_cache[0]
        discovery_url = f"{self.settings.gitea_base_url}/.well-known/openid-configuration"
        discovery = await self._fetch_json_document(discovery_url)
        issuer = discovery.get("issuer")
        jwks_uri = discovery.get("jwks_uri")
        if not isinstance(issuer, str) or not issuer.strip():
            raise OAuthTokenValidationError(
                "Unable to validate OAuth token at this time.",
                "oauth_discovery_missing_issuer",
            )
        if not isinstance(jwks_uri, str) or not jwks_uri.strip():
            raise OAuthTokenValidationError(
                "Unable to validate OAuth token at this time.",
                "oauth_discovery_missing_jwks_uri",
            )
        self._discovery_cache = (discovery, now + self.settings.oauth_cache_ttl_seconds)
        return discovery

    async def _get_jwks(self, jwks_uri: str) -> dict[str, Any]:
        """Get the JWKS document for *jwks_uri*, cached per URI.

        Raises:
            OAuthTokenValidationError: If the document cannot be fetched or has
                no non-empty ``keys`` list.
        """
        now = time.monotonic()
        cached = self._jwks_cache.get(jwks_uri)
        if cached and now < cached[1]:
            return cached[0]
        jwks = await self._fetch_json_document(jwks_uri)
        keys = jwks.get("keys")
        if not isinstance(keys, list) or not keys:
            raise OAuthTokenValidationError(
                "Unable to validate OAuth token at this time.",
                "oauth_jwks_missing_keys",
            )
        self._jwks_cache[jwks_uri] = (jwks, now + self.settings.oauth_cache_ttl_seconds)
        return jwks

    async def _validate_jwt(self, token: str) -> dict[str, Any]:
        """Validate a JWT access token using OIDC discovery and JWKS.

        Only RS256 tokens with a ``kid`` present in the JWKS are accepted. The
        issuer is always checked. The audience is checked when
        ``oauth_expected_audience`` (or, failing that, ``gitea_oauth_client_id``)
        is set.

        Raises:
            OAuthTokenValidationError: For any header, key, signature or claim
                failure, or when metadata cannot be obtained.
        """
        discovery = await self._get_discovery_document()
        # NOTE(review): the discovery issuer is compared with its trailing "/"
        # removed, while PyJWT compares `iss` exactly. Confirm this matches the
        # `iss` Gitea puts in its tokens.
        issuer = str(discovery["issuer"]).rstrip("/")
        jwks_uri = str(discovery["jwks_uri"])
        jwks = await self._get_jwks(jwks_uri)
        try:
            header = jwt.get_unverified_header(token)
        except InvalidTokenError as exc:
            raise OAuthTokenValidationError(
                "Invalid or expired OAuth token.", "oauth_jwt_header"
            ) from exc
        algorithm = header.get("alg")
        key_id = header.get("kid")
        # Pin the algorithm so a token cannot choose e.g. "none" or HS256.
        if algorithm != "RS256":
            raise OAuthTokenValidationError("Invalid or expired OAuth token.", "oauth_jwt_alg")
        if not isinstance(key_id, str) or not key_id.strip():
            raise OAuthTokenValidationError("Invalid or expired OAuth token.", "oauth_jwt_kid")
        matching_key = next(
            (
                key
                for key in jwks.get("keys", [])
                if isinstance(key, dict) and key.get("kid") == key_id
            ),
            None,
        )
        if matching_key is None:
            raise OAuthTokenValidationError(
                "Invalid or expired OAuth token.", "oauth_jwt_key_not_found"
            )
        try:
            public_key = RSAAlgorithm.from_jwk(json.dumps(matching_key))
        except Exception as exc:
            raise OAuthTokenValidationError(
                "Unable to validate OAuth token at this time.",
                "oauth_jwt_invalid_jwk",
            ) from exc
        expected_audience = (
            self.settings.oauth_expected_audience.strip()
            or self.settings.gitea_oauth_client_id.strip()
        )
        decode_options = cast(Any, {"verify_aud": bool(expected_audience)})
        try:
            claims = jwt.decode(
                token,
                key=cast(Any, public_key),
                algorithms=["RS256"],
                issuer=issuer,
                audience=expected_audience or None,
                options=decode_options,
            )
        except InvalidTokenError as exc:
            raise OAuthTokenValidationError(
                "Invalid or expired OAuth token.", "oauth_jwt_invalid"
            ) from exc
        if not isinstance(claims, dict):
            raise OAuthTokenValidationError("Invalid or expired OAuth token.", "oauth_jwt_claims")
        return self._build_principal(claims, ("preferred_username", "name"))

    async def _validate_userinfo(self, token: str) -> dict[str, Any]:
        """Validate a token via Gitea's userinfo endpoint (the fallback for opaque tokens).

        Raises:
            OAuthTokenValidationError: On 401/403 (token rejected), other
                non-200 statuses, network errors, or a malformed response body.
        """
        userinfo_url = f"{self.settings.gitea_base_url}/login/oauth/userinfo"
        try:
            async with httpx.AsyncClient(timeout=self.settings.request_timeout_seconds) as client:
                response = await client.get(
                    userinfo_url,
                    headers={
                        "Authorization": f"Bearer {token}",
                        "Accept": "application/json",
                    },
                )
        except httpx.RequestError as exc:
            raise OAuthTokenValidationError(
                "Unable to validate OAuth token at this time.",
                "oauth_userinfo_network",
            ) from exc
        if response.status_code in {401, 403}:
            raise OAuthTokenValidationError(
                "Invalid or expired OAuth token.", "oauth_userinfo_denied"
            )
        if response.status_code != 200:
            raise OAuthTokenValidationError(
                "Unable to validate OAuth token at this time.",
                "oauth_userinfo_unavailable",
            )
        try:
            payload = response.json()
        except ValueError as exc:
            raise OAuthTokenValidationError(
                "Invalid or expired OAuth token.", "oauth_userinfo_json"
            ) from exc
        if not isinstance(payload, dict):
            raise OAuthTokenValidationError(
                "Invalid or expired OAuth token.", "oauth_userinfo_type"
            )
        return self._build_principal(payload, ("preferred_username", "login", "name"))

    async def validate_oauth_token(
        self,
        token: str | None,
        client_ip: str,
        user_agent: str,
    ) -> tuple[bool, str | None, dict[str, Any] | None]:
        """Validate an incoming OAuth token and return the principal context.

        Args:
            token: Bearer token extracted from the request, or ``None``.
            client_ip: Caller address; also the rate-limiting key.
            user_agent: Caller user agent, recorded in the audit log.

        Returns:
            ``(True, None, principal)`` on success, otherwise
            ``(False, public_error_message, None)``.
        """
        if not self._check_rate_limit(client_ip):
            return False, "Too many failed authentication attempts. Try again later.", None
        if not token:
            self._record_failed_attempt(client_ip)
            return False, "Authorization header missing or empty.", None
        try:
            if self._looks_like_jwt(token):
                try:
                    principal = await self._validate_jwt(token)
                except OAuthTokenValidationError:
                    # Some providers issue opaque access tokens; verify those via userinfo.
                    principal = await self._validate_userinfo(token)
            else:
                principal = await self._validate_userinfo(token)
        except OAuthTokenValidationError as exc:
            self._record_failed_attempt(client_ip)
            self.audit.log_access_denied(
                tool_name="oauth_authentication",
                reason=exc.reason,
            )
            return False, exc.public_message, None
        self.audit.log_tool_invocation(
            tool_name="oauth_authentication",
            result_status="success",
            params={
                "client_ip": client_ip,
                "user_agent": user_agent,
                "gitea_user": principal.get("login", "unknown"),
            },
        )
        return True, None, principal
_oauth_validator: GiteaOAuthValidator | None = None


def get_oauth_validator() -> GiteaOAuthValidator:
    """Return the process-wide OAuth validator, building it lazily on first call."""
    global _oauth_validator
    if _oauth_validator is not None:
        return _oauth_validator
    # First use: construct once and keep it in the module-level slot.
    _oauth_validator = GiteaOAuthValidator()
    return _oauth_validator
def reset_oauth_validator() -> None:
    """Forget the cached global validator so the next lookup builds a new one.

    Mainly useful in tests that change settings between cases.
    """
    global _oauth_validator
    _oauth_validator = None