Implement GuardDen Discord moderation bot
Features:
- Core moderation: warn, kick, ban, timeout, strike system
- Automod: banned words filter, scam detection, anti-spam, link filtering
- AI moderation: Claude/OpenAI integration, NSFW detection, phishing analysis
- Verification system: button, captcha, math, emoji challenges
- Rate limiting system with configurable scopes
- Event logging: joins, leaves, message edits/deletes, voice activity
- Per-guild configuration with caching
- Docker deployment support

Bug fixes applied:
- Fixed await on session.delete() in guild_config.py
- Fixed memory leak in AI moderation message tracking (use deque)
- Added error handling to bot shutdown
- Added error handling to timeout command
- Removed unused Literal import
- Added prefix validation
- Added image analysis limit (3 per message)
- Fixed test mock for SQLAlchemy model
This commit is contained in:
16
src/guardden/services/__init__.py
Normal file
16
src/guardden/services/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Services for GuardDen."""
|
||||
|
||||
from guardden.services.automod import AutomodService
|
||||
from guardden.services.database import Database
|
||||
from guardden.services.ratelimit import RateLimiter, get_rate_limiter, ratelimit
|
||||
from guardden.services.verification import ChallengeType, VerificationService
|
||||
|
||||
__all__ = [
|
||||
"AutomodService",
|
||||
"ChallengeType",
|
||||
"Database",
|
||||
"RateLimiter",
|
||||
"VerificationService",
|
||||
"get_rate_limiter",
|
||||
"ratelimit",
|
||||
]
|
||||
6
src/guardden/services/ai/__init__.py
Normal file
6
src/guardden/services/ai/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""AI services for content moderation."""
|
||||
|
||||
from guardden.services.ai.base import AIProvider, ModerationResult
|
||||
from guardden.services.ai.factory import create_ai_provider
|
||||
|
||||
__all__ = ["AIProvider", "ModerationResult", "create_ai_provider"]
|
||||
261
src/guardden/services/ai/anthropic_provider.py
Normal file
261
src/guardden/services/ai/anthropic_provider.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""Anthropic Claude AI provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from guardden.services.ai.base import (
|
||||
AIProvider,
|
||||
ContentCategory,
|
||||
ImageAnalysisResult,
|
||||
ModerationResult,
|
||||
PhishingAnalysisResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Content moderation system prompt
|
||||
MODERATION_SYSTEM_PROMPT = """You are a content moderation AI for a Discord server. Analyze the given message and determine if it violates community guidelines.
|
||||
|
||||
Categories to check:
|
||||
- harassment: Personal attacks, bullying, intimidation
|
||||
- hate_speech: Discrimination, slurs, dehumanization based on identity
|
||||
- sexual: Explicit sexual content, sexual solicitation
|
||||
- violence: Threats, graphic violence, encouraging harm
|
||||
- self_harm: Suicide, self-injury content or encouragement
|
||||
- spam: Repetitive, promotional, or low-quality content
|
||||
- scam: Phishing attempts, fraudulent offers, impersonation
|
||||
- misinformation: Dangerous false information
|
||||
|
||||
Respond in this exact JSON format:
|
||||
{
|
||||
"is_flagged": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"categories": ["category1", "category2"],
|
||||
"explanation": "Brief explanation",
|
||||
"suggested_action": "none/warn/delete/timeout/ban"
|
||||
}
|
||||
|
||||
Be balanced - flag genuinely problematic content but allow normal conversation, jokes, and mild language. Consider context."""
|
||||
|
||||
IMAGE_ANALYSIS_PROMPT = """Analyze this image for content moderation purposes. Check for:
|
||||
- NSFW content (nudity, sexual content)
|
||||
- Violence or gore
|
||||
- Disturbing or shocking content
|
||||
- Any content inappropriate for a general audience
|
||||
|
||||
Respond in this exact JSON format:
|
||||
{
|
||||
"is_nsfw": true/false,
|
||||
"is_violent": true/false,
|
||||
"is_disturbing": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"description": "Brief description of the image",
|
||||
"categories": ["category1", "category2"]
|
||||
}
|
||||
|
||||
Be accurate but not overly sensitive - artistic nudity or mild violence in appropriate contexts may be acceptable."""
|
||||
|
||||
PHISHING_ANALYSIS_PROMPT = """Analyze this URL and message context for phishing or scam indicators.
|
||||
|
||||
Check for:
|
||||
- Domain impersonation (typosquatting, lookalike domains)
|
||||
- Urgency tactics ("act now", "limited time")
|
||||
- Requests for credentials or personal info
|
||||
- Too-good-to-be-true offers
|
||||
- Suspicious redirects or URL shorteners
|
||||
- Mismatched or hidden URLs
|
||||
|
||||
Respond in this exact JSON format:
|
||||
{
|
||||
"is_phishing": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"risk_factors": ["factor1", "factor2"],
|
||||
"explanation": "Brief explanation"
|
||||
}"""
|
||||
|
||||
|
||||
class AnthropicProvider(AIProvider):
    """AI provider using Anthropic's Claude API.

    All analysis methods fail open: on API or parse errors they return a
    non-flagged result whose text field carries the error message, so a
    provider outage never blocks normal chat.
    """

    def __init__(self, api_key: str, model: str = "claude-3-haiku-20240307") -> None:
        """
        Initialize Anthropic provider.

        Args:
            api_key: Anthropic API key
            model: Model to use (default: claude-3-haiku for speed/cost)

        Raises:
            ImportError: If the optional ``anthropic`` package is not installed.
        """
        try:
            import anthropic
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "anthropic package required. Install with: pip install anthropic"
            ) from err

        self.client = anthropic.AsyncAnthropic(api_key=api_key)
        self.model = model
        logger.info(f"Initialized Anthropic provider with model: {model}")

    async def _call_api(self, system: str, user_content: Any, max_tokens: int = 500) -> str:
        """Make an API call to Claude and return the first text block.

        Args:
            system: System prompt.
            user_content: Either a plain string or a multimodal content list.
            max_tokens: Completion token cap.

        Raises:
            Exception: Re-raises any API error after logging it.
        """
        try:
            message = await self.client.messages.create(
                model=self.model,
                max_tokens=max_tokens,
                system=system,
                messages=[{"role": "user", "content": user_content}],
            )
            return message.content[0].text
        except Exception as e:
            logger.error(f"Anthropic API error: {e}")
            raise

    def _parse_json_response(self, response: str) -> dict:
        """Parse JSON from a model response, tolerating markdown code fences.

        Claude sometimes wraps JSON in ```json ... ``` fences despite the
        prompt asking for bare JSON. Strip a leading fence line and — if
        present — a trailing ``` line (ignoring trailing whitespace, which
        the previous exact-match check missed) before parsing.

        Raises:
            json.JSONDecodeError: If the remaining text is not valid JSON.
        """
        import json

        text = response.strip()
        if text.startswith("```"):
            lines = text.split("\n")
            # Drop the opening fence (``` or ```json); drop the closing
            # fence only when it is actually there.
            if lines and lines[-1].strip() == "```":
                lines = lines[1:-1]
            else:
                lines = lines[1:]
            text = "\n".join(lines)

        return json.loads(text)

    async def moderate_text(
        self,
        content: str,
        context: str | None = None,
        sensitivity: int = 50,
    ) -> ModerationResult:
        """Analyze text content for policy violations.

        Args:
            content: The text to analyze.
            context: Optional conversation/server context prepended to the prompt.
            sensitivity: 0-100; below 30 relaxes the prompt, above 70 tightens it.

        Returns:
            ModerationResult; on any error, a non-flagged result whose
            explanation carries the error message.
        """
        # Adjust prompt based on sensitivity
        sensitivity_note = ""
        if sensitivity < 30:
            sensitivity_note = "\n\nBe lenient - only flag clearly problematic content."
        elif sensitivity > 70:
            sensitivity_note = "\n\nBe strict - flag anything potentially problematic."

        system = MODERATION_SYSTEM_PROMPT + sensitivity_note

        user_message = f"Message to analyze:\n{content}"
        if context:
            user_message = f"Context: {context}\n\n{user_message}"

        try:
            response = await self._call_api(system, user_message)
            data = self._parse_json_response(response)

            # Map the model's category strings onto our enum, dropping any
            # label it invents that we don't recognize. (Previously this
            # relied on the str-enum equality quirk of
            # ContentCategory.__members__.values(); explicit construction
            # with a ValueError guard is equivalent and unambiguous.)
            categories: list[ContentCategory] = []
            for cat in data.get("categories", []):
                try:
                    categories.append(ContentCategory(cat))
                except ValueError:
                    logger.debug(f"Ignoring unknown moderation category: {cat}")

            return ModerationResult(
                is_flagged=data.get("is_flagged", False),
                confidence=float(data.get("confidence", 0.0)),
                categories=categories,
                explanation=data.get("explanation", ""),
                suggested_action=data.get("suggested_action", "none"),
            )

        except Exception as e:
            logger.error(f"Error moderating text: {e}")
            return ModerationResult(
                is_flagged=False,
                explanation=f"Error analyzing content: {str(e)}",
            )

    async def analyze_image(
        self,
        image_url: str,
        sensitivity: int = 50,
    ) -> ImageAnalysisResult:
        """Analyze an image for NSFW or inappropriate content.

        Downloads the image, base64-encodes it, and sends it to Claude as a
        multimodal message.

        Args:
            image_url: URL of the image to analyze.
            sensitivity: 0-100; below 30 relaxes the prompt, above 70 tightens it.

        Returns:
            ImageAnalysisResult; on download/API errors the description
            carries the failure reason and all flags stay False.
        """
        import base64

        import aiohttp

        sensitivity_note = ""
        if sensitivity < 30:
            sensitivity_note = "\n\nBe lenient - only flag explicit content."
        elif sensitivity > 70:
            sensitivity_note = "\n\nBe strict - flag suggestive content as well."

        system = IMAGE_ANALYSIS_PROMPT + sensitivity_note

        try:
            # Download image and convert to base64
            async with aiohttp.ClientSession() as session:
                async with session.get(image_url) as resp:
                    if resp.status != 200:
                        return ImageAnalysisResult(
                            description=f"Failed to download image: HTTP {resp.status}"
                        )

                    # NOTE(review): the server-reported type is forwarded
                    # as-is; Claude accepts only jpeg/png/gif/webp — an
                    # unexpected content type will surface as an API error.
                    content_type = resp.content_type or "image/jpeg"
                    image_data = await resp.read()

            # Check file size (max 20MB for Claude)
            if len(image_data) > 20 * 1024 * 1024:
                return ImageAnalysisResult(description="Image too large to analyze")

            base64_image = base64.standard_b64encode(image_data).decode("utf-8")

            # Create multimodal message
            user_content = [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": content_type,
                        "data": base64_image,
                    },
                },
                {"type": "text", "text": "Analyze this image for content moderation."},
            ]

            response = await self._call_api(system, user_content)
            data = self._parse_json_response(response)

            return ImageAnalysisResult(
                is_nsfw=data.get("is_nsfw", False),
                is_violent=data.get("is_violent", False),
                is_disturbing=data.get("is_disturbing", False),
                confidence=float(data.get("confidence", 0.0)),
                description=data.get("description", ""),
                categories=data.get("categories", []),
            )

        except Exception as e:
            logger.error(f"Error analyzing image: {e}")
            return ImageAnalysisResult(description=f"Error analyzing image: {str(e)}")

    async def analyze_phishing(
        self,
        url: str,
        message_content: str | None = None,
    ) -> PhishingAnalysisResult:
        """Analyze a URL for phishing/scam indicators.

        Args:
            url: The URL to analyze.
            message_content: Optional full message supplied as context.

        Returns:
            PhishingAnalysisResult; on error, a non-flagged result whose
            explanation carries the error message.
        """
        user_message = f"URL to analyze: {url}"
        if message_content:
            user_message += f"\n\nFull message context:\n{message_content}"

        try:
            response = await self._call_api(PHISHING_ANALYSIS_PROMPT, user_message)
            data = self._parse_json_response(response)

            return PhishingAnalysisResult(
                is_phishing=data.get("is_phishing", False),
                confidence=float(data.get("confidence", 0.0)),
                risk_factors=data.get("risk_factors", []),
                explanation=data.get("explanation", ""),
            )

        except Exception as e:
            logger.error(f"Error analyzing phishing: {e}")
            return PhishingAnalysisResult(explanation=f"Error analyzing URL: {str(e)}")

    async def close(self) -> None:
        """Clean up resources (closes the underlying HTTP client)."""
        await self.client.close()
