update
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s

This commit is contained in:
2026-01-17 21:57:04 +01:00
parent 831eed8dbc
commit abef368a68
19 changed files with 677 additions and 757 deletions

View File

@@ -7,7 +7,7 @@ import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import NamedTuple, Sequence, TYPE_CHECKING
from typing import TYPE_CHECKING, NamedTuple, Sequence
from urllib.parse import urlparse
if TYPE_CHECKING:
@@ -16,6 +16,7 @@ else:
try:
import discord # type: ignore
except ModuleNotFoundError: # pragma: no cover
class _DiscordStub:
class Message: # minimal stub for type hints
pass
@@ -26,120 +27,122 @@ from guardden.models.guild import BannedWord
logger = logging.getLogger(__name__)
# Circuit breaker for regex safety
class RegexTimeoutError(Exception):
    """Raised when a regex search exceeds its allotted execution time."""
class RegexCircuitBreaker:
    """Circuit breaker to prevent catastrophic backtracking in regex patterns.

    Patterns that time out are disabled for a cool-down period, and
    obviously dangerous patterns (nested quantifiers, excessive length)
    are rejected before execution.
    """

    def __init__(self, timeout_seconds: float = 0.1):
        # Maximum wall-clock time a single regex search may take.
        self.timeout_seconds = timeout_seconds
        # Pattern -> UTC timestamp of its last timeout failure.
        self.failed_patterns: dict[str, datetime] = {}
        self.failure_threshold = timedelta(minutes=5)  # Disable pattern for 5 minutes after failure

    def _timeout_handler(self, signum, frame):
        """Signal handler for regex timeout."""
        raise RegexTimeoutError("Regex execution timed out")

    def is_pattern_disabled(self, pattern: str) -> bool:
        """Check if a pattern is temporarily disabled due to timeouts."""
        if pattern not in self.failed_patterns:
            return False
        failure_time = self.failed_patterns[pattern]
        if datetime.now(timezone.utc) - failure_time > self.failure_threshold:
            # Re-enable the pattern after the cool-down threshold has elapsed.
            del self.failed_patterns[pattern]
            return False
        return True

    def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
        """Safely execute regex search with timeout protection.

        Returns True when the pattern matches ``text``; False on no match,
        timeout, invalid pattern, rejected-dangerous pattern, or while the
        pattern is in its cool-down window.
        """
        if self.is_pattern_disabled(pattern):
            logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
            return False
        # Basic pattern validation to catch obviously problematic patterns
        if self._is_dangerous_pattern(pattern):
            logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
            return False
        old_handler = None
        try:
            # Set up timeout signal (Unix systems only)
            if hasattr(signal, "SIGALRM"):
                old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
                # BUG FIX: signal.alarm() takes whole seconds, so the old
                # ``alarm(int(timeout * 1000))`` armed a ~100 s alarm for a
                # 0.1 s budget. setitimer supports sub-second intervals.
                signal.setitimer(signal.ITIMER_REAL, self.timeout_seconds)
            start_time = time.perf_counter()
            # Compile and execute regex
            compiled_pattern = re.compile(pattern, flags)
            result = bool(compiled_pattern.search(text))
            execution_time = time.perf_counter() - start_time
            # Log slow patterns for monitoring
            if execution_time > self.timeout_seconds * 0.8:
                logger.warning(
                    f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
                )
            return result
        except RegexTimeoutError:
            # Pattern took too long, disable it temporarily
            self.failed_patterns[pattern] = datetime.now(timezone.utc)
            logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
            return False
        except re.error as e:
            logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
            return False
        except Exception as e:
            logger.error(f"Unexpected error in regex execution: {e}")
            return False
        finally:
            # Clean up timeout signal
            if hasattr(signal, "SIGALRM") and old_handler is not None:
                signal.setitimer(signal.ITIMER_REAL, 0)
                signal.signal(signal.SIGALRM, old_handler)

    def _is_dangerous_pattern(self, pattern: str) -> bool:
        """Basic heuristic to detect potentially dangerous regex patterns."""
        # Literal substrings that commonly cause catastrophic backtracking.
        dangerous_indicators = [
            r"(\w+)+",  # Nested quantifiers
            r"(\d+)+",  # Nested quantifiers on digits
            r"(.+)+",  # Nested quantifiers on anything
            r"(.*)+",  # Nested quantifiers on anything (greedy)
            r"(\w*)+",  # Nested quantifiers with *
            r"(\S+)+",  # Nested quantifiers on non-whitespace
        ]
        # Check for excessively long patterns
        if len(pattern) > 500:
            return True
        # Check for nested quantifiers (simplified detection)
        if "+)+" in pattern or "*)+" in pattern or "?)+" in pattern:
            return True
        # Check for excessive repetition operators
        if pattern.count("+") > 10 or pattern.count("*") > 10:
            return True
        # Check for specific dangerous patterns
        for dangerous in dangerous_indicators:
            if dangerous in pattern:
                return True
        return False
