diff --git a/src/guardden/bot.py b/src/guardden/bot.py index 36787b9..ed13f2a 100644 --- a/src/guardden/bot.py +++ b/src/guardden/bot.py @@ -90,7 +90,7 @@ class GuardDen(commands.Bot): # Initialize services from guardden.services.guild_config import GuildConfigService - self.guild_config = GuildConfigService(self.database) + self.guild_config = GuildConfigService(self.database, settings=self.settings) from guardden.services.wordlist import WordlistService self.wordlist_service = WordlistService(self.database, self.settings) diff --git a/src/guardden/config.py b/src/guardden/config.py index 69c1add..b4b67c7 100644 --- a/src/guardden/config.py +++ b/src/guardden/config.py @@ -86,6 +86,43 @@ class WordlistSourceConfig(BaseModel): enabled: bool = True +class GuildDefaults(BaseModel): + """Default values for new guild settings (configurable via env). + + These values are used when creating a new guild configuration. + Override via environment variables with GUARDDEN_GUILD_DEFAULT_ prefix. + Example: GUARDDEN_GUILD_DEFAULT_PREFIX=? sets the default prefix to "?" + """ + + prefix: str = Field(default="!", min_length=1, max_length=10) + locale: str = Field(default="en", min_length=2, max_length=10) + automod_enabled: bool = True + anti_spam_enabled: bool = True + link_filter_enabled: bool = False + message_rate_limit: int = Field(default=5, ge=1) + message_rate_window: int = Field(default=5, ge=1) + duplicate_threshold: int = Field(default=3, ge=1) + mention_limit: int = Field(default=5, ge=1) + mention_rate_limit: int = Field(default=10, ge=1) + mention_rate_window: int = Field(default=60, ge=1) + ai_moderation_enabled: bool = True + ai_sensitivity: int = Field(default=80, ge=0, le=100) + ai_confidence_threshold: float = Field(default=0.7, ge=0.0, le=1.0) + ai_log_only: bool = False + nsfw_detection_enabled: bool = True + verification_enabled: bool = False + verification_type: Literal["button", "captcha", "math", "emoji"] = "button" + strike_actions: dict = Field( + default_factory=lambda: { + "1": {"action": "warn"}, + "3": {"action": "timeout", "duration": 3600}, + "5": {"action": "kick"}, + "7": {"action": "ban"}, + } + ) + scam_allowlist: list[str] = Field(default_factory=list) + + class Settings(BaseSettings): """Application settings loaded from environment variables.""" @@ -95,6 +132,7 @@ class Settings(BaseSettings): case_sensitive=False, env_prefix="GUARDDEN_", env_parse_none_str="", + env_nested_delimiter="_", ) @classmethod @@ -164,6 +202,12 @@ class Settings(BaseSettings): description="Managed wordlist sources (JSON array via env overrides)", ) + # Guild defaults (used when creating new guild configurations) + guild_default: GuildDefaults = Field( + default_factory=GuildDefaults, + description="Default values for new guild settings", + ) + @field_validator("allowed_guilds", "owner_ids", mode="before") @classmethod def _validate_id_list(cls, value: Any) -> list[int]: diff --git a/src/guardden/services/guild_config.py b/src/guardden/services/guild_config.py index 92f6249..0283c02 100644 --- a/src/guardden/services/guild_config.py +++ b/src/guardden/services/guild_config.py @@ -1,6 +1,9 @@ """Guild configuration service.""" +from __future__ import annotations + import logging +from typing import TYPE_CHECKING import discord from sqlalchemy import select @@ -10,14 +13,23 @@ from guardden.models import BannedWord, Guild, GuildSettings from guardden.services.cache import CacheService, get_cache_service from guardden.services.database import Database +if TYPE_CHECKING: + from guardden.config import Settings + logger = logging.getLogger(__name__) class GuildConfigService: """Manages guild configurations with multi-tier caching.""" - def __init__(self, database: Database, cache: CacheService | None = None) -> None: + def __init__( + self, + database: Database, + settings: Settings | None = None, + cache: CacheService | None = None, + ) -> None: self.database = database + self.settings = settings self.cache = cache or get_cache_service() self._memory_cache: dict[int, GuildSettings] = {} self._cache_ttl = 300 # 5 minutes @@ -88,9 +100,35 @@ class GuildConfigService: session.add(db_guild) await session.flush() - # Create default settings - settings = GuildSettings(guild_id=guild.id) - session.add(settings) + # Create settings with defaults from config (if available) + if self.settings and self.settings.guild_default: + defaults = self.settings.guild_default + guild_settings = GuildSettings( + guild_id=guild.id, + prefix=defaults.prefix, + locale=defaults.locale, + automod_enabled=defaults.automod_enabled, + anti_spam_enabled=defaults.anti_spam_enabled, + link_filter_enabled=defaults.link_filter_enabled, + message_rate_limit=defaults.message_rate_limit, + message_rate_window=defaults.message_rate_window, + duplicate_threshold=defaults.duplicate_threshold, + mention_limit=defaults.mention_limit, + mention_rate_limit=defaults.mention_rate_limit, + mention_rate_window=defaults.mention_rate_window, + ai_moderation_enabled=defaults.ai_moderation_enabled, + ai_sensitivity=defaults.ai_sensitivity, + ai_confidence_threshold=defaults.ai_confidence_threshold, + ai_log_only=defaults.ai_log_only, + nsfw_detection_enabled=defaults.nsfw_detection_enabled, + verification_enabled=defaults.verification_enabled, + verification_type=defaults.verification_type, + strike_actions=defaults.strike_actions, + scam_allowlist=defaults.scam_allowlist, + ) + else: + guild_settings = GuildSettings(guild_id=guild.id) + session.add(guild_settings) await session.commit() @@ -115,7 +153,7 @@ class GuildConfigService: await session.commit() # Invalidate cache - self._cache.pop(guild_id, None) + self._memory_cache.pop(guild_id, None) return settings diff --git a/tests/conftest.py b/tests/conftest.py index 2c833ab..ad05a6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ if str(SRC_DIR) not in sys.path: sys.path.insert(0, str(SRC_DIR)) # Import after path setup -from guardden.config import Settings +from guardden.config import GuildDefaults, Settings from guardden.models.base import Base from guardden.models.guild import BannedWord, Guild, GuildSettings from guardden.models.moderation import ModerationLog, Strike, UserNote @@ -99,6 +99,31 @@ def test_settings() -> Settings: ) +@pytest.fixture +def settings_with_custom_defaults() -> Settings: + """Return test settings with custom guild defaults.""" + custom_defaults = GuildDefaults( + prefix="?", + ai_sensitivity=50, + automod_enabled=False, + verification_enabled=True, + verification_type="captcha", + ) + return Settings( + discord_token="a" * 60, + discord_prefix="!test", + database_url="sqlite+aiosqlite:///test.db", + database_pool_min=1, + database_pool_max=1, + ai_provider="none", + log_level="DEBUG", + allowed_guilds=[], + owner_ids=[], + data_dir=Path("/tmp/guardden_test"), + guild_default=custom_defaults, + ) + + # ============================================================================== # Database Fixtures # ============================================================================== diff --git a/tests/test_config.py b/tests/test_config.py index 6c0850b..87c3b3a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,7 +3,7 @@ import pytest from pydantic import ValidationError -from guardden.config import Settings, _parse_id_list, _validate_discord_id +from guardden.config import GuildDefaults, Settings, _parse_id_list, _validate_discord_id from guardden.services.automod import normalize_domain @@ -244,3 +244,143 @@ class TestSecurityImprovements: os.environ["GUARDDEN_OWNER_IDS"] = original_owners else: os.environ.pop("GUARDDEN_OWNER_IDS", None) + + +class TestGuildDefaultsValidation: + """Test GuildDefaults model validation.""" + + def test_default_values(self): + """Test default factory creates valid GuildDefaults.""" + defaults = GuildDefaults() + assert defaults.prefix == "!" + assert defaults.locale == "en" + assert defaults.automod_enabled is True + assert defaults.ai_sensitivity == 80 + assert defaults.ai_confidence_threshold == 0.7 + assert defaults.verification_type == "button" + + def test_ai_sensitivity_valid_range(self): + """Test ai_sensitivity accepts values 0-100.""" + assert GuildDefaults(ai_sensitivity=0).ai_sensitivity == 0 + assert GuildDefaults(ai_sensitivity=50).ai_sensitivity == 50 + assert GuildDefaults(ai_sensitivity=100).ai_sensitivity == 100 + + def test_ai_sensitivity_invalid_range(self): + """Test ai_sensitivity rejects values outside 0-100.""" + with pytest.raises(ValidationError): + GuildDefaults(ai_sensitivity=-1) + with pytest.raises(ValidationError): + GuildDefaults(ai_sensitivity=101) + + def test_ai_confidence_threshold_valid_range(self): + """Test ai_confidence_threshold accepts values 0.0-1.0.""" + assert GuildDefaults(ai_confidence_threshold=0.0).ai_confidence_threshold == 0.0 + assert GuildDefaults(ai_confidence_threshold=0.5).ai_confidence_threshold == 0.5 + assert GuildDefaults(ai_confidence_threshold=1.0).ai_confidence_threshold == 1.0 + + def test_ai_confidence_threshold_invalid_range(self): + """Test ai_confidence_threshold rejects values outside 0.0-1.0.""" + with pytest.raises(ValidationError): + GuildDefaults(ai_confidence_threshold=-0.1) + with pytest.raises(ValidationError): + GuildDefaults(ai_confidence_threshold=1.1) + + def test_verification_type_valid_values(self): + """Test verification_type only accepts valid types.""" + valid_types = ["button", "captcha", "math", "emoji"] + for vtype in valid_types: + assert GuildDefaults(verification_type=vtype).verification_type == vtype + + def test_verification_type_invalid_values(self): + """Test verification_type rejects invalid types.""" + with pytest.raises(ValidationError): + GuildDefaults(verification_type="invalid") + with pytest.raises(ValidationError): + GuildDefaults(verification_type="") + + def test_positive_rate_limits(self): + """Test rate limit fields must be positive.""" + # Valid positive values + defaults = GuildDefaults( + message_rate_limit=1, + message_rate_window=1, + duplicate_threshold=1, + mention_limit=1, + mention_rate_limit=1, + mention_rate_window=1, + ) + assert defaults.message_rate_limit == 1 + + # Invalid zero or negative values + with pytest.raises(ValidationError): + GuildDefaults(message_rate_limit=0) + with pytest.raises(ValidationError): + GuildDefaults(message_rate_window=-1) + + def test_prefix_length_constraints(self): + """Test prefix has length constraints.""" + # Valid prefixes + assert GuildDefaults(prefix="!").prefix == "!" + assert GuildDefaults(prefix="??").prefix == "??" + assert GuildDefaults(prefix="!" * 10).prefix == "!" * 10 + + # Invalid: empty prefix + with pytest.raises(ValidationError): + GuildDefaults(prefix="") + + # Invalid: too long + with pytest.raises(ValidationError): + GuildDefaults(prefix="!" * 11) + + +class TestSettingsGuildDefaults: + """Test Settings.guild_default field.""" + + def test_guild_default_factory(self): + """Test guild_default uses factory default.""" + settings = Settings(discord_token="a" * 60) + assert settings.guild_default is not None + assert isinstance(settings.guild_default, GuildDefaults) + assert settings.guild_default.prefix == "!" + + def test_guild_default_custom_values(self): + """Test guild_default can be set with custom values.""" + custom_defaults = GuildDefaults( + prefix="?", + ai_sensitivity=50, + verification_enabled=True, + ) + settings = Settings(discord_token="a" * 60, guild_default=custom_defaults) + assert settings.guild_default.prefix == "?" + assert settings.guild_default.ai_sensitivity == 50 + assert settings.guild_default.verification_enabled is True + + def test_strike_actions_default(self): + """Test strike_actions has correct default structure.""" + defaults = GuildDefaults() + assert defaults.strike_actions == { + "1": {"action": "warn"}, + "3": {"action": "timeout", "duration": 3600}, + "5": {"action": "kick"}, + "7": {"action": "ban"}, + } + + def test_strike_actions_custom(self): + """Test strike_actions can be customized.""" + custom_actions = { + "1": {"action": "warn"}, + "5": {"action": "ban"}, + } + defaults = GuildDefaults(strike_actions=custom_actions) + assert defaults.strike_actions == custom_actions + + def test_scam_allowlist_default(self): + """Test scam_allowlist defaults to empty list.""" + defaults = GuildDefaults() + assert defaults.scam_allowlist == [] + + def test_scam_allowlist_custom(self): + """Test scam_allowlist can be customized.""" + custom_list = ["discord.com", "github.com"] + defaults = GuildDefaults(scam_allowlist=custom_list) + assert defaults.scam_allowlist == custom_list diff --git a/tests/test_database_integration.py b/tests/test_database_integration.py index 848ea12..cad172b 100644 --- a/tests/test_database_integration.py +++ b/tests/test_database_integration.py @@ -1,6 +1,7 @@ """Tests for database integration and models.""" from datetime import datetime, timezone +from unittest.mock import MagicMock import pytest from sqlalchemy import select @@ -8,6 +9,7 @@ from sqlalchemy import select from guardden.models.guild import BannedWord, Guild, GuildSettings from guardden.models.moderation import ModerationLog, Strike, UserNote from guardden.services.database import Database +from guardden.services.guild_config import GuildConfigService class TestDatabaseModels: @@ -312,3 +314,106 @@ class TestDatabaseSecurity: db_session.add(banned_word) await db_session.commit() await db_session.rollback() + + +class TestGuildConfigServiceWithDefaults: + """Test GuildConfigService.create_guild() with settings defaults.""" + + async def test_create_guild_uses_settings_defaults( + self, test_database, settings_with_custom_defaults, sample_guild_id, sample_owner_id + ): + """Test create_guild applies settings.guild_default values.""" + service = GuildConfigService(test_database, settings=settings_with_custom_defaults) + + # Create mock Discord guild + mock_guild = MagicMock() + mock_guild.id = sample_guild_id + mock_guild.name = "Test Guild" + mock_guild.owner_id = sample_owner_id + + # Create guild + db_guild = await service.create_guild(mock_guild) + + # Verify guild was created + assert db_guild.id == sample_guild_id + assert db_guild.name == "Test Guild" + + # Get settings and verify defaults were applied + guild_settings = await service.get_config(sample_guild_id) + assert guild_settings is not None + assert guild_settings.prefix == "?" # Custom default + assert guild_settings.ai_sensitivity == 50 # Custom default + assert guild_settings.automod_enabled is False # Custom default + assert guild_settings.verification_enabled is True # Custom default + assert guild_settings.verification_type == "captcha" # Custom default + + async def test_create_guild_without_settings( + self, test_database, sample_guild_id, sample_owner_id + ): + """Test create_guild works when settings is None.""" + service = GuildConfigService(test_database, settings=None) + + # Create mock Discord guild + mock_guild = MagicMock() + mock_guild.id = sample_guild_id + mock_guild.name = "Test Guild" + mock_guild.owner_id = sample_owner_id + + # Create guild + db_guild = await service.create_guild(mock_guild) + + # Verify guild was created + assert db_guild.id == sample_guild_id + + # Get settings and verify hardcoded defaults were used + guild_settings = await service.get_config(sample_guild_id) + assert guild_settings is not None + assert guild_settings.prefix == "!" # Hardcoded default + assert guild_settings.ai_sensitivity == 80 # Hardcoded default + assert guild_settings.automod_enabled is True # Hardcoded default + + async def test_create_guild_existing_guild_unchanged( + self, test_database, settings_with_custom_defaults, sample_guild_id, sample_owner_id + ): + """Test create_guild returns existing guild without changes.""" + service = GuildConfigService(test_database, settings=settings_with_custom_defaults) + + # Create mock Discord guild + mock_guild = MagicMock() + mock_guild.id = sample_guild_id + mock_guild.name = "Test Guild" + mock_guild.owner_id = sample_owner_id + + # Create guild first time + first_guild = await service.create_guild(mock_guild) + assert first_guild.id == sample_guild_id + + # Try to create again with different name + mock_guild.name = "Different Name" + second_guild = await service.create_guild(mock_guild) + + # Should return existing guild + assert second_guild.id == first_guild.id + assert second_guild.name == "Test Guild" # Original name + + async def test_create_guild_with_standard_settings( + self, test_database, test_settings, sample_guild_id, sample_owner_id + ): + """Test create_guild with standard test_settings fixture.""" + service = GuildConfigService(test_database, settings=test_settings) + + # Create mock Discord guild + mock_guild = MagicMock() + mock_guild.id = sample_guild_id + mock_guild.name = "Test Guild" + mock_guild.owner_id = sample_owner_id + + # Create guild + await service.create_guild(mock_guild) + + # Get settings and verify standard defaults + guild_settings = await service.get_config(sample_guild_id) + assert guild_settings is not None + # Standard settings use default GuildDefaults + assert guild_settings.prefix == "!" + assert guild_settings.ai_sensitivity == 80