"""WebSocket support for real-time dashboard updates."""

import asyncio
import logging
from datetime import datetime, timezone
from typing import Any

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from guardden.dashboard.config import DashboardSettings
from guardden.dashboard.schemas import WebSocketEvent

logger = logging.getLogger(__name__)

# Seconds the endpoint waits for a client message before it sends its own
# keep-alive ping.
_HEARTBEAT_INTERVAL = 30.0


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime.

    A naive ``datetime.now()`` serializes without an offset, so a browser
    client would interpret it in the client's own local zone instead of the
    server's. An explicit UTC offset removes that ambiguity.
    """
    return datetime.now(timezone.utc)


class ConnectionManager:
    """Manage WebSocket connections for real-time updates.

    Connections are grouped by guild ID. ``active_connections`` is only
    mutated while holding ``_lock``. Sends happen outside the lock on a
    snapshot, so a slow client never blocks registration or removal.
    """

    def __init__(self) -> None:
        # guild_id -> sockets currently subscribed to that guild's events.
        self.active_connections: dict[int, list[WebSocket]] = {}
        self._lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket, guild_id: int) -> None:
        """Accept a new WebSocket connection and register it under *guild_id*."""
        await websocket.accept()
        async with self._lock:
            self.active_connections.setdefault(guild_id, []).append(websocket)
        logger.info("New WebSocket connection for guild %s", guild_id)

    async def disconnect(self, websocket: WebSocket, guild_id: int) -> None:
        """Unregister a WebSocket connection. Unknown sockets are ignored."""
        async with self._lock:
            self._remove_locked(guild_id, [websocket])
        logger.info("WebSocket disconnected for guild %s", guild_id)

    def _remove_locked(self, guild_id: int, sockets: list[WebSocket]) -> None:
        """Drop *sockets* from *guild_id*'s list. The caller must hold ``_lock``.

        The guild entry is deleted once its list is empty, so the dict does
        not accumulate keys for guilds nobody is watching.
        """
        connections = self.active_connections.get(guild_id)
        if connections is None:
            return
        for ws in sockets:
            if ws in connections:
                connections.remove(ws)
        if not connections:
            del self.active_connections[guild_id]

    async def _send_safely(self, connection: WebSocket, message: str) -> bool:
        """Send *message* on *connection* and return False if the send failed."""
        try:
            await connection.send_text(message)
        except Exception as e:
            # Any send failure means the socket is no longer usable. The
            # caller prunes it instead of letting one bad client abort the
            # whole broadcast.
            logger.warning("Failed to send message to WebSocket: %s", e)
            return False
        return True

    async def _send_to_guild(self, guild_id: int, message: str) -> None:
        """Send an already-serialized *message* to every socket of *guild_id*.

        Sockets whose send fails are removed from the registry afterwards.
        """
        async with self._lock:
            # Take a snapshot so the sends below run without holding the lock.
            connections = list(self.active_connections.get(guild_id, ()))
        if not connections:
            return

        # Send concurrently so a single slow client does not delay the rest.
        results = await asyncio.gather(
            *(self._send_safely(conn, message) for conn in connections)
        )
        dead_connections = [
            conn for conn, ok in zip(connections, results) if not ok
        ]
        if dead_connections:
            async with self._lock:
                self._remove_locked(guild_id, dead_connections)

    async def broadcast_to_guild(self, guild_id: int, event: WebSocketEvent) -> None:
        """Broadcast an event to all connections for a specific guild."""
        await self._send_to_guild(guild_id, event.model_dump_json())

    async def broadcast_to_all(self, event: WebSocketEvent) -> None:
        """Broadcast an event to the connections of every registered guild."""
        async with self._lock:
            all_guilds = list(self.active_connections)
        # Serialize once rather than once per guild.
        message = event.model_dump_json()
        for guild_id in all_guilds:
            await self._send_to_guild(guild_id, message)

    def get_connection_count(self, guild_id: int | None = None) -> int:
        """Return the number of active connections.

        Args:
            guild_id: If given, count only that guild's connections.
                Otherwise count connections across all guilds.
        """
        if guild_id is not None:
            return len(self.active_connections.get(guild_id, []))
        return sum(len(conns) for conns in self.active_connections.values())


# Global connection manager shared by the router and the broadcast helpers.
connection_manager = ConnectionManager()


def create_websocket_router(settings: DashboardSettings) -> APIRouter:
    """Create the WebSocket API router.

    NOTE(review): ``settings`` is currently unused, and the endpoint performs
    no authentication or authorization of its own. Any client that can reach
    it may subscribe to any guild's events by choosing ``guild_id``. Confirm
    that access is enforced upstream; otherwise gate the connection before
    ``connection_manager.connect``.
    """
    router = APIRouter()

    @router.websocket("/ws/events")
    async def websocket_events(websocket: WebSocket, guild_id: int) -> None:
        """WebSocket endpoint for real-time events of one guild.

        The endpoint replies ``"pong"`` to a client ``"ping"`` text message.
        If the client stays silent for ``_HEARTBEAT_INTERVAL`` seconds, it
        sends a JSON ``ping`` event itself to keep the connection alive.
        """
        await connection_manager.connect(websocket, guild_id)
        try:
            # Send initial connection confirmation.
            await websocket.send_json(
                {
                    "type": "connected",
                    "guild_id": guild_id,
                    "timestamp": _utcnow().isoformat(),
                    "data": {"message": "Connected to real-time events"},
                }
            )

            # Keep the connection alive and handle incoming messages.
            while True:
                try:
                    data = await asyncio.wait_for(
                        websocket.receive_text(), timeout=_HEARTBEAT_INTERVAL
                    )
                except asyncio.TimeoutError:
                    # Client was quiet: send a periodic ping of our own.
                    await websocket.send_json(
                        {
                            "type": "ping",
                            "guild_id": guild_id,
                            "timestamp": _utcnow().isoformat(),
                            "data": {},
                        }
                    )
                    continue

                # Answer client heartbeats.
                if data == "ping":
                    await websocket.send_text("pong")
        except WebSocketDisconnect:
            logger.info("Client disconnected from WebSocket for guild %s", guild_id)
        except Exception as e:
            # Top-level boundary for this connection: log with traceback.
            logger.exception("WebSocket error for guild %s: %s", guild_id, e)
        finally:
            await connection_manager.disconnect(websocket, guild_id)

    return router


# Helper functions to broadcast events from other parts of the application


async def broadcast_moderation_action(
    guild_id: int,
    action: str,
    target_id: int,
    target_name: str,
    moderator_name: str,
    reason: str | None = None,
) -> None:
    """Broadcast a moderation action event to the guild's dashboard clients."""
    event = WebSocketEvent(
        type="moderation_action",
        guild_id=guild_id,
        timestamp=_utcnow(),
        data={
            "action": action,
            "target_id": target_id,
            "target_name": target_name,
            "moderator_name": moderator_name,
            "reason": reason,
        },
    )
    await connection_manager.broadcast_to_guild(guild_id, event)


