quick commit
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 6m9s
CI/CD Pipeline / Security Scanning (push) Successful in 26s
CI/CD Pipeline / Tests (3.11) (push) Failing after 5m24s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m23s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Deploy to Staging (push) Has been skipped
CI/CD Pipeline / Deploy to Production (push) Has been skipped
CI/CD Pipeline / Notification (push) Successful in 1s
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 6m9s
CI/CD Pipeline / Security Scanning (push) Successful in 26s
CI/CD Pipeline / Tests (3.11) (push) Failing after 5m24s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m23s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Deploy to Staging (push) Has been skipped
CI/CD Pipeline / Deploy to Production (push) Has been skipped
CI/CD Pipeline / Notification (push) Successful in 1s
This commit is contained in:
@@ -2,17 +2,150 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
import signal
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import NamedTuple
|
||||
from typing import NamedTuple, Sequence, TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import discord
|
||||
if TYPE_CHECKING:
|
||||
import discord
|
||||
else:
|
||||
try:
|
||||
import discord # type: ignore
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
class _DiscordStub:
|
||||
class Message: # minimal stub for type hints
|
||||
pass
|
||||
|
||||
from guardden.models import BannedWord
|
||||
discord = _DiscordStub() # type: ignore
|
||||
|
||||
from guardden.models.guild import BannedWord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Circuit breaker for regex safety
|
||||
class RegexTimeoutError(Exception):
    """Raised when a regex search exceeds its allotted execution time."""
|
||||
|
||||
|
||||
class RegexCircuitBreaker:
    """Circuit breaker to prevent catastrophic backtracking in regex patterns.

    Patterns that time out are temporarily disabled (see ``failure_threshold``)
    so one pathological pattern cannot repeatedly stall message processing.
    Timeout enforcement relies on SIGALRM and therefore only applies on the
    main thread of Unix systems; elsewhere only the heuristic pre-checks run.
    """

    def __init__(self, timeout_seconds: float = 0.1):
        # Maximum wall-clock time a single regex search may take.
        self.timeout_seconds = timeout_seconds
        # Pattern -> UTC time of last timeout; used to keep it disabled.
        self.failed_patterns: dict[str, datetime] = {}
        self.failure_threshold = timedelta(minutes=5)  # Disable pattern for 5 minutes after failure

    def _timeout_handler(self, signum, frame):
        """Signal handler for regex timeout."""
        raise RegexTimeoutError("Regex execution timed out")

    def is_pattern_disabled(self, pattern: str) -> bool:
        """Check if a pattern is temporarily disabled due to timeouts."""
        if pattern not in self.failed_patterns:
            return False

        failure_time = self.failed_patterns[pattern]
        if datetime.now(timezone.utc) - failure_time > self.failure_threshold:
            # Re-enable the pattern after threshold time
            del self.failed_patterns[pattern]
            return False

        return True

    def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
        """Safely execute a regex search with timeout protection.

        Returns True only when *pattern* matches *text*; returns False on no
        match, invalid pattern, rejected/disabled pattern, or timeout.
        """
        if self.is_pattern_disabled(pattern):
            logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
            return False

        # Basic pattern validation to catch obviously problematic patterns
        if self._is_dangerous_pattern(pattern):
            logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
            return False

        old_handler = None
        try:
            # Set up timeout signal (Unix systems only)
            if hasattr(signal, 'SIGALRM'):
                old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
                # BUG FIX: signal.alarm() takes whole SECONDS, so the previous
                # alarm(int(timeout_seconds * 1000)) armed a ~100-second alarm
                # instead of 100 ms.  setitimer() accepts fractional seconds
                # and delivers SIGALRM when the real-time timer expires.
                signal.setitimer(signal.ITIMER_REAL, self.timeout_seconds)

            start_time = time.perf_counter()

            # Compile and execute regex
            compiled_pattern = re.compile(pattern, flags)
            result = bool(compiled_pattern.search(text))

            execution_time = time.perf_counter() - start_time

            # Log slow patterns for monitoring
            if execution_time > self.timeout_seconds * 0.8:
                logger.warning(
                    f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
                )

            return result

        except RegexTimeoutError:
            # Pattern took too long, disable it temporarily
            self.failed_patterns[pattern] = datetime.now(timezone.utc)
            logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
            return False

        except re.error as e:
            logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
            return False

        except Exception as e:
            logger.error(f"Unexpected error in regex execution: {e}")
            return False

        finally:
            # Clean up timeout signal: cancel the timer, then restore the
            # previous handler (old_handler is only set on SIGALRM platforms).
            if old_handler is not None:
                signal.setitimer(signal.ITIMER_REAL, 0)
                signal.signal(signal.SIGALRM, old_handler)

    def _is_dangerous_pattern(self, pattern: str) -> bool:
        """Basic heuristic to detect potentially dangerous regex patterns."""
        # Literal substrings (not regexes) whose presence indicates nested
        # quantifiers -- the classic catastrophic-backtracking shape.
        dangerous_indicators = [
            r'(\w+)+',  # Nested quantifiers
            r'(\d+)+',  # Nested quantifiers on digits
            r'(.+)+',   # Nested quantifiers on anything
            r'(.*)+',   # Nested quantifiers on anything (greedy)
            r'(\w*)+',  # Nested quantifiers with *
            r'(\S+)+',  # Nested quantifiers on non-whitespace
        ]

        # Check for excessively long patterns
        if len(pattern) > 500:
            return True

        # Check for nested quantifiers (simplified detection)
        if '+)+' in pattern or '*)+' in pattern or '?)+' in pattern:
            return True

        # Check for excessive repetition operators
        if pattern.count('+') > 10 or pattern.count('*') > 10:
            return True

        # Check for specific dangerous patterns
        for dangerous in dangerous_indicators:
            if dangerous in pattern:
                return True

        return False
