from datetime import datetime, timedelta, timezone

import jwt
import msal
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from ..config import settings
from ..models.schemas import (
    AuthCallbackRequest,
    AuthCallbackResponse,
    AuthUrlResponse,
    UserResponse,
)

router = APIRouter(prefix="/api/auth", tags=["auth"])

# auto_error=False: a missing Authorization header yields ``None`` instead of
# an automatic error response, so get_current_user can raise its own 401.
security = HTTPBearer(auto_error=False)


def get_msal_app():
    """Create the MSAL confidential client application.

    Returns:
        A ``msal.ConfidentialClientApplication`` bound to the configured
        tenant, or ``None`` when any of ``ENTRA_TENANT_ID``,
        ``ENTRA_CLIENT_ID`` or ``ENTRA_CLIENT_SECRET`` is unset/falsy.
    """
    required_settings = (
        settings.ENTRA_TENANT_ID,
        settings.ENTRA_CLIENT_ID,
        settings.ENTRA_CLIENT_SECRET,
    )
    if not all(required_settings):
        return None

    return msal.ConfidentialClientApplication(
        client_id=settings.ENTRA_CLIENT_ID,
        client_credential=settings.ENTRA_CLIENT_SECRET,
        authority=f"https://login.microsoftonline.com/{settings.ENTRA_TENANT_ID}",
    )


def create_jwt_token(user_data: dict) -> str:
    """Create a signed JWT from identity-provider claims.

    Args:
        user_data: Claims mapping. ``oid`` (falling back to ``sub``) becomes
            the token subject; ``name`` and ``preferred_username`` are copied
            into ``name`` and ``email``.

    Returns:
        The encoded token, signed with ``settings.JWT_SECRET`` using
        ``settings.JWT_ALGORITHM`` and expiring ``settings.JWT_EXPIRY_HOURS``
        hours after issuance.
    """
    # Take one timestamp so that ``exp`` is exactly ``iat`` + expiry window.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": user_data.get("oid") or user_data.get("sub"),
        "name": user_data.get("name"),
        "email": user_data.get("preferred_username"),
        "exp": issued_at + timedelta(hours=settings.JWT_EXPIRY_HOURS),
        "iat": issued_at,
    }
    return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)


def decode_jwt_token(token: str) -> dict:
    """Decode and validate a JWT issued by this service.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded payload.

    Raises:
        HTTPException: 401 with detail ``"Token expired"`` when the signature
            has expired, or ``"Invalid token"`` for any other validation
            failure. The original PyJWT error is chained as the cause.
    """
    # ExpiredSignatureError is a subclass of InvalidTokenError, so it must be
    # handled first to keep the more specific message.
    try:
        return jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
    except jwt.ExpiredSignatureError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token expired",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    except jwt.InvalidTokenError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> dict:
    """Dependency that resolves the current user from the bearer JWT.

    Args:
        credentials: Parsed ``Authorization`` header, or ``None`` when the
            header is absent (``security`` is created with
            ``auto_error=False``).

    Returns:
        The decoded JWT payload.

    Raises:
        HTTPException: 401 when no credentials were supplied or the token
            fails validation in ``decode_jwt_token``.
    """
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return decode_jwt_token(credentials.credentials)
def _require_msal_app(detail: str):
    """Return the MSAL client, or raise a 503 carrying *detail* if unconfigured."""
    msal_app = get_msal_app()
    if not msal_app:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=detail,
        )
    return msal_app


@router.get("/login", response_model=AuthUrlResponse)
async def login():
    """Get the Microsoft OAuth2 authorization URL.

    Returns:
        ``AuthUrlResponse`` whose ``auth_url`` requests the ``User.Read``
        scope and redirects to ``settings.ENTRA_REDIRECT_URI``.

    Raises:
        HTTPException: 503 when the Entra settings are not configured.
    """
    msal_app = _require_msal_app(
        "Authentication not configured. Please set ENTRA_TENANT_ID, ENTRA_CLIENT_ID, and ENTRA_CLIENT_SECRET."
    )
    # NOTE(review): no ``state`` value is passed here and the callback only
    # receives ``code``, so the response is not tied to a request this app
    # started. Consider adding state/PKCE handling -- confirm with the client.
    auth_url = msal_app.get_authorization_request_url(
        scopes=["User.Read"], redirect_uri=settings.ENTRA_REDIRECT_URI
    )
    return AuthUrlResponse(auth_url=auth_url)


@router.post("/callback", response_model=AuthCallbackResponse)
async def callback(request: AuthCallbackRequest):
    """Exchange an authorization code for tokens and issue our own JWT.

    Args:
        request: Body carrying the authorization ``code`` returned by the
            identity provider.

    Returns:
        ``AuthCallbackResponse`` with the service JWT and the user's id,
        name and email taken from the ID token claims.

    Raises:
        HTTPException: 503 when authentication is not configured; 400 when
            the token exchange reports an error or the ID token claims carry
            neither ``oid`` nor ``sub``.
    """
    # Imported locally so this block does not depend on the module header.
    from fastapi.concurrency import run_in_threadpool

    msal_app = _require_msal_app("Authentication not configured")

    # The MSAL call is synchronous; run it in the threadpool so the event
    # loop is not blocked while the token exchange is in flight.
    result = await run_in_threadpool(
        msal_app.acquire_token_by_authorization_code,
        code=request.code,
        scopes=["User.Read"],
        redirect_uri=settings.ENTRA_REDIRECT_URI,
    )

    if "error" in result:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Authentication failed: {result.get('error_description', result.get('error'))}",
        )

    # Extract user info from the ID token claims.
    id_token_claims = result.get("id_token_claims") or {}
    user_id = id_token_claims.get("oid") or id_token_claims.get("sub")
    if not user_id:
        # Without a subject the issued JWT would not identify anyone.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Authentication failed: ID token is missing a subject claim",
        )

    # Create our own JWT from the same claims.
    token = create_jwt_token(id_token_claims)

    return AuthCallbackResponse(
        token=token,
        user=UserResponse(
            id=user_id,
            name=id_token_claims.get("name"),
            email=id_token_claims.get("preferred_username"),
        ),
    )


@router.get("/me", response_model=UserResponse)
async def me(current_user: dict = Depends(get_current_user)):
    """Return the current user's id, name and email from the decoded JWT."""
    return UserResponse(
        id=current_user.get("sub"),
        name=current_user.get("name"),
        email=current_user.get("email"),
    )


@router.post("/logout")
async def logout():
    """Acknowledge logout; no server-side state is touched.

    The client is expected to discard its token.
    """
    return {"message": "Logged out successfully"}


@router.get("/status")
async def auth_status():
    """Report whether all three Entra settings are set."""
    configured = all(
        [
            settings.ENTRA_TENANT_ID,
            settings.ENTRA_CLIENT_ID,
            settings.ENTRA_CLIENT_SECRET,
        ]
    )
    return {"configured": configured}