Add OAuth2/OIDC per-user Gitea authentication
Introduce a GiteaOAuthValidator for JWT and userinfo validation and fallbacks, add /oauth/token proxy, and thread per-user tokens through the request context and automation paths. Update config and .env.example for OAuth-first mode, add OpenAPI, extensive unit/integration tests, GitHub/Gitea CI workflows, docs, and lint/test enforcement (>=80% cov).
This commit is contained in:
366
src/aegis_gitea_mcp/oauth.py
Normal file
366
src/aegis_gitea_mcp/oauth.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""OAuth2/OIDC token validation for per-user Gitea authentication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from jwt import InvalidTokenError
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
|
||||
from aegis_gitea_mcp.audit import get_audit_logger
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
|
||||
|
||||
class OAuthTokenValidationError(RuntimeError):
|
||||
"""Raised when a provided OAuth token cannot be trusted."""
|
||||
|
||||
def __init__(self, public_message: str, reason: str) -> None:
|
||||
"""Initialize validation error details."""
|
||||
super().__init__(public_message)
|
||||
self.public_message = public_message
|
||||
self.reason = reason
|
||||
|
||||
|
||||
class GiteaOAuthValidator:
|
||||
"""Validate per-user OAuth access tokens issued by Gitea."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize OAuth validator state and caches."""
|
||||
self.settings = get_settings()
|
||||
self.audit = get_audit_logger()
|
||||
self._failed_attempts: dict[str, list[datetime]] = {}
|
||||
self._discovery_cache: tuple[dict[str, Any], float] | None = None
|
||||
self._jwks_cache: dict[str, tuple[dict[str, Any], float]] = {}
|
||||
|
||||
@staticmethod
|
||||
def extract_bearer_token(authorization_header: str | None) -> str | None:
|
||||
"""Extract token from `Authorization: Bearer <token>` header."""
|
||||
if not authorization_header:
|
||||
return None
|
||||
scheme, separator, token = authorization_header.partition(" ")
|
||||
if separator != " " or scheme != "Bearer":
|
||||
return None
|
||||
stripped = token.strip()
|
||||
if not stripped or " " in stripped:
|
||||
return None
|
||||
return stripped
|
||||
|
||||
def _check_rate_limit(self, identifier: str) -> bool:
|
||||
"""Check whether authentication failures exceed configured threshold."""
|
||||
now = datetime.now(timezone.utc)
|
||||
boundary = now.timestamp() - self.settings.auth_failure_window
|
||||
|
||||
if identifier in self._failed_attempts:
|
||||
self._failed_attempts[identifier] = [
|
||||
attempt
|
||||
for attempt in self._failed_attempts[identifier]
|
||||
if attempt.timestamp() > boundary
|
||||
]
|
||||
|
||||
return len(self._failed_attempts.get(identifier, [])) < self.settings.max_auth_failures
|
||||
|
||||
def _record_failed_attempt(self, identifier: str) -> None:
|
||||
"""Record a failed authentication attempt for rate limiting."""
|
||||
attempt_time = datetime.now(timezone.utc)
|
||||
self._failed_attempts.setdefault(identifier, []).append(attempt_time)
|
||||
|
||||
if len(self._failed_attempts[identifier]) >= self.settings.max_auth_failures:
|
||||
self.audit.log_security_event(
|
||||
event_type="oauth_rate_limit_exceeded",
|
||||
description="OAuth authentication failure threshold exceeded",
|
||||
severity="high",
|
||||
metadata={
|
||||
"identifier": identifier,
|
||||
"failure_count": len(self._failed_attempts[identifier]),
|
||||
"window_seconds": self.settings.auth_failure_window,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_jwt(token: str) -> bool:
|
||||
"""Return True when token has JWT segment structure."""
|
||||
return token.count(".") == 2
|
||||
|
||||
@staticmethod
|
||||
def _normalize_scopes(raw: Any) -> set[str]:
|
||||
"""Normalize scope claim variations to a set."""
|
||||
normalized: set[str] = set()
|
||||
if isinstance(raw, str):
|
||||
normalized.update(scope for scope in raw.split(" ") if scope)
|
||||
elif isinstance(raw, list):
|
||||
normalized.update(str(scope).strip() for scope in raw if str(scope).strip())
|
||||
return normalized
|
||||
|
||||
def _extract_scopes(self, payload: dict[str, Any]) -> set[str]:
|
||||
"""Extract scopes from JWT or userinfo payload."""
|
||||
scopes = set()
|
||||
scopes.update(self._normalize_scopes(payload.get("scope")))
|
||||
scopes.update(self._normalize_scopes(payload.get("scopes")))
|
||||
scopes.update(self._normalize_scopes(payload.get("scp")))
|
||||
return scopes
|
||||
|
||||
async def _fetch_json_document(self, url: str) -> dict[str, Any]:
|
||||
"""Fetch a JSON document from a trusted OAuth endpoint."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.settings.request_timeout_seconds) as client:
|
||||
response = await client.get(url, headers={"Accept": "application/json"})
|
||||
except httpx.RequestError as exc:
|
||||
raise OAuthTokenValidationError(
|
||||
"Unable to validate OAuth token at this time.",
|
||||
"oauth_network_error",
|
||||
) from exc
|
||||
|
||||
if response.status_code != 200:
|
||||
raise OAuthTokenValidationError(
|
||||
"Unable to validate OAuth token at this time.",
|
||||
"oauth_metadata_unavailable",
|
||||
)
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except ValueError as exc:
|
||||
raise OAuthTokenValidationError(
|
||||
"Unable to validate OAuth token at this time.",
|
||||
"oauth_metadata_invalid_json",
|
||||
) from exc
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise OAuthTokenValidationError(
|
||||
"Unable to validate OAuth token at this time.",
|
||||
"oauth_metadata_invalid_type",
|
||||
)
|
||||
return data
|
||||
|
||||
async def _get_discovery_document(self) -> dict[str, Any]:
|
||||
"""Get cached OIDC discovery metadata."""
|
||||
now = time.monotonic()
|
||||
if self._discovery_cache and now < self._discovery_cache[1]:
|
||||
return self._discovery_cache[0]
|
||||
|
||||
discovery_url = f"{self.settings.gitea_base_url}/.well-known/openid-configuration"
|
||||
discovery = await self._fetch_json_document(discovery_url)
|
||||
issuer = discovery.get("issuer")
|
||||
jwks_uri = discovery.get("jwks_uri")
|
||||
if not isinstance(issuer, str) or not issuer.strip():
|
||||
raise OAuthTokenValidationError(
|
||||
"Unable to validate OAuth token at this time.",
|
||||
"oauth_discovery_missing_issuer",
|
||||
)
|
||||
if not isinstance(jwks_uri, str) or not jwks_uri.strip():
|
||||
raise OAuthTokenValidationError(
|
||||
"Unable to validate OAuth token at this time.",
|
||||
"oauth_discovery_missing_jwks_uri",
|
||||
)
|
||||
|
||||
self._discovery_cache = (discovery, now + self.settings.oauth_cache_ttl_seconds)
|
||||
return discovery
|
||||
|
||||
async def _get_jwks(self, jwks_uri: str) -> dict[str, Any]:
|
||||
"""Get cached JWKS document."""
|
||||
now = time.monotonic()
|
||||
cached = self._jwks_cache.get(jwks_uri)
|
||||
if cached and now < cached[1]:
|
||||
return cached[0]
|
||||
|
||||
jwks = await self._fetch_json_document(jwks_uri)
|
||||
keys = jwks.get("keys")
|
||||
if not isinstance(keys, list) or not keys:
|
||||
raise OAuthTokenValidationError(
|
||||
"Unable to validate OAuth token at this time.",
|
||||
"oauth_jwks_missing_keys",
|
||||
)
|
||||
self._jwks_cache[jwks_uri] = (jwks, now + self.settings.oauth_cache_ttl_seconds)
|
||||
return jwks
|
||||
|
||||
async def _validate_jwt(self, token: str) -> dict[str, Any]:
|
||||
"""Validate JWT access token using OIDC discovery and JWKS."""
|
||||
discovery = await self._get_discovery_document()
|
||||
issuer = str(discovery["issuer"]).rstrip("/")
|
||||
jwks_uri = str(discovery["jwks_uri"])
|
||||
jwks = await self._get_jwks(jwks_uri)
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
except InvalidTokenError as exc:
|
||||
raise OAuthTokenValidationError(
|
||||
"Invalid or expired OAuth token.", "oauth_jwt_header"
|
||||
) from exc
|
||||
|
||||
algorithm = header.get("alg")
|
||||
key_id = header.get("kid")
|
||||
if algorithm != "RS256":
|
||||
raise OAuthTokenValidationError("Invalid or expired OAuth token.", "oauth_jwt_alg")
|
||||
if not isinstance(key_id, str) or not key_id.strip():
|
||||
raise OAuthTokenValidationError("Invalid or expired OAuth token.", "oauth_jwt_kid")
|
||||
|
||||
matching_key = None
|
||||
for key in jwks.get("keys", []):
|
||||
if isinstance(key, dict) and key.get("kid") == key_id:
|
||||
matching_key = key
|
||||
break
|
||||
if matching_key is None:
|
||||
raise OAuthTokenValidationError(
|
||||
"Invalid or expired OAuth token.", "oauth_jwt_key_not_found"
|
||||
)
|
||||
|
||||
try:
|
||||
public_key = RSAAlgorithm.from_jwk(json.dumps(matching_key))
|
||||
except Exception as exc:
|
||||
raise OAuthTokenValidationError(
|
||||
"Unable to validate OAuth token at this time.",
|
||||
"oauth_jwt_invalid_jwk",
|
||||
) from exc
|
||||
|
||||
expected_audience = (
|
||||
self.settings.oauth_expected_audience.strip()
|
||||
or self.settings.gitea_oauth_client_id.strip()
|
||||
)
|
||||
|
||||
decode_options = cast(Any, {"verify_aud": bool(expected_audience)})
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
key=cast(Any, public_key),
|
||||
algorithms=["RS256"],
|
||||
issuer=issuer,
|
||||
audience=expected_audience or None,
|
||||
options=decode_options,
|
||||
)
|
||||
except InvalidTokenError as exc:
|
||||
raise OAuthTokenValidationError(
|
||||
"Invalid or expired OAuth token.", "oauth_jwt_invalid"
|
||||
) from exc
|
||||
|
||||
if not isinstance(claims, dict):
|
||||
raise OAuthTokenValidationError("Invalid or expired OAuth token.", "oauth_jwt_claims")
|
||||
|
||||
scopes = self._extract_scopes(claims)
|
||||
login = (
|
||||
str(claims.get("preferred_username", "")).strip()
|
||||
or str(claims.get("name", "")).strip()
|
||||
or str(claims.get("sub", "unknown")).strip()
|
||||
)
|
||||
subject = str(claims.get("sub", login)).strip() or "unknown"
|
||||
return {
|
||||
"login": login,
|
||||
"subject": subject,
|
||||
"scopes": sorted(scopes),
|
||||
}
|
||||
|
||||
async def _validate_userinfo(self, token: str) -> dict[str, Any]:
|
||||
"""Validate token via Gitea userinfo endpoint (opaque token fallback)."""
|
||||
userinfo_url = f"{self.settings.gitea_base_url}/login/oauth/userinfo"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.settings.request_timeout_seconds) as client:
|
||||
response = await client.get(
|
||||
userinfo_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
except httpx.RequestError as exc:
|
||||
raise OAuthTokenValidationError(
|
||||
"Unable to validate OAuth token at this time.",
|
||||
"oauth_userinfo_network",
|
||||
) from exc
|
||||
|
||||
if response.status_code in {401, 403}:
|
||||
raise OAuthTokenValidationError(
|
||||
"Invalid or expired OAuth token.", "oauth_userinfo_denied"
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise OAuthTokenValidationError(
|
||||
"Unable to validate OAuth token at this time.",
|
||||
"oauth_userinfo_unavailable",
|
||||
)
|
||||
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError as exc:
|
||||
raise OAuthTokenValidationError(
|
||||
"Invalid or expired OAuth token.", "oauth_userinfo_json"
|
||||
) from exc
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise OAuthTokenValidationError(
|
||||
"Invalid or expired OAuth token.", "oauth_userinfo_type"
|
||||
)
|
||||
|
||||
scopes = self._extract_scopes(payload)
|
||||
login = (
|
||||
str(payload.get("preferred_username", "")).strip()
|
||||
or str(payload.get("login", "")).strip()
|
||||
or str(payload.get("name", "")).strip()
|
||||
or str(payload.get("sub", "unknown")).strip()
|
||||
)
|
||||
subject = str(payload.get("sub", login)).strip() or "unknown"
|
||||
return {
|
||||
"login": login,
|
||||
"subject": subject,
|
||||
"scopes": sorted(scopes),
|
||||
}
|
||||
|
||||
async def validate_oauth_token(
|
||||
self,
|
||||
token: str | None,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
) -> tuple[bool, str | None, dict[str, Any] | None]:
|
||||
"""Validate an incoming OAuth token and return principal context."""
|
||||
if not self._check_rate_limit(client_ip):
|
||||
return False, "Too many failed authentication attempts. Try again later.", None
|
||||
|
||||
if not token:
|
||||
self._record_failed_attempt(client_ip)
|
||||
return False, "Authorization header missing or empty.", None
|
||||
|
||||
try:
|
||||
if self._looks_like_jwt(token):
|
||||
try:
|
||||
principal = await self._validate_jwt(token)
|
||||
except OAuthTokenValidationError:
|
||||
# Some providers issue opaque access tokens; verify those via userinfo.
|
||||
principal = await self._validate_userinfo(token)
|
||||
else:
|
||||
principal = await self._validate_userinfo(token)
|
||||
except OAuthTokenValidationError as exc:
|
||||
self._record_failed_attempt(client_ip)
|
||||
self.audit.log_access_denied(
|
||||
tool_name="oauth_authentication",
|
||||
reason=exc.reason,
|
||||
)
|
||||
return False, exc.public_message, None
|
||||
|
||||
self.audit.log_tool_invocation(
|
||||
tool_name="oauth_authentication",
|
||||
result_status="success",
|
||||
params={
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent,
|
||||
"gitea_user": principal.get("login", "unknown"),
|
||||
},
|
||||
)
|
||||
return True, None, principal
|
||||
|
||||
|
||||
_oauth_validator: GiteaOAuthValidator | None = None
|
||||
|
||||
|
||||
def get_oauth_validator() -> GiteaOAuthValidator:
|
||||
"""Get or create the global OAuth validator instance."""
|
||||
global _oauth_validator
|
||||
if _oauth_validator is None:
|
||||
_oauth_validator = GiteaOAuthValidator()
|
||||
return _oauth_validator
|
||||
|
||||
|
||||
def reset_oauth_validator() -> None:
|
||||
"""Reset the global OAuth validator instance (primarily for testing)."""
|
||||
global _oauth_validator
|
||||
_oauth_validator = None
|
||||
Reference in New Issue
Block a user