update
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s

This commit is contained in:
2026-01-17 21:57:04 +01:00
parent 831eed8dbc
commit abef368a68
19 changed files with 677 additions and 757 deletions

View File

@@ -42,6 +42,7 @@ class GuardDen(commands.Bot):
self.database = Database(settings)
self.guild_config: "GuildConfigService | None" = None
self.ai_provider: AIProvider | None = None
self.wordlist_service = None
self.rate_limiter = RateLimiter()
async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]:
@@ -90,6 +91,9 @@ class GuardDen(commands.Bot):
from guardden.services.guild_config import GuildConfigService
self.guild_config = GuildConfigService(self.database)
from guardden.services.wordlist import WordlistService
self.wordlist_service = WordlistService(self.database, self.settings)
# Initialize AI provider
api_key = None
@@ -115,6 +119,7 @@ class GuardDen(commands.Bot):
"guardden.cogs.ai_moderation",
"guardden.cogs.verification",
"guardden.cogs.health",
"guardden.cogs.wordlist_sync",
]
failed_cogs = []
@@ -131,7 +136,7 @@ class GuardDen(commands.Bot):
except Exception as e:
logger.error(f"Unexpected error loading cog {cog}: {e}", exc_info=True)
failed_cogs.append(cog)
if failed_cogs:
logger.warning(f"Failed to load {len(failed_cogs)} cog(s): {', '.join(failed_cogs)}")
# Don't fail startup if some cogs fail to load, but log it prominently
@@ -146,7 +151,7 @@ class GuardDen(commands.Bot):
if self.guild_config:
initialized = 0
failed_guilds = []
for guild in self.guilds:
try:
if not self.is_guild_allowed(guild.id):
@@ -162,12 +167,17 @@ class GuardDen(commands.Bot):
await self.guild_config.create_guild(guild)
initialized += 1
except Exception as e:
logger.error(f"Failed to initialize config for guild {guild.id} ({guild.name}): {e}", exc_info=True)
logger.error(
f"Failed to initialize config for guild {guild.id} ({guild.name}): {e}",
exc_info=True,
)
failed_guilds.append(guild.id)
logger.info("Initialized config for %s guild(s)", initialized)
if failed_guilds:
logger.warning(f"Failed to initialize {len(failed_guilds)} guild(s): {failed_guilds}")
logger.warning(
f"Failed to initialize {len(failed_guilds)} guild(s): {failed_guilds}"
)
# Set presence
activity = discord.Activity(
@@ -206,9 +216,7 @@ class GuardDen(commands.Bot):
logger.info(f"Joined guild: {guild.name} (ID: {guild.id})")
if not self.is_guild_allowed(guild.id):
logger.warning(
"Guild %s (ID: %s) not in allowlist, leaving.", guild.name, guild.id
)
logger.warning("Guild %s (ID: %s) not in allowlist, leaving.", guild.name, guild.id)
await guild.leave()
return

View File

@@ -0,0 +1,38 @@
"""Background task for managed wordlist syncing."""
import logging
from discord.ext import commands, tasks
from guardden.services.wordlist import WordlistService
logger = logging.getLogger(__name__)
class WordlistSync(commands.Cog):
    """Periodic sync of managed wordlists into guild bans."""

    def __init__(self, bot: commands.Bot, service: WordlistService) -> None:
        self.bot = bot
        self.service = service
        # Honor the configured interval instead of the decorator's 1h default.
        self.sync_task.change_interval(hours=service.update_interval.total_seconds() / 3600)
        self.sync_task.start()

    def cog_unload(self) -> None:
        """Stop the background loop when the cog is unloaded."""
        self.sync_task.cancel()

    @tasks.loop(hours=1)
    async def sync_task(self) -> None:
        """Run one full wordlist sync without letting errors kill the loop."""
        try:
            await self.service.sync_all()
        except Exception:
            # An unhandled exception stops a tasks.loop permanently; log it
            # and keep the schedule alive for the next interval.
            logger.exception("Managed wordlist sync failed")

    @sync_task.before_loop
    async def before_sync_task(self) -> None:
        # Don't touch guild data until the bot's cache is ready.
        await self.bot.wait_until_ready()
async def setup(bot: commands.Bot) -> None:
    """Cog entry point; requires the bot to expose a wordlist service."""
    svc = getattr(bot, "wordlist_service", None)
    if not svc:
        logger.warning("Wordlist service not initialized; skipping sync task")
        return
    await bot.add_cog(WordlistSync(bot, svc))

View File

@@ -5,9 +5,9 @@ import re
from pathlib import Path
from typing import Any, Literal
from pydantic import Field, SecretStr, field_validator, ValidationError
from pydantic import BaseModel, Field, SecretStr, ValidationError, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic_settings.sources import EnvSettingsSource
# Discord snowflake ID validation regex (64-bit integers, 17-19 digits)
DISCORD_ID_PATTERN = re.compile(r"^\d{17,19}$")
@@ -19,17 +19,17 @@ def _validate_discord_id(value: str | int) -> int:
id_str = str(value)
else:
id_str = str(value).strip()
# Check format
if not DISCORD_ID_PATTERN.match(id_str):
raise ValueError(f"Invalid Discord ID format: {id_str}")
# Convert to int and validate range
discord_id = int(id_str)
# Discord snowflakes are 64-bit integers, minimum valid ID is around 2010
if discord_id < 100000000000000000 or discord_id > 9999999999999999999:
raise ValueError(f"Discord ID out of valid range: {discord_id}")
return discord_id
@@ -65,6 +65,27 @@ def _parse_id_list(value: Any) -> list[int]:
return parsed
class GuardDenEnvSettingsSource(EnvSettingsSource):
"""Environment settings source with safe list parsing."""
def decode_complex_value(self, field_name: str, field, value: Any):
if field_name in {"allowed_guilds", "owner_ids"} and isinstance(value, str):
return value
return super().decode_complex_value(field_name, field, value)
class WordlistSourceConfig(BaseModel):
"""Configuration for a managed wordlist source."""
name: str
url: str
category: Literal["hard", "soft", "context"]
action: Literal["delete", "warn", "strike"]
reason: str
is_regex: bool = False
enabled: bool = True
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
@@ -73,8 +94,25 @@ class Settings(BaseSettings):
env_file_encoding="utf-8",
case_sensitive=False,
env_prefix="GUARDDEN_",
env_parse_none_str="",
)
@classmethod
def settings_customise_sources(
cls,
settings_cls,
init_settings,
env_settings,
dotenv_settings,
file_secret_settings,
):
return (
init_settings,
GuardDenEnvSettingsSource(settings_cls),
dotenv_settings,
file_secret_settings,
)
# Discord settings
discord_token: SecretStr = Field(..., description="Discord bot token")
discord_prefix: str = Field(default="!", description="Default command prefix")
@@ -114,11 +152,43 @@ class Settings(BaseSettings):
# Paths
data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")
# Wordlist sync
wordlist_enabled: bool = Field(
default=True, description="Enable automatic managed wordlist syncing"
)
wordlist_update_hours: int = Field(
default=168, description="Managed wordlist sync interval in hours"
)
wordlist_sources: list[WordlistSourceConfig] = Field(
default_factory=list,
description="Managed wordlist sources (JSON array via env overrides)",
)
@field_validator("allowed_guilds", "owner_ids", mode="before")
@classmethod
def _validate_id_list(cls, value: Any) -> list[int]:
return _parse_id_list(value)
@field_validator("wordlist_sources", mode="before")
@classmethod
def _parse_wordlist_sources(cls, value: Any) -> list[WordlistSourceConfig]:
if value is None:
return []
if isinstance(value, list):
return [WordlistSourceConfig.model_validate(item) for item in value]
if isinstance(value, str):
text = value.strip()
if not text:
return []
try:
data = json.loads(text)
except json.JSONDecodeError as exc:
raise ValueError("Invalid JSON for wordlist_sources") from exc
if not isinstance(data, list):
raise ValueError("wordlist_sources must be a JSON array")
return [WordlistSourceConfig.model_validate(item) for item in data]
return []
@field_validator("discord_token")
@classmethod
def _validate_discord_token(cls, value: SecretStr) -> SecretStr:
@@ -126,11 +196,11 @@ class Settings(BaseSettings):
token = value.get_secret_value()
if not token:
raise ValueError("Discord token cannot be empty")
# Basic Discord token format validation (not perfect but catches common issues)
if len(token) < 50 or not re.match(r"^[A-Za-z0-9._-]+$", token):
raise ValueError("Invalid Discord token format")
return value
@field_validator("anthropic_api_key", "openai_api_key")
@@ -139,15 +209,15 @@ class Settings(BaseSettings):
"""Validate API key format if provided."""
if value is None:
return None
key = value.get_secret_value()
if not key:
return None
# Basic API key validation
if len(key) < 20:
raise ValueError("API key too short to be valid")
return value
def validate_configuration(self) -> None:
@@ -157,17 +227,21 @@ class Settings(BaseSettings):
raise ValueError("GUARDDEN_ANTHROPIC_API_KEY is required when AI provider is anthropic")
if self.ai_provider == "openai" and not self.openai_api_key:
raise ValueError("GUARDDEN_OPENAI_API_KEY is required when AI provider is openai")
# Database pool validation
if self.database_pool_min > self.database_pool_max:
raise ValueError("database_pool_min cannot be greater than database_pool_max")
if self.database_pool_min < 1:
raise ValueError("database_pool_min must be at least 1")
# Data directory validation
if not isinstance(self.data_dir, Path):
raise ValueError("data_dir must be a valid path")
# Wordlist validation
if self.wordlist_update_hours < 1:
raise ValueError("wordlist_update_hours must be at least 1")
def get_settings() -> Settings:
"""Get application settings instance."""

