"""Tests for database integration and models."""

from datetime import datetime, timezone
from unittest.mock import MagicMock

import pytest
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError

from guardden.models.guild import BannedWord, Guild, GuildSettings
from guardden.models.moderation import ModerationLog, Strike, UserNote
from guardden.services.database import Database
from guardden.services.guild_config import GuildConfigService


def _make_discord_guild(guild_id: int, name: str, owner_id: int) -> MagicMock:
    """Return a MagicMock standing in for a Discord guild object.

    Only ``id``, ``name`` and ``owner_id`` are set explicitly; any other
    attribute access falls through to MagicMock's auto-created children.
    """
    mock_guild = MagicMock()
    mock_guild.id = guild_id
    mock_guild.name = name
    mock_guild.owner_id = owner_id
    return mock_guild


class TestDatabaseModels:
    """Test database models and relationships."""

    async def test_guild_creation(self, db_session, sample_guild_id, sample_owner_id):
        """Test guild creation with settings."""
        guild = Guild(
            id=sample_guild_id,
            name="Test Guild",
            owner_id=sample_owner_id,
            premium=False,
        )
        db_session.add(guild)

        settings = GuildSettings(
            guild_id=sample_guild_id,
            prefix="!",
            automod_enabled=True,
            ai_moderation_enabled=False,
        )
        db_session.add(settings)
        await db_session.commit()

        # Test guild was created
        result = await db_session.execute(select(Guild).where(Guild.id == sample_guild_id))
        created_guild = result.scalar_one()
        assert created_guild.id == sample_guild_id
        assert created_guild.name == "Test Guild"
        assert created_guild.owner_id == sample_owner_id
        assert not created_guild.premium

        # The settings row inserted above must have been persisted too.
        result = await db_session.execute(
            select(GuildSettings).where(GuildSettings.guild_id == sample_guild_id)
        )
        created_settings = result.scalar_one()
        assert created_settings.prefix == "!"
        assert created_settings.automod_enabled
        assert not created_settings.ai_moderation_enabled

    async def test_guild_settings_relationship(self, test_guild, db_session):
        """Test guild-settings relationship."""
        # Load guild with settings
        result = await db_session.execute(select(Guild).where(Guild.id == test_guild.id))
        guild_with_settings = result.scalar_one()

        # Explicitly load the relationship (avoids implicit lazy loading in async).
        await db_session.refresh(guild_with_settings, ["settings"])
        assert guild_with_settings.settings is not None
        assert guild_with_settings.settings.guild_id == test_guild.id
        assert guild_with_settings.settings.prefix == "!"

    async def test_banned_word_creation(self, test_guild, db_session, sample_moderator_id):
        """Test banned word creation and relationship."""
        banned_word = BannedWord(
            guild_id=test_guild.id,
            pattern="testbadword",
            is_regex=False,
            action="delete",
            reason="Test ban",
            added_by=sample_moderator_id,
        )
        db_session.add(banned_word)
        await db_session.commit()

        # Verify creation
        result = await db_session.execute(
            select(BannedWord).where(BannedWord.guild_id == test_guild.id)
        )
        created_word = result.scalar_one()
        assert created_word.pattern == "testbadword"
        assert not created_word.is_regex
        assert created_word.action == "delete"
        assert created_word.added_by == sample_moderator_id

    async def test_moderation_log_creation(
        self, test_guild, db_session, sample_user_id, sample_moderator_id
    ):
        """Test moderation log creation."""
        mod_log = ModerationLog(
            guild_id=test_guild.id,
            target_id=sample_user_id,
            target_name="TestUser",
            moderator_id=sample_moderator_id,
            moderator_name="TestModerator",
            action="ban",
            reason="Test ban",
            is_automatic=False,
        )
        db_session.add(mod_log)
        await db_session.commit()

        # Verify creation
        result = await db_session.execute(
            select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
        )
        created_log = result.scalar_one()
        assert created_log.action == "ban"
        assert created_log.target_id == sample_user_id
        assert created_log.moderator_id == sample_moderator_id
        assert not created_log.is_automatic

    async def test_strike_creation(
        self, test_guild, db_session, sample_user_id, sample_moderator_id
    ):
        """Test strike creation and tracking."""
        strike = Strike(
            guild_id=test_guild.id,
            user_id=sample_user_id,
            user_name="TestUser",
            moderator_id=sample_moderator_id,
            reason="Test strike",
            points=1,
            is_active=True,
        )
        db_session.add(strike)
        await db_session.commit()

        # Verify creation
        result = await db_session.execute(
            select(Strike).where(Strike.guild_id == test_guild.id, Strike.user_id == sample_user_id)
        )
        created_strike = result.scalar_one()
        assert created_strike.points == 1
        assert created_strike.is_active
        assert created_strike.user_id == sample_user_id

    async def test_cascade_deletion(
        self, test_guild, db_session, sample_user_id, sample_moderator_id
    ):
        """Test that deleting a guild cascades to related records."""
        # Capture the primary key up front: after delete() + commit() the
        # instance is no longer attached to the session, so reading its
        # attributes afterwards is fragile.
        guild_id = test_guild.id

        # Add some related records
        banned_word = BannedWord(
            guild_id=guild_id,
            pattern="test",
            is_regex=False,
            action="delete",
            added_by=sample_moderator_id,
        )
        mod_log = ModerationLog(
            guild_id=guild_id,
            target_id=sample_user_id,
            target_name="TestUser",
            moderator_id=sample_moderator_id,
            moderator_name="TestModerator",
            action="warn",
            reason="Test warning",
            is_automatic=False,
        )
        strike = Strike(
            guild_id=guild_id,
            user_id=sample_user_id,
            user_name="TestUser",
            moderator_id=sample_moderator_id,
            reason="Test strike",
            points=1,
            is_active=True,
        )
        db_session.add_all([banned_word, mod_log, strike])
        await db_session.commit()

        # Delete the guild
        await db_session.delete(test_guild)
        await db_session.commit()

        # Verify related records were deleted
        for model in (BannedWord, ModerationLog, Strike):
            remaining = await db_session.execute(select(model).where(model.guild_id == guild_id))
            assert remaining.scalars().all() == [], (
                f"{model.__name__} rows survived guild deletion"
            )


