Files
GuardDen/src/guardden/config.py
latte 824dd681f7
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Failing after 4m58s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m0s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
quick update
2026-01-24 19:14:33 +01:00

295 lines
10 KiB
Python

"""Configuration management for GuardDen."""
import json
import re
from pathlib import Path
from typing import Any, Literal
from pydantic import BaseModel, Field, SecretStr, ValidationError, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic_settings.sources import EnvSettingsSource
# Discord snowflake ID validation regex (64-bit integers, 17-19 digits)
DISCORD_ID_PATTERN = re.compile(r"^\d{17,19}$")
def _validate_discord_id(value: str | int) -> int:
"""Validate a Discord snowflake ID."""
if isinstance(value, int):
id_str = str(value)
else:
id_str = str(value).strip()
# Check format
if not DISCORD_ID_PATTERN.match(id_str):
raise ValueError(f"Invalid Discord ID format: {id_str}")
# Convert to int and validate range
discord_id = int(id_str)
# Discord snowflakes are 64-bit integers, minimum valid ID is around 2010
if discord_id < 100000000000000000 or discord_id > 9999999999999999999:
raise ValueError(f"Discord ID out of valid range: {discord_id}")
return discord_id
def _parse_id_list(value: Any) -> list[int]:
    """Parse an environment value into a list of valid Discord IDs.

    Args:
        value: ``None``, a string of comma- or semicolon-separated IDs, a
            list/tuple/set of IDs, or a single ID-like value.

    Returns:
        The valid IDs in first-seen order with duplicates removed. Entries
        that fail validation are dropped (and reported via a log warning);
        ``None`` and blank strings yield an empty list.
    """
    # Local import: logging is not part of this module's import block.
    import logging

    if value is None:
        return []

    items: list[Any]
    if isinstance(value, (list, tuple, set, frozenset)):
        items = list(value)
    elif isinstance(value, str):
        # Only allow comma or semicolon separated values, no JSON parsing for security
        items = [part.strip() for part in value.replace(";", ",").split(",") if part.strip()]
    else:
        items = [value]

    parsed: list[int] = []
    seen: set[int] = set()
    for item in items:
        try:
            discord_id = _validate_discord_id(item)
        except (ValueError, TypeError):
            # Best effort: skip the bad entry instead of aborting startup, but
            # leave a trace so a typo in an ID list does not go unnoticed.
            logging.getLogger(__name__).warning(
                "Ignoring invalid Discord ID entry: %r", item
            )
            continue
        if discord_id not in seen:
            seen.add(discord_id)
            parsed.append(discord_id)
    return parsed
class GuardDenEnvSettingsSource(EnvSettingsSource):
    """Environment source that leaves the ID-list variables undecoded.

    String values for ``allowed_guilds`` and ``owner_ids`` are handed back
    untouched instead of going through the parent's complex-value decoding;
    every other field is delegated to the parent implementation.
    """

    def decode_complex_value(self, field_name: str, field, value: Any):
        """Return raw strings for the ID-list fields, else defer to the parent."""
        keep_raw = field_name in {"allowed_guilds", "owner_ids"} and isinstance(value, str)
        if not keep_raw:
            return super().decode_complex_value(field_name, field, value)
        return value
class WordlistSourceConfig(BaseModel):
    """One managed wordlist source entry."""

    # Required fields.
    name: str
    url: str
    category: Literal["hard", "soft", "context"]
    action: Literal["delete", "warn", "strike"]
    reason: str

    # Optional flags with their defaults.
    is_regex: bool = Field(default=False)
    enabled: bool = Field(default=True)
