Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Failing after 4m58s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m0s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
420 lines
15 KiB
Python
420 lines
15 KiB
Python
"""Tests for database integration and models."""

from datetime import datetime, timezone
from unittest.mock import MagicMock

import pytest
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError

from guardden.models.guild import BannedWord, Guild, GuildSettings
from guardden.models.moderation import ModerationLog, Strike, UserNote
from guardden.services.database import Database
from guardden.services.guild_config import GuildConfigService


class TestDatabaseModels:
    """Test database models and relationships."""

    async def test_guild_creation(self, db_session, sample_guild_id, sample_owner_id):
        """Test guild creation with settings.

        Persists a ``Guild`` plus its ``GuildSettings`` row, then reads both
        back to confirm the stored values.
        """
        guild = Guild(
            id=sample_guild_id,
            name="Test Guild",
            owner_id=sample_owner_id,
            premium=False,
        )
        db_session.add(guild)

        settings = GuildSettings(
            guild_id=sample_guild_id,
            prefix="!",
            automod_enabled=True,
            ai_moderation_enabled=False,
        )
        db_session.add(settings)

        await db_session.commit()

        # Test guild was created
        result = await db_session.execute(select(Guild).where(Guild.id == sample_guild_id))
        created_guild = result.scalar_one()

        assert created_guild.id == sample_guild_id
        assert created_guild.name == "Test Guild"
        assert created_guild.owner_id == sample_owner_id
        assert not created_guild.premium

        # The settings row added alongside the guild must have been persisted too.
        result = await db_session.execute(
            select(GuildSettings).where(GuildSettings.guild_id == sample_guild_id)
        )
        created_settings = result.scalar_one()

        assert created_settings.prefix == "!"
        assert created_settings.automod_enabled
        assert not created_settings.ai_moderation_enabled

    async def test_guild_settings_relationship(self, test_guild, db_session):
        """Test guild-settings relationship."""
        # Load guild with settings
        result = await db_session.execute(select(Guild).where(Guild.id == test_guild.id))
        guild_with_settings = result.scalar_one()

        # Explicitly load the relationship; implicit lazy loading is not
        # usable from async code.
        await db_session.refresh(guild_with_settings, ["settings"])
        assert guild_with_settings.settings is not None
        assert guild_with_settings.settings.guild_id == test_guild.id
        assert guild_with_settings.settings.prefix == "!"

    async def test_banned_word_creation(self, test_guild, db_session, sample_moderator_id):
        """Test banned word creation and relationship."""
        banned_word = BannedWord(
            guild_id=test_guild.id,
            pattern="testbadword",
            is_regex=False,
            action="delete",
            reason="Test ban",
            added_by=sample_moderator_id,
        )
        db_session.add(banned_word)
        await db_session.commit()

        # Verify creation
        result = await db_session.execute(
            select(BannedWord).where(BannedWord.guild_id == test_guild.id)
        )
        created_word = result.scalar_one()

        assert created_word.pattern == "testbadword"
        assert not created_word.is_regex
        assert created_word.action == "delete"
        assert created_word.reason == "Test ban"
        assert created_word.added_by == sample_moderator_id

    async def test_moderation_log_creation(
        self, test_guild, db_session, sample_user_id, sample_moderator_id
    ):
        """Test moderation log creation."""
        mod_log = ModerationLog(
            guild_id=test_guild.id,
            target_id=sample_user_id,
            target_name="TestUser",
            moderator_id=sample_moderator_id,
            moderator_name="TestModerator",
            action="ban",
            reason="Test ban",
            is_automatic=False,
        )
        db_session.add(mod_log)
        await db_session.commit()

        # Verify creation
        result = await db_session.execute(
            select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
        )
        created_log = result.scalar_one()

        assert created_log.action == "ban"
        assert created_log.target_id == sample_user_id
        assert created_log.moderator_id == sample_moderator_id
        assert not created_log.is_automatic

    async def test_strike_creation(
        self, test_guild, db_session, sample_user_id, sample_moderator_id
    ):
        """Test strike creation and tracking."""
        strike = Strike(
            guild_id=test_guild.id,
            user_id=sample_user_id,
            user_name="TestUser",
            moderator_id=sample_moderator_id,
            reason="Test strike",
            points=1,
            is_active=True,
        )
        db_session.add(strike)
        await db_session.commit()

        # Verify creation
        result = await db_session.execute(
            select(Strike).where(Strike.guild_id == test_guild.id, Strike.user_id == sample_user_id)
        )
        created_strike = result.scalar_one()

        assert created_strike.points == 1
        assert created_strike.is_active
        assert created_strike.user_id == sample_user_id

    async def test_cascade_deletion(
        self, test_guild, db_session, sample_user_id, sample_moderator_id
    ):
        """Test that deleting a guild cascades to related records.

        Adds one ``BannedWord``, ``ModerationLog`` and ``Strike`` for the
        fixture guild, deletes the guild, and checks that none of them remain.
        """
        # Capture the primary key up front: after delete + commit the ORM
        # instance is no longer a reliable source for its attributes.
        guild_id = test_guild.id

        # Add some related records
        banned_word = BannedWord(
            guild_id=guild_id,
            pattern="test",
            is_regex=False,
            action="delete",
            added_by=sample_moderator_id,
        )
        mod_log = ModerationLog(
            guild_id=guild_id,
            target_id=sample_user_id,
            target_name="TestUser",
            moderator_id=sample_moderator_id,
            moderator_name="TestModerator",
            action="warn",
            reason="Test warning",
            is_automatic=False,
        )
        strike = Strike(
            guild_id=guild_id,
            user_id=sample_user_id,
            user_name="TestUser",
            moderator_id=sample_moderator_id,
            reason="Test strike",
            points=1,
            is_active=True,
        )
        db_session.add_all([banned_word, mod_log, strike])
        await db_session.commit()

        # Delete the guild
        await db_session.delete(test_guild)
        await db_session.commit()

        # The guild row itself must be gone ...
        assert await db_session.get(Guild, guild_id) is None

        # ... and so must every dependent record.
        for model in (BannedWord, ModerationLog, Strike):
            remaining = await db_session.execute(select(model).where(model.guild_id == guild_id))
            assert remaining.scalars().all() == [], (
                f"{model.__name__} rows survived guild deletion"
            )


