"""Base classes for AI providers.""" import asyncio import logging from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from enum import Enum from typing import Literal, TypeVar class ContentCategory(str, Enum): """Categories of problematic content.""" SAFE = "safe" HARASSMENT = "harassment" HATE_SPEECH = "hate_speech" SEXUAL = "sexual" VIOLENCE = "violence" SELF_HARM = "self_harm" SPAM = "spam" SCAM = "scam" MISINFORMATION = "misinformation" _T = TypeVar("_T") @dataclass(frozen=True) class RetryConfig: """Retry configuration for AI calls.""" retries: int = 3 base_delay: float = 0.25 max_delay: float = 2.0 def parse_categories(values: list[str]) -> list[ContentCategory]: """Parse category values into ContentCategory enums.""" categories: list[ContentCategory] = [] for value in values: try: categories.append(ContentCategory(value)) except ValueError: continue return categories async def run_with_retries( operation: Callable[[], Awaitable[_T]], *, config: RetryConfig | None = None, logger: logging.Logger | None = None, operation_name: str = "AI call", ) -> _T: """Run an async operation with retries and backoff.""" retry_config = config or RetryConfig() delay = retry_config.base_delay last_error: Exception | None = None for attempt in range(1, retry_config.retries + 1): try: return await operation() except Exception as error: # noqa: BLE001 - we re-raise after retries last_error = error if attempt >= retry_config.retries: raise if logger: logger.warning( "%s failed (attempt %s/%s): %s", operation_name, attempt, retry_config.retries, error, ) await asyncio.sleep(delay) delay = min(retry_config.max_delay, delay * 2) if last_error: raise last_error raise RuntimeError("Retry loop exited unexpectedly") @dataclass class ModerationResult: """Result of AI content moderation.""" is_flagged: bool = False confidence: float = 0.0 # 0.0 to 1.0 categories: list[ContentCategory] = field(default_factory=list) explanation: str = "" suggested_action: Literal["none", "warn", "delete", "timeout", "ban"] = "none" @property def severity(self) -> int: """Get severity score 0-100 based on confidence and categories.""" if not self.is_flagged: return 0 # Base severity from confidence severity = int(self.confidence * 50) # Add severity based on category high_severity = { ContentCategory.HATE_SPEECH, ContentCategory.SELF_HARM, ContentCategory.SCAM, } medium_severity = { ContentCategory.HARASSMENT, ContentCategory.VIOLENCE, ContentCategory.SEXUAL, } for cat in self.categories: if cat in high_severity: severity += 30 elif cat in medium_severity: severity += 20 else: severity += 10 return min(severity, 100) @dataclass class ImageAnalysisResult: """Result of AI image analysis.""" is_nsfw: bool = False is_violent: bool = False is_disturbing: bool = False confidence: float = 0.0 description: str = "" categories: list[str] = field(default_factory=list) @dataclass class PhishingAnalysisResult: """Result of AI phishing/scam analysis.""" is_phishing: bool = False confidence: float = 0.0 risk_factors: list[str] = field(default_factory=list) explanation: str = "" class AIProvider(ABC): """Abstract base class for AI providers.""" @abstractmethod async def moderate_text( self, content: str, context: str | None = None, sensitivity: int = 50, ) -> ModerationResult: """ Analyze text content for policy violations. Args: content: The text to analyze context: Optional context about the conversation/server sensitivity: 0-100, higher means more strict Returns: ModerationResult with analysis """ pass @abstractmethod async def analyze_image( self, image_url: str, sensitivity: int = 50, ) -> ImageAnalysisResult: """ Analyze an image for NSFW or inappropriate content. Args: image_url: URL of the image to analyze sensitivity: 0-100, higher means more strict Returns: ImageAnalysisResult with analysis """ pass @abstractmethod async def analyze_phishing( self, url: str, message_content: str | None = None, ) -> PhishingAnalysisResult: """ Analyze a URL for phishing/scam indicators. Args: url: The URL to analyze message_content: Optional full message for context Returns: PhishingAnalysisResult with analysis """ pass @abstractmethod async def close(self) -> None: """Clean up resources.""" pass