163 lines
4.7 KiB
Python
163 lines
4.7 KiB
Python
from datetime import datetime, timedelta, timezone
|
|
|
|
import jwt
|
|
import msal
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
|
|
from ..config import settings
|
|
from ..models.schemas import (
|
|
AuthCallbackRequest,
|
|
AuthCallbackResponse,
|
|
AuthUrlResponse,
|
|
UserResponse,
|
|
)
|
|
|
|
# Router collecting the authentication endpoints; every route declared on it
# is served under the "/api/auth" prefix and tagged "auth".
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
|
# Bearer-token extractor used as a dependency by get_current_user.
# auto_error=False: the missing-credentials case is handled explicitly in
# get_current_user (which raises its own 401) instead of by HTTPBearer.
security = HTTPBearer(auto_error=False)
|
|
|
|
|
|
def get_msal_app():
    """Build the MSAL confidential client for the configured Entra tenant.

    Returns:
        A new ``msal.ConfidentialClientApplication`` when the tenant id,
        client id and client secret are all set in ``settings``; otherwise
        ``None`` so callers can report that authentication is not configured.
    """
    tenant_id = settings.ENTRA_TENANT_ID
    client_id = settings.ENTRA_CLIENT_ID
    client_secret = settings.ENTRA_CLIENT_SECRET

    # All three values are required; any falsy one means "not configured".
    if not (tenant_id and client_id and client_secret):
        return None

    authority_url = f"https://login.microsoftonline.com/{tenant_id}"
    return msal.ConfidentialClientApplication(
        client_id=client_id,
        client_credential=client_secret,
        authority=authority_url,
    )
|
|
|
|
|
|
def create_jwt_token(user_data: dict) -> str:
    """Create a signed session JWT from identity claims.

    Args:
        user_data: Claims mapping. The keys ``oid`` (falling back to
            ``sub``), ``name`` and ``preferred_username`` are read; a missing
            key results in a ``None`` value for the corresponding claim.

    Returns:
        The token produced by ``jwt.encode`` using ``settings.JWT_SECRET``
        and ``settings.JWT_ALGORITHM``.
    """
    # Read the clock once so that "exp" is exactly JWT_EXPIRY_HOURS after
    # "iat"; two separate now() calls yield slightly different instants.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": user_data.get("oid") or user_data.get("sub"),
        "name": user_data.get("name"),
        "email": user_data.get("preferred_username"),
        "exp": issued_at + timedelta(hours=settings.JWT_EXPIRY_HOURS),
        "iat": issued_at,
    }
    return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
|
|
|
|
|
def decode_jwt_token(token: str) -> dict:
    """Decode and validate a session JWT.

    Args:
        token: Encoded JWT string.

    Returns:
        The decoded payload as returned by ``jwt.decode``.

    Raises:
        HTTPException: 401 with detail "Token expired" when the token's
            expiry has passed, or 401 with detail "Invalid token" for any
            other validation failure. Both carry a ``WWW-Authenticate:
            Bearer`` challenge header and chain the underlying PyJWT error.
    """
    # A 401 response should tell the client which auth scheme to use.
    challenge = {"WWW-Authenticate": "Bearer"}
    try:
        return jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
    # ExpiredSignatureError must be handled first: it is a subclass of
    # InvalidTokenError and would otherwise be reported as "Invalid token".
    except jwt.ExpiredSignatureError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token expired",
            headers=challenge,
        ) from exc
    except jwt.InvalidTokenError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers=challenge,
        ) from exc
|
|
|
|
|
|
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> dict:
    """Dependency that resolves the current user from the bearer token.

    Args:
        credentials: Bearer credentials supplied by the ``security``
            dependency. The value is falsy when the request carried no
            usable Authorization header, which is handled here.

    Returns:
        The decoded JWT payload from ``decode_jwt_token``.

    Raises:
        HTTPException: 401 "Not authenticated" when no credentials were
            supplied; ``decode_jwt_token`` raises its own 401 errors for
            expired or invalid tokens.
    """
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            # A 401 response should advertise the expected auth scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )

    return decode_jwt_token(credentials.credentials)
