feat: Complete cog and service rewrites

- Automod cog: 520 -> 100 lines (spam only, no commands)
- AI moderation cog: 664 -> 250 lines (images only, full cost controls)
- Automod service: 600+ -> 200 lines (spam only)
- All cost control measures implemented
- NSFW video domain blocking
- Rate limiting per guild and per user
- Image deduplication
- File size limits
- Configurable via YAML

Next: Update AI providers and models
This commit is contained in:
2026-01-27 19:17:18 +01:00
parent 08815a3dd0
commit d972f6f51c
3 changed files with 308 additions and 1509 deletions

View File

@@ -1,14 +1,11 @@
"""Automod service for content filtering and spam detection."""
"""Automod service for spam detection - Minimal Version."""
import logging
import re
import signal
import time
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, NamedTuple, Sequence
from urllib.parse import urlparse
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import discord
@@ -16,221 +13,17 @@ else:
try:
import discord # type: ignore
except ModuleNotFoundError: # pragma: no cover
class _DiscordStub:
class Message: # minimal stub for type hints
class Message:
pass
discord = _DiscordStub() # type: ignore
from guardden.models.guild import BannedWord
logger = logging.getLogger(__name__)
# Circuit breaker for regex safety
class RegexTimeoutError(Exception):
    """Signals that a regex search exceeded its wall-clock budget."""
class RegexCircuitBreaker:
    """Circuit breaker that guards against catastrophic regex backtracking.

    A pattern that times out is disabled for ``failure_threshold`` so a
    hostile pattern cannot repeatedly stall message processing.
    """

    def __init__(self, timeout_seconds: float = 0.1):
        # Per-search wall-clock budget, in seconds.
        self.timeout_seconds = timeout_seconds
        # pattern -> UTC time of its most recent timeout failure
        self.failed_patterns: dict[str, datetime] = {}
        # Disable a failed pattern for 5 minutes before retrying it.
        self.failure_threshold = timedelta(minutes=5)

    def _timeout_handler(self, signum, frame):
        """SIGALRM handler: abort the in-flight regex search."""
        raise RegexTimeoutError("Regex execution timed out")

    def is_pattern_disabled(self, pattern: str) -> bool:
        """Return True while *pattern* is still inside its cool-down window."""
        if pattern not in self.failed_patterns:
            return False
        if datetime.now(timezone.utc) - self.failed_patterns[pattern] > self.failure_threshold:
            # Cool-down expired: re-enable the pattern.
            del self.failed_patterns[pattern]
            return False
        return True

    def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
        """Run a regex search with timeout protection; return whether it matched.

        Returns False (instead of raising) for disabled, dangerous-looking,
        invalid, or timed-out patterns.
        """
        if self.is_pattern_disabled(pattern):
            logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
            return False
        # Cheap static screen before spending any execution time on it.
        if self._is_dangerous_pattern(pattern):
            logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
            return False
        old_handler = None
        try:
            # SIGALRM-based timeout is only available on Unix systems.
            if hasattr(signal, "SIGALRM"):
                old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
                # BUGFIX: signal.alarm() takes whole seconds, so the previous
                # alarm(int(timeout_seconds * 1000)) armed a ~100 s alarm for a
                # 0.1 s budget. setitimer() supports sub-second timeouts.
                signal.setitimer(signal.ITIMER_REAL, self.timeout_seconds)
            start_time = time.perf_counter()
            compiled_pattern = re.compile(pattern, flags)
            result = bool(compiled_pattern.search(text))
            execution_time = time.perf_counter() - start_time
            # Log patterns that come close to the budget for monitoring.
            if execution_time > self.timeout_seconds * 0.8:
                logger.warning(
                    f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
                )
            return result
        except RegexTimeoutError:
            # The search exceeded its budget: disable the pattern temporarily.
            self.failed_patterns[pattern] = datetime.now(timezone.utc)
            logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
            return False
        except re.error as e:
            logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
            return False
        except Exception as e:
            logger.error(f"Unexpected error in regex execution: {e}")
            return False
        finally:
            # Cancel any pending timer and restore the previous handler.
            if hasattr(signal, "SIGALRM") and old_handler is not None:
                signal.setitimer(signal.ITIMER_REAL, 0)
                signal.signal(signal.SIGALRM, old_handler)

    def _is_dangerous_pattern(self, pattern: str) -> bool:
        """Heuristically reject patterns prone to catastrophic backtracking.

        This is a literal-substring check against known-bad fragments plus a
        few cheap structural heuristics — not a full regex analysis.
        """
        dangerous_indicators = [
            r"(\w+)+",  # Nested quantifiers
            r"(\d+)+",  # Nested quantifiers on digits
            r"(.+)+",   # Nested quantifiers on anything
            r"(.*)+",   # Nested quantifiers on anything (greedy)
            r"(\w*)+",  # Nested quantifiers with *
            r"(\S+)+",  # Nested quantifiers on non-whitespace
        ]
        # Excessively long patterns.
        if len(pattern) > 500:
            return True
        # Nested quantifiers such as "(a+)+" (simplified detection).
        if "+)+" in pattern or "*)+" in pattern or "?)+" in pattern:
            return True
        # Excessive repetition operators.
        if pattern.count("+") > 10 or pattern.count("*") > 10:
            return True
        return any(fragment in pattern for fragment in dangerous_indicators)