View File

@@ -0,0 +1,16 @@
"""Dashboard entrypoint for `python -m guardden.dashboard`."""
import os
import uvicorn
def main() -> None:
    """Launch the dashboard ASGI app with uvicorn, configured via environment."""
    env = os.environ
    uvicorn.run(
        "guardden.dashboard.main:app",
        host=env.get("GUARDDEN_DASHBOARD_HOST", "0.0.0.0"),
        port=int(env.get("GUARDDEN_DASHBOARD_PORT", "8000")),
        log_level=env.get("GUARDDEN_LOG_LEVEL", "info").lower(),
    )


if __name__ == "__main__":
    main()

View File

@@ -3,7 +3,7 @@
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, Text
from sqlalchemy import JSON, Boolean, Float, ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -59,7 +59,9 @@ class GuildSettings(Base, TimestampMixin):
# Role configuration
mute_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
verified_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
mod_role_ids: Mapped[dict] = mapped_column(JSONB, default=list, nullable=False)
mod_role_ids: Mapped[dict] = mapped_column(
JSONB().with_variant(JSON(), "sqlite"), default=list, nullable=False
)
# Moderation settings
automod_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
@@ -73,11 +75,13 @@ class GuildSettings(Base, TimestampMixin):
mention_limit: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
mention_rate_limit: Mapped[int] = mapped_column(Integer, default=10, nullable=False)
mention_rate_window: Mapped[int] = mapped_column(Integer, default=60, nullable=False)
scam_allowlist: Mapped[list[str]] = mapped_column(JSONB, default=list, nullable=False)
scam_allowlist: Mapped[list[str]] = mapped_column(
JSONB().with_variant(JSON(), "sqlite"), default=list, nullable=False
)
# Strike thresholds (actions at each threshold)
strike_actions: Mapped[dict] = mapped_column(
JSONB,
JSONB().with_variant(JSON(), "sqlite"),
default=lambda: {
"1": {"action": "warn"},
"3": {"action": "timeout", "duration": 3600},
@@ -88,11 +92,11 @@ class GuildSettings(Base, TimestampMixin):
)
# AI moderation settings
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=50, nullable=False) # 0-100 scale
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=80, nullable=False) # 0-100 scale
ai_confidence_threshold: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
ai_log_only: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
# Verification settings
verification_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
@@ -120,6 +124,9 @@ class BannedWord(Base, TimestampMixin):
String(20), default="delete", nullable=False
) # delete, warn, strike
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
source: Mapped[str | None] = mapped_column(String(100), nullable=True)
category: Mapped[str | None] = mapped_column(String(20), nullable=True)
managed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
# Who added this and when
added_by: Mapped[int] = mapped_column(SnowflakeID, nullable=False)

