update
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s

This commit is contained in:
2026-01-17 21:57:04 +01:00
parent 831eed8dbc
commit abef368a68
19 changed files with 677 additions and 757 deletions

View File

@@ -1,10 +1,11 @@
"""Tests for database integration and models."""
import pytest
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from guardden.models.guild import Guild, GuildSettings, BannedWord
from guardden.models.guild import BannedWord, Guild, GuildSettings
from guardden.models.moderation import ModerationLog, Strike, UserNote
from guardden.services.database import Database
@@ -21,7 +22,7 @@ class TestDatabaseModels:
premium=False,
)
db_session.add(guild)
settings = GuildSettings(
guild_id=sample_guild_id,
prefix="!",
@@ -29,13 +30,13 @@ class TestDatabaseModels:
ai_moderation_enabled=False,
)
db_session.add(settings)
await db_session.commit()
# Test guild was created
result = await db_session.execute(select(Guild).where(Guild.id == sample_guild_id))
created_guild = result.scalar_one()
assert created_guild.id == sample_guild_id
assert created_guild.name == "Test Guild"
assert created_guild.owner_id == sample_owner_id
@@ -44,11 +45,9 @@ class TestDatabaseModels:
async def test_guild_settings_relationship(self, test_guild, db_session):
"""Test guild-settings relationship."""
# Load guild with settings
result = await db_session.execute(
select(Guild).where(Guild.id == test_guild.id)
)
result = await db_session.execute(select(Guild).where(Guild.id == test_guild.id))
guild_with_settings = result.scalar_one()
# Test relationship loading
await db_session.refresh(guild_with_settings, ["settings"])
assert guild_with_settings.settings is not None
@@ -67,24 +66,20 @@ class TestDatabaseModels:
)
db_session.add(banned_word)
await db_session.commit()
# Verify creation
result = await db_session.execute(
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
)
created_word = result.scalar_one()
assert created_word.pattern == "testbadword"
assert not created_word.is_regex
assert created_word.action == "delete"
assert created_word.added_by == sample_moderator_id
async def test_moderation_log_creation(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
self, test_guild, db_session, sample_user_id, sample_moderator_id
):
"""Test moderation log creation."""
mod_log = ModerationLog(
@@ -99,24 +94,20 @@ class TestDatabaseModels:
)
db_session.add(mod_log)
await db_session.commit()
# Verify creation
result = await db_session.execute(
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
)
created_log = result.scalar_one()
assert created_log.action == "ban"
assert created_log.target_id == sample_user_id
assert created_log.moderator_id == sample_moderator_id
assert not created_log.is_automatic
async def test_strike_creation(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
self, test_guild, db_session, sample_user_id, sample_moderator_id
):
"""Test strike creation and tracking."""
strike = Strike(
@@ -130,26 +121,19 @@ class TestDatabaseModels:
)
db_session.add(strike)
await db_session.commit()
# Verify creation
result = await db_session.execute(
select(Strike).where(
Strike.guild_id == test_guild.id,
Strike.user_id == sample_user_id
)
select(Strike).where(Strike.guild_id == test_guild.id, Strike.user_id == sample_user_id)
)
created_strike = result.scalar_one()
assert created_strike.points == 1
assert created_strike.is_active
assert created_strike.user_id == sample_user_id
async def test_cascade_deletion(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
self, test_guild, db_session, sample_user_id, sample_moderator_id
):
"""Test that deleting a guild cascades to related records."""
# Add some related records
@@ -160,7 +144,7 @@ class TestDatabaseModels:
action="delete",
added_by=sample_moderator_id,
)
mod_log = ModerationLog(
guild_id=test_guild.id,
target_id=sample_user_id,
@@ -171,7 +155,7 @@ class TestDatabaseModels:
reason="Test warning",
is_automatic=False,
)
strike = Strike(
guild_id=test_guild.id,
user_id=sample_user_id,
@@ -181,28 +165,26 @@ class TestDatabaseModels:
points=1,
is_active=True,
)
db_session.add_all([banned_word, mod_log, strike])
await db_session.commit()
# Delete the guild
await db_session.delete(test_guild)
await db_session.commit()
# Verify related records were deleted
banned_words = await db_session.execute(
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
)
assert len(banned_words.scalars().all()) == 0
mod_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
)
assert len(mod_logs.scalars().all()) == 0
strikes = await db_session.execute(
select(Strike).where(Strike.guild_id == test_guild.id)
)
strikes = await db_session.execute(select(Strike).where(Strike.guild_id == test_guild.id))
assert len(strikes.scalars().all()) == 0
@@ -210,11 +192,7 @@ class TestDatabaseIndexes:
"""Test that database indexes work as expected."""
async def test_moderation_log_indexes(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
self, test_guild, db_session, sample_user_id, sample_moderator_id
):
"""Test moderation log indexing for performance."""
# Create multiple moderation logs
@@ -231,23 +209,23 @@ class TestDatabaseIndexes:
is_automatic=bool(i % 2),
)
logs.append(log)
db_session.add_all(logs)
await db_session.commit()
# Test queries that should use indexes
# Query by guild_id
guild_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
)
assert len(guild_logs.scalars().all()) == 10
# Query by target_id
target_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.target_id == sample_user_id)
)
assert len(target_logs.scalars().all()) == 1
# Query by is_automatic
auto_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.is_automatic == True)
@@ -255,11 +233,7 @@ class TestDatabaseIndexes:
assert len(auto_logs.scalars().all()) == 5
async def test_strike_indexes(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
self, test_guild, db_session, sample_user_id, sample_moderator_id
):
"""Test strike indexing for performance."""
# Create multiple strikes
@@ -275,18 +249,15 @@ class TestDatabaseIndexes:
is_active=bool(i % 2),
)
strikes.append(strike)
db_session.add_all(strikes)
await db_session.commit()
# Test active strikes query
active_strikes = await db_session.execute(
select(Strike).where(
Strike.guild_id == test_guild.id,
Strike.is_active == True
)
select(Strike).where(Strike.guild_id == test_guild.id, Strike.is_active == True)
)
assert len(active_strikes.scalars().all()) == 3 # indices 1, 3
assert len(active_strikes.scalars().all()) == 2 # indices 1, 3
class TestDatabaseSecurity:
@@ -304,11 +275,9 @@ class TestDatabaseSecurity:
)
db_session.add(guild)
await db_session.commit()
# Verify it was stored correctly
result = await db_session.execute(
select(Guild).where(Guild.id == valid_guild_id)
)
result = await db_session.execute(select(Guild).where(Guild.id == valid_guild_id))
stored_guild = result.scalar_one()
assert stored_guild.id == valid_guild_id
@@ -321,13 +290,11 @@ class TestDatabaseSecurity:
"' OR '1'='1",
"<script>alert('xss')</script>",
]
for malicious_input in malicious_inputs:
# Try to use malicious input in a query
# SQLAlchemy should prevent injection through parameterized queries
result = await db_session.execute(
select(Guild).where(Guild.name == malicious_input)
)
result = await db_session.execute(select(Guild).where(Guild.name == malicious_input))
# Should not find anything (and not crash)
assert result.scalar_one_or_none() is None
@@ -343,4 +310,5 @@ class TestDatabaseSecurity:
added_by=123456789012345678,
)
db_session.add(banned_word)
await db_session.commit()
await db_session.commit()
await db_session.rollback()