feature/entra id authentication added
This commit is contained in:
162
backend/app/api/auth.py
Normal file
162
backend/app/api/auth.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt
|
||||
import msal
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from ..config import settings
|
||||
from ..models.schemas import (
|
||||
AuthCallbackRequest,
|
||||
AuthCallbackResponse,
|
||||
AuthUrlResponse,
|
||||
UserResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def get_msal_app():
|
||||
"""Create MSAL confidential client application"""
|
||||
if not all(
|
||||
[
|
||||
settings.ENTRA_TENANT_ID,
|
||||
settings.ENTRA_CLIENT_ID,
|
||||
settings.ENTRA_CLIENT_SECRET,
|
||||
]
|
||||
):
|
||||
return None
|
||||
|
||||
return msal.ConfidentialClientApplication(
|
||||
client_id=settings.ENTRA_CLIENT_ID,
|
||||
client_credential=settings.ENTRA_CLIENT_SECRET,
|
||||
authority=f"https://login.microsoftonline.com/{settings.ENTRA_TENANT_ID}",
|
||||
)
|
||||
|
||||
|
||||
def create_jwt_token(user_data: dict) -> str:
|
||||
"""Create JWT token with user data"""
|
||||
payload = {
|
||||
"sub": user_data.get("oid") or user_data.get("sub"),
|
||||
"name": user_data.get("name"),
|
||||
"email": user_data.get("preferred_username"),
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=settings.JWT_EXPIRY_HOURS),
|
||||
"iat": datetime.now(timezone.utc),
|
||||
}
|
||||
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def decode_jwt_token(token: str) -> dict:
|
||||
"""Decode and validate JWT token"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
|
||||
)
|
||||
except jwt.InvalidTokenError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
) -> dict:
|
||||
"""Dependency to get current user from JWT token"""
|
||||
if not credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
|
||||
)
|
||||
|
||||
return decode_jwt_token(credentials.credentials)
|
||||
|
||||
|
||||
@router.get("/login", response_model=AuthUrlResponse)
|
||||
async def login():
|
||||
"""Get Microsoft OAuth2 authorization URL"""
|
||||
msal_app = get_msal_app()
|
||||
|
||||
if not msal_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication not configured. Please set ENTRA_TENANT_ID, ENTRA_CLIENT_ID, and ENTRA_CLIENT_SECRET.",
|
||||
)
|
||||
|
||||
auth_url = msal_app.get_authorization_request_url(
|
||||
scopes=["User.Read"], redirect_uri=settings.ENTRA_REDIRECT_URI
|
||||
)
|
||||
|
||||
return AuthUrlResponse(auth_url=auth_url)
|
||||
|
||||
|
||||
@router.post("/callback", response_model=AuthCallbackResponse)
|
||||
async def callback(request: AuthCallbackRequest):
|
||||
"""Exchange authorization code for tokens"""
|
||||
msal_app = get_msal_app()
|
||||
|
||||
if not msal_app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication not configured",
|
||||
)
|
||||
|
||||
result = msal_app.acquire_token_by_authorization_code(
|
||||
code=request.code,
|
||||
scopes=["User.Read"],
|
||||
redirect_uri=settings.ENTRA_REDIRECT_URI,
|
||||
)
|
||||
|
||||
if "error" in result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Authentication failed: {result.get('error_description', result.get('error'))}",
|
||||
)
|
||||
|
||||
# Extract user info from ID token claims
|
||||
id_token_claims = result.get("id_token_claims", {})
|
||||
|
||||
# Create our JWT token
|
||||
token = create_jwt_token(id_token_claims)
|
||||
|
||||
return AuthCallbackResponse(
|
||||
token=token,
|
||||
user=UserResponse(
|
||||
id=id_token_claims.get("oid") or id_token_claims.get("sub"),
|
||||
name=id_token_claims.get("name"),
|
||||
email=id_token_claims.get("preferred_username"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def me(current_user: dict = Depends(get_current_user)):
|
||||
"""Get current user info"""
|
||||
return UserResponse(
|
||||
id=current_user.get("sub"),
|
||||
name=current_user.get("name"),
|
||||
email=current_user.get("email"),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout():
|
||||
"""Logout (client should clear token)"""
|
||||
return {"message": "Logged out successfully"}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def auth_status():
|
||||
"""Check if authentication is configured"""
|
||||
configured = all(
|
||||
[
|
||||
settings.ENTRA_TENANT_ID,
|
||||
settings.ENTRA_CLIENT_ID,
|
||||
settings.ENTRA_CLIENT_SECRET,
|
||||
]
|
||||
)
|
||||
return {"configured": configured}
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from ..config import settings
|
||||
from ..middleware.auth import require_auth
|
||||
from ..models.schemas import ChatRequest, ChatResponse, ProviderListResponse
|
||||
from ..services.provider_manager import provider_manager
|
||||
|
||||
@@ -11,7 +12,7 @@ router = APIRouter(prefix="/api/chat", tags=["chat"])
|
||||
|
||||
|
||||
@router.post("/", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
async def chat(request: ChatRequest, user: dict = Depends(require_auth)):
|
||||
"""
|
||||
Non-streaming chat endpoint
|
||||
"""
|
||||
@@ -30,7 +31,7 @@ async def chat(request: ChatRequest):
|
||||
|
||||
|
||||
@router.post("/stream")
|
||||
async def chat_stream(request: ChatRequest):
|
||||
async def chat_stream(request: ChatRequest, user: dict = Depends(require_auth)):
|
||||
"""
|
||||
Streaming chat endpoint - returns SSE (Server-Sent Events)
|
||||
"""
|
||||
|
||||
@@ -25,6 +25,17 @@ class Settings(BaseSettings):
|
||||
RATE_LIMIT_REQUESTS: int = 10
|
||||
RATE_LIMIT_WINDOW: int = 60
|
||||
|
||||
# Microsoft Entra ID
|
||||
ENTRA_TENANT_ID: Optional[str] = None
|
||||
ENTRA_CLIENT_ID: Optional[str] = None
|
||||
ENTRA_CLIENT_SECRET: Optional[str] = None
|
||||
ENTRA_REDIRECT_URI: str = "http://localhost:3000/auth/callback"
|
||||
|
||||
# JWT
|
||||
JWT_SECRET: str = "change-this-in-production-use-a-secure-random-string"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_EXPIRY_HOURS: int = 24
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
@@ -4,7 +4,7 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from .api import chat
|
||||
from .api import auth, chat
|
||||
from .config import settings
|
||||
from .services.provider_manager import provider_manager
|
||||
|
||||
@@ -26,6 +26,7 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
# Include routers
|
||||
app.include_router(auth.router)
|
||||
app.include_router(chat.router)
|
||||
|
||||
|
||||
|
||||
0
backend/app/middleware/__init__.py
Normal file
0
backend/app/middleware/__init__.py
Normal file
45
backend/app/middleware/auth.py
Normal file
45
backend/app/middleware/auth.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import jwt
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from ..config import settings
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def decode_jwt_token(token: str) -> dict:
|
||||
"""Decode and validate JWT token"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
|
||||
)
|
||||
except jwt.InvalidTokenError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
|
||||
|
||||
async def require_auth(request: Request):
|
||||
"""Dependency to require authentication"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
|
||||
if not auth_header:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
|
||||
)
|
||||
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authorization header",
|
||||
)
|
||||
|
||||
token = auth_header[7:] # Remove "Bearer " prefix
|
||||
user = decode_jwt_token(token)
|
||||
request.state.user = user
|
||||
return user
|
||||
@@ -22,3 +22,23 @@ class ChatResponse(BaseModel):
|
||||
class ProviderListResponse(BaseModel):
|
||||
providers: List[str]
|
||||
default: str
|
||||
|
||||
|
||||
# Auth schemas
|
||||
class AuthUrlResponse(BaseModel):
|
||||
auth_url: str
|
||||
|
||||
|
||||
class AuthCallbackRequest(BaseModel):
|
||||
code: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
|
||||
|
||||
class AuthCallbackResponse(BaseModel):
|
||||
token: str
|
||||
user: UserResponse
|
||||
|
||||
@@ -6,3 +6,5 @@ pydantic>=2.6.0
|
||||
pydantic-settings>=2.1.0
|
||||
python-dotenv>=1.0.0
|
||||
httpx>=0.27.0
|
||||
msal>=1.24.0
|
||||
PyJWT>=2.8.0
|
||||
|
||||
Reference in New Issue
Block a user