async def broadcast_user_join(
    guild_id: int,
    user_id: int,
    username: str,
) -> None:
    """Broadcast a user join event to the guild's dashboard clients."""
    event = WebSocketEvent(
        type="user_join",
        guild_id=guild_id,
        timestamp=_utcnow(),
        data={
            "user_id": user_id,
            "username": username,
        },
    )
    await connection_manager.broadcast_to_guild(guild_id, event)


async def broadcast_ai_alert(
    guild_id: int,
    user_id: int,
    severity: str,
    category: str,
    confidence: float,
) -> None:
    """Broadcast an AI moderation alert to the guild's dashboard clients."""
    event = WebSocketEvent(
        type="ai_alert",
        guild_id=guild_id,
        timestamp=_utcnow(),
        data={
            "user_id": user_id,
            "severity": severity,
            "category": category,
            "confidence": confidence,
        },
    )
    await connection_manager.broadcast_to_guild(guild_id, event)


async def broadcast_system_event(
    event_type: str,
    data: dict[str, Any],
    guild_id: int | None = None,
) -> None:
    """Broadcast a generic system event.

    Args:
        event_type: Value placed in the event's ``type`` field.
        data: Arbitrary payload placed in the event's ``data`` field.
        guild_id: Target guild. When it is ``None`` (or 0), the event is sent
            to every connected guild and carries ``guild_id=0`` in its payload.
    """
    event = WebSocketEvent(
        type=event_type,
        guild_id=guild_id or 0,
        timestamp=_utcnow(),
        data=data,
    )
    if guild_id:
        await connection_manager.broadcast_to_guild(guild_id, event)
    else:
        await connection_manager.broadcast_to_all(event)