Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Failing after 4m58s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m0s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
295 lines
10 KiB
Python
295 lines
10 KiB
Python
"""Configuration management for GuardDen."""
|
|
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Any, Literal
|
|
|
|
from pydantic import BaseModel, Field, SecretStr, ValidationError, field_validator
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
from pydantic_settings.sources import EnvSettingsSource
|
|
|
|
# Discord snowflake ID validation regex (64-bit integers, 17-19 digits)
|
|
DISCORD_ID_PATTERN = re.compile(r"^\d{17,19}$")
|
|
|
|
|
|
def _validate_discord_id(value: str | int) -> int:
|
|
"""Validate a Discord snowflake ID."""
|
|
if isinstance(value, int):
|
|
id_str = str(value)
|
|
else:
|
|
id_str = str(value).strip()
|
|
|
|
# Check format
|
|
if not DISCORD_ID_PATTERN.match(id_str):
|
|
raise ValueError(f"Invalid Discord ID format: {id_str}")
|
|
|
|
# Convert to int and validate range
|
|
discord_id = int(id_str)
|
|
# Discord snowflakes are 64-bit integers, minimum valid ID is around 2010
|
|
if discord_id < 100000000000000000 or discord_id > 9999999999999999999:
|
|
raise ValueError(f"Discord ID out of valid range: {discord_id}")
|
|
|
|
return discord_id
|
|
|
|
|
|
def _parse_id_list(value: Any) -> list[int]:
    """Parse a configuration value into a de-duplicated list of Discord IDs.

    Accepted inputs:
        * ``None`` or a blank string: yields an empty list.
        * A list, tuple, set or frozenset: each element is a candidate ID.
        * A string: split on commas and semicolons (deliberately not parsed
          as JSON); blank parts are ignored.
        * Anything else: treated as a single candidate ID.

    Invalid candidates are skipped so one bad entry does not discard the rest,
    but each skipped entry is logged as a warning. This matters because the
    ``allowed_guilds`` setting treats an empty list as "allow all", so silently
    dropping every entry would widen access without any signal.

    Args:
        value: Raw value from the environment or from init kwargs.

    Returns:
        Valid IDs in first-seen order with duplicates removed.
    """
    # Local import keeps this block self-contained (no file-level import needed).
    import logging

    if value is None:
        return []

    items: list[Any]
    if isinstance(value, (list, tuple, set, frozenset)):
        items = list(value)
    elif isinstance(value, str):
        text = value.strip()
        if not text:
            return []
        # Only allow comma or semicolon separated values, no JSON parsing for security
        items = [part.strip() for part in text.replace(";", ",").split(",") if part.strip()]
    else:
        items = [value]

    log = logging.getLogger(__name__)
    parsed: list[int] = []
    seen: set[int] = set()
    for item in items:
        try:
            discord_id = _validate_discord_id(item)
        except (ValueError, TypeError):
            # Best-effort: keep the remaining entries, but make the drop visible.
            log.warning("Ignoring invalid Discord ID in configuration: %r", item)
            continue
        if discord_id not in seen:
            seen.add(discord_id)
            parsed.append(discord_id)

    return parsed
|
|
|
|
|
|
class GuardDenEnvSettingsSource(EnvSettingsSource):
    """Environment settings source with safe list parsing.

    String values for the ``allowed_guilds`` and ``owner_ids`` fields are
    handed on unchanged instead of going through the base class's
    complex-value decoding; every other field keeps the inherited behaviour.
    """

    def decode_complex_value(self, field_name: str, field, value: Any):
        """Return raw strings for the ID-list fields; otherwise defer to the base class."""
        is_id_list_field = field_name in ("allowed_guilds", "owner_ids")
        if not (is_id_list_field and isinstance(value, str)):
            return super().decode_complex_value(field_name, field, value)
        # Left as text so the field's own validator can split and check it.
        return value
|
|
|
|
|
|
class WordlistSourceConfig(BaseModel):
    """Describe one managed wordlist source.

    ``name``, ``url``, ``category``, ``action`` and ``reason`` are required;
    ``is_regex`` and ``enabled`` fall back to the defaults shown below.
    """

    # Required free-form text fields (no extra constraints are declared here).
    name: str
    url: str

    # Closed sets of accepted values; anything else fails validation.
    category: Literal["hard", "soft", "context"]
    action: Literal["delete", "warn", "strike"]

    reason: str

    # Optional flags.
    is_regex: bool = False
    enabled: bool = True
|
|
|
|
|
|
def _default_strike_actions() -> dict:
    """Build a fresh copy of the default ``strike_actions`` mapping."""
    # Keys are strings; presumably strike counts at which the action applies
    # (confirm against the code that consumes this mapping).
    return {
        "1": {"action": "warn"},
        "3": {"action": "timeout", "duration": 3600},
        "5": {"action": "kick"},
        "7": {"action": "ban"},
    }