# Global circuit breaker instance — shared module-wide so a timed-out
# pattern stays disabled for every caller.
_regex_circuit_breaker = RegexCircuitBreaker()
# Known scam/phishing patterns (compiled with re.IGNORECASE by the consumer)
SCAM_PATTERNS = [
    # Discord scam patterns
    r"discord(?:[-.]?(?:gift|nitro|free|claim|steam))[\w.-]*\.(?!com|gg)[a-z]{2,}",
    r"(?:free|claim|get)[-.\s]?(?:discord[-.\s]?)?nitro",
    r"(?:steam|discord)[-.\s]?community[-.\s]?(?:giveaway|gift)",
    # Generic phishing
    r"(?:verify|confirm)[-.\s]?(?:your)?[-.\s]?account",
    r"(?:suspended|locked|limited)[-.\s]?account",
    r"click[-.\s]?(?:here|this)[-.\s]?(?:to[-.\s]?)?(?:verify|claim|get)",
    # Crypto scams
    r"(?:free|claim|airdrop)[-.\s]?(?:crypto|bitcoin|eth|nft)",
    r"(?:double|2x)[-.\s]?your[-.\s]?(?:crypto|bitcoin|eth)",
]
# Suspicious TLDs often used in phishing (leading dot included so they can
# be matched as hostname suffixes)
SUSPICIOUS_TLDS = {
    ".xyz",
    ".top",
    ".club",
    ".work",
    ".click",
    ".link",
    ".info",
    ".ru",
    ".cn",
    ".tk",
    ".ml",
    ".ga",
    ".cf",
    ".gq",
}
# URL pattern for extraction - more restrictive for security.
# First alternative: any http(s) URL; second: bare domains limited to a
# fixed set of TLDs.
URL_PATTERN = re.compile(
    r"https?://(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?:/[^\s]*)?|"
    r"(?:www\.)?[a-zA-Z0-9-]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|gov|edu)(?:/[^\s]*)?",
    re.IGNORECASE,
)
class SpamRecord(NamedTuple):
    """Record of a message for spam tracking."""
    content_hash: str  # normalised fingerprint of the message content
    timestamp: datetime  # when the message was seen (UTC)
@dataclass
class UserSpamTracker:
    """Per-user state accumulated while watching for spammy behaviour."""

    messages: list[SpamRecord] = field(default_factory=list)  # recent message records
    mention_count: int = 0  # mentions seen in the current window
    last_mention_time: datetime | None = None
    duplicate_count: int = 0
    last_action_time: datetime | None = None

    def cleanup(self, max_age: timedelta = timedelta(minutes=1)) -> None:
        """Drop message records older than *max_age* from the history."""
        oldest_allowed = datetime.now(timezone.utc) - max_age
        self.messages = [record for record in self.messages if record.timestamp > oldest_allowed]
@dataclass
class AutomodResult:
    """Result of automod check.

    Flags tell the caller which moderation actions to take; all default to
    "do nothing".
    """
    should_delete: bool = False  # delete the offending message
    should_warn: bool = False  # warn the author
    should_strike: bool = False  # record a strike against the author
    should_timeout: bool = False  # time the author out
    timeout_duration: int = 0  # seconds
    reason: str = ""  # human-readable explanation
    matched_filter: str = ""  # identifier of the filter that matched