class GuildDefaults(BaseModel):
    """Baseline values applied when a new guild configuration is created.

    Each value is configurable through the environment using the
    GUARDDEN_GUILD_DEFAULT_ prefix; for instance
    GUARDDEN_GUILD_DEFAULT_PREFIX=? makes "?" the default prefix.
    """

    # General
    prefix: str = Field("!", min_length=1, max_length=10)
    locale: str = Field("en", min_length=2, max_length=10)

    # Automod / anti-spam toggles
    automod_enabled: bool = True
    anti_spam_enabled: bool = True
    link_filter_enabled: bool = False

    # Rate limits and thresholds (all must be >= 1)
    message_rate_limit: int = Field(5, ge=1)
    message_rate_window: int = Field(5, ge=1)
    duplicate_threshold: int = Field(3, ge=1)
    mention_limit: int = Field(5, ge=1)
    mention_rate_limit: int = Field(10, ge=1)
    mention_rate_window: int = Field(60, ge=1)

    # AI moderation
    ai_moderation_enabled: bool = True
    ai_sensitivity: int = Field(80, ge=0, le=100)
    ai_confidence_threshold: float = Field(0.7, ge=0.0, le=1.0)
    ai_log_only: bool = False
    nsfw_detection_enabled: bool = True

    # Verification
    verification_enabled: bool = False
    verification_type: Literal["button", "captcha", "math", "emoji"] = "button"

    # Strike escalation table, keyed by the strike count as a string.
    # NOTE(review): "duration" is presumably in seconds — confirm with the consumer.
    strike_actions: dict = Field(
        default_factory=lambda: {
            "1": {"action": "warn"},
            "3": {"action": "timeout", "duration": 3600},
            "5": {"action": "kick"},
            "7": {"action": "ban"},
        },
    )

    scam_allowlist: list[str] = Field(default_factory=list)
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Variables carry the ``GUARDDEN_`` prefix and are matched
    case-insensitively; a UTF-8 ``.env`` file is configured as well. Field
    validators run at construction time; call ``validate_configuration()``
    afterwards for the cross-field runtime checks.
    """

    # NOTE(review): the nested delimiter is a single underscore while nested
    # field names (e.g. ``automod_enabled``) contain underscores themselves —
    # confirm against the pydantic-settings docs that such fields can still be
    # overridden from the environment (see ``env_nested_max_split``).
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        env_prefix="GUARDDEN_",
        env_parse_none_str="",
        env_nested_delimiter="_",
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls,
        init_settings,
        env_settings,
        dotenv_settings,
        file_secret_settings,
    ):
        """Return the settings sources, swapping in the custom env source.

        ``env_settings`` is intentionally unused: it is replaced by a
        ``GuardDenEnvSettingsSource`` built for ``settings_cls``.

        NOTE(review): ``dotenv_settings`` is passed through unchanged, so ID
        lists read from the ``.env`` file presumably do not get the same
        raw-string handling as real environment variables — confirm that
        comma-separated IDs work from ``.env``.
        """
        return (
            init_settings,
            GuardDenEnvSettingsSource(settings_cls),
            dotenv_settings,
            file_secret_settings,
        )

    # Discord settings
    discord_token: SecretStr = Field(..., description="Discord bot token")
    discord_prefix: str = Field(default="!", description="Default command prefix")

    # Database settings
    database_url: SecretStr = Field(
        default=SecretStr("postgresql://guardden:guardden@localhost:5432/guardden"),
        description="PostgreSQL connection URL",
    )
    database_pool_min: int = Field(default=5, description="Minimum database pool size")
    database_pool_max: int = Field(default=20, description="Maximum database pool size")

    # AI settings (optional)
    ai_provider: Literal["anthropic", "openai", "none"] = Field(
        default="none", description="AI provider for content moderation"
    )
    anthropic_api_key: SecretStr | None = Field(default=None, description="Anthropic API key")
    openai_api_key: SecretStr | None = Field(default=None, description="OpenAI API key")

    # Logging
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
        default="INFO", description="Logging level"
    )
    log_json: bool = Field(default=False, description="Use JSON structured logging format")
    log_file: str | None = Field(default=None, description="Log file path (optional)")

    # Access control
    allowed_guilds: list[int] = Field(
        default_factory=list,
        description="Guild IDs the bot is allowed to join (empty = allow all)",
    )
    owner_ids: list[int] = Field(
        default_factory=list,
        description="Owner user IDs with elevated access (empty = allow admins)",
    )

    # Paths
    data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")

    # Wordlist sync
    wordlist_enabled: bool = Field(
        default=True, description="Enable automatic managed wordlist syncing"
    )
    wordlist_update_hours: int = Field(
        default=168, description="Managed wordlist sync interval in hours"
    )
    wordlist_sources: list[WordlistSourceConfig] = Field(
        default_factory=list,
        description="Managed wordlist sources (JSON array via env overrides)",
    )

    # Guild defaults (used when creating new guild configurations)
    guild_default: GuildDefaults = Field(
        default_factory=GuildDefaults,
        description="Default values for new guild settings",
    )

    @field_validator("allowed_guilds", "owner_ids", mode="before")
    @classmethod
    def _validate_id_list(cls, value: Any) -> list[int]:
        """Normalise the raw ID-list input through ``_parse_id_list``."""
        return _parse_id_list(value)

    @field_validator("wordlist_sources", mode="before")
    @classmethod
    def _parse_wordlist_sources(cls, value: Any) -> list[WordlistSourceConfig]:
        """Accept a list of entries or a JSON-array string of entries.

        ``None``, blank strings and any other input type yield an empty list.

        Raises:
            ValueError: If a string is not valid JSON or does not decode to a
                JSON array.
        """
        if value is None:
            return []
        if isinstance(value, list):
            return [WordlistSourceConfig.model_validate(item) for item in value]
        if isinstance(value, str):
            text = value.strip()
            if not text:
                return []
            try:
                data = json.loads(text)
            except json.JSONDecodeError as exc:
                raise ValueError("Invalid JSON for wordlist_sources") from exc
            if not isinstance(data, list):
                raise ValueError("wordlist_sources must be a JSON array")
            return [WordlistSourceConfig.model_validate(item) for item in data]
        return []

    @field_validator("discord_token")
    @classmethod
    def _validate_discord_token(cls, value: SecretStr) -> SecretStr:
        """Validate Discord bot token format.

        Raises:
            ValueError: If the token is empty, shorter than 50 characters, or
                contains characters other than letters, digits, ``.``, ``_``
                and ``-``.
        """
        token = value.get_secret_value()
        if not token:
            raise ValueError("Discord token cannot be empty")
        # Basic Discord token format validation (not perfect but catches common issues).
        # fullmatch is required here: with re.match the "$" anchor also matches
        # just before a trailing newline, which would let "token\n" through.
        if len(token) < 50 or not re.fullmatch(r"^[A-Za-z0-9._-]+$", token):
            raise ValueError("Invalid Discord token format")
        return value

    @field_validator("anthropic_api_key", "openai_api_key")
    @classmethod
    def _validate_api_key(cls, value: SecretStr | None) -> SecretStr | None:
        """Validate API key format if provided.

        A missing or empty key is normalised to ``None``.

        Raises:
            ValueError: If a non-empty key is shorter than 20 characters.
        """
        if value is None:
            return None
        key = value.get_secret_value()
        if not key:
            return None
        # Basic API key validation
        if len(key) < 20:
            raise ValueError("API key too short to be valid")
        return value

    def validate_configuration(self) -> None:
        """Validate the settings for runtime usage.

        Raises:
            ValueError: If the selected AI provider has no API key, the
                database pool bounds are inconsistent, ``data_dir`` is not a
                ``Path``, or the wordlist interval is below one hour.
        """
        # AI provider validation
        if self.ai_provider == "anthropic" and not self.anthropic_api_key:
            raise ValueError("GUARDDEN_ANTHROPIC_API_KEY is required when AI provider is anthropic")
        if self.ai_provider == "openai" and not self.openai_api_key:
            raise ValueError("GUARDDEN_OPENAI_API_KEY is required when AI provider is openai")

        # Database pool validation
        if self.database_pool_min > self.database_pool_max:
            raise ValueError("database_pool_min cannot be greater than database_pool_max")
        if self.database_pool_min < 1:
            raise ValueError("database_pool_min must be at least 1")

        # Data directory validation
        if not isinstance(self.data_dir, Path):
            raise ValueError("data_dir must be a valid path")

        # Wordlist validation
        if self.wordlist_update_hours < 1:
            raise ValueError("wordlist_update_hours must be at least 1")
def get_settings() -> Settings:
    """Build a new ``Settings`` object, run its runtime checks, and return it."""
    loaded = Settings()
    # Cross-field checks; raises ValueError on an inconsistent configuration.
    loaded.validate_configuration()
    return loaded