306 lines
10 KiB
Python
306 lines
10 KiB
Python
import logging
from datetime import datetime, timedelta, timezone

import jwt
import msal
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from ..config import settings
from ..models.schemas import (
    AuthCallbackRequest,
    AuthCallbackResponse,
    AuthUrlResponse,
    UserResponse,
)
|
|
|
|
# Configure logging
_AUTH_LOG_PATH = "/app/auth.log"


def build_log_handlers(log_path=_AUTH_LOG_PATH):
    """Build the handler list for the root logging configuration.

    Returns a FileHandler for ``log_path`` followed by a StreamHandler. When
    the log file cannot be opened (for example the directory does not exist
    outside the container, or it is not writable), the FileHandler is left
    out instead of letting the OSError abort the import of this module.
    """
    handlers = []
    try:
        handlers.append(logging.FileHandler(log_path))
    except OSError:
        # Deliberate best-effort: the fallback is reported once logging is
        # configured (see the check after basicConfig below).
        pass
    handlers.append(logging.StreamHandler())
    return handlers


_log_handlers = build_log_handlers()

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)

if not any(isinstance(handler, logging.FileHandler) for handler in _log_handlers):
    logging.getLogger(__name__).warning(
        "Could not open log file %s; logging to stream only", _AUTH_LOG_PATH
    )
|
|
|
|
logger = logging.getLogger(name=__name__)

# Every route declared below is served under /api/auth and grouped under the
# "auth" OpenAPI tag.
router = APIRouter(tags=["auth"], prefix="/api/auth")

# auto_error=False lets requests without bearer credentials reach
# get_current_user, which performs its own `if not credentials` check.
security = HTTPBearer(auto_error=False)
|
|
|
|
|
|
def get_msal_app():
    """Create an MSAL confidential client application from Entra ID settings.

    Reads ENTRA_TENANT_ID, ENTRA_CLIENT_ID and ENTRA_CLIENT_SECRET from
    ``settings``. A new application object is built on every call.

    Returns:
        The ``msal.ConfidentialClientApplication``, or ``None`` when any
        required setting is empty or when construction raises. The endpoints
        treat ``None`` as "authentication not configured".
    """
    # NOTE(review): endpoints call this once per request; caching the
    # instance may be worthwhile — confirm settings never change at runtime.
    logger.info("Checking MSAL configuration")

    required_settings = {
        "ENTRA_TENANT_ID": settings.ENTRA_TENANT_ID,
        "ENTRA_CLIENT_ID": settings.ENTRA_CLIENT_ID,
        "ENTRA_CLIENT_SECRET": settings.ENTRA_CLIENT_SECRET,
    }
    missing_settings = [name for name, value in required_settings.items() if not value]
    if missing_settings:
        logger.error("Missing required Entra ID settings: %s", missing_settings)
        return None

    logger.info("All Entra ID settings present, creating MSAL app")
    try:
        msal_app = msal.ConfidentialClientApplication(
            client_id=settings.ENTRA_CLIENT_ID,
            client_credential=settings.ENTRA_CLIENT_SECRET,
            authority=f"https://login.microsoftonline.com/{settings.ENTRA_TENANT_ID}",
        )
    except Exception as exc:
        # logger.exception keeps the traceback; the message alone rarely
        # explains why construction failed.
        logger.exception("Failed to create MSAL application: %s", exc)
        return None

    logger.info("MSAL application created successfully")
    return msal_app
|
|
|
|
|
|
def create_jwt_token(user_data: dict) -> str:
    """Create a signed application JWT for the user described by ``user_data``.

    ``user_data`` is read as ID-token claims: the subject is ``oid`` (falling
    back to ``sub``), ``name`` is copied, and ``preferred_username`` is stored
    as ``email``. The token is signed with ``settings.JWT_SECRET`` using
    ``settings.JWT_ALGORITHM`` and expires ``settings.JWT_EXPIRY_HOURS`` hours
    after it is issued.

    Raises:
        Any exception raised while building or encoding the token; it is
        logged with its traceback and re-raised.
    """
    user_id = user_data.get("oid") or user_data.get("sub")
    # Log only the opaque user id. Context is passed as %-args because a dict
    # passed as the sole logging argument is silently dropped from the message.
    logger.info("Creating JWT token: user_id=%s", user_id)

    try:
        # Take a single timestamp so exp - iat is exactly the configured lifetime.
        issued_at = datetime.now(timezone.utc)
        payload = {
            "sub": user_id,
            "name": user_data.get("name"),
            "email": user_data.get("preferred_username"),
            "exp": issued_at + timedelta(hours=settings.JWT_EXPIRY_HOURS),
            "iat": issued_at,
        }
        token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
    except Exception as exc:
        logger.exception("Failed to create JWT token: %s", exc)
        raise

    logger.info(
        "JWT token created successfully: expires_in_hours=%s algorithm=%s",
        settings.JWT_EXPIRY_HOURS,
        settings.JWT_ALGORITHM,
    )
    return token
