update
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
This commit is contained in:
@@ -7,7 +7,7 @@ import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import NamedTuple, Sequence, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, NamedTuple, Sequence
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -16,6 +16,7 @@ else:
|
||||
try:
|
||||
import discord # type: ignore
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
|
||||
class _DiscordStub:
|
||||
class Message: # minimal stub for type hints
|
||||
pass
|
||||
@@ -26,120 +27,122 @@ from guardden.models.guild import BannedWord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Circuit breaker for regex safety
|
||||
class RegexTimeoutError(Exception):
    """Signal that a regex search ran longer than its allowed time budget."""
|
||||
|
||||
|
||||
class RegexCircuitBreaker:
    """Circuit breaker to prevent catastrophic backtracking in regex patterns.

    A pattern that exceeds ``timeout_seconds`` is recorded in
    ``failed_patterns`` and skipped until ``failure_threshold`` has elapsed.
    Patterns that look dangerous (see ``_is_dangerous_pattern``) are rejected
    without being executed at all.
    """

    def __init__(self, timeout_seconds: float = 0.1):
        """Create a breaker.

        Args:
            timeout_seconds: Wall-clock budget for one search, in seconds.
                Fractions of a second are honoured.
        """
        self.timeout_seconds = timeout_seconds
        # Maps a pattern string to the UTC time it last timed out.
        self.failed_patterns: dict[str, datetime] = {}
        # Disable a pattern for 5 minutes after a failure.
        self.failure_threshold = timedelta(minutes=5)

    def _timeout_handler(self, signum, frame):
        """Signal handler for regex timeout; always raises RegexTimeoutError."""
        raise RegexTimeoutError("Regex execution timed out")

    def is_pattern_disabled(self, pattern: str) -> bool:
        """Check if a pattern is temporarily disabled due to timeouts.

        A pattern whose cool-down has expired is removed from
        ``failed_patterns`` as a side effect, re-enabling it.
        """
        failure_time = self.failed_patterns.get(pattern)
        if failure_time is None:
            return False

        if datetime.now(timezone.utc) - failure_time > self.failure_threshold:
            # Re-enable the pattern after the threshold time.
            del self.failed_patterns[pattern]
            return False

        return True

    def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
        """Safely execute a regex search with timeout protection.

        Args:
            pattern: Regular expression source.
            text: Text to search.
            flags: ``re`` flags passed to ``re.compile``.

        Returns:
            True if the pattern matches somewhere in ``text``. False if it
            does not match, is disabled, is rejected as dangerous, is invalid,
            times out, or any other error occurs.
        """
        if self.is_pattern_disabled(pattern):
            logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
            return False

        # Basic pattern validation to catch obviously problematic patterns.
        if self._is_dangerous_pattern(pattern):
            logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
            return False

        # Interval timers are Unix-only; elsewhere the search runs unprotected.
        timer_supported = hasattr(signal, "SIGALRM") and hasattr(signal, "setitimer")
        old_handler = None
        # Tracked separately from old_handler: signal.signal() may return None
        # for a previous handler, and the timer must still be cleared then.
        handler_installed = False

        try:
            if timer_supported:
                # NOTE: signal.signal() raises ValueError outside the main
                # thread; the generic handler below then logs it and the
                # search reports no match.
                old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
                handler_installed = True
                # signal.alarm() only takes whole seconds, so a sub-second
                # budget needs setitimer(), which accepts float seconds.
                signal.setitimer(signal.ITIMER_REAL, self.timeout_seconds)

            start_time = time.perf_counter()

            # Compile and execute regex.
            compiled_pattern = re.compile(pattern, flags)
            result = bool(compiled_pattern.search(text))

            execution_time = time.perf_counter() - start_time

            # Log slow patterns for monitoring.
            if execution_time > self.timeout_seconds * 0.8:
                logger.warning(
                    f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
                )

            return result

        except RegexTimeoutError:
            # Pattern took too long, disable it temporarily.
            self.failed_patterns[pattern] = datetime.now(timezone.utc)
            logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
            return False

        except re.error as e:
            logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
            return False

        except Exception as e:
            logger.error(f"Unexpected error in regex execution: {e}")
            return False

        finally:
            # Disarm the timer first so it cannot fire during cleanup, then
            # put back whatever handler was installed before.
            if handler_installed:
                signal.setitimer(signal.ITIMER_REAL, 0)
                if old_handler is not None:
                    signal.signal(signal.SIGALRM, old_handler)

    def _is_dangerous_pattern(self, pattern: str) -> bool:
        """Basic heuristic to detect potentially dangerous regex patterns.

        This is plain substring and length matching, not regex parsing.
        """
        # Literal fragments that are commonly problematic.
        dangerous_indicators = (
            r"(\w+)+",  # Nested quantifiers
            r"(\d+)+",  # Nested quantifiers on digits
            r"(.+)+",  # Nested quantifiers on anything
            r"(.*)+",  # Nested quantifiers on anything (greedy)
            r"(\w*)+",  # Nested quantifiers with *
            r"(\S+)+",  # Nested quantifiers on non-whitespace
        )

        # Check for excessively long patterns.
        if len(pattern) > 500:
            return True

        # Check for nested quantifiers (simplified detection).
        if "+)+" in pattern or "*)+" in pattern or "?)+" in pattern:
            return True

        # Check for excessive repetition operators.
        if pattern.count("+") > 10 or pattern.count("*") > 10:
            return True

        # Check for specific dangerous patterns.
        return any(dangerous in pattern for dangerous in dangerous_indicators)