|
||||
149
src/guardden/services/ai/base.py
Normal file
149
src/guardden/services/ai/base.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Base classes for AI providers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class ContentCategory(str, Enum):
    """Closed set of labels a moderation result can carry.

    Inherits from ``str`` so members compare equal to their raw values
    (e.g. ``ContentCategory.SPAM == "spam"``) and serialize cleanly.
    """

    SAFE = "safe"                      # explicitly non-violating
    HARASSMENT = "harassment"          # personal attacks, bullying
    HATE_SPEECH = "hate_speech"        # identity-based discrimination
    SEXUAL = "sexual"                  # explicit sexual content
    VIOLENCE = "violence"              # threats, graphic violence
    SELF_HARM = "self_harm"            # suicide / self-injury content
    SPAM = "spam"                      # repetitive or low-quality content
    SCAM = "scam"                      # phishing, fraud, impersonation
    MISINFORMATION = "misinformation"  # dangerous false information
|
||||
|
||||
|
||||
@dataclass
class ModerationResult:
    """Outcome of an AI text-moderation check."""

    # Whether the content was judged to violate policy.
    is_flagged: bool = False
    # Model confidence in the verdict, 0.0 to 1.0.
    confidence: float = 0.0
    # Which policy categories the content matched.
    categories: list[ContentCategory] = field(default_factory=list)
    # Short human-readable justification from the model.
    explanation: str = ""
    # Recommended moderation response.
    suggested_action: Literal["none", "warn", "delete", "timeout", "ban"] = "none"

    @property
    def severity(self) -> int:
        """Severity score 0-100 combining confidence with category weight."""
        if not self.is_flagged:
            return 0

        # Confidence contributes up to 50 points...
        score = int(self.confidence * 50)

        # ...then each category adds a fixed bump: 30 for the worst
        # offences, 20 for serious ones, 10 for anything else.
        worst = {
            ContentCategory.HATE_SPEECH,
            ContentCategory.SELF_HARM,
            ContentCategory.SCAM,
        }
        serious = {
            ContentCategory.HARASSMENT,
            ContentCategory.VIOLENCE,
            ContentCategory.SEXUAL,
        }

        for category in self.categories:
            if category in worst:
                score += 30
            elif category in serious:
                score += 20
            else:
                score += 10

        # Clamp to the documented 0-100 range.
        return min(score, 100)
|
||||
|
||||
|
||||
@dataclass
class ImageAnalysisResult:
    """Outcome of an AI image-safety check."""

    is_nsfw: bool = False        # nudity / sexual content detected
    is_violent: bool = False     # violence or gore detected
    is_disturbing: bool = False  # shocking / disturbing imagery detected
    confidence: float = 0.0      # model confidence, 0.0 to 1.0
    description: str = ""        # brief summary (or error text on failure)
    # Free-form labels reported by the model.
    categories: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class PhishingAnalysisResult:
    """Outcome of an AI phishing/scam URL check."""

    is_phishing: bool = False  # whether the URL was judged malicious
    confidence: float = 0.0    # model confidence, 0.0 to 1.0
    # Specific indicators the model reported (e.g. typosquatting).
    risk_factors: list[str] = field(default_factory=list)
    explanation: str = ""      # brief summary (or error text on failure)