|
|
|
|
|
|
def decode_jwt_token(token: str) -> dict:
    """Decode and validate a bearer JWT.

    The signature is verified with ``settings.JWT_SECRET``, and only
    ``settings.JWT_ALGORITHM`` is accepted.

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException: 401 with detail "Token expired", "Invalid token" or
            "Token validation failed". It always carries a
            ``WWW-Authenticate: Bearer`` challenge (RFC 6750 §3).
    """
    logger.info("Decoding JWT token")

    try:
        payload = jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
    except jwt.ExpiredSignatureError as exc:
        # Must precede InvalidTokenError, of which it is a subclass.
        logger.warning("JWT token expired: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token expired",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    except jwt.InvalidTokenError as exc:
        logger.warning("Invalid JWT token: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    except Exception as exc:
        logger.exception("Unexpected error decoding JWT token: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token validation failed",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    # Logging happens outside the try so a formatting problem can never turn
    # an already-validated token into a 401.
    exp = payload.get("exp")
    expires_at = (
        datetime.fromtimestamp(exp, timezone.utc).isoformat()
        if isinstance(exp, (int, float))
        else None
    )
    logger.info(
        "JWT token decoded successfully: user_id=%s expires_at=%s",
        payload.get("sub"),
        expires_at,
    )
    return payload
|
|
|
|
|
|
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> dict:
    """FastAPI dependency that returns the claims of the caller's bearer JWT.

    ``security`` is created with ``auto_error=False``, so absent credentials
    arrive here as a falsy value and are rejected explicitly.

    Raises:
        HTTPException: 401 "Not authenticated" when no bearer credentials
            were supplied, or any 401 raised by ``decode_jwt_token``.
    """
    logger.info("Getting current user from credentials")

    if not credentials:
        logger.warning("No credentials provided")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            # RFC 6750 §3 / RFC 7235 §3.1: a 401 must carry an auth challenge.
            headers={"WWW-Authenticate": "Bearer"},
        )

    logger.info("Credentials found, decoding token")
    return decode_jwt_token(credentials.credentials)
|
|
|
|
|
|
@router.get("/login", response_model=AuthUrlResponse)
async def login():
    """Return the Microsoft Entra ID OAuth2 authorization URL.

    The URL requests the ``User.Read`` scope and uses
    ``settings.ENTRA_REDIRECT_URI`` as the redirect URI.

    Raises:
        HTTPException: 503 when the MSAL app is unavailable (missing Entra ID
            settings or construction failure); 500 when building the URL fails.
    """
    # NOTE(review): no OAuth `state` value is generated here or checked in
    # the callback; consider adding one against login CSRF (RFC 6749 §10.12).
    logger.info("Login endpoint called")

    # MSAL client construction is synchronous (and may perform HTTP
    # discovery); run it in the threadpool so it cannot stall the event loop.
    msal_app = await run_in_threadpool(get_msal_app)

    if not msal_app:
        logger.error("MSAL app not available for login")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Authentication not configured. Please set ENTRA_TENANT_ID, ENTRA_CLIENT_ID, and ENTRA_CLIENT_SECRET.",
        )

    scopes = ["User.Read"]
    try:
        logger.info(
            "Generating authorization URL: scopes=%s redirect_uri=%s",
            scopes,
            settings.ENTRA_REDIRECT_URI,
        )
        auth_url = msal_app.get_authorization_request_url(
            scopes=scopes, redirect_uri=settings.ENTRA_REDIRECT_URI
        )
        logger.info("Authorization URL generated successfully: url_length=%d", len(auth_url))
        return AuthUrlResponse(auth_url=auth_url)
    except Exception as exc:
        logger.exception("Failed to generate authorization URL: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to generate authorization URL"
        ) from exc
|
|
|
|
|
|
@router.post("/callback", response_model=AuthCallbackResponse)
async def callback(request: AuthCallbackRequest):
    """Exchange an OAuth2 authorization code for an application JWT.

    The handler redeems ``request.code`` through MSAL (scope ``User.Read``,
    redirect URI ``settings.ENTRA_REDIRECT_URI``). It then reads the user from
    the returned ID-token claims and issues a token via ``create_jwt_token``.

    Raises:
        HTTPException: 503 if authentication is not configured; 400 if the
            code exchange reports an error or the ID token carries neither
            ``oid`` nor ``sub``; 500 for any other unexpected failure.
    """
    # Never log the authorization code itself, not even a prefix: it is a
    # credential that can be redeemed for tokens until it expires.
    logger.info(
        "Callback endpoint called: code_length=%d",
        len(request.code) if request.code else 0,
    )

    # MSAL calls are synchronous; keep them off the event loop.
    msal_app = await run_in_threadpool(get_msal_app)

    if not msal_app:
        logger.error("MSAL app not available for callback")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Authentication not configured",
        )

    scopes = ["User.Read"]
    try:
        logger.info(
            "Exchanging authorization code for tokens: scopes=%s redirect_uri=%s",
            scopes,
            settings.ENTRA_REDIRECT_URI,
        )
        # The token request is a blocking HTTP round trip to Entra ID.
        result = await run_in_threadpool(
            msal_app.acquire_token_by_authorization_code,
            code=request.code,
            scopes=scopes,
            redirect_uri=settings.ENTRA_REDIRECT_URI,
        )

        logger.info(
            "Token exchange result: has_access_token=%s has_id_token=%s has_error=%s",
            "access_token" in result,
            "id_token" in result,
            "error" in result,
        )

        if "error" in result:
            logger.error(
                "Token exchange failed: error=%s error_description=%s correlation_id=%s",
                result.get("error"),
                result.get("error_description"),
                result.get("correlation_id"),
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Authentication failed: {result.get('error_description', result.get('error'))}",
            )

        # Extract user info from ID token claims.
        id_token_claims = result.get("id_token_claims") or {}
        user_id = id_token_claims.get("oid") or id_token_claims.get("sub")
        logger.info(
            "ID token claims extracted: claims_keys=%s user_id=%s",
            list(id_token_claims.keys()),
            user_id,
        )

        # Refuse to mint an application token that identifies nobody.
        if not user_id:
            logger.error("ID token contains neither 'oid' nor 'sub'")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Authentication failed: no user identity in ID token",
            )

        token = create_jwt_token(id_token_claims)

        response_data = AuthCallbackResponse(
            token=token,
            user=UserResponse(
                id=user_id,
                name=id_token_claims.get("name"),
                email=id_token_claims.get("preferred_username"),
            ),
        )

        logger.info("Callback completed successfully: user_id=%s", response_data.user.id)
        return response_data

    except HTTPException:
        # Re-raise HTTP exceptions as-is.
        raise
    except Exception as exc:
        logger.exception("Unexpected error in callback: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error during authentication"
        ) from exc