View File

@@ -7,7 +7,7 @@ import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import NamedTuple, Sequence, TYPE_CHECKING
from typing import TYPE_CHECKING, NamedTuple, Sequence
from urllib.parse import urlparse
if TYPE_CHECKING:
@@ -16,6 +16,7 @@ else:
try:
import discord # type: ignore
except ModuleNotFoundError: # pragma: no cover
class _DiscordStub:
class Message: # minimal stub for type hints
pass
@@ -26,120 +27,122 @@ from guardden.models.guild import BannedWord
logger = logging.getLogger(__name__)
# Circuit breaker for regex safety
class RegexTimeoutError(Exception):
"""Raised when regex execution takes too long."""
pass
class RegexCircuitBreaker:
"""Circuit breaker to prevent catastrophic backtracking in regex patterns."""
def __init__(self, timeout_seconds: float = 0.1):
self.timeout_seconds = timeout_seconds
self.failed_patterns: dict[str, datetime] = {}
self.failure_threshold = timedelta(minutes=5) # Disable pattern for 5 minutes after failure
def _timeout_handler(self, signum, frame):
"""Signal handler for regex timeout."""
raise RegexTimeoutError("Regex execution timed out")
def is_pattern_disabled(self, pattern: str) -> bool:
"""Check if a pattern is temporarily disabled due to timeouts."""
if pattern not in self.failed_patterns:
return False
failure_time = self.failed_patterns[pattern]
if datetime.now(timezone.utc) - failure_time > self.failure_threshold:
# Re-enable the pattern after threshold time
del self.failed_patterns[pattern]
return False
return True
def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
"""Safely execute regex search with timeout protection."""
if self.is_pattern_disabled(pattern):
logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
return False
# Basic pattern validation to catch obviously problematic patterns
if self._is_dangerous_pattern(pattern):
logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
return False
old_handler = None
try:
# Set up timeout signal (Unix systems only)
if hasattr(signal, 'SIGALRM'):
if hasattr(signal, "SIGALRM"):
old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
signal.alarm(int(self.timeout_seconds * 1000)) # Convert to milliseconds
start_time = time.perf_counter()
# Compile and execute regex
compiled_pattern = re.compile(pattern, flags)
result = bool(compiled_pattern.search(text))
execution_time = time.perf_counter() - start_time
# Log slow patterns for monitoring
if execution_time > self.timeout_seconds * 0.8:
logger.warning(
f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
)
return result
except RegexTimeoutError:
# Pattern took too long, disable it temporarily
self.failed_patterns[pattern] = datetime.now(timezone.utc)
logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
return False
except re.error as e:
logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
return False
except Exception as e:
logger.error(f"Unexpected error in regex execution: {e}")
return False
finally:
# Clean up timeout signal
if hasattr(signal, 'SIGALRM') and old_handler is not None:
if hasattr(signal, "SIGALRM") and old_handler is not None:
signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)
def _is_dangerous_pattern(self, pattern: str) -> bool:
"""Basic heuristic to detect potentially dangerous regex patterns."""
# Check for patterns that are commonly problematic
dangerous_indicators = [
r'(\w+)+', # Nested quantifiers
r'(\d+)+', # Nested quantifiers on digits
r'(.+)+', # Nested quantifiers on anything
r'(.*)+', # Nested quantifiers on anything (greedy)
r'(\w*)+', # Nested quantifiers with *
r'(\S+)+', # Nested quantifiers on non-whitespace
r"(\w+)+", # Nested quantifiers
r"(\d+)+", # Nested quantifiers on digits
r"(.+)+", # Nested quantifiers on anything
r"(.*)+", # Nested quantifiers on anything (greedy)
r"(\w*)+", # Nested quantifiers with *
r"(\S+)+", # Nested quantifiers on non-whitespace
]
# Check for excessively long patterns
if len(pattern) > 500:
return True
# Check for nested quantifiers (simplified detection)
if '+)+' in pattern or '*)+' in pattern or '?)+' in pattern:
if "+)+" in pattern or "*)+" in pattern or "?)+" in pattern:
return True
# Check for excessive repetition operators
if pattern.count('+') > 10 or pattern.count('*') > 10:
if pattern.count("+") > 10 or pattern.count("*") > 10:
return True
# Check for specific dangerous patterns
for dangerous in dangerous_indicators:
if dangerous in pattern:
return True
return False
@@ -240,34 +243,43 @@ def normalize_domain(value: str) -> str:
"""Normalize a domain or URL for allowlist checks with security validation."""
if not value or not isinstance(value, str):
return ""
if any(char in value for char in ["\x00", "\n", "\r", "\t"]):
return ""
text = value.strip().lower()
if not text or len(text) > 2000: # Prevent excessively long URLs
return ""
# Sanitize input to prevent injection attacks
if any(char in text for char in ['\x00', '\n', '\r', '\t']):
return ""
try:
if "://" not in text:
text = f"http://{text}"
parsed = urlparse(text)
hostname = parsed.hostname or ""
# Additional validation for hostname
if not hostname or len(hostname) > 253: # RFC limit
return ""
# Check for malicious patterns
if any(char in hostname for char in [' ', '\x00', '\n', '\r', '\t']):
if any(char in hostname for char in [" ", "\x00", "\n", "\r", "\t"]):
return ""
if not re.fullmatch(r"[a-z0-9.-]+", hostname):
return ""
if hostname.startswith(".") or hostname.endswith(".") or ".." in hostname:
return ""
for label in hostname.split("."):
if not label:
return ""
if label.startswith("-") or label.endswith("-"):
return ""
# Remove www prefix
if hostname.startswith("www."):
hostname = hostname[4:]
return hostname
except (ValueError, UnicodeError, Exception):
# urlparse can raise various exceptions with malicious input
@@ -305,13 +317,13 @@ class AutomodService:
# Normalize: lowercase, remove extra spaces, remove special chars
# Use simple string operations for basic patterns to avoid regex overhead
normalized = content.lower()
# Remove special characters (simplified approach)
normalized = ''.join(c for c in normalized if c.isalnum() or c.isspace())
normalized = "".join(c for c in normalized if c.isalnum() or c.isspace())
# Normalize whitespace
normalized = ' '.join(normalized.split())
normalized = " ".join(normalized.split())
return normalized
def check_banned_words(
@@ -369,14 +381,14 @@ class AutomodService:
# Limit URL length to prevent processing extremely long URLs
if len(url) > 2000:
continue
url_lower = url.lower()
hostname = normalize_domain(url)
# Skip if hostname normalization failed (security check)
if not hostname:
continue
if allowlist_set and is_allowed_domain(hostname, allowlist_set):
continue
@@ -540,3 +552,11 @@ class AutomodService:
def cleanup_guild(self, guild_id: int) -> None:
"""Remove all tracking data for a guild."""
self._spam_trackers.pop(guild_id, None)
_automod_service = AutomodService()
def detect_scam_links(content: str, allowlist: list[str] | None = None) -> AutomodResult | None:
"""Convenience wrapper for scam detection."""
return _automod_service.check_scam_links(content, allowlist)

View File

@@ -141,6 +141,9 @@ class GuildConfigService:
is_regex: bool = False,
action: str = "delete",
reason: str | None = None,
source: str | None = None,
category: str | None = None,
managed: bool = False,
) -> BannedWord:
"""Add a banned word to a guild."""
async with self.database.session() as session:
@@ -150,6 +153,9 @@ class GuildConfigService:
is_regex=is_regex,
action=action,
reason=reason,
source=source,
category=category,
managed=managed,
added_by=added_by,
)
session.add(banned_word)