@dataclass(frozen=True)
class SpamConfig:
    """Configuration for spam thresholds."""
    """Spam detection configuration."""
    # NOTE(review): the duplicate docstring above is a diff artifact (old and
    # new versions interleaved); the second string is a no-op statement.
    message_rate_limit: int = 5  # max messages per rate window
    message_rate_window: int = 5  # rate window length, seconds
    duplicate_threshold: int = 3  # identical messages before flagging
# NOTE(review): the "@@ ... @@" line below is a diff hunk marker — fields such
# as mention_limit and mention_rate_limit (referenced by check_spam) were
# elided here and must be restored from the full file.
@@ -239,324 +32,158 @@ class SpamConfig:
    mention_rate_window: int = 60  # mention-rate window length, seconds
def normalize_domain(value: str) -> str:
    """Normalise a domain or URL to a bare lowercase hostname.

    Used for allowlist comparisons. Returns "" for anything that does not
    parse to a syntactically valid hostname, so malformed or hostile input
    can never match an allowlist entry.

    Args:
        value: A domain name or URL (scheme optional).

    Returns:
        The lowercase hostname with any leading "www." removed, or "" on
        any validation failure.
    """
    if not value or not isinstance(value, str):
        return ""
    # Reject control characters outright.
    if any(char in value for char in ["\x00", "\n", "\r", "\t"]):
        return ""
    text = value.strip().lower()
    if not text or len(text) > 2000:  # Prevent excessively long URLs
        return ""
    try:
        if "://" not in text:
            # urlparse only populates .hostname when a scheme is present.
            text = f"http://{text}"
        parsed = urlparse(text)
        hostname = parsed.hostname or ""
        # RFC limit: a full domain name is at most 253 characters.
        if not hostname or len(hostname) > 253:
            return ""
        # Check for malicious patterns.
        if any(char in hostname for char in [" ", "\x00", "\n", "\r", "\t"]):
            return ""
        if not re.fullmatch(r"[a-z0-9.-]+", hostname):
            return ""
        if hostname.startswith(".") or hostname.endswith(".") or ".." in hostname:
            return ""
        for label in hostname.split("."):
            if not label:
                return ""
            if label.startswith("-") or label.endswith("-"):
                return ""
        # Treat "www.example.com" and "example.com" as the same host.
        if hostname.startswith("www."):
            hostname = hostname[4:]
        return hostname
    except Exception:
        # BUGFIX: the previous clause caught (ValueError, UnicodeError,
        # Exception) — listing subclasses alongside Exception is redundant.
        # urlparse can raise various exceptions with malicious input.
        return ""
@dataclass
class AutomodResult:
    """Result of an automod check.

    NOTE(review): this redefines the AutomodResult declared earlier in the
    module (a diff artifact); this later definition wins at import time.
    Unlike the earlier version, matched_filter and reason are required and
    should_delete defaults to True.
    """
    matched_filter: str  # identifier of the filter that matched
    reason: str  # human-readable explanation
    should_delete: bool = True  # delete the offending message
    should_warn: bool = False  # warn the author
    should_strike: bool = False  # record a strike against the author
    should_timeout: bool = False  # time the author out
    timeout_duration: int | None = None  # timeout length in seconds, if any
def is_allowed_domain(hostname: str, allowlist: set[str]) -> bool:
    """Return True when *hostname* equals, or is a subdomain of, any
    allowlisted domain."""
    if not hostname:
        return False
    return any(
        hostname == allowed or hostname.endswith(f".{allowed}")
        for allowed in allowlist
    )
