Files
GuardDen/src/guardden/services/ai/base.py
latte 831eed8dbc
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 6m9s
CI/CD Pipeline / Security Scanning (push) Successful in 26s
CI/CD Pipeline / Tests (3.11) (push) Failing after 5m24s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m23s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Deploy to Staging (push) Has been skipped
CI/CD Pipeline / Deploy to Production (push) Has been skipped
CI/CD Pipeline / Notification (push) Successful in 1s
quick commit
2026-01-17 20:24:43 +01:00

211 lines
5.3 KiB
Python

"""Base classes for AI providers."""
import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import Literal, TypeVar
class ContentCategory(str, Enum):
    """Categories of problematic content.

    Inherits from ``str`` so members compare equal to — and serialize as —
    their plain string values (used by ``parse_categories`` for lookup).
    """

    # SAFE marks non-violating content; every other member is a violation type.
    SAFE = "safe"
    HARASSMENT = "harassment"
    HATE_SPEECH = "hate_speech"
    SEXUAL = "sexual"
    VIOLENCE = "violence"
    SELF_HARM = "self_harm"
    SPAM = "spam"
    SCAM = "scam"
    MISINFORMATION = "misinformation"
_T = TypeVar("_T")
@dataclass(frozen=True)
class RetryConfig:
    """Retry configuration for AI calls."""

    # Total number of attempts made by run_with_retries (the first try counts).
    retries: int = 3
    # Initial backoff delay in seconds; doubled after each failed attempt.
    base_delay: float = 0.25
    # Ceiling for the backoff delay in seconds.
    max_delay: float = 2.0
def parse_categories(values: list[str]) -> list[ContentCategory]:
    """Convert raw string values into ``ContentCategory`` members.

    Unknown values are silently dropped; the order of recognized values
    is preserved.
    """
    known_values = {category.value for category in ContentCategory}
    return [ContentCategory(value) for value in values if value in known_values]
async def run_with_retries(
    operation: Callable[[], Awaitable[_T]],
    *,
    config: RetryConfig | None = None,
    logger: logging.Logger | None = None,
    operation_name: str = "AI call",
) -> _T:
    """Run an async operation with retries and exponential backoff.

    Args:
        operation: Zero-argument callable returning an awaitable to run.
        config: Retry settings; defaults to ``RetryConfig()``.
        logger: Optional logger used to warn about failed attempts.
        operation_name: Label for the operation in log messages.

    Returns:
        The result of the first successful attempt.

    Raises:
        Exception: Whatever ``operation`` last raised, once all attempts
            are exhausted.
    """
    retry_config = config or RetryConfig()
    # Guard against retries < 1: always make at least one attempt instead of
    # skipping the loop and raising an unrelated RuntimeError.
    attempts = max(1, retry_config.retries)
    delay = retry_config.base_delay
    for attempt in range(1, attempts + 1):
        try:
            return await operation()
        except Exception as error:  # noqa: BLE001 - we re-raise after retries
            # Final attempt: propagate the error to the caller.
            if attempt >= attempts:
                raise
            if logger:
                logger.warning(
                    "%s failed (attempt %s/%s): %s",
                    operation_name,
                    attempt,
                    attempts,
                    error,
                )
            await asyncio.sleep(delay)
            # Exponential backoff, capped at max_delay.
            delay = min(retry_config.max_delay, delay * 2)
    # Unreachable: every iteration either returns or raises, and the loop
    # runs at least once.
    raise AssertionError("Retry loop exited unexpectedly")
@dataclass
class ModerationResult:
    """Result of AI content moderation."""

    is_flagged: bool = False
    confidence: float = 0.0  # 0.0 to 1.0
    categories: list[ContentCategory] = field(default_factory=list)
    explanation: str = ""
    suggested_action: Literal["none", "warn", "delete", "timeout", "ban"] = "none"

    @property
    def severity(self) -> int:
        """Severity score 0-100 derived from confidence and categories."""
        if not self.is_flagged:
            return 0
        # Confidence contributes up to 50 points.
        score = int(self.confidence * 50)
        # Per-category bumps; any category not listed here adds 10.
        weights = {
            ContentCategory.HATE_SPEECH: 30,
            ContentCategory.SELF_HARM: 30,
            ContentCategory.SCAM: 30,
            ContentCategory.HARASSMENT: 20,
            ContentCategory.VIOLENCE: 20,
            ContentCategory.SEXUAL: 20,
        }
        score += sum(weights.get(category, 10) for category in self.categories)
        return min(score, 100)
@dataclass
class ImageAnalysisResult:
    """Result of AI image analysis."""

    # Boolean verdicts along the main image-safety dimensions.
    is_nsfw: bool = False
    is_violent: bool = False
    is_disturbing: bool = False
    # Confidence in the verdicts; presumably 0.0-1.0 like
    # ModerationResult.confidence — TODO confirm against providers.
    confidence: float = 0.0
    # Free-text description of the analyzed image.
    description: str = ""
    # Raw category labels as plain strings (not ContentCategory members).
    categories: list[str] = field(default_factory=list)
@dataclass
class PhishingAnalysisResult:
    """Result of AI phishing/scam analysis."""

    is_phishing: bool = False
    # Confidence in the verdict; presumably 0.0-1.0 like
    # ModerationResult.confidence — TODO confirm against providers.
    confidence: float = 0.0
    # Human-readable indicators that contributed to the verdict.
    risk_factors: list[str] = field(default_factory=list)
    explanation: str = ""
class AIProvider(ABC):
    """Abstract base class for AI providers.

    Concrete providers implement text moderation, image analysis, and
    phishing analysis, and release any held resources in ``close``.
    """

    @abstractmethod
    async def moderate_text(
        self,
        content: str,
        context: str | None = None,
        sensitivity: int = 50,
    ) -> ModerationResult:
        """
        Analyze text content for policy violations.

        Args:
            content: The text to analyze.
            context: Optional context about the conversation/server.
            sensitivity: 0-100, higher means more strict.

        Returns:
            ModerationResult with the analysis.
        """
        pass

    @abstractmethod
    async def analyze_image(
        self,
        image_url: str,
        sensitivity: int = 50,
    ) -> ImageAnalysisResult:
        """
        Analyze an image for NSFW or inappropriate content.

        Args:
            image_url: URL of the image to analyze.
            sensitivity: 0-100, higher means more strict.

        Returns:
            ImageAnalysisResult with the analysis.
        """
        pass

    @abstractmethod
    async def analyze_phishing(
        self,
        url: str,
        message_content: str | None = None,
    ) -> PhishingAnalysisResult:
        """
        Analyze a URL for phishing/scam indicators.

        Args:
            url: The URL to analyze.
            message_content: Optional full message for context.

        Returns:
            PhishingAnalysisResult with the analysis.
        """
        pass

    @abstractmethod
    async def close(self) -> None:
        """Clean up resources held by the provider (e.g. HTTP sessions)."""
        pass