Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 6m9s
CI/CD Pipeline / Security Scanning (push) Successful in 26s
CI/CD Pipeline / Tests (3.11) (push) Failing after 5m24s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m23s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Deploy to Staging (push) Has been skipped
CI/CD Pipeline / Deploy to Production (push) Has been skipped
CI/CD Pipeline / Notification (push) Successful in 1s
211 lines
5.3 KiB
Python
211 lines
5.3 KiB
Python
"""Base classes for AI providers."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Awaitable, Callable
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Literal, TypeVar
|
|
|
|
|
|
class ContentCategory(str, Enum):
    """Categories of problematic content.

    Subclasses ``str`` as well as ``Enum``, so members compare equal to
    their raw string values (e.g. ``ContentCategory.SPAM == "spam"``)
    and serialize naturally.
    """

    SAFE = "safe"  # no violation detected
    HARASSMENT = "harassment"
    HATE_SPEECH = "hate_speech"
    SEXUAL = "sexual"
    VIOLENCE = "violence"
    SELF_HARM = "self_harm"
    SPAM = "spam"
    SCAM = "scam"
    MISINFORMATION = "misinformation"
|
|
|
|
|
|
# Generic result type returned by run_with_retries.
_T = TypeVar("_T")
|
|
|
|
|
|
@dataclass(frozen=True)
class RetryConfig:
    """Retry configuration for AI calls.

    Frozen (immutable), so one instance can safely be shared across
    concurrent calls.
    """

    retries: int = 3  # total number of attempts, including the first
    base_delay: float = 0.25  # initial backoff delay in seconds (doubled each retry)
    max_delay: float = 2.0  # ceiling for the backoff delay in seconds
|
|
|
|
|
|
def parse_categories(values: list[str]) -> list[ContentCategory]:
    """Parse category values into ContentCategory enums.

    Values that do not correspond to a known category are silently
    dropped; input order of the recognized values is preserved.
    """
    known = {category.value for category in ContentCategory}
    return [ContentCategory(value) for value in values if value in known]
|
|
|
|
|
|
async def run_with_retries(
    operation: Callable[[], Awaitable[_T]],
    *,
    config: RetryConfig | None = None,
    logger: logging.Logger | None = None,
    operation_name: str = "AI call",
) -> _T:
    """Run an async operation with retries and exponential backoff.

    Args:
        operation: Zero-argument coroutine factory to execute.
        config: Retry settings; a default RetryConfig is used when None.
        logger: Optional logger; failures before the last attempt are
            logged at WARNING level.
        operation_name: Human-readable label used in log messages.

    Returns:
        Whatever ``operation`` returns on its first successful attempt.

    Raises:
        Exception: Re-raises the final attempt's error once retries are
            exhausted.
    """
    cfg = config or RetryConfig()
    wait = cfg.base_delay
    failure: Exception | None = None

    attempt = 0
    while attempt < cfg.retries:
        attempt += 1
        try:
            return await operation()
        except Exception as exc:  # noqa: BLE001 - we re-raise after retries
            failure = exc
            # Out of attempts: propagate the original exception as-is.
            if attempt >= cfg.retries:
                raise
            if logger:
                logger.warning(
                    "%s failed (attempt %s/%s): %s",
                    operation_name,
                    attempt,
                    cfg.retries,
                    exc,
                )
            await asyncio.sleep(wait)
            # Exponential backoff, capped at the configured maximum.
            wait = min(cfg.max_delay, wait * 2)

    # Defensive: only reachable when cfg.retries <= 0.
    if failure:
        raise failure
    raise RuntimeError("Retry loop exited unexpectedly")
|
|
|
|
|
|
@dataclass
class ModerationResult:
    """Result of AI content moderation."""

    is_flagged: bool = False  # whether the content violated policy at all
    confidence: float = 0.0  # 0.0 to 1.0
    categories: list[ContentCategory] = field(default_factory=list)
    explanation: str = ""
    suggested_action: Literal["none", "warn", "delete", "timeout", "ban"] = "none"

    @property
    def severity(self) -> int:
        """Get severity score 0-100 based on confidence and categories."""
        if not self.is_flagged:
            return 0

        # Per-category contribution; anything not listed counts as low (10).
        weights = {
            ContentCategory.HATE_SPEECH: 30,
            ContentCategory.SELF_HARM: 30,
            ContentCategory.SCAM: 30,
            ContentCategory.HARASSMENT: 20,
            ContentCategory.VIOLENCE: 20,
            ContentCategory.SEXUAL: 20,
        }

        # Confidence contributes up to 50 points; categories add the rest,
        # and the total is clamped to 100.
        score = int(self.confidence * 50)
        score += sum(weights.get(category, 10) for category in self.categories)
        return min(score, 100)
|
|
|
|
|
|
@dataclass
class ImageAnalysisResult:
    """Result of AI image analysis."""

    is_nsfw: bool = False  # image contains sexual/NSFW content
    is_violent: bool = False  # image depicts violence
    is_disturbing: bool = False  # image is graphic/disturbing
    confidence: float = 0.0  # presumably 0.0 to 1.0, mirroring ModerationResult.confidence
    description: str = ""  # free-text description of the image
    categories: list[str] = field(default_factory=list)  # provider-specific labels (plain strings, not ContentCategory)
|
|
|
|
|
|
@dataclass
class PhishingAnalysisResult:
    """Result of AI phishing/scam analysis."""

    is_phishing: bool = False  # whether the URL looks like phishing/scam
    confidence: float = 0.0  # presumably 0.0 to 1.0, mirroring ModerationResult.confidence
    risk_factors: list[str] = field(default_factory=list)  # individual indicators found
    explanation: str = ""  # human-readable rationale for the verdict
|
|
|
|
|
|
class AIProvider(ABC):
    """Abstract base class for AI providers.

    Concrete providers supply text moderation, image analysis,
    phishing analysis, and resource cleanup.
    """

    @abstractmethod
    async def moderate_text(
        self,
        content: str,
        context: str | None = None,
        sensitivity: int = 50,
    ) -> ModerationResult:
        """Analyze text content for policy violations.

        Args:
            content: The text to check.
            context: Optional context about the conversation/server.
            sensitivity: Strictness from 0 to 100; higher is stricter.

        Returns:
            A ModerationResult describing the analysis.
        """

    @abstractmethod
    async def analyze_image(
        self,
        image_url: str,
        sensitivity: int = 50,
    ) -> ImageAnalysisResult:
        """Analyze an image for NSFW or otherwise inappropriate content.

        Args:
            image_url: URL of the image to inspect.
            sensitivity: Strictness from 0 to 100; higher is stricter.

        Returns:
            An ImageAnalysisResult describing the analysis.
        """

    @abstractmethod
    async def analyze_phishing(
        self,
        url: str,
        message_content: str | None = None,
    ) -> PhishingAnalysisResult:
        """Analyze a URL for phishing/scam indicators.

        Args:
            url: The URL to inspect.
            message_content: Optional full message for extra context.

        Returns:
            A PhishingAnalysisResult describing the analysis.
        """

    @abstractmethod
    async def close(self) -> None:
        """Release any resources held by the provider."""
|