71 lines
2.3 KiB
Python
71 lines
2.3 KiB
Python
import json
import logging

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse

from ..config import settings
from ..middleware.auth import require_auth
from ..models.schemas import ChatRequest, ChatResponse, ProviderListResponse
from ..services.provider_manager import provider_manager
|
|
|
|
# Router for the chat endpoints: every route registered on it gets the
# "/api/chat" path prefix and the "chat" OpenAPI tag.
router = APIRouter(
    tags=["chat"],
    prefix="/api/chat",
)
|
|
|
|
|
|
@router.post("/", response_model=ChatResponse)
async def chat(request: ChatRequest, user: dict = Depends(require_auth)):
    """Non-streaming chat endpoint.

    Looks up the provider named by ``request.provider`` through
    ``provider_manager``, awaits its ``chat`` coroutine with
    ``request.message``, and returns the reply together with the provider's
    name as a ``ChatResponse``.

    ``user`` is not read in the body. Declaring the ``require_auth``
    dependency is presumably what enforces authentication on this route.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised while handling the
            request (e.g. by ``get_provider``). Any ``HTTPException`` raised
            downstream is propagated unchanged. Every other exception is
            logged and reported as a 500.
    """
    try:
        provider = provider_manager.get_provider(request.provider)
        response = await provider.chat(request.message)
        return ChatResponse(message=response, provider=provider.get_provider_name())
    except HTTPException:
        # Already carries the intended status and detail. Without this clause
        # the generic handler below would rewrap it as a 500.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        # The HTTPException raised below becomes a plain response, so record
        # the real traceback here or it is lost.
        logging.getLogger(__name__).exception("Chat request failed")
        # NOTE(review): echoing the exception text to the client may expose
        # internal details; consider a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing request: {e}",
        ) from e
|
|
|
|
|
|
@router.post("/stream")
async def chat_stream(request: ChatRequest, user: dict = Depends(require_auth)):
    """Streaming chat endpoint that returns Server-Sent Events (SSE).

    Each chunk yielded by the provider's ``chat_stream`` is sent as a
    ``data:`` event carrying ``{"chunk": ...}``. A final ``{"done": true}``
    event marks normal completion. If the provider fails mid-stream, an
    ``{"error": ...}`` event is sent instead of ``done``. The failure is also
    logged, because by then the response status can no longer be changed.

    ``user`` is not read in the body. Declaring the ``require_auth``
    dependency is presumably what enforces authentication on this route.

    Raises:
        HTTPException: 400 when resolving the provider raises ``ValueError``.
    """
    # Only the provider lookup runs before the response is returned. The
    # generator body runs later, while the response is being streamed.
    try:
        provider = provider_manager.get_provider(request.provider)
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e

    async def event_generator():
        """Yield SSE-formatted ``data:`` lines for the provider's output."""
        try:
            async for chunk in provider.chat_stream(request.message):
                yield f"data: {json.dumps({'chunk': chunk})}\n\n"

            yield f"data: {json.dumps({'done': True})}\n\n"
        except Exception as e:
            # Otherwise the only trace of the failure is the event sent to
            # the client.
            logging.getLogger(__name__).exception("Chat stream failed")
            # NOTE(review): echoing the exception text to the client may
            # expose internal details; consider a generic message.
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
|
|
|
|
|
|
@router.get("/providers", response_model=ProviderListResponse)
async def list_providers():
    """Report the providers known to ``provider_manager`` plus the default.

    The default comes from ``settings.DEFAULT_PROVIDER``.

    NOTE(review): unlike the chat endpoints, this route declares no
    ``require_auth`` dependency, so it appears reachable without
    authentication. Confirm that is intended.
    """
    available = provider_manager.get_available_providers()
    fallback = settings.DEFAULT_PROVIDER
    return ProviderListResponse(providers=available, default=fallback)
|