class TestDatabaseIndexes:
    """Test that database indexes work as expected."""

    async def test_moderation_log_indexes(
        self, test_guild, db_session, sample_user_id, sample_moderator_id
    ):
        """Test moderation log indexing for performance."""
        # Create ten logs for distinct targets; odd indices are automatic.
        logs = [
            ModerationLog(
                guild_id=test_guild.id,
                target_id=sample_user_id + i,
                target_name=f"TestUser{i}",
                moderator_id=sample_moderator_id,
                moderator_name="TestModerator",
                action="warn",
                reason=f"Test warning {i}",
                is_automatic=bool(i % 2),
            )
            for i in range(10)
        ]
        db_session.add_all(logs)
        await db_session.commit()

        # Test queries that should use indexes. Every query is scoped to this
        # guild so rows created elsewhere cannot skew the counts.
        # Query by guild_id
        guild_logs = await db_session.execute(
            select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
        )
        assert len(guild_logs.scalars().all()) == 10

        # Query by target_id
        target_logs = await db_session.execute(
            select(ModerationLog).where(
                ModerationLog.guild_id == test_guild.id,
                ModerationLog.target_id == sample_user_id,
            )
        )
        assert len(target_logs.scalars().all()) == 1

        # Query by is_automatic
        auto_logs = await db_session.execute(
            select(ModerationLog).where(
                ModerationLog.guild_id == test_guild.id,
                ModerationLog.is_automatic.is_(True),
            )
        )
        assert len(auto_logs.scalars().all()) == 5  # odd i in range(10)

    async def test_strike_indexes(
        self, test_guild, db_session, sample_user_id, sample_moderator_id
    ):
        """Test strike indexing for performance."""
        # Create five strikes for distinct users; odd indices are active.
        strikes = [
            Strike(
                guild_id=test_guild.id,
                user_id=sample_user_id + i,
                user_name=f"TestUser{i}",
                moderator_id=sample_moderator_id,
                reason=f"Strike {i}",
                points=1,
                is_active=bool(i % 2),
            )
            for i in range(5)
        ]
        db_session.add_all(strikes)
        await db_session.commit()

        # Test active strikes query
        active_strikes = await db_session.execute(
            select(Strike).where(Strike.guild_id == test_guild.id, Strike.is_active.is_(True))
        )
        assert len(active_strikes.scalars().all()) == 2  # indices 1, 3