|
||||
|
||||
|
||||
# Module-level circuit breaker shared by all regex checks in this module.
|
||||
_regex_circuit_breaker = RegexCircuitBreaker()
|
||||
|
||||
|
||||
# Known scam/phishing patterns
|
||||
SCAM_PATTERNS = [
|
||||
@@ -47,10 +180,10 @@ SUSPICIOUS_TLDS = {
|
||||
".gq",
|
||||
}
|
||||
|
||||
# URL pattern for extraction
|
||||
# URL pattern for extraction - more restrictive for security
|
||||
URL_PATTERN = re.compile(
|
||||
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*|"
|
||||
r"(?:www\.)?[-\w]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|ru|cn)[^\s]*",
|
||||
r"https?://(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?:/[^\s]*)?|"
|
||||
r"(?:www\.)?[a-zA-Z0-9-]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|gov|edu)(?:/[^\s]*)?",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
@@ -91,6 +224,66 @@ class AutomodResult:
|
||||
matched_filter: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SpamConfig:
    """Immutable bundle of spam-detection thresholds."""

    # Max messages allowed within message_rate_window seconds.
    message_rate_limit: int = 5
    message_rate_window: int = 5
    # Identical messages needed before flagging duplicate spam.
    duplicate_threshold: int = 3
    # Max mentions allowed in a single message.
    mention_limit: int = 5
    # Max mentions allowed within mention_rate_window seconds.
    mention_rate_limit: int = 10
    mention_rate_window: int = 60
|
||||
|
||||
|
||||
def normalize_domain(value: str) -> str:
    """Normalize a domain or URL for allowlist checks with security validation.

    Returns the lowercased hostname with any leading ``www.`` removed, or an
    empty string when the input is missing, oversized, or malformed.
    """
    if not value or not isinstance(value, str):
        return ""

    text = value.strip().lower()
    if not text or len(text) > 2000:  # Prevent excessively long URLs
        return ""

    # Sanitize input to prevent injection attacks
    if any(char in text for char in ['\x00', '\n', '\r', '\t']):
        return ""

    try:
        # urlparse only yields a hostname when a scheme is present.
        if "://" not in text:
            text = f"http://{text}"

        parsed = urlparse(text)
        hostname = parsed.hostname or ""

        # Additional validation for hostname
        if not hostname or len(hostname) > 253:  # RFC limit on domain length
            return ""

        # Check for malicious patterns
        if any(char in hostname for char in [' ', '\x00', '\n', '\r', '\t']):
            return ""

        # Remove www prefix
        if hostname.startswith("www."):
            hostname = hostname[4:]

        return hostname
    except Exception:
        # FIX: the previous tuple (ValueError, UnicodeError, Exception) was
        # redundant -- Exception already subsumes the other two.  urlparse can
        # raise assorted exceptions on malicious input, so stay broad here and
        # treat any failure as "not a usable domain".
        return ""
|
||||
|
||||
|
||||
def is_allowed_domain(hostname: str, allowlist: set[str]) -> bool:
    """Return True when *hostname* equals, or is a subdomain of, an allowlisted domain."""
    if not hostname:
        return False
    return any(
        hostname == allowed or hostname.endswith(f".{allowed}")
        for allowed in allowlist
    )
|
||||
|
||||
|
||||
class AutomodService:
|
||||
"""Service for automatic content moderation."""
|
||||
|
||||
@@ -104,23 +297,25 @@ class AutomodService:
|
||||
lambda: defaultdict(UserSpamTracker)
|
||||
)
|
||||
|
||||
# Spam thresholds
|
||||
self.message_rate_limit = 5 # messages per window
|
||||
self.message_rate_window = 5 # seconds
|
||||
self.duplicate_threshold = 3 # same message count
|
||||
self.mention_limit = 5 # mentions per message
|
||||
self.mention_rate_limit = 10 # mentions per window
|
||||
self.mention_rate_window = 60 # seconds
|
||||
# Default spam thresholds
|
||||
self.default_spam_config = SpamConfig()
|
||||
|
||||
def _get_content_hash(self, content: str) -> str:
|
||||
"""Get a normalized hash of message content for duplicate detection."""
|
||||
# Normalize: lowercase, remove extra spaces, remove special chars
|
||||
normalized = re.sub(r"[^\w\s]", "", content.lower())
|
||||
normalized = re.sub(r"\s+", " ", normalized).strip()
|
||||
# Use simple string operations for basic patterns to avoid regex overhead
|
||||
normalized = content.lower()
|
||||
|
||||
# Remove special characters (simplified approach)
|
||||
normalized = ''.join(c for c in normalized if c.isalnum() or c.isspace())
|
||||
|
||||
# Normalize whitespace
|
||||
normalized = ' '.join(normalized.split())
|
||||
|
||||
return normalized
|
||||
|
||||
def check_banned_words(
|
||||
self, content: str, banned_words: list[BannedWord]
|
||||
self, content: str, banned_words: Sequence[BannedWord]
|
||||
) -> AutomodResult | None:
|
||||
"""Check message against banned words list."""
|
||||
content_lower = content.lower()
|
||||
@@ -129,12 +324,9 @@ class AutomodService:
|
||||
matched = False
|
||||
|
||||
if banned.is_regex:
|
||||
try:
|
||||
if re.search(banned.pattern, content, re.IGNORECASE):
|
||||
matched = True
|
||||
except re.error:
|
||||
logger.warning(f"Invalid regex pattern: {banned.pattern}")
|
||||
continue
|
||||
# Use circuit breaker for safe regex execution
|
||||
if _regex_circuit_breaker.safe_regex_search(banned.pattern, content, re.IGNORECASE):
|
||||
matched = True
|
||||
else:
|
||||
if banned.pattern.lower() in content_lower:
|
||||
matched = True
|
||||
@@ -155,7 +347,9 @@ class AutomodService:
|
||||
|
||||
return None
|
||||
|
||||
def check_scam_links(self, content: str) -> AutomodResult | None:
|
||||
def check_scam_links(
|
||||
self, content: str, allowlist: list[str] | None = None
|
||||
) -> AutomodResult | None:
|
||||
"""Check message for scam/phishing patterns."""
|
||||
# Check for known scam patterns
|
||||
for pattern in self._scam_patterns:
|
||||
@@ -167,10 +361,25 @@ class AutomodService:
|
||||
matched_filter="scam_pattern",
|
||||
)
|
||||
|
||||
allowlist_set = {normalize_domain(domain) for domain in allowlist or [] if domain}
|
||||
|
||||
# Check URLs for suspicious TLDs
|
||||
urls = URL_PATTERN.findall(content)
|
||||
for url in urls:
|
||||
# Limit URL length to prevent processing extremely long URLs
|
||||
if len(url) > 2000:
|
||||
continue
|
||||
|
||||
url_lower = url.lower()
|
||||
hostname = normalize_domain(url)
|
||||
|
||||
# Skip if hostname normalization failed (security check)
|
||||
if not hostname:
|
||||
continue
|
||||
|
||||
if allowlist_set and is_allowed_domain(hostname, allowlist_set):
|
||||
continue
|
||||
|
||||
for tld in SUSPICIOUS_TLDS:
|
||||
if tld in url_lower:
|
||||
# Additional check: is it trying to impersonate a known domain?
|
||||
@@ -194,12 +403,21 @@ class AutomodService:
|
||||
return None
|
||||
|
||||
def check_spam(
|
||||
self, message: discord.Message, anti_spam_enabled: bool = True
|
||||
self,
|
||||
message: discord.Message,
|
||||
anti_spam_enabled: bool = True,
|
||||
spam_config: SpamConfig | None = None,
|
||||
) -> AutomodResult | None:
|
||||
"""Check message for spam behavior."""
|
||||
if not anti_spam_enabled:
|
||||
return None
|
||||
|
||||
# Skip DM messages
|
||||
if message.guild is None:
|
||||
return None
|
||||
|
||||
config = spam_config or self.default_spam_config
|
||||
|
||||
guild_id = message.guild.id
|
||||
user_id = message.author.id
|
||||
tracker = self._spam_trackers[guild_id][user_id]
|
||||
@@ -213,21 +431,24 @@ class AutomodService:
|
||||
tracker.messages.append(SpamRecord(content_hash, now))
|
||||
|
||||
# Rate limit check
|
||||
recent_window = now - timedelta(seconds=self.message_rate_window)
|
||||
recent_window = now - timedelta(seconds=config.message_rate_window)
|
||||
recent_messages = [m for m in tracker.messages if m.timestamp > recent_window]
|
||||
|
||||
if len(recent_messages) > self.message_rate_limit:
|
||||
if len(recent_messages) > config.message_rate_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
timeout_duration=60, # 1 minute timeout
|
||||
reason=f"Sending messages too fast ({len(recent_messages)} in {self.message_rate_window}s)",
|
||||
reason=(
|
||||
f"Sending messages too fast ({len(recent_messages)} in "
|
||||
f"{config.message_rate_window}s)"
|
||||
),
|
||||
matched_filter="rate_limit",
|
||||
)
|
||||
|
||||
# Duplicate message check
|
||||
duplicate_count = sum(1 for m in tracker.messages if m.content_hash == content_hash)
|
||||
if duplicate_count >= self.duplicate_threshold:
|
||||
if duplicate_count >= config.duplicate_threshold:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
@@ -240,7 +461,7 @@ class AutomodService:
|
||||
if message.mention_everyone:
|
||||
mention_count += 100 # Treat @everyone as many mentions
|
||||
|
||||
if mention_count > self.mention_limit:
|
||||
if mention_count > config.mention_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
@@ -249,6 +470,26 @@ class AutomodService:
|
||||
matched_filter="mass_mention",
|
||||
)
|
||||
|
||||
if mention_count > 0:
|
||||
if tracker.last_mention_time:
|
||||
window = timedelta(seconds=config.mention_rate_window)
|
||||
if now - tracker.last_mention_time > window:
|
||||
tracker.mention_count = 0
|
||||
tracker.mention_count += mention_count
|
||||
tracker.last_mention_time = now
|
||||
|
||||
if tracker.mention_count > config.mention_rate_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
timeout_duration=300,
|
||||
reason=(
|
||||
"Too many mentions in a short period "
|
||||
f"({tracker.mention_count} in {config.mention_rate_window}s)"
|
||||
),
|
||||
matched_filter="mention_rate",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_invite_links(self, content: str, allow_invites: bool = True) -> AutomodResult | None:
|
||||
|
||||
Reference in New Issue
Block a user