class GuildDefaults(BaseModel):
    """Default values for new guild settings (configurable via env).

    These values are used when creating a new guild configuration.
    Override via environment variables with GUARDDEN_GUILD_DEFAULT_ prefix.
    Example: GUARDDEN_GUILD_DEFAULT_PREFIX=? sets the default prefix to "?"
    """

    # General
    prefix: str = Field(default="!", min_length=1, max_length=10)
    locale: str = Field(default="en", min_length=2, max_length=10)

    # Automod toggles
    automod_enabled: bool = True
    anti_spam_enabled: bool = True
    link_filter_enabled: bool = False

    # Rate and threshold limits (each must be at least 1)
    message_rate_limit: int = Field(default=5, ge=1)
    message_rate_window: int = Field(default=5, ge=1)
    duplicate_threshold: int = Field(default=3, ge=1)
    mention_limit: int = Field(default=5, ge=1)
    mention_rate_limit: int = Field(default=10, ge=1)
    mention_rate_window: int = Field(default=60, ge=1)

    # AI moderation
    ai_moderation_enabled: bool = True
    ai_sensitivity: int = Field(default=80, ge=0, le=100)
    ai_confidence_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
    ai_log_only: bool = False
    nsfw_detection_enabled: bool = True

    # Verification
    verification_enabled: bool = False
    verification_type: Literal["button", "captcha", "math", "emoji"] = "button"

    # Mutable defaults come from factories so instances never share state.
    strike_actions: dict = Field(default_factory=_default_strike_actions)
    scam_allowlist: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Variables use the ``GUARDDEN_`` prefix and are matched case-insensitively.
    A ``.env`` file (UTF-8) is also read. Field-level validators run when the
    model is built; cross-field checks live in :meth:`validate_configuration`,
    which must be called explicitly.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        env_prefix="GUARDDEN_",
        env_parse_none_str="",
        # NOTE(review): the nested delimiter is a single underscore, which also
        # appears inside nested field names (e.g. guild_default.message_rate_limit).
        # pydantic-settings may split GUARDDEN_GUILD_DEFAULT_MESSAGE_RATE_LIMIT
        # into nested keys instead of one field. Confirm multi-word overrides
        # work, or consider "__" / env_nested_max_split.
        env_nested_delimiter="_",
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls,
        init_settings,
        env_settings,
        dotenv_settings,
        file_secret_settings,
    ):
        """Return the settings sources, swapping in the custom env source.

        The ``env_settings`` argument is not used: a fresh
        ``GuardDenEnvSettingsSource`` built from ``settings_cls`` takes its
        place. The other sources are passed through unchanged.
        """
        # NOTE(review): dotenv_settings is the stock source, so the raw-string
        # handling for allowed_guilds/owner_ids presumably does not apply to
        # values read from .env. Verify comma-separated IDs work there.
        return (
            init_settings,
            GuardDenEnvSettingsSource(settings_cls),
            dotenv_settings,
            file_secret_settings,
        )

    # Discord settings
    discord_token: SecretStr = Field(..., description="Discord bot token")
    discord_prefix: str = Field(default="!", description="Default command prefix")

    # Database settings
    database_url: SecretStr = Field(
        default=SecretStr("postgresql://guardden:guardden@localhost:5432/guardden"),
        description="PostgreSQL connection URL",
    )
    database_pool_min: int = Field(default=5, description="Minimum database pool size")
    database_pool_max: int = Field(default=20, description="Maximum database pool size")

    # AI settings (optional)
    ai_provider: Literal["anthropic", "openai", "none"] = Field(
        default="none", description="AI provider for content moderation"
    )
    anthropic_api_key: SecretStr | None = Field(default=None, description="Anthropic API key")
    openai_api_key: SecretStr | None = Field(default=None, description="OpenAI API key")

    # Logging
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
        default="INFO", description="Logging level"
    )
    log_json: bool = Field(default=False, description="Use JSON structured logging format")
    log_file: str | None = Field(default=None, description="Log file path (optional)")

    # Access control
    allowed_guilds: list[int] = Field(
        default_factory=list,
        description="Guild IDs the bot is allowed to join (empty = allow all)",
    )
    owner_ids: list[int] = Field(
        default_factory=list,
        description="Owner user IDs with elevated access (empty = allow admins)",
    )

    # Paths
    data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")

    # Wordlist sync
    wordlist_enabled: bool = Field(
        default=True, description="Enable automatic managed wordlist syncing"
    )
    wordlist_update_hours: int = Field(
        default=168, description="Managed wordlist sync interval in hours"
    )
    wordlist_sources: list[WordlistSourceConfig] = Field(
        default_factory=list,
        description="Managed wordlist sources (JSON array via env overrides)",
    )

    # Guild defaults (used when creating new guild configurations)
    guild_default: GuildDefaults = Field(
        default_factory=GuildDefaults,
        description="Default values for new guild settings",
    )

    @field_validator("allowed_guilds", "owner_ids", mode="before")
    @classmethod
    def _validate_id_list(cls, value: Any) -> list[int]:
        """Normalise raw ID-list input via ``_parse_id_list`` before type validation."""
        return _parse_id_list(value)

    @field_validator("wordlist_sources", mode="before")
    @classmethod
    def _parse_wordlist_sources(cls, value: Any) -> list[WordlistSourceConfig]:
        """Coerce raw ``wordlist_sources`` input into validated source configs.

        Accepts ``None`` or a blank string (empty list), a list or tuple of
        items, or a string containing a JSON array.

        Raises:
            ValueError: If a string is not valid JSON, if the JSON is not an
                array, or if the value is of any other unsupported type.
                Unsupported types used to be turned into an empty list, which
                hid misconfiguration.
        """
        if value is None:
            return []

        if isinstance(value, str):
            text = value.strip()
            if not text:
                return []
            try:
                value = json.loads(text)
            except json.JSONDecodeError as exc:
                raise ValueError("Invalid JSON for wordlist_sources") from exc
            if not isinstance(value, list):
                raise ValueError("wordlist_sources must be a JSON array")
        elif isinstance(value, tuple):
            value = list(value)
        elif not isinstance(value, list):
            raise ValueError(
                "wordlist_sources must be a list or a JSON array string, "
                f"got {type(value).__name__}"
            )

        return [WordlistSourceConfig.model_validate(item) for item in value]

    @field_validator("discord_token")
    @classmethod
    def _validate_discord_token(cls, value: SecretStr) -> SecretStr:
        """Validate Discord bot token format.

        Raises:
            ValueError: If the token is empty, shorter than 50 characters, or
                contains characters other than letters, digits, ``.``, ``_``
                and ``-``.
        """
        token = value.get_secret_value()
        if not token:
            raise ValueError("Discord token cannot be empty")

        # Basic Discord token format validation (not perfect but catches common issues)
        if len(token) < 50 or not re.match(r"^[A-Za-z0-9._-]+$", token):
            raise ValueError("Invalid Discord token format")

        return value

    @field_validator("anthropic_api_key", "openai_api_key")
    @classmethod
    def _validate_api_key(cls, value: SecretStr | None) -> SecretStr | None:
        """Validate API key format if provided.

        Returns:
            ``None`` when no key or an empty key was supplied, otherwise the
            key unchanged.

        Raises:
            ValueError: If a non-empty key is shorter than 20 characters.
        """
        if value is None:
            return None

        key = value.get_secret_value()
        if not key:
            # An empty secret is treated the same as "not configured".
            return None

        # Basic API key validation
        if len(key) < 20:
            raise ValueError("API key too short to be valid")

        return value

    def validate_configuration(self) -> None:
        """Validate the settings for runtime usage.

        Performs checks that span several fields or are not expressed as
        field constraints. Checks run in order and the first failure wins.

        Raises:
            ValueError: If the selected AI provider has no API key, the
                database pool bounds are inconsistent or below 1, ``data_dir``
                is not a ``Path``, or ``wordlist_update_hours`` is below 1.
        """
        # AI provider validation
        if self.ai_provider == "anthropic" and not self.anthropic_api_key:
            raise ValueError("GUARDDEN_ANTHROPIC_API_KEY is required when AI provider is anthropic")
        if self.ai_provider == "openai" and not self.openai_api_key:
            raise ValueError("GUARDDEN_OPENAI_API_KEY is required when AI provider is openai")

        # Database pool validation
        if self.database_pool_min > self.database_pool_max:
            raise ValueError("database_pool_min cannot be greater than database_pool_max")
        if self.database_pool_min < 1:
            raise ValueError("database_pool_min must be at least 1")

        # Data directory validation
        if not isinstance(self.data_dir, Path):
            raise ValueError("data_dir must be a valid path")

        # Wordlist validation
        if self.wordlist_update_hours < 1:
            raise ValueError("wordlist_update_hours must be at least 1")
|
|
|
|
|
|
def get_settings() -> Settings:
    """Build a ``Settings`` instance, run its runtime checks, and return it.

    Any exception raised while constructing the settings or by
    ``Settings.validate_configuration`` propagates to the caller.
    """
    loaded = Settings()
    # Cross-field checks are not part of model construction, so run them here.
    loaded.validate_configuration()
    return loaded
|