|
|
|
|
|
|
@router.get("/me", response_model=UserResponse)
async def me(current_user: dict = Depends(get_current_user)):
    """Return the authenticated user's profile from their JWT claims.

    Maps the claims returned by ``get_current_user`` as follows: ``sub`` to
    ``id``, ``name`` to ``name``, and ``email`` to ``email``.
    """
    user_id = current_user.get("sub")
    # Pass context as %-args: a dict passed as the sole logging argument is
    # treated as a format mapping and silently dropped from the message.
    logger.info("Me endpoint called: user_id=%s", user_id)

    return UserResponse(
        id=user_id,
        name=current_user.get("name"),
        email=current_user.get("email"),
    )
|
|
|
|
|
|
@router.post("/logout")
async def logout():
    """Acknowledge a logout request.

    No server-side state is touched and the issued JWT is not revoked here;
    discarding the token is left to the client.
    """
    logger.info("Logout endpoint called")
    response = {"message": "Logged out successfully"}
    return response
|
|
|
|
|
|
@router.get("/status")
async def auth_status():
    """Report whether Entra ID authentication is configured.

    Returns ``{"configured": bool}``, which is true only when ENTRA_TENANT_ID,
    ENTRA_CLIENT_ID and ENTRA_CLIENT_SECRET are all non-empty. Which individual
    settings are present is logged but not returned to the caller.
    """
    logger.info("Auth status endpoint called")

    status_info = {
        "has_tenant_id": bool(settings.ENTRA_TENANT_ID),
        "has_client_id": bool(settings.ENTRA_CLIENT_ID),
        "has_client_secret": bool(settings.ENTRA_CLIENT_SECRET),
    }
    configured = all(status_info.values())
    status_info["configured"] = configured

    # A dict passed as the sole logging argument is used as the %-format
    # mapping, so the message needs %(key)s placeholders for it to appear.
    logger.info(
        "Auth status checked: configured=%(configured)s "
        "has_tenant_id=%(has_tenant_id)s has_client_id=%(has_client_id)s "
        "has_client_secret=%(has_client_secret)s",
        status_info,
    )

    return {"configured": configured}
|