class TestDatabaseIndexes:
    """Test that database indexes work as expected.

    These tests exercise the query shapes the indexed columns serve. They
    check result correctness only, not query plans.
    """

    async def test_moderation_log_indexes(
        self, test_guild, db_session, sample_user_id, sample_moderator_id
    ):
        """Test moderation log indexing for performance."""
        # Ten warnings against distinct targets; odd indices are automatic.
        logs = [
            ModerationLog(
                guild_id=test_guild.id,
                target_id=sample_user_id + i,
                target_name=f"TestUser{i}",
                moderator_id=sample_moderator_id,
                moderator_name="TestModerator",
                action="warn",
                reason=f"Test warning {i}",
                is_automatic=bool(i % 2),
            )
            for i in range(10)
        ]
        db_session.add_all(logs)
        await db_session.commit()

        # Query by guild_id
        guild_logs = await db_session.execute(
            select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
        )
        assert len(guild_logs.scalars().all()) == 10

        # Query by target_id, scoped to this guild so rows from elsewhere
        # cannot inflate the count.
        target_logs = await db_session.execute(
            select(ModerationLog).where(
                ModerationLog.guild_id == test_guild.id,
                ModerationLog.target_id == sample_user_id,
            )
        )
        assert len(target_logs.scalars().all()) == 1

        # Query by is_automatic. `.is_(True)` expresses the same filter
        # without the `== True` comparison that linters flag (E712).
        auto_logs = await db_session.execute(
            select(ModerationLog).where(
                ModerationLog.guild_id == test_guild.id,
                ModerationLog.is_automatic.is_(True),
            )
        )
        assert len(auto_logs.scalars().all()) == 5  # indices 1, 3, 5, 7, 9

    async def test_strike_indexes(
        self, test_guild, db_session, sample_user_id, sample_moderator_id
    ):
        """Test strike indexing for performance."""
        # Five strikes against distinct users; odd indices are active.
        strikes = [
            Strike(
                guild_id=test_guild.id,
                user_id=sample_user_id + i,
                user_name=f"TestUser{i}",
                moderator_id=sample_moderator_id,
                reason=f"Strike {i}",
                points=1,
                is_active=bool(i % 2),
            )
            for i in range(5)
        ]
        db_session.add_all(strikes)
        await db_session.commit()

        # Test active strikes query
        active_strikes = await db_session.execute(
            select(Strike).where(Strike.guild_id == test_guild.id, Strike.is_active.is_(True))
        )
        assert len(active_strikes.scalars().all()) == 2  # indices 1, 3


class TestDatabaseSecurity:
    """Test database security features."""

    async def test_snowflake_id_validation(self, db_session):
        """Test that snowflake IDs are properly validated.

        Stores an 18-digit Discord-style id as a guild primary key and checks
        that it reads back unchanged.
        """
        # Valid snowflake ID
        valid_guild_id = 123456789012345678
        guild = Guild(
            id=valid_guild_id,
            name="Valid Guild",
            owner_id=123456789012345679,
            premium=False,
        )
        db_session.add(guild)
        await db_session.commit()

        # Verify it was stored correctly
        result = await db_session.execute(select(Guild).where(Guild.id == valid_guild_id))
        stored_guild = result.scalar_one()
        assert stored_guild.id == valid_guild_id

    async def test_sql_injection_prevention(self, db_session, test_guild):
        """Test that SQL injection is prevented.

        Uses classic injection payloads as bound comparison values. They must
        match nothing, and the guilds table must still hold the fixture guild
        afterwards.
        """
        # Attempt to inject malicious SQL through user input
        malicious_inputs = [
            "'; DROP TABLE guilds; --",
            "' UNION SELECT * FROM guild_settings --",
            "' OR '1'='1",
            "<script>alert('xss')</script>",
        ]

        for malicious_input in malicious_inputs:
            # SQLAlchemy sends the value as a bound parameter, so it is
            # compared literally rather than parsed as SQL.
            result = await db_session.execute(select(Guild).where(Guild.name == malicious_input))
            # Should not find anything (and not crash)
            assert result.scalar_one_or_none() is None

        # Had any payload been interpolated into the SQL text, the DROP TABLE
        # could have removed the table; confirm the fixture guild is still there.
        survivor = await db_session.execute(select(Guild).where(Guild.id == test_guild.id))
        assert survivor.scalar_one_or_none() is not None

    async def test_data_integrity_constraints(self, db_session, sample_guild_id):
        """Test that database constraints are enforced.

        A ``BannedWord`` that references a guild id with no ``Guild`` row must
        be rejected by the foreign-key constraint when it is committed.
        """
        banned_word = BannedWord(
            guild_id=999999999999999999,  # Non-existent guild
            pattern="test",
            is_regex=False,
            action="delete",
            added_by=123456789012345678,
        )
        db_session.add(banned_word)

        # Only the commit (which flushes the INSERT) should fail. It is alone
        # inside the block so an unrelated error cannot pass for the
        # constraint firing.
        # NOTE(review): this relies on the test database enforcing foreign
        # keys (SQLite only does so with PRAGMA foreign_keys=ON) — confirm the
        # test engine enables it.
        with pytest.raises(IntegrityError):
            await db_session.commit()

        # The failed flush leaves the session needing a rollback before reuse.
        await db_session.rollback()


