update
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
This commit is contained in:
@@ -5,9 +5,9 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, SecretStr, field_validator, ValidationError
|
||||
from pydantic import BaseModel, Field, SecretStr, ValidationError, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from pydantic_settings.sources import EnvSettingsSource
|
||||
|
||||
# Discord snowflake ID validation regex (64-bit integers, 17-19 digits)
|
||||
DISCORD_ID_PATTERN = re.compile(r"^\d{17,19}$")
|
||||
@@ -19,17 +19,17 @@ def _validate_discord_id(value: str | int) -> int:
|
||||
id_str = str(value)
|
||||
else:
|
||||
id_str = str(value).strip()
|
||||
|
||||
|
||||
# Check format
|
||||
if not DISCORD_ID_PATTERN.match(id_str):
|
||||
raise ValueError(f"Invalid Discord ID format: {id_str}")
|
||||
|
||||
|
||||
# Convert to int and validate range
|
||||
discord_id = int(id_str)
|
||||
# Discord snowflakes are 64-bit integers, minimum valid ID is around 2010
|
||||
if discord_id < 100000000000000000 or discord_id > 9999999999999999999:
|
||||
raise ValueError(f"Discord ID out of valid range: {discord_id}")
|
||||
|
||||
|
||||
return discord_id
|
||||
|
||||
|
||||
@@ -65,6 +65,27 @@ def _parse_id_list(value: Any) -> list[int]:
|
||||
return parsed
|
||||
|
||||
|
||||
class GuardDenEnvSettingsSource(EnvSettingsSource):
|
||||
"""Environment settings source with safe list parsing."""
|
||||
|
||||
def decode_complex_value(self, field_name: str, field, value: Any):
|
||||
if field_name in {"allowed_guilds", "owner_ids"} and isinstance(value, str):
|
||||
return value
|
||||
return super().decode_complex_value(field_name, field, value)
|
||||
|
||||
|
||||
class WordlistSourceConfig(BaseModel):
|
||||
"""Configuration for a managed wordlist source."""
|
||||
|
||||
name: str
|
||||
url: str
|
||||
category: Literal["hard", "soft", "context"]
|
||||
action: Literal["delete", "warn", "strike"]
|
||||
reason: str
|
||||
is_regex: bool = False
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
@@ -73,8 +94,25 @@ class Settings(BaseSettings):
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
env_prefix="GUARDDEN_",
|
||||
env_parse_none_str="",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls,
|
||||
init_settings,
|
||||
env_settings,
|
||||
dotenv_settings,
|
||||
file_secret_settings,
|
||||
):
|
||||
return (
|
||||
init_settings,
|
||||
GuardDenEnvSettingsSource(settings_cls),
|
||||
dotenv_settings,
|
||||
file_secret_settings,
|
||||
)
|
||||
|
||||
# Discord settings
|
||||
discord_token: SecretStr = Field(..., description="Discord bot token")
|
||||
discord_prefix: str = Field(default="!", description="Default command prefix")
|
||||
@@ -114,11 +152,43 @@ class Settings(BaseSettings):
|
||||
# Paths
|
||||
data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")
|
||||
|
||||
# Wordlist sync
|
||||
wordlist_enabled: bool = Field(
|
||||
default=True, description="Enable automatic managed wordlist syncing"
|
||||
)
|
||||
wordlist_update_hours: int = Field(
|
||||
default=168, description="Managed wordlist sync interval in hours"
|
||||
)
|
||||
wordlist_sources: list[WordlistSourceConfig] = Field(
|
||||
default_factory=list,
|
||||
description="Managed wordlist sources (JSON array via env overrides)",
|
||||
)
|
||||
|
||||
@field_validator("allowed_guilds", "owner_ids", mode="before")
|
||||
@classmethod
|
||||
def _validate_id_list(cls, value: Any) -> list[int]:
|
||||
return _parse_id_list(value)
|
||||
|
||||
@field_validator("wordlist_sources", mode="before")
|
||||
@classmethod
|
||||
def _parse_wordlist_sources(cls, value: Any) -> list[WordlistSourceConfig]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return [WordlistSourceConfig.model_validate(item) for item in value]
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError("Invalid JSON for wordlist_sources") from exc
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("wordlist_sources must be a JSON array")
|
||||
return [WordlistSourceConfig.model_validate(item) for item in data]
|
||||
return []
|
||||
|
||||
@field_validator("discord_token")
|
||||
@classmethod
|
||||
def _validate_discord_token(cls, value: SecretStr) -> SecretStr:
|
||||
@@ -126,11 +196,11 @@ class Settings(BaseSettings):
|
||||
token = value.get_secret_value()
|
||||
if not token:
|
||||
raise ValueError("Discord token cannot be empty")
|
||||
|
||||
|
||||
# Basic Discord token format validation (not perfect but catches common issues)
|
||||
if len(token) < 50 or not re.match(r"^[A-Za-z0-9._-]+$", token):
|
||||
raise ValueError("Invalid Discord token format")
|
||||
|
||||
|
||||
return value
|
||||
|
||||
@field_validator("anthropic_api_key", "openai_api_key")
|
||||
@@ -139,15 +209,15 @@ class Settings(BaseSettings):
|
||||
"""Validate API key format if provided."""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
|
||||
key = value.get_secret_value()
|
||||
if not key:
|
||||
return None
|
||||
|
||||
|
||||
# Basic API key validation
|
||||
if len(key) < 20:
|
||||
raise ValueError("API key too short to be valid")
|
||||
|
||||
|
||||
return value
|
||||
|
||||
def validate_configuration(self) -> None:
|
||||
@@ -157,17 +227,21 @@ class Settings(BaseSettings):
|
||||
raise ValueError("GUARDDEN_ANTHROPIC_API_KEY is required when AI provider is anthropic")
|
||||
if self.ai_provider == "openai" and not self.openai_api_key:
|
||||
raise ValueError("GUARDDEN_OPENAI_API_KEY is required when AI provider is openai")
|
||||
|
||||
|
||||
# Database pool validation
|
||||
if self.database_pool_min > self.database_pool_max:
|
||||
raise ValueError("database_pool_min cannot be greater than database_pool_max")
|
||||
if self.database_pool_min < 1:
|
||||
raise ValueError("database_pool_min must be at least 1")
|
||||
|
||||
|
||||
# Data directory validation
|
||||
if not isinstance(self.data_dir, Path):
|
||||
raise ValueError("data_dir must be a valid path")
|
||||
|
||||
# Wordlist validation
|
||||
if self.wordlist_update_hours < 1:
|
||||
raise ValueError("wordlist_update_hours must be at least 1")
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""Get application settings instance."""
|
||||
|
||||
Reference in New Issue
Block a user