|
||||
|
||||
|
||||
@@ -240,34 +243,43 @@ def normalize_domain(value: str) -> str:
|
||||
"""Normalize a domain or URL for allowlist checks with security validation."""
|
||||
if not value or not isinstance(value, str):
|
||||
return ""
|
||||
|
||||
|
||||
if any(char in value for char in ["\x00", "\n", "\r", "\t"]):
|
||||
return ""
|
||||
|
||||
text = value.strip().lower()
|
||||
if not text or len(text) > 2000: # Prevent excessively long URLs
|
||||
return ""
|
||||
|
||||
# Sanitize input to prevent injection attacks
|
||||
if any(char in text for char in ['\x00', '\n', '\r', '\t']):
|
||||
return ""
|
||||
|
||||
|
||||
try:
|
||||
if "://" not in text:
|
||||
text = f"http://{text}"
|
||||
|
||||
|
||||
parsed = urlparse(text)
|
||||
hostname = parsed.hostname or ""
|
||||
|
||||
|
||||
# Additional validation for hostname
|
||||
if not hostname or len(hostname) > 253: # RFC limit
|
||||
return ""
|
||||
|
||||
|
||||
# Check for malicious patterns
|
||||
if any(char in hostname for char in [' ', '\x00', '\n', '\r', '\t']):
|
||||
if any(char in hostname for char in [" ", "\x00", "\n", "\r", "\t"]):
|
||||
return ""
|
||||
|
||||
|
||||
if not re.fullmatch(r"[a-z0-9.-]+", hostname):
|
||||
return ""
|
||||
if hostname.startswith(".") or hostname.endswith(".") or ".." in hostname:
|
||||
return ""
|
||||
for label in hostname.split("."):
|
||||
if not label:
|
||||
return ""
|
||||
if label.startswith("-") or label.endswith("-"):
|
||||
return ""
|
||||
|
||||
# Remove www prefix
|
||||
if hostname.startswith("www."):
|
||||
hostname = hostname[4:]
|
||||
|
||||
|
||||
return hostname
|
||||
except (ValueError, UnicodeError, Exception):
|
||||
# urlparse can raise various exceptions with malicious input
|
||||
@@ -305,13 +317,13 @@ class AutomodService:
|
||||
# Normalize: lowercase, remove extra spaces, remove special chars
|
||||
# Use simple string operations for basic patterns to avoid regex overhead
|
||||
normalized = content.lower()
|
||||
|
||||
|
||||
# Remove special characters (simplified approach)
|
||||
normalized = ''.join(c for c in normalized if c.isalnum() or c.isspace())
|
||||
|
||||
normalized = "".join(c for c in normalized if c.isalnum() or c.isspace())
|
||||
|
||||
# Normalize whitespace
|
||||
normalized = ' '.join(normalized.split())
|
||||
|
||||
normalized = " ".join(normalized.split())
|
||||
|
||||
return normalized
|
||||
|
||||
def check_banned_words(
|
||||
@@ -369,14 +381,14 @@ class AutomodService:
|
||||
# Limit URL length to prevent processing extremely long URLs
|
||||
if len(url) > 2000:
|
||||
continue
|
||||
|
||||
|
||||
url_lower = url.lower()
|
||||
hostname = normalize_domain(url)
|
||||
|
||||
|
||||
# Skip if hostname normalization failed (security check)
|
||||
if not hostname:
|
||||
continue
|
||||
|
||||
|
||||
if allowlist_set and is_allowed_domain(hostname, allowlist_set):
|
||||
continue
|
||||
|
||||
@@ -540,3 +552,11 @@ class AutomodService:
|
||||
def cleanup_guild(self, guild_id: int) -> None:
    """Discard the entry stored under ``guild_id`` in ``_spam_trackers``."""
    trackers = self._spam_trackers
    # The default argument makes an unknown guild a no-op instead of a KeyError.
    trackers.pop(guild_id, None)
|
||||
|
||||
|
||||
_automod_service = AutomodService()
|
||||
|
||||
|
||||
def detect_scam_links(content: str, allowlist: list[str] | None = None) -> AutomodResult | None:
    """Run scam-link detection on ``content`` via the module-level service.

    Both arguments are forwarded unchanged to
    ``_automod_service.check_scam_links`` and its result is returned as-is.
    """
    outcome = _automod_service.check_scam_links(content, allowlist)
    return outcome
|
||||
|
||||
Reference in New Issue
Block a user