class TestGuildConfigServiceWithDefaults:
    """Test GuildConfigService.create_guild() with settings defaults."""

    @staticmethod
    def _mock_discord_guild(guild_id, name, owner_id):
        """Return a MagicMock standing in for a Discord guild.

        Only ``id``, ``name`` and ``owner_id`` are set explicitly, matching
        the attributes these tests configure.
        """
        mock_guild = MagicMock()
        mock_guild.id = guild_id
        mock_guild.name = name
        mock_guild.owner_id = owner_id
        return mock_guild

    async def test_create_guild_uses_settings_defaults(
        self, test_database, settings_with_custom_defaults, sample_guild_id, sample_owner_id
    ):
        """Test create_guild applies settings.guild_default values."""
        service = GuildConfigService(test_database, settings=settings_with_custom_defaults)
        mock_guild = self._mock_discord_guild(sample_guild_id, "Test Guild", sample_owner_id)

        # Create guild
        db_guild = await service.create_guild(mock_guild)

        # Verify guild was created
        assert db_guild.id == sample_guild_id
        assert db_guild.name == "Test Guild"

        # Get settings and verify defaults were applied
        guild_settings = await service.get_config(sample_guild_id)
        assert guild_settings is not None
        assert guild_settings.prefix == "?"  # Custom default
        assert guild_settings.ai_sensitivity == 50  # Custom default
        assert guild_settings.automod_enabled is False  # Custom default
        assert guild_settings.verification_enabled is True  # Custom default
        assert guild_settings.verification_type == "captcha"  # Custom default

    async def test_create_guild_without_settings(
        self, test_database, sample_guild_id, sample_owner_id
    ):
        """Test create_guild works when settings is None."""
        service = GuildConfigService(test_database, settings=None)
        mock_guild = self._mock_discord_guild(sample_guild_id, "Test Guild", sample_owner_id)

        # Create guild
        db_guild = await service.create_guild(mock_guild)

        # Verify guild was created
        assert db_guild.id == sample_guild_id

        # Get settings and verify hardcoded defaults were used
        guild_settings = await service.get_config(sample_guild_id)
        assert guild_settings is not None
        assert guild_settings.prefix == "!"  # Hardcoded default
        assert guild_settings.ai_sensitivity == 80  # Hardcoded default
        assert guild_settings.automod_enabled is True  # Hardcoded default

    async def test_create_guild_existing_guild_unchanged(
        self, test_database, settings_with_custom_defaults, sample_guild_id, sample_owner_id
    ):
        """Test create_guild returns existing guild without changes."""
        service = GuildConfigService(test_database, settings=settings_with_custom_defaults)
        mock_guild = self._mock_discord_guild(sample_guild_id, "Test Guild", sample_owner_id)

        # Create guild first time
        first_guild = await service.create_guild(mock_guild)
        assert first_guild.id == sample_guild_id

        # Try to create again with different name
        mock_guild.name = "Different Name"
        second_guild = await service.create_guild(mock_guild)

        # Should return existing guild
        assert second_guild.id == first_guild.id
        assert second_guild.name == "Test Guild"  # Original name

    async def test_create_guild_with_standard_settings(
        self, test_database, test_settings, sample_guild_id, sample_owner_id
    ):
        """Test create_guild with standard test_settings fixture."""
        service = GuildConfigService(test_database, settings=test_settings)
        mock_guild = self._mock_discord_guild(sample_guild_id, "Test Guild", sample_owner_id)

        # Create guild
        await service.create_guild(mock_guild)

        # Get settings and verify standard defaults
        guild_settings = await service.get_config(sample_guild_id)
        assert guild_settings is not None
        # Standard settings use default GuildDefaults
        assert guild_settings.prefix == "!"
        assert guild_settings.ai_sensitivity == 80