def normalize_domain(value: str) -> str:
    """Normalize a domain or URL for allowlist checks with security validation.

    Returns the lowercase hostname with any leading ``www.`` removed, or ""
    when the input is missing, malformed, too long, or fails validation.
    """
    if not value or not isinstance(value, str):
        return ""
    # Sanitize input to prevent injection attacks (control characters).
    if any(char in value for char in ["\x00", "\n", "\r", "\t"]):
        return ""
    text = value.strip().lower()
    if not text or len(text) > 2000:  # Prevent excessively long URLs
        return ""
    try:
        if "://" not in text:
            text = f"http://{text}"
        parsed = urlparse(text)
        hostname = parsed.hostname or ""
        # Additional validation for hostname
        if not hostname or len(hostname) > 253:  # RFC limit
            return ""
        # Check for malicious patterns
        if any(char in hostname for char in [" ", "\x00", "\n", "\r", "\t"]):
            return ""
        if not re.fullmatch(r"[a-z0-9.-]+", hostname):
            return ""
        if hostname.startswith(".") or hostname.endswith(".") or ".." in hostname:
            return ""
        for label in hostname.split("."):
            if not label:
                return ""
            if label.startswith("-") or label.endswith("-"):
                return ""
        # Remove www prefix
        if hostname.startswith("www."):
            hostname = hostname[4:]
        return hostname
    except Exception:
        # urlparse can raise various exceptions with malicious input;
        # treat anything unexpected as an invalid domain.
        return ""
@@ -305,13 +317,13 @@ class AutomodService:
# Normalize: lowercase, remove extra spaces, remove special chars
# Use simple string operations for basic patterns to avoid regex overhead
normalized = content.lower()
# Remove special characters (simplified approach)
normalized = ''.join(c for c in normalized if c.isalnum() or c.isspace())
normalized = "".join(c for c in normalized if c.isalnum() or c.isspace())
# Normalize whitespace
normalized = ' '.join(normalized.split())
normalized = " ".join(normalized.split())
return normalized
def check_banned_words(
@@ -369,14 +381,14 @@ class AutomodService:
# Limit URL length to prevent processing extremely long URLs
if len(url) > 2000:
continue
url_lower = url.lower()
hostname = normalize_domain(url)
# Skip if hostname normalization failed (security check)
if not hostname:
continue
if allowlist_set and is_allowed_domain(hostname, allowlist_set):
continue
@@ -540,3 +552,11 @@ class AutomodService:
def cleanup_guild(self, guild_id: int) -> None:
    """Remove all tracking data for a guild."""
    # Discard the guild's spam tracker if one exists; unknown ids are a no-op.
    if guild_id in self._spam_trackers:
        del self._spam_trackers[guild_id]
# Shared module-level service instance backing the convenience wrappers.
_automod_service = AutomodService()


def detect_scam_links(content: str, allowlist: list[str] | None = None) -> AutomodResult | None:
    """Convenience wrapper for scam detection.

    Delegates to the module-level AutomodService instance.
    """
    service = _automod_service
    return service.check_scam_links(content, allowlist)