|
||||
|
||||
|
||||
class AIProvider(ABC):
    """Abstract interface every AI moderation backend must implement.

    Concrete providers (Anthropic, OpenAI, and the disabled no-op stub)
    supply all four coroutines below.
    """

    @abstractmethod
    async def moderate_text(
        self,
        content: str,
        context: str | None = None,
        sensitivity: int = 50,
    ) -> ModerationResult:
        """
        Analyze text content for policy violations.

        Args:
            content: The text to analyze
            context: Optional context about the conversation/server
            sensitivity: 0-100, higher means more strict

        Returns:
            ModerationResult with analysis
        """
        ...

    @abstractmethod
    async def analyze_image(
        self,
        image_url: str,
        sensitivity: int = 50,
    ) -> ImageAnalysisResult:
        """
        Analyze an image for NSFW or inappropriate content.

        Args:
            image_url: URL of the image to analyze
            sensitivity: 0-100, higher means more strict

        Returns:
            ImageAnalysisResult with analysis
        """
        ...

    @abstractmethod
    async def analyze_phishing(
        self,
        url: str,
        message_content: str | None = None,
    ) -> PhishingAnalysisResult:
        """
        Analyze a URL for phishing/scam indicators.

        Args:
            url: The URL to analyze
            message_content: Optional full message for context

        Returns:
            PhishingAnalysisResult with analysis
        """
        ...

    @abstractmethod
    async def close(self) -> None:
        """Release any resources held by the provider (HTTP clients, etc.)."""
        ...
|
||||
67
src/guardden/services/ai/factory.py
Normal file
67
src/guardden/services/ai/factory.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Factory for creating AI providers."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from guardden.services.ai.base import AIProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NullProvider(AIProvider):
    """No-op provider used when AI moderation is switched off.

    Every analysis coroutine immediately returns an empty, non-flagged
    result object, so callers never need to special-case the disabled
    state.
    """

    async def moderate_text(self, content, context=None, sensitivity=50):
        # Imported locally, mirroring the lazy-import style of this module.
        from guardden.services.ai.base import ModerationResult

        return ModerationResult()

    async def analyze_image(self, image_url, sensitivity=50):
        from guardden.services.ai.base import ImageAnalysisResult

        return ImageAnalysisResult()

    async def analyze_phishing(self, url, message_content=None):
        from guardden.services.ai.base import PhishingAnalysisResult

        return PhishingAnalysisResult()

    async def close(self):
        # Nothing to release.
        pass
|
||||
|
||||
|
||||
def create_ai_provider(
|
||||
provider: Literal["anthropic", "openai", "none"],
|
||||
api_key: str | None = None,
|
||||
) -> AIProvider:
|
||||
"""
|
||||
Create an AI provider instance.
|
||||
|
||||
Args:
|
||||
provider: The provider type to create
|
||||
api_key: API key for the provider
|
||||
|
||||
Returns:
|
||||
AIProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is unknown or API key is missing
|
||||
"""
|
||||
if provider == "none":
|
||||
logger.info("AI moderation disabled")
|
||||
return NullProvider()
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(f"API key required for {provider} provider")
|
||||
|
||||
if provider == "anthropic":
|
||||
from guardden.services.ai.anthropic_provider import AnthropicProvider
|
||||
|
||||
return AnthropicProvider(api_key)
|
||||
|
||||
if provider == "openai":
|
||||
from guardden.services.ai.openai_provider import OpenAIProvider
|
||||
|
||||
return OpenAIProvider(api_key)
|
||||
|
||||
raise ValueError(f"Unknown AI provider: {provider}")
|
||||
213
src/guardden/services/ai/openai_provider.py
Normal file
213
src/guardden/services/ai/openai_provider.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""OpenAI AI provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from guardden.services.ai.base import (
|
||||
AIProvider,
|
||||
ContentCategory,
|
||||
ImageAnalysisResult,
|
||||
ModerationResult,
|
||||
PhishingAnalysisResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(AIProvider):
    """AI provider using OpenAI's API.

    Text moderation uses the dedicated (cheap/fast) moderation endpoint;
    image and phishing analysis use chat completions with JSON-mode output.
    All analysis methods fail open: on errors they return a non-flagged
    result carrying the error message.
    """

    def __init__(self, api_key: str, model: str = "gpt-4o-mini") -> None:
        """
        Initialize OpenAI provider.

        Args:
            api_key: OpenAI API key
            model: Model to use (default: gpt-4o-mini for speed/cost)

        Raises:
            ImportError: If the optional ``openai`` package is not installed.
        """
        try:
            import openai
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "openai package required. Install with: pip install openai"
            ) from err

        self.client = openai.AsyncOpenAI(api_key=api_key)
        self.model = model
        logger.info(f"Initialized OpenAI provider with model: {model}")

    async def _call_api(
        self,
        system: str,
        user_content: Any,
        max_tokens: int = 500,
    ) -> str:
        """Make a JSON-mode chat completion call and return the raw text.

        Raises:
            Exception: Re-raises any API error after logging it.
        """
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                messages=[
                    {"role": "system", "content": system},
                    {"role": "user", "content": user_content},
                ],
                # JSON mode guarantees the reply is a bare JSON object.
                response_format={"type": "json_object"},
            )
            return response.choices[0].message.content or ""
        except Exception as e:
            logger.error(f"OpenAI API error: {e}")
            raise

    def _parse_json_response(self, response: str) -> dict:
        """Parse JSON from a response (JSON mode means no fence stripping needed).

        Raises:
            json.JSONDecodeError: If the text is not valid JSON.
        """
        import json

        return json.loads(response)

    async def moderate_text(
        self,
        content: str,
        context: str | None = None,
        sensitivity: int = 50,
    ) -> ModerationResult:
        """Analyze text content for policy violations.

        Uses OpenAI's moderation endpoint rather than a chat completion.
        ``context`` is accepted for interface parity but is not used — the
        moderation endpoint only takes the raw input text.

        Args:
            content: The text to analyze.
            context: Unused by this provider.
            sensitivity: 0-100; shifts the flagging threshold from 0.7
                (lenient) down to 0.3 (strict).

        Returns:
            ModerationResult; on error, a non-flagged result whose
            explanation carries the error message.
        """
        try:
            mod_response = await self.client.moderations.create(input=content)
            results = mod_response.results[0]

            # Map OpenAI categories to our categories
            category_mapping = {
                "harassment": ContentCategory.HARASSMENT,
                "harassment/threatening": ContentCategory.HARASSMENT,
                "hate": ContentCategory.HATE_SPEECH,
                "hate/threatening": ContentCategory.HATE_SPEECH,
                "self-harm": ContentCategory.SELF_HARM,
                "self-harm/intent": ContentCategory.SELF_HARM,
                "self-harm/instructions": ContentCategory.SELF_HARM,
                "sexual": ContentCategory.SEXUAL,
                "sexual/minors": ContentCategory.SEXUAL,
                "violence": ContentCategory.VIOLENCE,
                "violence/graphic": ContentCategory.VIOLENCE,
            }

            flagged_categories = []
            max_score = 0.0

            for category, score in results.category_scores.model_dump().items():
                # BUGFIX: track the maximum over *all* scores. Previously
                # max_score was only updated for scores > 0.5, which made
                # the sensitivity-adjusted threshold below unreachable in
                # its 0.3-0.5 range and skewed the reported confidence.
                max_score = max(max_score, score)
                # 0.5 remains the fixed bar for attributing a concrete category.
                if score > 0.5 and category in category_mapping:
                    flagged_categories.append(category_mapping[category])

            # Adjust threshold based on sensitivity: 0.7 (lenient) .. 0.3 (strict)
            threshold = 0.3 + (0.4 * (100 - sensitivity) / 100)

            if results.flagged or max_score > threshold:
                return ModerationResult(
                    is_flagged=True,
                    confidence=max_score,
                    categories=list(set(flagged_categories)),
                    explanation="Content flagged by moderation API",
                    suggested_action="delete" if max_score > 0.8 else "warn",
                )

            return ModerationResult(is_flagged=False, confidence=1.0 - max_score)

        except Exception as e:
            logger.error(f"Error moderating text: {e}")
            return ModerationResult(
                is_flagged=False,
                explanation=f"Error analyzing content: {str(e)}",
            )

    async def analyze_image(
        self,
        image_url: str,
        sensitivity: int = 50,
    ) -> ImageAnalysisResult:
        """Analyze an image for NSFW or inappropriate content.

        Args:
            image_url: URL of the image (passed directly to the API).
            sensitivity: 0-100; below 30 relaxes the prompt, above 70 tightens it.

        Returns:
            ImageAnalysisResult; on error the description carries the
            failure reason and all flags stay False.
        """
        system = """Analyze this image for content moderation. Respond in JSON format:
{
    "is_nsfw": true/false,
    "is_violent": true/false,
    "is_disturbing": true/false,
    "confidence": 0.0-1.0,
    "description": "Brief description",
    "categories": ["category1"]
}"""

        sensitivity_note = ""
        if sensitivity < 30:
            sensitivity_note = " Be lenient - only flag explicit content."
        elif sensitivity > 70:
            sensitivity_note = " Be strict - flag suggestive content."

        try:
            response = await self.client.chat.completions.create(
                # NOTE(review): pinned to a known vision-capable model
                # rather than self.model — self.model may lack vision
                # support; confirm before unifying.
                model="gpt-4o-mini",
                max_tokens=500,
                messages=[
                    {"role": "system", "content": system + sensitivity_note},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "Analyze this image for moderation."},
                            {"type": "image_url", "image_url": {"url": image_url}},
                        ],
                    },
                ],
                response_format={"type": "json_object"},
            )

            # Fall back to an empty object so a blank reply parses cleanly.
            data = self._parse_json_response(response.choices[0].message.content or "{}")

            return ImageAnalysisResult(
                is_nsfw=data.get("is_nsfw", False),
                is_violent=data.get("is_violent", False),
                is_disturbing=data.get("is_disturbing", False),
                confidence=float(data.get("confidence", 0.0)),
                description=data.get("description", ""),
                categories=data.get("categories", []),
            )

        except Exception as e:
            logger.error(f"Error analyzing image: {e}")
            return ImageAnalysisResult(description=f"Error analyzing image: {str(e)}")

    async def analyze_phishing(
        self,
        url: str,
        message_content: str | None = None,
    ) -> PhishingAnalysisResult:
        """Analyze a URL for phishing/scam indicators.

        Args:
            url: The URL to analyze.
            message_content: Optional full message supplied as context.

        Returns:
            PhishingAnalysisResult; on error, a non-flagged result whose
            explanation carries the error message.
        """
        system = """Analyze the URL for phishing/scam indicators. Respond in JSON:
{
    "is_phishing": true/false,
    "confidence": 0.0-1.0,
    "risk_factors": ["factor1"],
    "explanation": "Brief explanation"
}

Check for: domain impersonation, urgency tactics, credential requests, too-good-to-be-true offers."""

        user_message = f"URL: {url}"
        if message_content:
            user_message += f"\n\nMessage context: {message_content}"

        try:
            response = await self._call_api(system, user_message)
            data = self._parse_json_response(response)

            return PhishingAnalysisResult(
                is_phishing=data.get("is_phishing", False),
                confidence=float(data.get("confidence", 0.0)),
                risk_factors=data.get("risk_factors", []),
                explanation=data.get("explanation", ""),
            )

        except Exception as e:
            logger.error(f"Error analyzing phishing: {e}")
            return PhishingAnalysisResult(explanation=f"Error analyzing URL: {str(e)}")

    async def close(self) -> None:
        """Clean up resources (closes the underlying HTTP client)."""
        await self.client.close()
