try to fix entra
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt
|
||||
@@ -13,129 +14,258 @@ from ..models.schemas import (
|
||||
UserResponse,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('/app/auth.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def get_msal_app():
|
||||
"""Create MSAL confidential client application"""
|
||||
if not all(
|
||||
[
|
||||
settings.ENTRA_TENANT_ID,
|
||||
settings.ENTRA_CLIENT_ID,
|
||||
settings.ENTRA_CLIENT_SECRET,
|
||||
]
|
||||
):
|
||||
logger.info("Checking MSAL configuration")
|
||||
|
||||
required_settings = [
|
||||
("ENTRA_TENANT_ID", settings.ENTRA_TENANT_ID),
|
||||
("ENTRA_CLIENT_ID", settings.ENTRA_CLIENT_ID),
|
||||
("ENTRA_CLIENT_SECRET", settings.ENTRA_CLIENT_SECRET),
|
||||
]
|
||||
|
||||
missing_settings = [name for name, value in required_settings if not value]
|
||||
if missing_settings:
|
||||
logger.error(f"Missing required Entra ID settings: {missing_settings}")
|
||||
return None
|
||||
|
||||
return msal.ConfidentialClientApplication(
|
||||
client_id=settings.ENTRA_CLIENT_ID,
|
||||
client_credential=settings.ENTRA_CLIENT_SECRET,
|
||||
authority=f"https://login.microsoftonline.com/{settings.ENTRA_TENANT_ID}",
|
||||
)
|
||||
logger.info("All Entra ID settings present, creating MSAL app")
|
||||
try:
|
||||
msal_app = msal.ConfidentialClientApplication(
|
||||
client_id=settings.ENTRA_CLIENT_ID,
|
||||
client_credential=settings.ENTRA_CLIENT_SECRET,
|
||||
authority=f"https://login.microsoftonline.com/{settings.ENTRA_TENANT_ID}",
|
||||
)
|
||||
logger.info("MSAL application created successfully")
|
||||
return msal_app
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create MSAL application: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_jwt_token(user_data: dict) -> str:
|
||||
"""Create JWT token with user data"""
|
||||
payload = {
|
||||
"sub": user_data.get("oid") or user_data.get("sub"),
|
||||
"name": user_data.get("name"),
|
||||
"email": user_data.get("preferred_username"),
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=settings.JWT_EXPIRY_HOURS),
|
||||
"iat": datetime.now(timezone.utc),
|
||||
}
|
||||
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
logger.info("Creating JWT token", {
|
||||
"user_id": user_data.get("oid") or user_data.get("sub"),
|
||||
"user_name": user_data.get("name"),
|
||||
"user_email": user_data.get("preferred_username")
|
||||
})
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"sub": user_data.get("oid") or user_data.get("sub"),
|
||||
"name": user_data.get("name"),
|
||||
"email": user_data.get("preferred_username"),
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=settings.JWT_EXPIRY_HOURS),
|
||||
"iat": datetime.now(timezone.utc),
|
||||
}
|
||||
|
||||
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
logger.info("JWT token created successfully", {
|
||||
"expires_in_hours": settings.JWT_EXPIRY_HOURS,
|
||||
"algorithm": settings.JWT_ALGORITHM
|
||||
})
|
||||
return token
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create JWT token: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def decode_jwt_token(token: str) -> dict:
|
||||
"""Decode and validate JWT token"""
|
||||
logger.info("Decoding JWT token")
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
logger.info("JWT token decoded successfully", {
|
||||
"user_id": payload.get("sub"),
|
||||
"user_name": payload.get("name"),
|
||||
"expires_at": datetime.fromtimestamp(payload.get("exp", 0), timezone.utc).isoformat()
|
||||
})
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
logger.warning("JWT token expired", {"error": str(e)})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
|
||||
)
|
||||
except jwt.InvalidTokenError:
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.warning("Invalid JWT token", {"error": str(e)})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error decoding JWT token: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token validation failed"
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
) -> dict:
|
||||
"""Dependency to get current user from JWT token"""
|
||||
logger.info("Getting current user from credentials")
|
||||
|
||||
if not credentials:
|
||||
logger.warning("No credentials provided")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
|
||||
)
|
||||
|
||||
logger.info("Credentials found, decoding token")
|
||||
return decode_jwt_token(credentials.credentials)
|
||||
|
||||
|
||||
@router.get("/login", response_model=AuthUrlResponse)
|
||||
async def login():
|
||||
"""Get Microsoft OAuth2 authorization URL"""
|
||||
logger.info("Login endpoint called")
|
||||
|
||||
msal_app = get_msal_app()
|
||||
|
||||
if not msal_app:
|
||||
logger.error("MSAL app not available for login")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication not configured. Please set ENTRA_TENANT_ID, ENTRA_CLIENT_ID, and ENTRA_CLIENT_SECRET.",
|
||||
)
|
||||
|
||||
auth_url = msal_app.get_authorization_request_url(
|
||||
scopes=["User.Read"], redirect_uri=settings.ENTRA_REDIRECT_URI
|
||||
)
|
||||
try:
|
||||
logger.info("Generating authorization URL", {
|
||||
"scopes": ["User.Read"],
|
||||
"redirect_uri": settings.ENTRA_REDIRECT_URI
|
||||
})
|
||||
auth_url = msal_app.get_authorization_request_url(
|
||||
scopes=["User.Read"], redirect_uri=settings.ENTRA_REDIRECT_URI
|
||||
)
|
||||
|
||||
return AuthUrlResponse(auth_url=auth_url)
|
||||
logger.info("Authorization URL generated successfully", {
|
||||
"url_length": len(auth_url),
|
||||
"url_start": auth_url[:100] + "..." if len(auth_url) > 100 else auth_url
|
||||
})
|
||||
|
||||
return AuthUrlResponse(auth_url=auth_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate authorization URL: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate authorization URL"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/callback", response_model=AuthCallbackResponse)
|
||||
async def callback(request: AuthCallbackRequest):
|
||||
"""Exchange authorization code for tokens"""
|
||||
logger.info("Callback endpoint called", {
|
||||
"code_length": len(request.code) if request.code else 0,
|
||||
"code_start": request.code[:50] + "..." if request.code and len(request.code) > 50 else request.code
|
||||
})
|
||||
|
||||
msal_app = get_msal_app()
|
||||
|
||||
if not msal_app:
|
||||
logger.error("MSAL app not available for callback")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication not configured",
|
||||
)
|
||||
|
||||
result = msal_app.acquire_token_by_authorization_code(
|
||||
code=request.code,
|
||||
scopes=["User.Read"],
|
||||
redirect_uri=settings.ENTRA_REDIRECT_URI,
|
||||
)
|
||||
try:
|
||||
logger.info("Exchanging authorization code for tokens", {
|
||||
"scopes": ["User.Read"],
|
||||
"redirect_uri": settings.ENTRA_REDIRECT_URI
|
||||
})
|
||||
|
||||
if "error" in result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Authentication failed: {result.get('error_description', result.get('error'))}",
|
||||
result = msal_app.acquire_token_by_authorization_code(
|
||||
code=request.code,
|
||||
scopes=["User.Read"],
|
||||
redirect_uri=settings.ENTRA_REDIRECT_URI,
|
||||
)
|
||||
|
||||
# Extract user info from ID token claims
|
||||
id_token_claims = result.get("id_token_claims", {})
|
||||
logger.info("Token exchange result", {
|
||||
"has_access_token": "access_token" in result,
|
||||
"has_id_token": "id_token" in result,
|
||||
"has_error": "error" in result,
|
||||
"error": result.get("error"),
|
||||
"error_description": result.get("error_description")
|
||||
})
|
||||
|
||||
# Create our JWT token
|
||||
token = create_jwt_token(id_token_claims)
|
||||
if "error" in result:
|
||||
logger.error("Token exchange failed", {
|
||||
"error": result.get("error"),
|
||||
"error_description": result.get("error_description"),
|
||||
"correlation_id": result.get("correlation_id")
|
||||
})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Authentication failed: {result.get('error_description', result.get('error'))}",
|
||||
)
|
||||
|
||||
return AuthCallbackResponse(
|
||||
token=token,
|
||||
user=UserResponse(
|
||||
id=id_token_claims.get("oid") or id_token_claims.get("sub"),
|
||||
name=id_token_claims.get("name"),
|
||||
email=id_token_claims.get("preferred_username"),
|
||||
),
|
||||
)
|
||||
# Extract user info from ID token claims
|
||||
id_token_claims = result.get("id_token_claims", {})
|
||||
logger.info("ID token claims extracted", {
|
||||
"claims_keys": list(id_token_claims.keys()),
|
||||
"user_id": id_token_claims.get("oid") or id_token_claims.get("sub"),
|
||||
"user_name": id_token_claims.get("name"),
|
||||
"user_email": id_token_claims.get("preferred_username")
|
||||
})
|
||||
|
||||
# Create our JWT token
|
||||
token = create_jwt_token(id_token_claims)
|
||||
|
||||
response_data = AuthCallbackResponse(
|
||||
token=token,
|
||||
user=UserResponse(
|
||||
id=id_token_claims.get("oid") or id_token_claims.get("sub"),
|
||||
name=id_token_claims.get("name"),
|
||||
email=id_token_claims.get("preferred_username"),
|
||||
),
|
||||
)
|
||||
|
||||
logger.info("Callback completed successfully", {
|
||||
"user_id": response_data.user.id,
|
||||
"user_name": response_data.user.name
|
||||
})
|
||||
|
||||
return response_data
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in callback: {e}", {"traceback": str(e)})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error during authentication"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def me(current_user: dict = Depends(get_current_user)):
|
||||
"""Get current user info"""
|
||||
logger.info("Me endpoint called", {
|
||||
"user_id": current_user.get("sub"),
|
||||
"user_name": current_user.get("name")
|
||||
})
|
||||
|
||||
return UserResponse(
|
||||
id=current_user.get("sub"),
|
||||
name=current_user.get("name"),
|
||||
@@ -146,12 +276,15 @@ async def me(current_user: dict = Depends(get_current_user)):
|
||||
@router.post("/logout")
|
||||
async def logout():
|
||||
"""Logout (client should clear token)"""
|
||||
logger.info("Logout endpoint called")
|
||||
return {"message": "Logged out successfully"}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def auth_status():
|
||||
"""Check if authentication is configured"""
|
||||
logger.info("Auth status endpoint called")
|
||||
|
||||
configured = all(
|
||||
[
|
||||
settings.ENTRA_TENANT_ID,
|
||||
@@ -159,4 +292,14 @@ async def auth_status():
|
||||
settings.ENTRA_CLIENT_SECRET,
|
||||
]
|
||||
)
|
||||
|
||||
status_info = {
|
||||
"configured": configured,
|
||||
"has_tenant_id": bool(settings.ENTRA_TENANT_ID),
|
||||
"has_client_id": bool(settings.ENTRA_CLIENT_ID),
|
||||
"has_client_secret": bool(settings.ENTRA_CLIENT_SECRET),
|
||||
}
|
||||
|
||||
logger.info("Auth status checked", status_info)
|
||||
|
||||
return {"configured": configured}
|
||||
|
||||
@@ -9,7 +9,14 @@ from .config import settings
|
||||
from .services.provider_manager import provider_manager
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('/app/devden.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
@@ -30,9 +37,25 @@ app.include_router(auth.router)
|
||||
app.include_router(chat.router)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request, exc):
|
||||
"""Global exception handler to log all errors"""
|
||||
logger.error(f"Unhandled exception: {exc}", {
|
||||
"url": str(request.url),
|
||||
"method": request.method,
|
||||
"headers": dict(request.headers),
|
||||
"traceback": str(exc)
|
||||
})
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal server error"}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
logger.info("Health check requested")
|
||||
return JSONResponse(
|
||||
content={
|
||||
"status": "healthy",
|
||||
@@ -41,6 +64,18 @@ async def health_check():
|
||||
)
|
||||
|
||||
|
||||
@app.get("/logs")
|
||||
async def get_logs():
|
||||
"""Get recent log entries (for debugging)"""
|
||||
try:
|
||||
with open('/app/devden.log', 'r') as f:
|
||||
lines = f.readlines()[-50:] # Last 50 lines
|
||||
return {"logs": lines}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read logs: {e}")
|
||||
return {"error": "Failed to read logs"}
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("DevDen API starting up...")
|
||||
|
||||
Reference in New Issue
Block a user