update
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
This commit is contained in:
@@ -42,6 +42,7 @@ class GuardDen(commands.Bot):
|
||||
self.database = Database(settings)
|
||||
self.guild_config: "GuildConfigService | None" = None
|
||||
self.ai_provider: AIProvider | None = None
|
||||
self.wordlist_service = None
|
||||
self.rate_limiter = RateLimiter()
|
||||
|
||||
async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]:
|
||||
@@ -90,6 +91,9 @@ class GuardDen(commands.Bot):
|
||||
from guardden.services.guild_config import GuildConfigService
|
||||
|
||||
self.guild_config = GuildConfigService(self.database)
|
||||
from guardden.services.wordlist import WordlistService
|
||||
|
||||
self.wordlist_service = WordlistService(self.database, self.settings)
|
||||
|
||||
# Initialize AI provider
|
||||
api_key = None
|
||||
@@ -115,6 +119,7 @@ class GuardDen(commands.Bot):
|
||||
"guardden.cogs.ai_moderation",
|
||||
"guardden.cogs.verification",
|
||||
"guardden.cogs.health",
|
||||
"guardden.cogs.wordlist_sync",
|
||||
]
|
||||
|
||||
failed_cogs = []
|
||||
@@ -131,7 +136,7 @@ class GuardDen(commands.Bot):
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error loading cog {cog}: {e}", exc_info=True)
|
||||
failed_cogs.append(cog)
|
||||
|
||||
|
||||
if failed_cogs:
|
||||
logger.warning(f"Failed to load {len(failed_cogs)} cog(s): {', '.join(failed_cogs)}")
|
||||
# Don't fail startup if some cogs fail to load, but log it prominently
|
||||
@@ -146,7 +151,7 @@ class GuardDen(commands.Bot):
|
||||
if self.guild_config:
|
||||
initialized = 0
|
||||
failed_guilds = []
|
||||
|
||||
|
||||
for guild in self.guilds:
|
||||
try:
|
||||
if not self.is_guild_allowed(guild.id):
|
||||
@@ -162,12 +167,17 @@ class GuardDen(commands.Bot):
|
||||
await self.guild_config.create_guild(guild)
|
||||
initialized += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize config for guild {guild.id} ({guild.name}): {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Failed to initialize config for guild {guild.id} ({guild.name}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
failed_guilds.append(guild.id)
|
||||
|
||||
logger.info("Initialized config for %s guild(s)", initialized)
|
||||
if failed_guilds:
|
||||
logger.warning(f"Failed to initialize {len(failed_guilds)} guild(s): {failed_guilds}")
|
||||
logger.warning(
|
||||
f"Failed to initialize {len(failed_guilds)} guild(s): {failed_guilds}"
|
||||
)
|
||||
|
||||
# Set presence
|
||||
activity = discord.Activity(
|
||||
@@ -206,9 +216,7 @@ class GuardDen(commands.Bot):
|
||||
logger.info(f"Joined guild: {guild.name} (ID: {guild.id})")
|
||||
|
||||
if not self.is_guild_allowed(guild.id):
|
||||
logger.warning(
|
||||
"Guild %s (ID: %s) not in allowlist, leaving.", guild.name, guild.id
|
||||
)
|
||||
logger.warning("Guild %s (ID: %s) not in allowlist, leaving.", guild.name, guild.id)
|
||||
await guild.leave()
|
||||
return
|
||||
|
||||
|
||||
38
src/guardden/cogs/wordlist_sync.py
Normal file
38
src/guardden/cogs/wordlist_sync.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Background task for managed wordlist syncing."""
|
||||
|
||||
import logging
|
||||
|
||||
from discord.ext import commands, tasks
|
||||
|
||||
from guardden.services.wordlist import WordlistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WordlistSync(commands.Cog):
    """Background cog that periodically syncs managed wordlists into guild bans."""

    def __init__(self, bot: commands.Bot, service: WordlistService) -> None:
        self.bot = bot
        self.service = service
        # Convert the service's timedelta interval into fractional hours for
        # the task loop, overriding the hours=1 placeholder on the decorator.
        interval_hours = service.update_interval.total_seconds() / 3600
        self.sync_task.change_interval(hours=interval_hours)
        self.sync_task.start()

    def cog_unload(self) -> None:
        # Stop the background loop when the cog is removed/reloaded.
        self.sync_task.cancel()

    @tasks.loop(hours=1)
    async def sync_task(self) -> None:
        # Actual interval is set in __init__ via change_interval().
        await self.service.sync_all()

    @sync_task.before_loop
    async def before_sync_task(self) -> None:
        # Hold off the first sync until the bot's cache is populated.
        await self.bot.wait_until_ready()
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot) -> None:
    """Cog entry point: register WordlistSync when the bot carries a wordlist service."""
    service = getattr(bot, "wordlist_service", None)
    if service:
        await bot.add_cog(WordlistSync(bot, service))
    else:
        logger.warning("Wordlist service not initialized; skipping sync task")
|
||||
@@ -5,9 +5,9 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, SecretStr, field_validator, ValidationError
|
||||
from pydantic import BaseModel, Field, SecretStr, ValidationError, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from pydantic_settings.sources import EnvSettingsSource
|
||||
|
||||
# Discord snowflake ID validation regex (64-bit integers, 17-19 digits)
|
||||
DISCORD_ID_PATTERN = re.compile(r"^\d{17,19}$")
|
||||
@@ -19,17 +19,17 @@ def _validate_discord_id(value: str | int) -> int:
|
||||
id_str = str(value)
|
||||
else:
|
||||
id_str = str(value).strip()
|
||||
|
||||
|
||||
# Check format
|
||||
if not DISCORD_ID_PATTERN.match(id_str):
|
||||
raise ValueError(f"Invalid Discord ID format: {id_str}")
|
||||
|
||||
|
||||
# Convert to int and validate range
|
||||
discord_id = int(id_str)
|
||||
# Discord snowflakes are 64-bit integers, minimum valid ID is around 2010
|
||||
if discord_id < 100000000000000000 or discord_id > 9999999999999999999:
|
||||
raise ValueError(f"Discord ID out of valid range: {discord_id}")
|
||||
|
||||
|
||||
return discord_id
|
||||
|
||||
|
||||
@@ -65,6 +65,27 @@ def _parse_id_list(value: Any) -> list[int]:
|
||||
return parsed
|
||||
|
||||
|
||||
class GuardDenEnvSettingsSource(EnvSettingsSource):
    """Env source that skips JSON decoding for selected list fields.

    pydantic-settings normally JSON-decodes "complex" env values. For the
    ID-list fields the raw string is passed through unchanged so the model's
    own validators can parse comma-separated input safely.
    """

    # Fields whose raw env strings bypass JSON decoding.
    _RAW_STRING_FIELDS = frozenset({"allowed_guilds", "owner_ids"})

    def decode_complex_value(self, field_name: str, field, value: Any):
        if isinstance(value, str) and field_name in self._RAW_STRING_FIELDS:
            return value
        return super().decode_complex_value(field_name, field, value)
|
||||
|
||||
|
||||
class WordlistSourceConfig(BaseModel):
    """Configuration for a managed wordlist source."""

    # Unique identifier for this source.
    name: str
    # HTTP(S) location of the raw wordlist text.
    url: str
    # Severity bucket for matched words.
    category: Literal["hard", "soft", "context"]
    # Moderation action applied when an entry matches.
    action: Literal["delete", "warn", "strike"]
    # Human-readable reason attached to enforcement actions.
    reason: str
    # Treat entries as regex patterns instead of literal words.
    is_regex: bool = False
    # Allows disabling a source without removing it from configuration.
    enabled: bool = True
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
@@ -73,8 +94,25 @@ class Settings(BaseSettings):
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
env_prefix="GUARDDEN_",
|
||||
env_parse_none_str="",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls,
|
||||
init_settings,
|
||||
env_settings,
|
||||
dotenv_settings,
|
||||
file_secret_settings,
|
||||
):
|
||||
return (
|
||||
init_settings,
|
||||
GuardDenEnvSettingsSource(settings_cls),
|
||||
dotenv_settings,
|
||||
file_secret_settings,
|
||||
)
|
||||
|
||||
# Discord settings
|
||||
discord_token: SecretStr = Field(..., description="Discord bot token")
|
||||
discord_prefix: str = Field(default="!", description="Default command prefix")
|
||||
@@ -114,11 +152,43 @@ class Settings(BaseSettings):
|
||||
# Paths
|
||||
data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")
|
||||
|
||||
# Wordlist sync
|
||||
wordlist_enabled: bool = Field(
|
||||
default=True, description="Enable automatic managed wordlist syncing"
|
||||
)
|
||||
wordlist_update_hours: int = Field(
|
||||
default=168, description="Managed wordlist sync interval in hours"
|
||||
)
|
||||
wordlist_sources: list[WordlistSourceConfig] = Field(
|
||||
default_factory=list,
|
||||
description="Managed wordlist sources (JSON array via env overrides)",
|
||||
)
|
||||
|
||||
@field_validator("allowed_guilds", "owner_ids", mode="before")
|
||||
@classmethod
|
||||
def _validate_id_list(cls, value: Any) -> list[int]:
|
||||
return _parse_id_list(value)
|
||||
|
||||
@field_validator("wordlist_sources", mode="before")
|
||||
@classmethod
|
||||
def _parse_wordlist_sources(cls, value: Any) -> list[WordlistSourceConfig]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return [WordlistSourceConfig.model_validate(item) for item in value]
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError("Invalid JSON for wordlist_sources") from exc
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("wordlist_sources must be a JSON array")
|
||||
return [WordlistSourceConfig.model_validate(item) for item in data]
|
||||
return []
|
||||
|
||||
@field_validator("discord_token")
|
||||
@classmethod
|
||||
def _validate_discord_token(cls, value: SecretStr) -> SecretStr:
|
||||
@@ -126,11 +196,11 @@ class Settings(BaseSettings):
|
||||
token = value.get_secret_value()
|
||||
if not token:
|
||||
raise ValueError("Discord token cannot be empty")
|
||||
|
||||
|
||||
# Basic Discord token format validation (not perfect but catches common issues)
|
||||
if len(token) < 50 or not re.match(r"^[A-Za-z0-9._-]+$", token):
|
||||
raise ValueError("Invalid Discord token format")
|
||||
|
||||
|
||||
return value
|
||||
|
||||
@field_validator("anthropic_api_key", "openai_api_key")
|
||||
@@ -139,15 +209,15 @@ class Settings(BaseSettings):
|
||||
"""Validate API key format if provided."""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
|
||||
key = value.get_secret_value()
|
||||
if not key:
|
||||
return None
|
||||
|
||||
|
||||
# Basic API key validation
|
||||
if len(key) < 20:
|
||||
raise ValueError("API key too short to be valid")
|
||||
|
||||
|
||||
return value
|
||||
|
||||
def validate_configuration(self) -> None:
|
||||
@@ -157,17 +227,21 @@ class Settings(BaseSettings):
|
||||
raise ValueError("GUARDDEN_ANTHROPIC_API_KEY is required when AI provider is anthropic")
|
||||
if self.ai_provider == "openai" and not self.openai_api_key:
|
||||
raise ValueError("GUARDDEN_OPENAI_API_KEY is required when AI provider is openai")
|
||||
|
||||
|
||||
# Database pool validation
|
||||
if self.database_pool_min > self.database_pool_max:
|
||||
raise ValueError("database_pool_min cannot be greater than database_pool_max")
|
||||
if self.database_pool_min < 1:
|
||||
raise ValueError("database_pool_min must be at least 1")
|
||||
|
||||
|
||||
# Data directory validation
|
||||
if not isinstance(self.data_dir, Path):
|
||||
raise ValueError("data_dir must be a valid path")
|
||||
|
||||
# Wordlist validation
|
||||
if self.wordlist_update_hours < 1:
|
||||
raise ValueError("wordlist_update_hours must be at least 1")
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""Get application settings instance."""
|
||||
|
||||
16
src/guardden/dashboard/__main__.py
Normal file
16
src/guardden/dashboard/__main__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Dashboard entrypoint for `python -m guardden.dashboard`."""
|
||||
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
||||
def main() -> None:
    """Launch the dashboard ASGI app with uvicorn, configured from env vars."""
    # NOTE(review): default bind is 0.0.0.0 (all interfaces) — confirm intended.
    host = os.getenv("GUARDDEN_DASHBOARD_HOST", "0.0.0.0")
    port = int(os.getenv("GUARDDEN_DASHBOARD_PORT", "8000"))
    log_level = os.getenv("GUARDDEN_LOG_LEVEL", "info").lower()
    uvicorn.run(
        "guardden.dashboard.main:app",
        host=host,
        port=port,
        log_level=log_level,
    )


if __name__ == "__main__":
    main()
|
||||
@@ -3,7 +3,7 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy import JSON, Boolean, Float, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
@@ -59,7 +59,9 @@ class GuildSettings(Base, TimestampMixin):
|
||||
# Role configuration
|
||||
mute_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
verified_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
mod_role_ids: Mapped[dict] = mapped_column(JSONB, default=list, nullable=False)
|
||||
mod_role_ids: Mapped[dict] = mapped_column(
|
||||
JSONB().with_variant(JSON(), "sqlite"), default=list, nullable=False
|
||||
)
|
||||
|
||||
# Moderation settings
|
||||
automod_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
@@ -73,11 +75,13 @@ class GuildSettings(Base, TimestampMixin):
|
||||
mention_limit: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
|
||||
mention_rate_limit: Mapped[int] = mapped_column(Integer, default=10, nullable=False)
|
||||
mention_rate_window: Mapped[int] = mapped_column(Integer, default=60, nullable=False)
|
||||
scam_allowlist: Mapped[list[str]] = mapped_column(JSONB, default=list, nullable=False)
|
||||
scam_allowlist: Mapped[list[str]] = mapped_column(
|
||||
JSONB().with_variant(JSON(), "sqlite"), default=list, nullable=False
|
||||
)
|
||||
|
||||
# Strike thresholds (actions at each threshold)
|
||||
strike_actions: Mapped[dict] = mapped_column(
|
||||
JSONB,
|
||||
JSONB().with_variant(JSON(), "sqlite"),
|
||||
default=lambda: {
|
||||
"1": {"action": "warn"},
|
||||
"3": {"action": "timeout", "duration": 3600},
|
||||
@@ -88,11 +92,11 @@ class GuildSettings(Base, TimestampMixin):
|
||||
)
|
||||
|
||||
# AI moderation settings
|
||||
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=50, nullable=False) # 0-100 scale
|
||||
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=80, nullable=False) # 0-100 scale
|
||||
ai_confidence_threshold: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
|
||||
ai_log_only: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
|
||||
# Verification settings
|
||||
verification_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
@@ -120,6 +124,9 @@ class BannedWord(Base, TimestampMixin):
|
||||
String(20), default="delete", nullable=False
|
||||
) # delete, warn, strike
|
||||
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
source: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
category: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
managed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Who added this and when
|
||||
added_by: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import NamedTuple, Sequence, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, NamedTuple, Sequence
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -16,6 +16,7 @@ else:
|
||||
try:
|
||||
import discord # type: ignore
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
|
||||
class _DiscordStub:
|
||||
class Message: # minimal stub for type hints
|
||||
pass
|
||||
@@ -26,120 +27,122 @@ from guardden.models.guild import BannedWord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Circuit breaker for regex safety
|
||||
class RegexTimeoutError(Exception):
|
||||
"""Raised when regex execution takes too long."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RegexCircuitBreaker:
|
||||
"""Circuit breaker to prevent catastrophic backtracking in regex patterns."""
|
||||
|
||||
|
||||
def __init__(self, timeout_seconds: float = 0.1):
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.failed_patterns: dict[str, datetime] = {}
|
||||
self.failure_threshold = timedelta(minutes=5) # Disable pattern for 5 minutes after failure
|
||||
|
||||
|
||||
def _timeout_handler(self, signum, frame):
|
||||
"""Signal handler for regex timeout."""
|
||||
raise RegexTimeoutError("Regex execution timed out")
|
||||
|
||||
|
||||
def is_pattern_disabled(self, pattern: str) -> bool:
|
||||
"""Check if a pattern is temporarily disabled due to timeouts."""
|
||||
if pattern not in self.failed_patterns:
|
||||
return False
|
||||
|
||||
|
||||
failure_time = self.failed_patterns[pattern]
|
||||
if datetime.now(timezone.utc) - failure_time > self.failure_threshold:
|
||||
# Re-enable the pattern after threshold time
|
||||
del self.failed_patterns[pattern]
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
|
||||
"""Safely execute regex search with timeout protection."""
|
||||
if self.is_pattern_disabled(pattern):
|
||||
logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
|
||||
return False
|
||||
|
||||
|
||||
# Basic pattern validation to catch obviously problematic patterns
|
||||
if self._is_dangerous_pattern(pattern):
|
||||
logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
|
||||
return False
|
||||
|
||||
|
||||
old_handler = None
|
||||
try:
|
||||
# Set up timeout signal (Unix systems only)
|
||||
if hasattr(signal, 'SIGALRM'):
|
||||
if hasattr(signal, "SIGALRM"):
|
||||
old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
|
||||
signal.alarm(int(self.timeout_seconds * 1000)) # Convert to milliseconds
|
||||
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
|
||||
# Compile and execute regex
|
||||
compiled_pattern = re.compile(pattern, flags)
|
||||
result = bool(compiled_pattern.search(text))
|
||||
|
||||
|
||||
execution_time = time.perf_counter() - start_time
|
||||
|
||||
|
||||
# Log slow patterns for monitoring
|
||||
if execution_time > self.timeout_seconds * 0.8:
|
||||
logger.warning(
|
||||
f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
|
||||
)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except RegexTimeoutError:
|
||||
# Pattern took too long, disable it temporarily
|
||||
self.failed_patterns[pattern] = datetime.now(timezone.utc)
|
||||
logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
|
||||
return False
|
||||
|
||||
|
||||
except re.error as e:
|
||||
logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in regex execution: {e}")
|
||||
return False
|
||||
|
||||
|
||||
finally:
|
||||
# Clean up timeout signal
|
||||
if hasattr(signal, 'SIGALRM') and old_handler is not None:
|
||||
if hasattr(signal, "SIGALRM") and old_handler is not None:
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
|
||||
|
||||
def _is_dangerous_pattern(self, pattern: str) -> bool:
|
||||
"""Basic heuristic to detect potentially dangerous regex patterns."""
|
||||
# Check for patterns that are commonly problematic
|
||||
dangerous_indicators = [
|
||||
r'(\w+)+', # Nested quantifiers
|
||||
r'(\d+)+', # Nested quantifiers on digits
|
||||
r'(.+)+', # Nested quantifiers on anything
|
||||
r'(.*)+', # Nested quantifiers on anything (greedy)
|
||||
r'(\w*)+', # Nested quantifiers with *
|
||||
r'(\S+)+', # Nested quantifiers on non-whitespace
|
||||
r"(\w+)+", # Nested quantifiers
|
||||
r"(\d+)+", # Nested quantifiers on digits
|
||||
r"(.+)+", # Nested quantifiers on anything
|
||||
r"(.*)+", # Nested quantifiers on anything (greedy)
|
||||
r"(\w*)+", # Nested quantifiers with *
|
||||
r"(\S+)+", # Nested quantifiers on non-whitespace
|
||||
]
|
||||
|
||||
|
||||
# Check for excessively long patterns
|
||||
if len(pattern) > 500:
|
||||
return True
|
||||
|
||||
|
||||
# Check for nested quantifiers (simplified detection)
|
||||
if '+)+' in pattern or '*)+' in pattern or '?)+' in pattern:
|
||||
if "+)+" in pattern or "*)+" in pattern or "?)+" in pattern:
|
||||
return True
|
||||
|
||||
|
||||
# Check for excessive repetition operators
|
||||
if pattern.count('+') > 10 or pattern.count('*') > 10:
|
||||
if pattern.count("+") > 10 or pattern.count("*") > 10:
|
||||
return True
|
||||
|
||||
|
||||
# Check for specific dangerous patterns
|
||||
for dangerous in dangerous_indicators:
|
||||
if dangerous in pattern:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -240,34 +243,43 @@ def normalize_domain(value: str) -> str:
|
||||
"""Normalize a domain or URL for allowlist checks with security validation."""
|
||||
if not value or not isinstance(value, str):
|
||||
return ""
|
||||
|
||||
|
||||
if any(char in value for char in ["\x00", "\n", "\r", "\t"]):
|
||||
return ""
|
||||
|
||||
text = value.strip().lower()
|
||||
if not text or len(text) > 2000: # Prevent excessively long URLs
|
||||
return ""
|
||||
|
||||
# Sanitize input to prevent injection attacks
|
||||
if any(char in text for char in ['\x00', '\n', '\r', '\t']):
|
||||
return ""
|
||||
|
||||
|
||||
try:
|
||||
if "://" not in text:
|
||||
text = f"http://{text}"
|
||||
|
||||
|
||||
parsed = urlparse(text)
|
||||
hostname = parsed.hostname or ""
|
||||
|
||||
|
||||
# Additional validation for hostname
|
||||
if not hostname or len(hostname) > 253: # RFC limit
|
||||
return ""
|
||||
|
||||
|
||||
# Check for malicious patterns
|
||||
if any(char in hostname for char in [' ', '\x00', '\n', '\r', '\t']):
|
||||
if any(char in hostname for char in [" ", "\x00", "\n", "\r", "\t"]):
|
||||
return ""
|
||||
|
||||
|
||||
if not re.fullmatch(r"[a-z0-9.-]+", hostname):
|
||||
return ""
|
||||
if hostname.startswith(".") or hostname.endswith(".") or ".." in hostname:
|
||||
return ""
|
||||
for label in hostname.split("."):
|
||||
if not label:
|
||||
return ""
|
||||
if label.startswith("-") or label.endswith("-"):
|
||||
return ""
|
||||
|
||||
# Remove www prefix
|
||||
if hostname.startswith("www."):
|
||||
hostname = hostname[4:]
|
||||
|
||||
|
||||
return hostname
|
||||
except (ValueError, UnicodeError, Exception):
|
||||
# urlparse can raise various exceptions with malicious input
|
||||
@@ -305,13 +317,13 @@ class AutomodService:
|
||||
# Normalize: lowercase, remove extra spaces, remove special chars
|
||||
# Use simple string operations for basic patterns to avoid regex overhead
|
||||
normalized = content.lower()
|
||||
|
||||
|
||||
# Remove special characters (simplified approach)
|
||||
normalized = ''.join(c for c in normalized if c.isalnum() or c.isspace())
|
||||
|
||||
normalized = "".join(c for c in normalized if c.isalnum() or c.isspace())
|
||||
|
||||
# Normalize whitespace
|
||||
normalized = ' '.join(normalized.split())
|
||||
|
||||
normalized = " ".join(normalized.split())
|
||||
|
||||
return normalized
|
||||
|
||||
def check_banned_words(
|
||||
@@ -369,14 +381,14 @@ class AutomodService:
|
||||
# Limit URL length to prevent processing extremely long URLs
|
||||
if len(url) > 2000:
|
||||
continue
|
||||
|
||||
|
||||
url_lower = url.lower()
|
||||
hostname = normalize_domain(url)
|
||||
|
||||
|
||||
# Skip if hostname normalization failed (security check)
|
||||
if not hostname:
|
||||
continue
|
||||
|
||||
|
||||
if allowlist_set and is_allowed_domain(hostname, allowlist_set):
|
||||
continue
|
||||
|
||||
@@ -540,3 +552,11 @@ class AutomodService:
|
||||
def cleanup_guild(self, guild_id: int) -> None:
|
||||
"""Remove all tracking data for a guild."""
|
||||
self._spam_trackers.pop(guild_id, None)
|
||||
|
||||
|
||||
_automod_service = AutomodService()
|
||||
|
||||
|
||||
def detect_scam_links(content: str, allowlist: list[str] | None = None) -> AutomodResult | None:
|
||||
"""Convenience wrapper for scam detection."""
|
||||
return _automod_service.check_scam_links(content, allowlist)
|
||||
|
||||
@@ -141,6 +141,9 @@ class GuildConfigService:
|
||||
is_regex: bool = False,
|
||||
action: str = "delete",
|
||||
reason: str | None = None,
|
||||
source: str | None = None,
|
||||
category: str | None = None,
|
||||
managed: bool = False,
|
||||
) -> BannedWord:
|
||||
"""Add a banned word to a guild."""
|
||||
async with self.database.session() as session:
|
||||
@@ -150,6 +153,9 @@ class GuildConfigService:
|
||||
is_regex=is_regex,
|
||||
action=action,
|
||||
reason=reason,
|
||||
source=source,
|
||||
category=category,
|
||||
managed=managed,
|
||||
added_by=added_by,
|
||||
)
|
||||
session.add(banned_word)
|
||||
|
||||
180
src/guardden/services/wordlist.py
Normal file
180
src/guardden/services/wordlist.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""Managed wordlist sync service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Iterable
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from guardden.config import Settings, WordlistSourceConfig
|
||||
from guardden.models import BannedWord, Guild
|
||||
from guardden.services.database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_WORDLIST_ENTRY_LENGTH = 128
|
||||
REQUEST_TIMEOUT = 20.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WordlistSource:
    """Immutable descriptor of one remote wordlist to sync."""

    # Unique identifier; also stamped onto synced BannedWord rows as `source`.
    name: str
    # URL of the raw wordlist text to download.
    url: str
    # Severity bucket (mirrors WordlistSourceConfig.category).
    category: str
    # Moderation action for matches (mirrors WordlistSourceConfig.action).
    action: str
    # Reason text attached to enforcement actions.
    reason: str
    # Whether entries are regex patterns rather than literal words.
    is_regex: bool = False
|
||||
|
||||
|
||||
# Fallback sources used when Settings.wordlist_sources is empty
# (see WordlistService._load_sources).
DEFAULT_SOURCES: list[WordlistSource] = [
    WordlistSource(
        name="ldnoobw_en",
        url="https://raw.githubusercontent.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/master/en",
        category="soft",
        action="warn",
        reason="Auto list: profanity",
        is_regex=False,
    ),
]
|
||||
|
||||
|
||||
def _normalize_entry(line: str) -> str:
    """Lowercase and trim a raw wordlist line; return "" if unusable.

    Entries longer than MAX_WORDLIST_ENTRY_LENGTH are rejected to keep
    pathological lines out of the ban tables.
    """
    candidate = line.strip().lower()
    if not candidate or len(candidate) > MAX_WORDLIST_ENTRY_LENGTH:
        return ""
    return candidate
|
||||
|
||||
|
||||
def _parse_wordlist(text: str) -> list[str]:
    """Extract normalized, de-duplicated entries from raw wordlist text.

    Blank lines and comment lines (starting with '#', '//', or ';') are
    skipped; order of first appearance is preserved.
    """
    seen: set[str] = set()
    entries: list[str] = []
    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if not stripped or stripped.startswith(("#", "//", ";")):
            continue
        entry = _normalize_entry(stripped)
        if entry and entry not in seen:
            seen.add(entry)
            entries.append(entry)
    return entries
|
||||
|
||||
|
||||
class WordlistService:
    """Fetches and syncs managed wordlists into per-guild bans."""

    def __init__(self, database: Database, settings: Settings) -> None:
        self.database = database
        self.settings = settings
        # Resolved source list: configured sources, or DEFAULT_SOURCES fallback.
        self.sources = self._load_sources(settings)
        # How often the background cog should trigger sync_all().
        self.update_interval = timedelta(hours=settings.wordlist_update_hours)
        # UTC timestamp of the last completed sync; None until first run.
        self.last_sync: datetime | None = None

    @staticmethod
    def _load_sources(settings: Settings) -> list[WordlistSource]:
        # Convert configured WordlistSourceConfig entries into WordlistSource,
        # skipping disabled ones; fall back to DEFAULT_SOURCES when none set.
        if settings.wordlist_sources:
            sources: list[WordlistSource] = []
            for src in settings.wordlist_sources:
                if not src.enabled:
                    continue
                sources.append(
                    WordlistSource(
                        name=src.name,
                        url=src.url,
                        category=src.category,
                        action=src.action,
                        reason=src.reason,
                        is_regex=src.is_regex,
                    )
                )
            return sources
        return list(DEFAULT_SOURCES)

    async def _fetch_source(self, source: WordlistSource) -> list[str]:
        # Download one wordlist and return its parsed, normalized entries.
        # Raises httpx errors (incl. HTTPStatusError) on failure; handled by caller.
        async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
            response = await client.get(source.url)
            response.raise_for_status()
            return _parse_wordlist(response.text)

    async def sync_all(self) -> None:
        """Fetch every source and reconcile all known guilds.

        Best-effort: a source that fails to download or yields no entries
        is logged and skipped without aborting the remaining sources.
        """
        if not self.settings.wordlist_enabled:
            logger.info("Managed wordlist sync disabled")
            return
        if not self.sources:
            logger.warning("No wordlist sources configured")
            return

        logger.info("Starting managed wordlist sync (%d sources)", len(self.sources))
        # Snapshot guild IDs once, in a short-lived session.
        async with self.database.session() as session:
            guild_ids = list((await session.execute(select(Guild.id))).scalars().all())

        for source in self.sources:
            try:
                entries = await self._fetch_source(source)
            except Exception as exc:
                logger.error("Failed to fetch wordlist %s: %s", source.name, exc)
                continue

            if not entries:
                logger.warning("Wordlist %s returned no entries", source.name)
                continue

            await self._sync_source_to_guilds(source, entries, guild_ids)

        # Marked complete even if individual sources were skipped above.
        self.last_sync = datetime.now(timezone.utc)
        logger.info("Managed wordlist sync completed")

    async def _sync_source_to_guilds(
        self, source: WordlistSource, entries: Iterable[str], guild_ids: list[int]
    ) -> None:
        """Diff-apply one source's entries against each guild's managed bans.

        Per guild: managed rows from this source no longer in the fetched
        list are deleted; new entries are inserted with managed=True and
        added_by=0 (system).
        """
        entry_set = set(entries)
        # NOTE(review): assumes database.session() commits on successful exit
        # — confirm against Database implementation.
        async with self.database.session() as session:
            for guild_id in guild_ids:
                result = await session.execute(
                    select(BannedWord).where(
                        BannedWord.guild_id == guild_id,
                        BannedWord.managed.is_(True),
                        BannedWord.source == source.name,
                    )
                )
                existing = list(result.scalars().all())
                existing_set = {word.pattern.lower() for word in existing}

                to_add = entry_set - existing_set
                to_remove = existing_set - entry_set

                if to_remove:
                    await session.execute(
                        delete(BannedWord).where(
                            BannedWord.guild_id == guild_id,
                            BannedWord.managed.is_(True),
                            BannedWord.source == source.name,
                            BannedWord.pattern.in_(to_remove),
                        )
                    )

                if to_add:
                    session.add_all(
                        [
                            BannedWord(
                                guild_id=guild_id,
                                pattern=pattern,
                                is_regex=source.is_regex,
                                action=source.action,
                                reason=source.reason,
                                source=source.name,
                                category=source.category,
                                managed=True,
                                added_by=0,
                            )
                            for pattern in to_add
                        ]
                    )
|
||||
Reference in New Issue
Block a user