View File

@@ -0,0 +1,180 @@
"""Managed wordlist sync service."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Iterable
import httpx
from sqlalchemy import delete, select
from guardden.config import Settings, WordlistSourceConfig
from guardden.models import BannedWord, Guild
from guardden.services.database import Database
logger = logging.getLogger(__name__)
# Entries longer than this are dropped during normalization (defensive cap).
MAX_WORDLIST_ENTRY_LENGTH = 128
# Per-request timeout in seconds when fetching a remote wordlist.
REQUEST_TIMEOUT = 20.0
@dataclass(frozen=True)
class WordlistSource:
    """Immutable description of one remote wordlist to sync."""

    name: str  # unique key; stored on BannedWord.source to track ownership
    url: str  # plain-text list, one entry per line
    category: str  # "hard" / "soft" / "context" (see WordlistSourceConfig)
    action: str  # moderation action on match: "delete" / "warn" / "strike"
    reason: str  # reason recorded on each banned word created from this source
    is_regex: bool = False  # entries are regex patterns rather than literals
# Built-in sources used when settings provide no wordlist_sources override
# (see WordlistService._load_sources).
DEFAULT_SOURCES: list[WordlistSource] = [
    WordlistSource(
        name="ldnoobw_en",
        url="https://raw.githubusercontent.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/master/en",
        category="soft",
        action="warn",
        reason="Auto list: profanity",
        is_regex=False,
    ),
]
def _normalize_entry(line: str) -> str:
    """Return the lowercased, stripped entry, or "" if empty or oversized."""
    candidate = line.strip().lower()
    if not candidate or len(candidate) > MAX_WORDLIST_ENTRY_LENGTH:
        return ""
    return candidate
def _parse_wordlist(text: str) -> list[str]:
    """Parse raw wordlist text into unique normalized entries, order preserved."""
    comment_prefixes = ("#", "//", ";")
    seen: set[str] = set()
    result: list[str] = []
    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if not stripped or stripped.startswith(comment_prefixes):
            continue
        entry = _normalize_entry(stripped)
        if entry and entry not in seen:
            seen.add(entry)
            result.append(entry)
    return result
class WordlistService:
"""Fetches and syncs managed wordlists into per-guild bans."""
def __init__(self, database: Database, settings: Settings) -> None:
    """Capture dependencies and resolve the configured source list."""
    self.database = database
    self.settings = settings
    # Sources are resolved once at construction; config changes need a restart.
    self.sources = self._load_sources(settings)
    self.update_interval = timedelta(hours=settings.wordlist_update_hours)
    # UTC timestamp of the last completed sync; None until the first run.
    self.last_sync: datetime | None = None
@staticmethod
def _load_sources(settings: Settings) -> list[WordlistSource]:
    """Build the source list from settings, falling back to DEFAULT_SOURCES."""
    if not settings.wordlist_sources:
        return list(DEFAULT_SOURCES)
    # Configured sources override the defaults; disabled entries are skipped.
    return [
        WordlistSource(
            name=cfg.name,
            url=cfg.url,
            category=cfg.category,
            action=cfg.action,
            reason=cfg.reason,
            is_regex=cfg.is_regex,
        )
        for cfg in settings.wordlist_sources
        if cfg.enabled
    ]
async def _fetch_source(self, source: WordlistSource) -> list[str]:
    """Download and parse one wordlist source.

    Raises httpx errors on network failure or non-2xx status; the caller
    (sync_all) catches and logs per-source failures.
    """
    # NOTE(review): httpx does not follow redirects by default — confirm the
    # configured URLs serve content directly.
    async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
        response = await client.get(source.url)
        response.raise_for_status()
        return _parse_wordlist(response.text)
async def sync_all(self) -> None:
    """Fetch every configured source and sync its entries to all guilds.

    Per-source fetch failures are logged and skipped so one bad source does
    not block the others. Updates ``last_sync`` when the pass completes.
    """
    if not self.settings.wordlist_enabled:
        logger.info("Managed wordlist sync disabled")
        return
    if not self.sources:
        logger.warning("No wordlist sources configured")
        return
    logger.info("Starting managed wordlist sync (%d sources)", len(self.sources))
    # Snapshot guild ids up front; each source sync opens its own session.
    async with self.database.session() as session:
        guild_ids = list((await session.execute(select(Guild.id))).scalars().all())
    for source in self.sources:
        try:
            entries = await self._fetch_source(source)
        except Exception as exc:
            # Network/HTTP errors: skip this source, keep syncing the rest.
            logger.error("Failed to fetch wordlist %s: %s", source.name, exc)
            continue
        if not entries:
            logger.warning("Wordlist %s returned no entries", source.name)
            continue
        await self._sync_source_to_guilds(source, entries, guild_ids)
    self.last_sync = datetime.now(timezone.utc)
    logger.info("Managed wordlist sync completed")
async def _sync_source_to_guilds(
    self, source: WordlistSource, entries: Iterable[str], guild_ids: list[int]
) -> None:
    """Diff one source's entries against each guild's managed rows.

    Adds entries missing from the guild and deletes rows no longer present
    in the source; only rows with managed=True and a matching source name
    are touched, so manually-added words are left alone.
    """
    entry_set = set(entries)
    # One session for the whole source; assumes Database.session() commits
    # on exit — TODO confirm against its implementation.
    async with self.database.session() as session:
        for guild_id in guild_ids:
            result = await session.execute(
                select(BannedWord).where(
                    BannedWord.guild_id == guild_id,
                    BannedWord.managed.is_(True),
                    BannedWord.source == source.name,
                )
            )
            existing = list(result.scalars().all())
            # Compare case-insensitively; source entries are lowercased by
            # _normalize_entry, so lowercase the stored patterns too.
            existing_set = {word.pattern.lower() for word in existing}
            to_add = entry_set - existing_set
            to_remove = existing_set - entry_set
            if to_remove:
                await session.execute(
                    delete(BannedWord).where(
                        BannedWord.guild_id == guild_id,
                        BannedWord.managed.is_(True),
                        BannedWord.source == source.name,
                        BannedWord.pattern.in_(to_remove),
                    )
                )
            if to_add:
                session.add_all(
                    [
                        BannedWord(
                            guild_id=guild_id,
                            pattern=pattern,
                            is_regex=source.is_regex,
                            action=source.action,
                            reason=source.reason,
                            source=source.name,
                            category=source.category,
                            managed=True,
                            # added_by=0 presumably marks a system-managed
                            # entry (no human moderator) — verify convention.
                            added_by=0,
                        )
                        for pattern in to_add
                    ]
                )