|
||||
301
src/guardden/services/automod.py
Normal file
301
src/guardden/services/automod.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Automod service for content filtering and spam detection."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import NamedTuple
|
||||
|
||||
import discord
|
||||
|
||||
from guardden.models import BannedWord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Known scam/phishing patterns
|
||||
SCAM_PATTERNS = [
|
||||
# Discord scam patterns
|
||||
r"discord(?:[-.]?(?:gift|nitro|free|claim|steam))[\w.-]*\.(?!com|gg)[a-z]{2,}",
|
||||
r"(?:free|claim|get)[-.\s]?(?:discord[-.\s]?)?nitro",
|
||||
r"(?:steam|discord)[-.\s]?community[-.\s]?(?:giveaway|gift)",
|
||||
# Generic phishing
|
||||
r"(?:verify|confirm)[-.\s]?(?:your)?[-.\s]?account",
|
||||
r"(?:suspended|locked|limited)[-.\s]?account",
|
||||
r"click[-.\s]?(?:here|this)[-.\s]?(?:to[-.\s]?)?(?:verify|claim|get)",
|
||||
# Crypto scams
|
||||
r"(?:free|claim|airdrop)[-.\s]?(?:crypto|bitcoin|eth|nft)",
|
||||
r"(?:double|2x)[-.\s]?your[-.\s]?(?:crypto|bitcoin|eth)",
|
||||
]
|
||||
|
||||
# Suspicious TLDs often used in phishing
|
||||
SUSPICIOUS_TLDS = {
|
||||
".xyz",
|
||||
".top",
|
||||
".club",
|
||||
".work",
|
||||
".click",
|
||||
".link",
|
||||
".info",
|
||||
".ru",
|
||||
".cn",
|
||||
".tk",
|
||||
".ml",
|
||||
".ga",
|
||||
".cf",
|
||||
".gq",
|
||||
}
|
||||
|
||||
# URL pattern for extraction
|
||||
URL_PATTERN = re.compile(
|
||||
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*|"
|
||||
r"(?:www\.)?[-\w]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|ru|cn)[^\s]*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
class SpamRecord(NamedTuple):
    """Immutable (hash, time) pair logged per message for spam tracking."""

    # Normalized content hash used for duplicate detection.
    content_hash: str
    # When the message was observed (timezone-aware UTC elsewhere in this module).
    timestamp: datetime
|
||||
|
||||
|
||||
@dataclass
class UserSpamTracker:
    """Rolling per-user state backing the spam heuristics."""

    # Recent messages; pruned by cleanup().
    messages: list[SpamRecord] = field(default_factory=list)
    # Running mention counters used for mass-mention detection.
    mention_count: int = 0
    last_mention_time: datetime | None = None
    # Consecutive duplicate-message counter.
    duplicate_count: int = 0
    # When automod last acted on this user.
    last_action_time: datetime | None = None

    def cleanup(self, max_age: timedelta = timedelta(minutes=1)) -> None:
        """Drop tracked messages older than *max_age* (default: one minute)."""
        threshold = datetime.now(timezone.utc) - max_age
        self.messages = [record for record in self.messages if record.timestamp > threshold]
|
||||
|
||||
|
||||
@dataclass
class AutomodResult:
    """Actions the automod pipeline recommends for a checked message."""

    should_delete: bool = False   # remove the offending message
    should_warn: bool = False     # issue a warning to the author
    should_strike: bool = False   # record a strike against the author
    should_timeout: bool = False  # apply a timeout to the author
    timeout_duration: int = 0     # timeout length in seconds (0 = none)
    reason: str = ""              # human-readable justification
    matched_filter: str = ""      # identifier of the filter that triggered
|
||||
|
||||
|
||||
class AutomodService:
|
||||
"""Service for automatic content moderation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Compile scam patterns
|
||||
self._scam_patterns = [re.compile(p, re.IGNORECASE) for p in SCAM_PATTERNS]
|
||||
|
||||
# Per-guild, per-user spam tracking
|
||||
# Structure: {guild_id: {user_id: UserSpamTracker}}
|
||||
self._spam_trackers: dict[int, dict[int, UserSpamTracker]] = defaultdict(
|
||||
lambda: defaultdict(UserSpamTracker)
|
||||
)
|
||||
|
||||
# Spam thresholds
|
||||
self.message_rate_limit = 5 # messages per window
|
||||
self.message_rate_window = 5 # seconds
|
||||
self.duplicate_threshold = 3 # same message count
|
||||
self.mention_limit = 5 # mentions per message
|
||||
self.mention_rate_limit = 10 # mentions per window
|
||||
self.mention_rate_window = 60 # seconds
|
||||
|
||||
def _get_content_hash(self, content: str) -> str:
|
||||
"""Get a normalized hash of message content for duplicate detection."""
|
||||
# Normalize: lowercase, remove extra spaces, remove special chars
|
||||
normalized = re.sub(r"[^\w\s]", "", content.lower())
|
||||
normalized = re.sub(r"\s+", " ", normalized).strip()
|
||||
return normalized
|
||||
|
||||
def check_banned_words(
|
||||
self, content: str, banned_words: list[BannedWord]
|
||||
) -> AutomodResult | None:
|
||||
"""Check message against banned words list."""
|
||||
content_lower = content.lower()
|
||||
|
||||
for banned in banned_words:
|
||||
matched = False
|
||||
|
||||
if banned.is_regex:
|
||||
try:
|
||||
if re.search(banned.pattern, content, re.IGNORECASE):
|
||||
matched = True
|
||||
except re.error:
|
||||
logger.warning(f"Invalid regex pattern: {banned.pattern}")
|
||||
continue
|
||||
else:
|
||||
if banned.pattern.lower() in content_lower:
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
result = AutomodResult(
|
||||
should_delete=True,
|
||||
reason=banned.reason or f"Matched banned word filter",
|
||||
matched_filter=f"banned_word:{banned.id}",
|
||||
)
|
||||
|
||||
if banned.action == "warn":
|
||||
result.should_warn = True
|
||||
elif banned.action == "strike":
|
||||
result.should_strike = True
|
||||
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
def check_scam_links(self, content: str) -> AutomodResult | None:
|
||||
"""Check message for scam/phishing patterns."""
|
||||
# Check for known scam patterns
|
||||
for pattern in self._scam_patterns:
|
||||
if pattern.search(content):
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
reason="Message matched known scam/phishing pattern",
|
||||
matched_filter="scam_pattern",
|
||||
)
|
||||
|
||||
# Check URLs for suspicious TLDs
|
||||
urls = URL_PATTERN.findall(content)
|
||||
for url in urls:
|
||||
url_lower = url.lower()
|
||||
for tld in SUSPICIOUS_TLDS:
|
||||
if tld in url_lower:
|
||||
# Additional check: is it trying to impersonate a known domain?
|
||||
impersonation_keywords = [
|
||||
"discord",
|
||||
"steam",
|
||||
"nitro",
|
||||
"gift",
|
||||
"free",
|
||||
"login",
|
||||
"verify",
|
||||
]
|
||||
if any(kw in url_lower for kw in impersonation_keywords):
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
reason=f"Suspicious link detected: {url[:50]}",
|
||||
matched_filter="suspicious_link",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_spam(
|
||||
self, message: discord.Message, anti_spam_enabled: bool = True
|
||||
) -> AutomodResult | None:
|
||||
"""Check message for spam behavior."""
|
||||
if not anti_spam_enabled:
|
||||
return None
|
||||
|
||||
guild_id = message.guild.id
|
||||
user_id = message.author.id
|
||||
tracker = self._spam_trackers[guild_id][user_id]
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Cleanup old records
|
||||
tracker.cleanup()
|
||||
|
||||
# Check message rate
|
||||
content_hash = self._get_content_hash(message.content)
|
||||
tracker.messages.append(SpamRecord(content_hash, now))
|
||||
|
||||
# Rate limit check
|
||||
recent_window = now - timedelta(seconds=self.message_rate_window)
|
||||
recent_messages = [m for m in tracker.messages if m.timestamp > recent_window]
|
||||
|
||||
if len(recent_messages) > self.message_rate_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
timeout_duration=60, # 1 minute timeout
|
||||
reason=f"Sending messages too fast ({len(recent_messages)} in {self.message_rate_window}s)",
|
||||
matched_filter="rate_limit",
|
||||
)
|
||||
|
||||
# Duplicate message check
|
||||
duplicate_count = sum(1 for m in tracker.messages if m.content_hash == content_hash)
|
||||
if duplicate_count >= self.duplicate_threshold:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
reason=f"Duplicate message detected ({duplicate_count} times)",
|
||||
matched_filter="duplicate",
|
||||
)
|
||||
|
||||
# Mass mention check
|
||||
mention_count = len(message.mentions) + len(message.role_mentions)
|
||||
if message.mention_everyone:
|
||||
mention_count += 100 # Treat @everyone as many mentions
|
||||
|
||||
if mention_count > self.mention_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
timeout_duration=300, # 5 minute timeout
|
||||
reason=f"Mass mentions detected ({mention_count} mentions)",
|
||||
matched_filter="mass_mention",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_invite_links(self, content: str, allow_invites: bool = True) -> AutomodResult | None:
|
||||
"""Check for Discord invite links."""
|
||||
if allow_invites:
|
||||
return None
|
||||
|
||||
invite_pattern = re.compile(
|
||||
r"(?:https?://)?(?:www\.)?(?:discord\.(?:gg|io|me|li)|discordapp\.com/invite)/[\w-]+",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
if invite_pattern.search(content):
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
reason="Discord invite links are not allowed",
|
||||
matched_filter="invite_link",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_all_caps(
|
||||
self, content: str, threshold: float = 0.7, min_length: int = 10
|
||||
) -> AutomodResult | None:
|
||||
"""Check for excessive caps usage."""
|
||||
# Only check messages with enough letters
|
||||
letters = [c for c in content if c.isalpha()]
|
||||
if len(letters) < min_length:
|
||||
return None
|
||||
|
||||
caps_count = sum(1 for c in letters if c.isupper())
|
||||
caps_ratio = caps_count / len(letters)
|
||||
|
||||
if caps_ratio > threshold:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
reason="Excessive caps usage",
|
||||
matched_filter="caps",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def reset_user_tracker(self, guild_id: int, user_id: int) -> None:
|
||||
"""Reset spam tracking for a user."""
|
||||
if guild_id in self._spam_trackers:
|
||||
self._spam_trackers[guild_id].pop(user_id, None)
|
||||
|
||||
def cleanup_guild(self, guild_id: int) -> None:
|
||||
"""Remove all tracking data for a guild."""
|
||||
self._spam_trackers.pop(guild_id, None)
|
||||
99
src/guardden/services/database.py
Normal file
99
src/guardden/services/database.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Database connection and session management."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import asyncpg
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from guardden.config import Settings
|
||||
from guardden.models import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
    """Manages database connections and sessions.

    Maintains two handles to the same PostgreSQL database: a SQLAlchemy
    async engine / session factory for ORM work, and a raw asyncpg pool
    for performance-critical operations.
    """

    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        self._engine = None
        self._session_factory = None
        self._pool: asyncpg.Pool | None = None

    async def connect(self) -> None:
        """Initialize the SQLAlchemy engine, session factory and asyncpg pool."""
        db_url = self.settings.database_url.get_secret_value()

        # SQLAlchemy needs the "+asyncpg" driver suffix, while raw asyncpg
        # must NOT see it (it rejects the "postgresql+asyncpg" scheme).
        # Normalize both forms regardless of how the URL was supplied.
        if db_url.startswith("postgresql+asyncpg://"):
            sqlalchemy_url = db_url
            asyncpg_url = db_url.replace("postgresql+asyncpg://", "postgresql://", 1)
        elif db_url.startswith("postgresql://"):
            sqlalchemy_url = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
            asyncpg_url = db_url
        else:
            sqlalchemy_url = db_url
            asyncpg_url = db_url

        self._engine = create_async_engine(
            sqlalchemy_url,
            pool_size=self.settings.database_pool_min,
            max_overflow=self.settings.database_pool_max - self.settings.database_pool_min,
            echo=self.settings.log_level == "DEBUG",
        )

        self._session_factory = async_sessionmaker(
            self._engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

        # Raw asyncpg pool for performance-critical operations.  If pool
        # creation fails, dispose the engine so a half-initialized
        # connect() does not leak connections.
        try:
            self._pool = await asyncpg.create_pool(
                asyncpg_url,
                min_size=self.settings.database_pool_min,
                max_size=self.settings.database_pool_max,
            )
        except Exception:
            await self._engine.dispose()
            self._engine = None
            self._session_factory = None
            raise

        logger.info("Database connection established")

    async def disconnect(self) -> None:
        """Close the asyncpg pool and dispose the SQLAlchemy engine."""
        if self._pool:
            await self._pool.close()
            self._pool = None

        if self._engine:
            await self._engine.dispose()
            self._engine = None

        logger.info("Database connections closed")

    async def create_tables(self) -> None:
        """Create all tables declared on the ORM metadata.

        Raises:
            RuntimeError: if called before connect().
        """
        if not self._engine:
            raise RuntimeError("Database not connected")

        async with self._engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        logger.info("Database tables created")

    @asynccontextmanager
    async def session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an AsyncSession that commits on success, rolls back on error.

        Raises:
            RuntimeError: if called before connect().
        """
        if not self._session_factory:
            raise RuntimeError("Database not connected")

        async with self._session_factory() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    @property
    def pool(self) -> asyncpg.Pool:
        """The raw asyncpg connection pool.

        Raises:
            RuntimeError: if accessed before connect().
        """
        if not self._pool:
            raise RuntimeError("Database not connected")
        return self._pool
|
||||
145
src/guardden/services/guild_config.py
Normal file
145
src/guardden/services/guild_config.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Guild configuration service."""
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
|
||||
import discord
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from guardden.models import BannedWord, Guild, GuildSettings
|
||||
from guardden.services.database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuildConfigService:
    """Manages guild configurations with caching.

    Reads go through an in-memory cache of GuildSettings keyed by guild id;
    writes commit to the database and invalidate the cached entry.
    """

    def __init__(self, database: Database) -> None:
        self.database = database
        # GuildSettings cache keyed by guild id; invalidated on writes.
        self._cache: dict[int, GuildSettings] = {}

    async def get_config(self, guild_id: int) -> GuildSettings | None:
        """Get guild configuration, using cache if available."""
        if guild_id in self._cache:
            return self._cache[guild_id]

        async with self.database.session() as session:
            result = await session.execute(
                select(GuildSettings).where(GuildSettings.guild_id == guild_id)
            )
            settings = result.scalar_one_or_none()

            if settings:
                self._cache[guild_id] = settings

            return settings

    async def get_guild(self, guild_id: int) -> Guild | None:
        """Get full guild data including settings and banned words."""
        async with self.database.session() as session:
            result = await session.execute(
                select(Guild)
                .options(selectinload(Guild.settings), selectinload(Guild.banned_words))
                .where(Guild.id == guild_id)
            )
            return result.scalar_one_or_none()

    async def create_guild(self, guild: discord.Guild) -> Guild:
        """Create a new guild entry with default settings (idempotent)."""
        async with self.database.session() as session:
            # Check if guild already exists
            existing = await session.get(Guild, guild.id)
            if existing:
                return existing

            # Create new guild
            db_guild = Guild(
                id=guild.id,
                name=guild.name,
                owner_id=guild.owner_id,
            )
            session.add(db_guild)
            # Flush so the guild row exists before the settings FK row.
            await session.flush()

            # Create default settings
            settings = GuildSettings(guild_id=guild.id)
            session.add(settings)

            await session.commit()

            logger.info(f"Created guild config for {guild.name} (ID: {guild.id})")
            return db_guild

    async def update_settings(self, guild_id: int, **kwargs) -> GuildSettings | None:
        """Update guild settings.

        Unknown keys in **kwargs are silently ignored.  Returns the updated
        settings, or None when the guild has no settings row.
        """
        async with self.database.session() as session:
            result = await session.execute(
                select(GuildSettings).where(GuildSettings.guild_id == guild_id)
            )
            settings = result.scalar_one_or_none()

            if not settings:
                return None

            for key, value in kwargs.items():
                if hasattr(settings, key):
                    setattr(settings, key, value)

            await session.commit()

        # Invalidate cache so the next read sees the new values.
        self._cache.pop(guild_id, None)

        return settings

    def invalidate_cache(self, guild_id: int) -> None:
        """Remove a guild from the cache."""
        self._cache.pop(guild_id, None)

    async def get_banned_words(self, guild_id: int) -> list[BannedWord]:
        """Get all banned words for a guild."""
        async with self.database.session() as session:
            result = await session.execute(
                select(BannedWord).where(BannedWord.guild_id == guild_id)
            )
            return list(result.scalars().all())

    async def add_banned_word(
        self,
        guild_id: int,
        pattern: str,
        added_by: int,
        is_regex: bool = False,
        action: str = "delete",
        reason: str | None = None,
    ) -> BannedWord:
        """Add a banned word to a guild and return the persisted row."""
        async with self.database.session() as session:
            banned_word = BannedWord(
                guild_id=guild_id,
                pattern=pattern,
                is_regex=is_regex,
                action=action,
                reason=reason,
                added_by=added_by,
            )
            session.add(banned_word)
            await session.commit()
            return banned_word

    async def remove_banned_word(self, guild_id: int, word_id: int) -> bool:
        """Remove a banned word from a guild.

        Returns True when a matching word existed and was deleted.
        """
        async with self.database.session() as session:
            result = await session.execute(
                select(BannedWord).where(BannedWord.id == word_id, BannedWord.guild_id == guild_id)
            )
            word = result.scalar_one_or_none()

            if word:
                # AsyncSession.delete is a coroutine in SQLAlchemy 2.x and
                # must be awaited — without the await the row was never
                # actually deleted.
                await session.delete(word)
                await session.commit()
                return True

            return False
|
||||
300
src/guardden/services/ratelimit.py
Normal file
300
src/guardden/services/ratelimit.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Rate limiting service for command and action throttling."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitScope(str, Enum):
    """Scope of rate limiting.

    Determines which identifiers form the bucket key (see
    ``RateLimiter._get_bucket_key``): USER buckets are shared across
    guilds, MEMBER is per (guild, user), CHANNEL and GUILD throttle a
    whole channel or guild.  Inherits ``str`` so values compare and
    serialize as plain strings.
    """

    USER = "user"  # Per user globally
    MEMBER = "member"  # Per user per guild
    CHANNEL = "channel"  # Per channel
    GUILD = "guild"  # Per guild
|
||||
|
||||
|
||||
@dataclass
class RateLimitBucket:
    """Sliding-window request log for a single rate-limit bucket.

    Holds the timestamps of requests observed within the last
    ``window_seconds``; the bucket is saturated once ``max_requests``
    timestamps remain inside the window.
    """

    max_requests: int
    window_seconds: float
    requests: list[datetime] = field(default_factory=list)

    def cleanup(self) -> None:
        """Forget requests that have fallen out of the window."""
        horizon = datetime.now(timezone.utc) - timedelta(seconds=self.window_seconds)
        self.requests = [stamp for stamp in self.requests if stamp > horizon]

    def is_limited(self) -> bool:
        """True when the window already holds the maximum number of requests."""
        self.cleanup()
        return len(self.requests) >= self.max_requests

    def record(self) -> None:
        """Log one request at the current time."""
        self.requests.append(datetime.now(timezone.utc))

    def remaining(self) -> int:
        """Number of further requests allowed in the current window."""
        self.cleanup()
        return max(0, self.max_requests - len(self.requests))

    def reset_after(self) -> float:
        """Seconds until the oldest logged request leaves the window.

        Returns 0.0 for an empty bucket.
        """
        self.cleanup()
        if not self.requests:
            return 0.0
        earliest = min(self.requests)
        expires = earliest + timedelta(seconds=self.window_seconds)
        seconds_left = (expires - datetime.now(timezone.utc)).total_seconds()
        return max(0.0, seconds_left)
|
||||
|
||||
|
||||
@dataclass
class RateLimitConfig:
    """Configuration for a rate limit.

    Declares how many requests are allowed per rolling window, and which
    identifiers key the bucket (see RateLimitScope).
    """

    max_requests: int  # allowed requests per window
    window_seconds: float  # rolling window length in seconds
    scope: RateLimitScope = RateLimitScope.MEMBER  # bucket key granularity

    def create_bucket(self) -> RateLimitBucket:
        # Fresh, empty bucket enforcing this config's limit and window.
        return RateLimitBucket(
            max_requests=self.max_requests,
            window_seconds=self.window_seconds,
        )
|
||||
|
||||
|
||||
@dataclass
class RateLimitResult:
    """Result of a rate limit check."""

    is_limited: bool  # True when the request should be rejected
    remaining: int  # requests left in the current window (999 if unconfigured)
    reset_after: float  # seconds until a slot frees up (0 when not limited)
    bucket_key: str  # resolved scope key; "" when the action is unconfigured
|
||||
|
||||
|
||||
class RateLimiter:
    """General-purpose rate limiter.

    Tracks sliding-window buckets per (action, scope key).  ``check()`` is
    read-only; ``acquire()`` records the request when it is allowed.
    Actions without a configured limit are always allowed.
    """

    # Default rate limits for various actions
    DEFAULT_LIMITS = {
        "command": RateLimitConfig(5, 10, RateLimitScope.MEMBER),  # 5 commands per 10s
        "moderation": RateLimitConfig(10, 60, RateLimitScope.MEMBER),  # 10 mod actions per minute
        "verification": RateLimitConfig(3, 300, RateLimitScope.MEMBER),  # 3 verifications per 5 min
        "message": RateLimitConfig(10, 10, RateLimitScope.MEMBER),  # 10 messages per 10s
        "api_call": RateLimitConfig(
            30, 60, RateLimitScope.GUILD
        ),  # 30 API calls per minute per guild
    }

    def __init__(self) -> None:
        # Buckets keyed first by action name, then by scope-specific key.
        self._buckets: dict[str, dict[str, RateLimitBucket]] = defaultdict(dict)
        self._configs: dict[str, RateLimitConfig] = dict(self.DEFAULT_LIMITS)

    def configure(self, action: str, config: RateLimitConfig) -> None:
        """Install *config* for *action*, discarding that action's buckets."""
        self._configs[action] = config
        self._buckets[action].clear()

    def _get_bucket_key(
        self,
        scope: RateLimitScope,
        user_id: int | None = None,
        guild_id: int | None = None,
        channel_id: int | None = None,
    ) -> str:
        """Build the bucket key string for *scope* from the relevant ids."""
        if scope == RateLimitScope.USER:
            return f"user:{user_id}"
        if scope == RateLimitScope.MEMBER:
            return f"member:{guild_id}:{user_id}"
        if scope == RateLimitScope.CHANNEL:
            return f"channel:{channel_id}"
        if scope == RateLimitScope.GUILD:
            return f"guild:{guild_id}"
        return f"unknown:{user_id}:{guild_id}"

    def check(
        self,
        action: str,
        user_id: int | None = None,
        guild_id: int | None = None,
        channel_id: int | None = None,
    ) -> RateLimitResult:
        """
        Check if an action is rate limited.

        Does not record the request - use `acquire()` for that.
        """
        config = self._configs.get(action)
        if config is None:
            # Unconfigured actions are never limited.
            return RateLimitResult(
                is_limited=False,
                remaining=999,
                reset_after=0,
                bucket_key="",
            )

        key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
        bucket = self._buckets[action].get(key)

        if bucket is None:
            # No requests recorded yet for this key.
            return RateLimitResult(
                is_limited=False,
                remaining=config.max_requests,
                reset_after=0,
                bucket_key=key,
            )

        return RateLimitResult(
            is_limited=bucket.is_limited(),
            remaining=bucket.remaining(),
            reset_after=bucket.reset_after(),
            bucket_key=key,
        )

    def acquire(
        self,
        action: str,
        user_id: int | None = None,
        guild_id: int | None = None,
        channel_id: int | None = None,
    ) -> RateLimitResult:
        """
        Attempt to acquire a rate limit slot.

        Records the request if not limited.
        """
        config = self._configs.get(action)
        if config is None:
            return RateLimitResult(
                is_limited=False,
                remaining=999,
                reset_after=0,
                bucket_key="",
            )

        key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
        bucket = self._buckets[action].setdefault(key, config.create_bucket())

        if bucket.is_limited():
            return RateLimitResult(
                is_limited=True,
                remaining=0,
                reset_after=bucket.reset_after(),
                bucket_key=key,
            )

        bucket.record()

        return RateLimitResult(
            is_limited=False,
            remaining=bucket.remaining(),
            reset_after=bucket.reset_after(),
            bucket_key=key,
        )

    def reset(
        self,
        action: str,
        user_id: int | None = None,
        guild_id: int | None = None,
        channel_id: int | None = None,
    ) -> bool:
        """Reset rate limit for a specific bucket; True if one existed."""
        config = self._configs.get(action)
        if config is None:
            return False

        key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
        return self._buckets[action].pop(key, None) is not None

    def cleanup(self) -> int:
        """Drop buckets whose windows have fully expired. Returns count removed."""
        removed = 0
        for action_buckets in self._buckets.values():
            # Snapshot the keys: we delete from the dict while walking it.
            for key in list(action_buckets):
                bucket = action_buckets[key]
                bucket.cleanup()
                if not bucket.requests:
                    del action_buckets[key]
                    removed += 1
        return removed
|
||||
|
||||
|
||||
# Global rate limiter instance
|
||||
_rate_limiter: RateLimiter | None = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""Get or create the global rate limiter instance."""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
def ratelimit(
|
||||
action: str = "command",
|
||||
max_requests: int | None = None,
|
||||
window_seconds: float | None = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator for rate limiting commands.
|
||||
|
||||
Usage:
|
||||
@ratelimit("moderation", max_requests=5, window_seconds=60)
|
||||
async def my_command(self, ctx):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
async def wrapper(self, ctx, *args, **kwargs):
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# Configure if custom limits provided
|
||||
if max_requests is not None and window_seconds is not None:
|
||||
limiter.configure(
|
||||
action,
|
||||
RateLimitConfig(max_requests, window_seconds, RateLimitScope.MEMBER),
|
||||
)
|
||||
|
||||
result = limiter.acquire(
|
||||
action,
|
||||
user_id=ctx.author.id,
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
)
|
||||
|
||||
if result.is_limited:
|
||||
await ctx.send(
|
||||
f"You're being rate limited. Try again in {result.reset_after:.1f} seconds.",
|
||||
delete_after=5,
|
||||
)
|
||||
return
|
||||
|
||||
return await func(self, ctx, *args, **kwargs)
|
||||
|
||||
# Preserve function metadata
|
||||
wrapper.__name__ = func.__name__
|
||||
wrapper.__doc__ = func.__doc__
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
300
src/guardden/services/verification.py
Normal file
300
src/guardden/services/verification.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Verification service for new member challenges."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChallengeType(str, Enum):
    """Types of verification challenges.

    Inherits ``str`` so values serialize/compare as plain strings.
    NOTE(review): VerificationService registers generators only for BUTTON,
    CAPTCHA, MATH and EMOJI; QUESTIONS falls back to BUTTON there.
    """

    BUTTON = "button"  # Simple button click
    CAPTCHA = "captcha"  # Text-based captcha
    MATH = "math"  # Simple math problem
    EMOJI = "emoji"  # Select correct emoji
    QUESTIONS = "questions"  # Custom questions
|
||||
|
||||
|
||||
@dataclass
class Challenge:
    """A single verification challenge posed to a joining member.

    Answers are compared case-insensitively after stripping whitespace,
    and every call to check_answer() consumes one attempt, pass or fail.
    """

    challenge_type: ChallengeType
    question: str
    answer: str
    options: list[str] = field(default_factory=list)  # For multiple choice
    expires_at: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc) + timedelta(minutes=10)
    )
    attempts: int = 0
    max_attempts: int = 3

    @property
    def is_expired(self) -> bool:
        """Whether the challenge deadline has passed."""
        return self.expires_at < datetime.now(timezone.utc)

    def check_answer(self, response: str) -> bool:
        """Consume one attempt and report whether *response* matches."""
        self.attempts += 1
        normalized = response.strip().lower()
        return normalized == self.answer.lower()
|
||||
|
||||
|
||||
@dataclass
class PendingVerification:
    """Tracks a pending verification for a user."""

    user_id: int
    guild_id: int
    challenge: Challenge
    message_id: int | None = None  # id of the sent challenge prompt, if any
    channel_id: int | None = None  # channel the prompt was posted in, if any
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class ChallengeGenerator(ABC):
    """Abstract base class for challenge generators.

    One concrete subclass exists per ChallengeType that the
    VerificationService supports.
    """

    @abstractmethod
    def generate(self) -> Challenge:
        """Generate a new challenge."""
        pass
|
||||
|
||||
|
||||
class ButtonChallengeGenerator(ChallengeGenerator):
    """Generates simple button click challenges."""

    def generate(self) -> Challenge:
        # Fixed challenge: the button interaction handler is expected to
        # submit "verified" as the answer on the member's behalf.
        return Challenge(
            challenge_type=ChallengeType.BUTTON,
            question="Click the button below to verify you're human.",
            answer="verified",
        )
|
||||
|
||||
|
||||
class CaptchaChallengeGenerator(ChallengeGenerator):
    """Generates text-based captcha challenges."""

    # Alphabet excludes look-alike characters (0/O, 1/I) so codes are
    # unambiguous when rendered in a monospace block.
    _ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
    _NOISE = ".-*~^"

    def __init__(self, length: int = 6) -> None:
        self.length = length

    def generate(self) -> Challenge:
        """Build a captcha challenge around a randomly drawn code."""
        code = "".join(random.choices(self._ALPHABET, k=self.length))
        rendered = self._create_visual(code)

        return Challenge(
            challenge_type=ChallengeType.CAPTCHA,
            question=f"Enter the code shown below:\n```\n{rendered}\n```",
            answer=code,
        )

    def _create_visual(self, code: str) -> str:
        """Render *code* letter-spaced between lines of noise characters."""

        def noise_line() -> str:
            return "".join(random.choices(self._NOISE, k=len(code) * 2))

        rows = [
            noise_line(),
            noise_line(),
            " ".join(code),
            noise_line(),
            noise_line(),
        ]
        return "\n".join(rows)
|
||||
|
||||
|
||||
class MathChallengeGenerator(ChallengeGenerator):
    """Generates simple math problem challenges."""

    def generate(self) -> Challenge:
        """Build an arithmetic challenge (+, - or * with small operands)."""
        op = random.choice(["+", "-", "*"])

        # Keep products small; sums/differences get larger operands.
        if op == "*":
            lhs = random.randint(2, 10)
            rhs = random.randint(2, 10)
        else:
            lhs = random.randint(10, 50)
            rhs = random.randint(1, 20)

        # Swap so subtraction never produces a negative answer.
        if op == "-" and rhs > lhs:
            lhs, rhs = rhs, lhs

        solution = {"+": lhs + rhs, "-": lhs - rhs, "*": lhs * rhs}[op]

        return Challenge(
            challenge_type=ChallengeType.MATH,
            question=f"Solve this math problem: **{lhs} {op} {rhs} = ?**",
            answer=str(solution),
        )
|
||||
|
||||
|
||||
class EmojiChallengeGenerator(ChallengeGenerator):
    """Generates emoji selection challenges.

    Picks a themed emoji set, chooses one target emoji, and presents it
    shuffled among up to three decoys from the same set.
    """

    # Themed pools; each challenge draws target + decoys from one pool.
    EMOJI_SETS = [
        ("animals", ["🐶", "🐱", "🐭", "🐹", "🐰", "🦊", "🐻", "🐼"]),
        ("fruits", ["🍎", "🍐", "🍊", "🍋", "🍌", "🍉", "🍇", "🍓"]),
        ("weather", ["☀️", "🌙", "⭐", "🌧️", "❄️", "🌈", "⚡", "🌪️"]),
        ("sports", ["⚽", "🏀", "🏈", "⚾", "🎾", "🏐", "🏉", "🎱"]),
    ]

    def generate(self) -> Challenge:
        """Produce a pick-the-named-emoji challenge."""
        # NOTE(review): `category` is unused — the set name never reaches
        # the question text.
        category, emojis = random.choice(self.EMOJI_SETS)
        target = random.choice(emojis)

        # Create options with the target and some others
        options = [target]
        other_emojis = [e for e in emojis if e != target]
        options.extend(random.sample(other_emojis, min(3, len(other_emojis))))
        random.shuffle(options)

        return Challenge(
            challenge_type=ChallengeType.EMOJI,
            question=f"Select the {self._emoji_name(target)} emoji:",
            answer=target,
            options=options,
        )

    def _emoji_name(self, emoji: str) -> str:
        """Get a description of the emoji (falls back to "correct")."""
        names = {
            "🐶": "dog",
            "🐱": "cat",
            "🐭": "mouse",
            "🐹": "hamster",
            "🐰": "rabbit",
            "🦊": "fox",
            "🐻": "bear",
            "🐼": "panda",
            "🍎": "apple",
            "🍐": "pear",
            "🍊": "orange",
            "🍋": "lemon",
            "🍌": "banana",
            "🍉": "watermelon",
            "🍇": "grapes",
            "🍓": "strawberry",
            "☀️": "sun",
            "🌙": "moon",
            "⭐": "star",
            "🌧️": "rain",
            "❄️": "snowflake",
            "🌈": "rainbow",
            "⚡": "lightning",
            "🌪️": "tornado",
            "⚽": "soccer ball",
            "🏀": "basketball",
            "🏈": "football",
            "⚾": "baseball",
            "🎾": "tennis",
            "🏐": "volleyball",
            "🏉": "rugby",
            "🎱": "pool ball",
        }
        return names.get(emoji, "correct")
|
||||
|
||||
|
||||
class VerificationService:
    """Service for managing member verification.

    Keeps all pending challenges in memory keyed by (guild_id, user_id)
    and dispatches challenge creation to per-type generators.
    """

    def __init__(self) -> None:
        # Pending verifications keyed by (guild_id, user_id).
        self._pending: dict[tuple[int, int], PendingVerification] = {}

        # One generator per supported challenge type.
        self._generators: dict[ChallengeType, ChallengeGenerator] = {
            ChallengeType.BUTTON: ButtonChallengeGenerator(),
            ChallengeType.CAPTCHA: CaptchaChallengeGenerator(),
            ChallengeType.MATH: MathChallengeGenerator(),
            ChallengeType.EMOJI: EmojiChallengeGenerator(),
        }

    def create_challenge(
        self,
        user_id: int,
        guild_id: int,
        challenge_type: ChallengeType = ChallengeType.BUTTON,
    ) -> PendingVerification:
        """Create and register a new challenge, replacing any existing one.

        Challenge types without a registered generator fall back to BUTTON.
        """
        generator = self._generators.get(
            challenge_type, self._generators[ChallengeType.BUTTON]
        )

        pending = PendingVerification(
            user_id=user_id,
            guild_id=guild_id,
            challenge=generator.generate(),
        )
        self._pending[(guild_id, user_id)] = pending
        return pending

    def get_pending(self, guild_id: int, user_id: int) -> PendingVerification | None:
        """Look up the pending verification for a member, if any."""
        return self._pending.get((guild_id, user_id))

    def verify(self, guild_id: int, user_id: int, response: str) -> tuple[bool, str]:
        """
        Attempt to verify a user's response.

        Returns:
            Tuple of (success, message)
        """
        key = (guild_id, user_id)
        pending = self._pending.get(key)

        if pending is None:
            return False, "No pending verification found."

        challenge = pending.challenge

        # Expired or exhausted challenges are removed; the member must
        # request a fresh one.
        if challenge.is_expired:
            self._pending.pop(key, None)
            return False, "Verification expired. Please request a new one."

        if challenge.attempts >= challenge.max_attempts:
            self._pending.pop(key, None)
            return False, "Too many failed attempts. Please request a new verification."

        if challenge.check_answer(response):
            self._pending.pop(key, None)
            return True, "Verification successful!"

        left = challenge.max_attempts - challenge.attempts
        return False, f"Incorrect. {left} attempt(s) remaining."

    def cancel(self, guild_id: int, user_id: int) -> bool:
        """Drop a pending verification; True when one existed."""
        return self._pending.pop((guild_id, user_id), None) is not None

    def cleanup_expired(self) -> int:
        """Remove expired verifications. Returns count of removed."""
        stale = [key for key, entry in self._pending.items() if entry.challenge.is_expired]
        for key in stale:
            del self._pending[key]
        return len(stale)

    def get_pending_count(self, guild_id: int) -> int:
        """Count pending verifications belonging to *guild_id*."""
        return sum(1 for gid, _ in self._pending if gid == guild_id)
|
||||
Reference in New Issue
Block a user