class TestDatabaseSecurity:
    """Test database security features."""

    async def test_snowflake_id_validation(self, db_session):
        """Test that 18-digit snowflake IDs round-trip through the database unchanged.

        The values exceed the 32-bit integer range, so this catches columns
        that would truncate or overflow large IDs.
        """
        # Valid snowflake ID
        valid_guild_id = 123456789012345678
        valid_owner_id = 123456789012345679
        guild = Guild(
            id=valid_guild_id,
            name="Valid Guild",
            owner_id=valid_owner_id,
            premium=False,
        )
        db_session.add(guild)
        await db_session.commit()

        # Verify it was stored correctly
        result = await db_session.execute(select(Guild).where(Guild.id == valid_guild_id))
        stored_guild = result.scalar_one()
        assert stored_guild.id == valid_guild_id
        assert stored_guild.owner_id == valid_owner_id

    async def test_sql_injection_prevention(self, db_session, test_guild):
        """Test that SQL injection is prevented."""
        # Attempt to inject malicious SQL through user input
        malicious_inputs = [
            "'; DROP TABLE guilds; --",
            "' UNION SELECT * FROM guild_settings --",
            "' OR '1'='1",
            "",
        ]

        for malicious_input in malicious_inputs:
            # SQLAlchemy binds the value as a parameter, so it is compared as
            # a literal string rather than spliced into the SQL text.
            result = await db_session.execute(select(Guild).where(Guild.name == malicious_input))
            # Should not find anything (and not crash)
            assert result.scalar_one_or_none() is None

        # The fixture guild must still be queryable, showing that the
        # "DROP TABLE" payload was treated as data and never executed.
        survivor = await db_session.execute(select(Guild).where(Guild.id == test_guild.id))
        assert survivor.scalar_one().id == test_guild.id

    async def test_data_integrity_constraints(self, db_session, sample_guild_id):
        """Test that the banned-word -> guild foreign key is enforced."""
        orphan_word = BannedWord(
            guild_id=999999999999999999,  # Non-existent guild
            pattern="test",
            is_regex=False,
            action="delete",
            added_by=123456789012345678,
        )
        db_session.add(orphan_word)

        # Only the flush inside commit() should fail, and it should fail with
        # a constraint violation rather than any arbitrary error.
        with pytest.raises(IntegrityError):
            await db_session.commit()

        # The failed commit leaves the transaction unusable; roll it back so
        # the session is left in a clean state. This must sit outside the
        # raises block, where it is actually reached.
        await db_session.rollback()


class TestGuildConfigServiceWithDefaults:
    """Test GuildConfigService.create_guild() with settings defaults."""

    async def test_create_guild_uses_settings_defaults(
        self, test_database, settings_with_custom_defaults, sample_guild_id, sample_owner_id
    ):
        """Test create_guild applies settings.guild_default values."""
        service = GuildConfigService(test_database, settings=settings_with_custom_defaults)
        mock_guild = _make_discord_guild(sample_guild_id, "Test Guild", sample_owner_id)

        # Create guild
        db_guild = await service.create_guild(mock_guild)

        # Verify guild was created
        assert db_guild.id == sample_guild_id
        assert db_guild.name == "Test Guild"

        # Get settings and verify defaults were applied
        guild_settings = await service.get_config(sample_guild_id)
        assert guild_settings is not None
        assert guild_settings.prefix == "?"  # Custom default
        assert guild_settings.ai_sensitivity == 50  # Custom default
        assert guild_settings.automod_enabled is False  # Custom default
        assert guild_settings.verification_enabled is True  # Custom default
        assert guild_settings.verification_type == "captcha"  # Custom default

    async def test_create_guild_without_settings(
        self, test_database, sample_guild_id, sample_owner_id
    ):
        """Test create_guild works when settings is None."""
        service = GuildConfigService(test_database, settings=None)
        mock_guild = _make_discord_guild(sample_guild_id, "Test Guild", sample_owner_id)

        # Create guild
        db_guild = await service.create_guild(mock_guild)

        # Verify guild was created
        assert db_guild.id == sample_guild_id

        # Get settings and verify hardcoded defaults were used
        guild_settings = await service.get_config(sample_guild_id)
        assert guild_settings is not None
        assert guild_settings.prefix == "!"  # Hardcoded default
        assert guild_settings.ai_sensitivity == 80  # Hardcoded default
        assert guild_settings.automod_enabled is True  # Hardcoded default

    async def test_create_guild_existing_guild_unchanged(
        self, test_database, settings_with_custom_defaults, sample_guild_id, sample_owner_id
    ):
        """Test create_guild returns existing guild without changes."""
        service = GuildConfigService(test_database, settings=settings_with_custom_defaults)
        mock_guild = _make_discord_guild(sample_guild_id, "Test Guild", sample_owner_id)

        # Create guild first time
        first_guild = await service.create_guild(mock_guild)
        assert first_guild.id == sample_guild_id

        # Try to create again with different name
        mock_guild.name = "Different Name"
        second_guild = await service.create_guild(mock_guild)

        # Should return existing guild
        assert second_guild.id == first_guild.id
        assert second_guild.name == "Test Guild"  # Original name

    async def test_create_guild_with_standard_settings(
        self, test_database, test_settings, sample_guild_id, sample_owner_id
    ):
        """Test create_guild with standard test_settings fixture."""
        service = GuildConfigService(test_database, settings=test_settings)
        mock_guild = _make_discord_guild(sample_guild_id, "Test Guild", sample_owner_id)

        # Create guild
        await service.create_guild(mock_guild)

        # Get settings and verify standard defaults
        guild_settings = await service.get_config(sample_guild_id)
        assert guild_settings is not None
        # Standard settings use default GuildDefaults
        assert guild_settings.prefix == "!"
        assert guild_settings.ai_sensitivity == 80