|
|
|
|
|
|
@router.get("/login", response_model=AuthUrlResponse)
async def login():
    """Return the Microsoft authorization URL the client should visit.

    Returns:
        ``AuthUrlResponse`` wrapping the URL built by MSAL for the
        ``User.Read`` scope and the configured redirect URI.

    Raises:
        HTTPException: 503 when the Entra settings are incomplete and no
            MSAL client can be created.
    """
    client = get_msal_app()

    if client:
        url = client.get_authorization_request_url(
            scopes=["User.Read"],
            redirect_uri=settings.ENTRA_REDIRECT_URI,
        )
        return AuthUrlResponse(auth_url=url)

    # No MSAL client means the Entra settings are missing.
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Authentication not configured. Please set ENTRA_TENANT_ID, ENTRA_CLIENT_ID, and ENTRA_CLIENT_SECRET.",
    )
|
|
|
|
|
|
@router.post("/callback", response_model=AuthCallbackResponse)
async def callback(request: AuthCallbackRequest):
    """Exchange an authorization code for a session token.

    Args:
        request: Request body; its ``code`` field is the authorization code
            passed to MSAL.

    Returns:
        ``AuthCallbackResponse`` containing the session JWT from
        ``create_jwt_token`` and the user's id, name and email taken from
        the ID token claims.

    Raises:
        HTTPException: 503 when authentication is not configured; 400 when
            MSAL reports an error or the token response carries no user
            identity.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    msal_app = get_msal_app()

    if not msal_app:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Authentication not configured",
        )

    # The MSAL token exchange is a synchronous HTTP round trip. Calling it
    # directly inside this coroutine would stall the event loop for every
    # other request, so it is run on a worker thread instead.
    result = await run_in_threadpool(
        msal_app.acquire_token_by_authorization_code,
        code=request.code,
        scopes=["User.Read"],
        redirect_uri=settings.ENTRA_REDIRECT_URI,
    )

    if "error" in result:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Authentication failed: {result.get('error_description', result.get('error'))}",
        )

    # Extract user info from ID token claims
    id_token_claims = result.get("id_token_claims") or {}
    user_id = id_token_claims.get("oid") or id_token_claims.get("sub")

    # Never issue a session token without a subject: it would identify no
    # one, and every later /me call would return an empty user.
    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Authentication failed: no user identity in token response",
        )

    # Create our JWT token
    token = create_jwt_token(id_token_claims)

    return AuthCallbackResponse(
        token=token,
        user=UserResponse(
            id=user_id,
            name=id_token_claims.get("name"),
            email=id_token_claims.get("preferred_username"),
        ),
    )
|
|
|
|
|
|
@router.get("/me", response_model=UserResponse)
async def me(current_user: dict = Depends(get_current_user)):
    """Return the authenticated user's profile.

    Args:
        current_user: Decoded JWT payload provided by ``get_current_user``.

    Returns:
        ``UserResponse`` populated from the payload's ``sub``, ``name`` and
        ``email`` claims (``None`` for any claim that is absent).
    """
    # Response field name -> JWT claim it is read from.
    claim_for_field = (("id", "sub"), ("name", "name"), ("email", "email"))
    profile = {field: current_user.get(claim) for field, claim in claim_for_field}
    return UserResponse(**profile)
|
|
|
|
|
|
@router.post("/logout")
async def logout():
    """Acknowledge a logout request.

    Nothing is changed on the server here; the client is expected to discard
    its token.

    Returns:
        A dict with a confirmation ``message``.
    """
    acknowledgement = {"message": "Logged out successfully"}
    return acknowledgement
|
|
|
|
|
|
@router.get("/status")
async def auth_status():
    """Report whether the Entra authentication settings are present.

    Returns:
        ``{"configured": True}`` when the tenant id, client id and client
        secret are all truthy in ``settings``, else ``{"configured": False}``.
    """
    required_settings = (
        settings.ENTRA_TENANT_ID,
        settings.ENTRA_CLIENT_ID,
        settings.ENTRA_CLIENT_SECRET,
    )
    is_configured = all(required_settings)
    return {"configured": is_configured}
|