class SpamTracker:
    """In-memory, per-guild/per-user tracking of message activity."""

    def __init__(self):
        # guild_id -> user_id -> message timestamps (epoch seconds)
        self.message_times: dict[int, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list))
        # guild_id -> user_id -> recent raw message contents (duplicate detection)
        self.message_contents: dict[int, dict[int, list[str]]] = defaultdict(lambda: defaultdict(list))
        # guild_id -> user_id -> mention timestamps (epoch seconds)
        self.mention_times: dict[int, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list))
        # Epoch time of the most recent cleanup pass.
        self.last_cleanup = time.time()

    def cleanup_old_entries(self):
        """Drop stale tracking data to bound memory; runs at most every 5 min."""
        now = time.time()
        if now - self.last_cleanup < 300:  # throttle: cleanup every 5 minutes
            return
        cutoff = now - 3600  # keep the last hour of timestamps
        for store in (self.message_times, self.mention_times):
            for guild_id in list(store):
                users = store[guild_id]
                for user_id in list(users):
                    fresh = [stamp for stamp in users[user_id] if stamp > cutoff]
                    if fresh:
                        users[user_id] = fresh
                    else:
                        # No recent activity: drop the user entirely.
                        del users[user_id]
                if not users:
                    del store[guild_id]
        # Contents are bounded by count rather than age: last 10 per user.
        for guild_id in list(self.message_contents):
            users = self.message_contents[guild_id]
            for user_id in list(users):
                trimmed = users[user_id][-10:]
                if trimmed:
                    users[user_id] = trimmed
                else:
                    del users[user_id]
            if not users:
                del self.message_contents[guild_id]
        self.last_cleanup = now
class AutomodService:
    """Service for automatic content moderation."""
    """Service for spam detection - no banned words, no scam links, no invites."""
    # NOTE(review): the duplicate docstring above and the two __init__
    # definitions below look like an old and a new version interleaved by a
    # diff. In Python the second __init__ silently replaces the first, so the
    # attributes set in the first one (_scam_patterns, _spam_trackers) are
    # never created at runtime — confirm which version is intended.

    def __init__(self) -> None:
        # Compile scam patterns once at construction.
        self._scam_patterns = [re.compile(p, re.IGNORECASE) for p in SCAM_PATTERNS]
        # Per-guild, per-user spam tracking
        # Structure: {guild_id: {user_id: UserSpamTracker}}
        self._spam_trackers: dict[int, dict[int, UserSpamTracker]] = defaultdict(
            lambda: defaultdict(UserSpamTracker)
        )

    # Default spam thresholds
    def __init__(self):
        self.spam_tracker = SpamTracker()
        self.default_spam_config = SpamConfig()
def _get_content_hash(self, content: str) -> str:
    """Return a normalised fingerprint of *content* for duplicate detection.

    Despite the name, this returns the normalised text itself (no hashing):
    lower-cased, stripped of everything but letters/digits/whitespace, with
    whitespace collapsed to single spaces. Plain string operations are used
    to avoid regex overhead.
    """
    kept = [ch for ch in content.lower() if ch.isalnum() or ch.isspace()]
    return " ".join("".join(kept).split())
def check_banned_words(
    self, content: str, banned_words: Sequence[BannedWord]
) -> AutomodResult | None:
    """Check *content* against the guild's banned-word filters.

    Args:
        content: Raw message text.
        banned_words: Filters to test; each may be a literal substring or a
            regex (``is_regex``), with an optional escalation ``action``.

    Returns:
        An AutomodResult for the first matching filter, or None.
    """
    content_lower = content.lower()
    for banned in banned_words:
        if banned.is_regex:
            # Regexes go through the circuit breaker so a catastrophic
            # pattern cannot stall processing.
            matched = _regex_circuit_breaker.safe_regex_search(
                banned.pattern, content, re.IGNORECASE
            )
        else:
            matched = banned.pattern.lower() in content_lower
        if not matched:
            continue
        result = AutomodResult(
            should_delete=True,
            # BUGFIX: was an f-string with no placeholders (f"...").
            reason=banned.reason or "Matched banned word filter",
            matched_filter=f"banned_word:{banned.id}",
        )
        if banned.action == "warn":
            result.should_warn = True
        elif banned.action == "strike":
            result.should_strike = True
        return result
    return None
def check_scam_links(
    self, content: str, allowlist: list[str] | None = None
) -> AutomodResult | None:
    """Check *content* for scam/phishing patterns and suspicious links.

    Args:
        content: Raw message text.
        allowlist: Optional domains (or URLs) that are exempt from the
            suspicious-TLD check; subdomains of an entry are also exempt.

    Returns:
        An AutomodResult when a scam pattern or a suspicious impersonating
        link is found, otherwise None.
    """
    # Check for known scam patterns first — these alone are enough to act.
    for pattern in self._scam_patterns:
        if pattern.search(content):
            return AutomodResult(
                should_delete=True,
                should_warn=True,
                reason="Message matched known scam/phishing pattern",
                matched_filter="scam_pattern",
            )
    allowlist_set = {normalize_domain(domain) for domain in allowlist or [] if domain}
    # Keywords commonly used when impersonating well-known services.
    impersonation_keywords = (
        "discord",
        "steam",
        "nitro",
        "gift",
        "free",
        "login",
        "verify",
    )
    # Check extracted URLs for suspicious TLDs.
    for url in URL_PATTERN.findall(content):
        # Limit URL length to prevent processing extremely long URLs.
        if len(url) > 2000:
            continue
        hostname = normalize_domain(url)
        # Skip if hostname normalization failed (security check).
        if not hostname:
            continue
        if allowlist_set and is_allowed_domain(hostname, allowlist_set):
            continue
        url_lower = url.lower()
        # BUGFIX: match the TLD as a hostname suffix; the previous substring
        # test over the whole URL flagged innocent paths like "/page.info".
        if any(hostname.endswith(tld) for tld in SUSPICIOUS_TLDS):
            # Only act when the link also looks like brand impersonation.
            if any(kw in url_lower for kw in impersonation_keywords):
                return AutomodResult(
                    should_delete=True,
                    should_warn=True,
                    reason=f"Suspicious link detected: {url[:50]}",
                    matched_filter="suspicious_link",
                )
    return None
def check_spam(
self,
message: discord.Message,
message: "discord.Message",
anti_spam_enabled: bool = True,
spam_config: SpamConfig | None = None,
) -> AutomodResult | None:
"""Check message for spam behavior."""
"""Check message for spam patterns.
Args:
message: Discord message to check
anti_spam_enabled: Whether spam detection is enabled
spam_config: Spam configuration settings
Returns:
AutomodResult if spam detected, None otherwise
"""
if not anti_spam_enabled:
return None
# Skip DM messages
if message.guild is None:
return None
config = spam_config or self.default_spam_config
guild_id = message.guild.id
user_id = message.author.id
tracker = self._spam_trackers[guild_id][user_id]
now = datetime.now(timezone.utc)
now = time.time()
# Cleanup old records
tracker.cleanup()
# Periodic cleanup
self.spam_tracker.cleanup_old_entries()
# Check message rate
content_hash = self._get_content_hash(message.content)
tracker.messages.append(SpamRecord(content_hash, now))
# Check 1: Message rate limiting
message_times = self.spam_tracker.message_times[guild_id][user_id]
cutoff_time = now - config.message_rate_window
# Rate limit check
recent_window = now - timedelta(seconds=config.message_rate_window)
recent_messages = [m for m in tracker.messages if m.timestamp > recent_window]
# Remove old timestamps
message_times = [ts for ts in message_times if ts > cutoff_time]
self.spam_tracker.message_times[guild_id][user_id] = message_times
if len(recent_messages) > config.message_rate_limit:
# Add current message
message_times.append(now)
if len(message_times) > config.message_rate_limit:
return AutomodResult(
matched_filter="spam_rate_limit",
reason=f"Exceeded message rate limit ({len(message_times)} messages in {config.message_rate_window}s)",
should_delete=True,
should_timeout=True,
timeout_duration=60, # 1 minute timeout
reason=(
f"Sending messages too fast ({len(recent_messages)} in "
f"{config.message_rate_window}s)"
),
matched_filter="rate_limit",
)
# Duplicate message check
duplicate_count = sum(1 for m in tracker.messages if m.content_hash == content_hash)
# Check 2: Duplicate messages
message_contents = self.spam_tracker.message_contents[guild_id][user_id]
message_contents.append(message.content)
self.spam_tracker.message_contents[guild_id][user_id] = message_contents[-10:] # Keep last 10
# Count duplicates in recent messages
duplicate_count = message_contents.count(message.content)
if duplicate_count >= config.duplicate_threshold:
return AutomodResult(
matched_filter="spam_duplicate",
reason=f"Duplicate message posted {duplicate_count} times",
should_delete=True,
should_warn=True,
reason=f"Duplicate message detected ({duplicate_count} times)",
matched_filter="duplicate",
)
# Mass mention check
mention_count = len(message.mentions) + len(message.role_mentions)
if message.mention_everyone:
mention_count += 100 # Treat @everyone as many mentions
# Check 3: Mass mentions in single message
mention_count = len(message.mentions)
if mention_count > config.mention_limit:
return AutomodResult(
matched_filter="spam_mass_mentions",
reason=f"Too many mentions in single message ({mention_count})",
should_delete=True,
should_timeout=True,
timeout_duration=300, # 5 minute timeout
reason=f"Mass mentions detected ({mention_count} mentions)",
matched_filter="mass_mention",
)
# Check 4: Mention rate limiting
if mention_count > 0:
if tracker.last_mention_time:
window = timedelta(seconds=config.mention_rate_window)
if now - tracker.last_mention_time > window:
tracker.mention_count = 0
tracker.mention_count += mention_count
tracker.last_mention_time = now
mention_times = self.spam_tracker.mention_times[guild_id][user_id]
mention_cutoff = now - config.mention_rate_window
if tracker.mention_count > config.mention_rate_limit:
# Remove old timestamps
mention_times = [ts for ts in mention_times if ts > mention_cutoff]
# Add current mentions
mention_times.extend([now] * mention_count)
self.spam_tracker.mention_times[guild_id][user_id] = mention_times
if len(mention_times) > config.mention_rate_limit:
return AutomodResult(
matched_filter="spam_mention_rate",
reason=f"Exceeded mention rate limit ({len(mention_times)} mentions in {config.mention_rate_window}s)",
should_delete=True,
should_timeout=True,
timeout_duration=300,
reason=(
"Too many mentions in a short period "
f"({tracker.mention_count} in {config.mention_rate_window}s)"
),
matched_filter="mention_rate",
)
return None
def check_invite_links(self, content: str, allow_invites: bool = True) -> AutomodResult | None:
    """Flag Discord invite links when the guild disallows them.

    Returns an AutomodResult requesting deletion when an invite link is
    present and *allow_invites* is False; otherwise None.
    """
    if allow_invites:
        return None
    invite_regex = re.compile(
        r"(?:https?://)?(?:www\.)?"
        r"(?:discord\.(?:gg|io|me|li)|discordapp\.com/invite)/[\w-]+",
        re.IGNORECASE,
    )
    if invite_regex.search(content) is None:
        return None
    return AutomodResult(
        should_delete=True,
        reason="Discord invite links are not allowed",
        matched_filter="invite_link",
    )
def check_all_caps(
    self, content: str, threshold: float = 0.7, min_length: int = 10
) -> AutomodResult | None:
    """Flag messages written mostly in capital letters.

    Messages with fewer than *min_length* letters are ignored; otherwise the
    message is flagged when the uppercase ratio exceeds *threshold*.
    """
    letters = [ch for ch in content if ch.isalpha()]
    # Too few letters to judge — short shouts are tolerated.
    if len(letters) < min_length:
        return None
    upper_ratio = sum(ch.isupper() for ch in letters) / len(letters)
    if upper_ratio <= threshold:
        return None
    return AutomodResult(
        should_delete=True,
        reason="Excessive caps usage",
        matched_filter="caps",
    )
def reset_user_tracker(self, guild_id: int, user_id: int) -> None:
    """Clear spam-tracking state for one user in one guild (no-op if absent)."""
    guild_trackers = self._spam_trackers.get(guild_id)
    if guild_trackers is not None:
        guild_trackers.pop(user_id, None)
def cleanup_guild(self, guild_id: int) -> None:
    """Remove all tracking data for a guild.

    Safe to call for guilds that were never tracked (pop with default).
    """
    self._spam_trackers.pop(guild_id, None)
# Module-level singleton backing the convenience wrapper below.
_automod_service = AutomodService()
def detect_scam_links(content: str, allowlist: list[str] | None = None) -> AutomodResult | None:
    """Convenience wrapper for scam detection.

    Delegates to AutomodService.check_scam_links on the module singleton.

    NOTE(review): with the later of the two __init__ definitions in this file
    in effect, _scam_patterns is never set on the instance, so this call
    would raise AttributeError — confirm which __init__ is intended.
    """
    return _automod_service.check_scam_links(content, allowlist)