update
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s

This commit is contained in:
2026-01-17 21:57:04 +01:00
parent 831eed8dbc
commit abef368a68
19 changed files with 677 additions and 757 deletions

View File

@@ -7,11 +7,11 @@ import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
import pytest
from sqlalchemy import create_engine
from sqlalchemy import create_engine, event, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
@@ -23,7 +23,7 @@ if str(SRC_DIR) not in sys.path:
# Import after path setup
from guardden.config import Settings
from guardden.models.base import Base
from guardden.models.guild import Guild, GuildSettings, BannedWord
from guardden.models.guild import BannedWord, Guild, GuildSettings
from guardden.models.moderation import ModerationLog, Strike, UserNote
from guardden.services.database import Database
@@ -52,6 +52,7 @@ def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> bool | None:
# Basic Test Fixtures
# ==============================================================================
@pytest.fixture
def sample_guild_id() -> int:
"""Return a sample Discord guild ID."""
@@ -80,11 +81,12 @@ def sample_owner_id() -> int:
# Configuration Fixtures
# ==============================================================================
@pytest.fixture
def test_settings() -> Settings:
"""Return test configuration settings."""
return Settings(
discord_token="test_token_12345678901234567890",
discord_token="a" * 60,
discord_prefix="!test",
database_url="sqlite+aiosqlite:///test.db",
database_pool_min=1,
@@ -101,6 +103,7 @@ def test_settings() -> Settings:
# Database Fixtures
# ==============================================================================
@pytest.fixture
async def test_database(test_settings: Settings) -> AsyncGenerator[Database, None]:
"""Create a test database with in-memory SQLite."""
@@ -111,19 +114,26 @@ async def test_database(test_settings: Settings) -> AsyncGenerator[Database, Non
poolclass=StaticPool,
echo=False,
)
@event.listens_for(engine.sync_engine, "connect")
def _enable_sqlite_foreign_keys(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
# Create all tables
async with engine.begin() as conn:
await conn.execute(text("PRAGMA foreign_keys=ON"))
await conn.run_sync(Base.metadata.create_all)
database = Database(test_settings)
database._engine = engine
database._session_factory = async_sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
yield database
await engine.dispose()
@@ -138,10 +148,9 @@ async def db_session(test_database: Database) -> AsyncGenerator[AsyncSession, No
# Model Fixtures
# ==============================================================================
@pytest.fixture
async def test_guild(
db_session: AsyncSession, sample_guild_id: int, sample_owner_id: int
) -> Guild:
async def test_guild(db_session: AsyncSession, sample_guild_id: int, sample_owner_id: int) -> Guild:
"""Create a test guild with settings."""
guild = Guild(
id=sample_guild_id,
@@ -150,7 +159,7 @@ async def test_guild(
premium=False,
)
db_session.add(guild)
# Create associated settings
settings = GuildSettings(
guild_id=sample_guild_id,
@@ -160,7 +169,7 @@ async def test_guild(
verification_enabled=False,
)
db_session.add(settings)
await db_session.commit()
await db_session.refresh(guild)
return guild
@@ -187,10 +196,7 @@ async def test_banned_word(
@pytest.fixture
async def test_moderation_log(
db_session: AsyncSession,
test_guild: Guild,
sample_user_id: int,
sample_moderator_id: int
db_session: AsyncSession, test_guild: Guild, sample_user_id: int, sample_moderator_id: int
) -> ModerationLog:
"""Create a test moderation log entry."""
mod_log = ModerationLog(
@@ -211,10 +217,7 @@ async def test_moderation_log(
@pytest.fixture
async def test_strike(
db_session: AsyncSession,
test_guild: Guild,
sample_user_id: int,
sample_moderator_id: int
db_session: AsyncSession, test_guild: Guild, sample_user_id: int, sample_moderator_id: int
) -> Strike:
"""Create a test strike."""
strike = Strike(
@@ -236,6 +239,7 @@ async def test_strike(
# Discord Mock Fixtures
# ==============================================================================
@pytest.fixture
def mock_discord_user(sample_user_id: int) -> MagicMock:
"""Create a mock Discord user."""
@@ -261,7 +265,7 @@ def mock_discord_member(mock_discord_user: MagicMock) -> MagicMock:
member.avatar = mock_discord_user.avatar
member.bot = mock_discord_user.bot
member.send = mock_discord_user.send
# Member-specific attributes
member.guild = MagicMock()
member.top_role = MagicMock()
@@ -271,7 +275,7 @@ def mock_discord_member(mock_discord_user: MagicMock) -> MagicMock:
member.kick = AsyncMock()
member.ban = AsyncMock()
member.timeout = AsyncMock()
return member
@@ -284,14 +288,14 @@ def mock_discord_guild(sample_guild_id: int, sample_owner_id: int) -> MagicMock:
guild.owner_id = sample_owner_id
guild.member_count = 100
guild.premium_tier = 0
# Methods
guild.get_member = MagicMock(return_value=None)
guild.get_channel = MagicMock(return_value=None)
guild.leave = AsyncMock()
guild.ban = AsyncMock()
guild.unban = AsyncMock()
return guild
@@ -327,9 +331,7 @@ def mock_discord_message(
@pytest.fixture
def mock_discord_context(
mock_discord_member: MagicMock,
mock_discord_guild: MagicMock,
mock_discord_channel: MagicMock
mock_discord_member: MagicMock, mock_discord_guild: MagicMock, mock_discord_channel: MagicMock
) -> MagicMock:
"""Create a mock Discord command context."""
ctx = MagicMock()
@@ -345,6 +347,7 @@ def mock_discord_context(
# Bot and Service Fixtures
# ==============================================================================
@pytest.fixture
def mock_bot(test_database: Database) -> MagicMock:
"""Create a mock GuardDen bot."""
@@ -363,6 +366,7 @@ def mock_bot(test_database: Database) -> MagicMock:
# Test Environment Setup
# ==============================================================================
@pytest.fixture(autouse=True)
def setup_test_environment() -> None:
"""Set up test environment variables."""

View File

@@ -3,7 +3,8 @@
import pytest
from pydantic import ValidationError
from guardden.config import Settings, _parse_id_list, _validate_discord_id, normalize_domain
from guardden.config import Settings, _parse_id_list, _validate_discord_id
from guardden.services.automod import normalize_domain
class TestDiscordIdValidation:
@@ -17,7 +18,7 @@ class TestDiscordIdValidation:
"1234567890123456789", # 19 digits
123456789012345678, # int format
]
for valid_id in valid_ids:
result = _validate_discord_id(valid_id)
assert isinstance(result, int)
@@ -35,7 +36,7 @@ class TestDiscordIdValidation:
"0", # zero
"-123456789012345678", # negative
]
for invalid_id in invalid_ids:
with pytest.raises(ValueError):
_validate_discord_id(invalid_id)
@@ -45,7 +46,7 @@ class TestDiscordIdValidation:
# Too small (before Discord existed)
with pytest.raises(ValueError):
_validate_discord_id("99999999999999999")
# Too large (exceeds 64-bit limit)
with pytest.raises(ValueError):
_validate_discord_id("99999999999999999999")
@@ -64,7 +65,7 @@ class TestIdListParsing:
("", []),
(None, []),
]
for input_value, expected in test_cases:
result = _parse_id_list(input_value)
assert result == expected
@@ -89,7 +90,7 @@ class TestIdListParsing:
"123456789012345678\n234567890123456789", # newline
"123456789012345678\r234567890123456789", # carriage return
]
for malicious_input in malicious_inputs:
result = _parse_id_list(malicious_input)
# Should filter out malicious entries
@@ -106,7 +107,7 @@ class TestSettingsValidation:
"Bot.MTIzNDU2Nzg5MDEyMzQ1Njc4.some_long_token_string_here",
"a" * 60, # minimum reasonable length
]
for token in valid_tokens:
settings = Settings(discord_token=token)
assert settings.discord_token.get_secret_value() == token
@@ -119,7 +120,7 @@ class TestSettingsValidation:
"token with spaces", # contains spaces
"token\nwith\nnewlines", # contains newlines
]
for token in invalid_tokens:
with pytest.raises(ValidationError):
Settings(discord_token=token)
@@ -131,7 +132,7 @@ class TestSettingsValidation:
settings = Settings(
discord_token="valid_token_" + "a" * 50,
ai_provider="anthropic",
anthropic_api_key=valid_key
anthropic_api_key=valid_key,
)
assert settings.anthropic_api_key.get_secret_value() == valid_key
@@ -140,23 +141,23 @@ class TestSettingsValidation:
Settings(
discord_token="valid_token_" + "a" * 50,
ai_provider="anthropic",
anthropic_api_key="short"
anthropic_api_key="short",
)
def test_configuration_validation_ai_provider(self):
"""Test AI provider configuration validation."""
settings = Settings(discord_token="valid_token_" + "a" * 50)
# Should pass with no AI provider
settings.ai_provider = "none"
settings.validate_configuration()
# Should fail with anthropic but no key
settings.ai_provider = "anthropic"
settings.anthropic_api_key = None
with pytest.raises(ValueError, match="GUARDDEN_ANTHROPIC_API_KEY is required"):
settings.validate_configuration()
# Should pass with anthropic and key
settings.anthropic_api_key = "sk-" + "a" * 50
settings.validate_configuration()
@@ -164,13 +165,13 @@ class TestSettingsValidation:
def test_configuration_validation_database_pool(self):
"""Test database pool configuration validation."""
settings = Settings(discord_token="valid_token_" + "a" * 50)
# Should fail with min > max
settings.database_pool_min = 10
settings.database_pool_max = 5
with pytest.raises(ValueError, match="database_pool_min cannot be greater"):
settings.validate_configuration()
# Should fail with min < 1
settings.database_pool_min = 0
settings.database_pool_max = 5
@@ -190,7 +191,7 @@ class TestSecurityImprovements:
"123456789012345678\x00\x01\x02",
"123456789012345678<script>alert('xss')</script>",
]
for attempt in injection_attempts:
# Should either raise an error or filter out the malicious input
try:
@@ -205,33 +206,41 @@ class TestSecurityImprovements:
def test_settings_with_malicious_env_vars(self):
"""Test that settings handle malicious environment variables."""
import os
# Save original values
original_guilds = os.environ.get("GUARDDEN_ALLOWED_GUILDS")
original_owners = os.environ.get("GUARDDEN_OWNER_IDS")
try:
# Set malicious environment variables
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678\x00,malicious"
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789\n567890123456789012"
try:
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678\x00,malicious"
except ValueError:
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678,malicious"
try:
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789\n567890123456789012"
except ValueError:
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789,567890123456789012"
settings = Settings(discord_token="valid_token_" + "a" * 50)
# Should filter out malicious entries
assert len(settings.allowed_guilds) <= 1
assert len(settings.owner_ids) <= 1
# Valid IDs should be preserved
assert 123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0
assert (
123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0
)
finally:
# Restore original values
if original_guilds is not None:
os.environ["GUARDDEN_ALLOWED_GUILDS"] = original_guilds
else:
os.environ.pop("GUARDDEN_ALLOWED_GUILDS", None)
if original_owners is not None:
os.environ["GUARDDEN_OWNER_IDS"] = original_owners
else:
os.environ.pop("GUARDDEN_OWNER_IDS", None)
os.environ.pop("GUARDDEN_OWNER_IDS", None)

View File

@@ -1,10 +1,11 @@
"""Tests for database integration and models."""
import pytest
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from guardden.models.guild import Guild, GuildSettings, BannedWord
from guardden.models.guild import BannedWord, Guild, GuildSettings
from guardden.models.moderation import ModerationLog, Strike, UserNote
from guardden.services.database import Database
@@ -21,7 +22,7 @@ class TestDatabaseModels:
premium=False,
)
db_session.add(guild)
settings = GuildSettings(
guild_id=sample_guild_id,
prefix="!",
@@ -29,13 +30,13 @@ class TestDatabaseModels:
ai_moderation_enabled=False,
)
db_session.add(settings)
await db_session.commit()
# Test guild was created
result = await db_session.execute(select(Guild).where(Guild.id == sample_guild_id))
created_guild = result.scalar_one()
assert created_guild.id == sample_guild_id
assert created_guild.name == "Test Guild"
assert created_guild.owner_id == sample_owner_id
@@ -44,11 +45,9 @@ class TestDatabaseModels:
async def test_guild_settings_relationship(self, test_guild, db_session):
"""Test guild-settings relationship."""
# Load guild with settings
result = await db_session.execute(
select(Guild).where(Guild.id == test_guild.id)
)
result = await db_session.execute(select(Guild).where(Guild.id == test_guild.id))
guild_with_settings = result.scalar_one()
# Test relationship loading
await db_session.refresh(guild_with_settings, ["settings"])
assert guild_with_settings.settings is not None
@@ -67,24 +66,20 @@ class TestDatabaseModels:
)
db_session.add(banned_word)
await db_session.commit()
# Verify creation
result = await db_session.execute(
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
)
created_word = result.scalar_one()
assert created_word.pattern == "testbadword"
assert not created_word.is_regex
assert created_word.action == "delete"
assert created_word.added_by == sample_moderator_id
async def test_moderation_log_creation(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
self, test_guild, db_session, sample_user_id, sample_moderator_id
):
"""Test moderation log creation."""
mod_log = ModerationLog(
@@ -99,24 +94,20 @@ class TestDatabaseModels:
)
db_session.add(mod_log)
await db_session.commit()
# Verify creation
result = await db_session.execute(
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
)
created_log = result.scalar_one()
assert created_log.action == "ban"
assert created_log.target_id == sample_user_id
assert created_log.moderator_id == sample_moderator_id
assert not created_log.is_automatic
async def test_strike_creation(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
self, test_guild, db_session, sample_user_id, sample_moderator_id
):
"""Test strike creation and tracking."""
strike = Strike(
@@ -130,26 +121,19 @@ class TestDatabaseModels:
)
db_session.add(strike)
await db_session.commit()
# Verify creation
result = await db_session.execute(
select(Strike).where(
Strike.guild_id == test_guild.id,
Strike.user_id == sample_user_id
)
select(Strike).where(Strike.guild_id == test_guild.id, Strike.user_id == sample_user_id)
)
created_strike = result.scalar_one()
assert created_strike.points == 1
assert created_strike.is_active
assert created_strike.user_id == sample_user_id
async def test_cascade_deletion(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
self, test_guild, db_session, sample_user_id, sample_moderator_id
):
"""Test that deleting a guild cascades to related records."""
# Add some related records
@@ -160,7 +144,7 @@ class TestDatabaseModels:
action="delete",
added_by=sample_moderator_id,
)
mod_log = ModerationLog(
guild_id=test_guild.id,
target_id=sample_user_id,
@@ -171,7 +155,7 @@ class TestDatabaseModels:
reason="Test warning",
is_automatic=False,
)
strike = Strike(
guild_id=test_guild.id,
user_id=sample_user_id,
@@ -181,28 +165,26 @@ class TestDatabaseModels:
points=1,
is_active=True,
)
db_session.add_all([banned_word, mod_log, strike])
await db_session.commit()
# Delete the guild
await db_session.delete(test_guild)
await db_session.commit()
# Verify related records were deleted
banned_words = await db_session.execute(
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
)
assert len(banned_words.scalars().all()) == 0
mod_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
)
assert len(mod_logs.scalars().all()) == 0
strikes = await db_session.execute(
select(Strike).where(Strike.guild_id == test_guild.id)
)
strikes = await db_session.execute(select(Strike).where(Strike.guild_id == test_guild.id))
assert len(strikes.scalars().all()) == 0
@@ -210,11 +192,7 @@ class TestDatabaseIndexes:
"""Test that database indexes work as expected."""
async def test_moderation_log_indexes(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
self, test_guild, db_session, sample_user_id, sample_moderator_id
):
"""Test moderation log indexing for performance."""
# Create multiple moderation logs
@@ -231,23 +209,23 @@ class TestDatabaseIndexes:
is_automatic=bool(i % 2),
)
logs.append(log)
db_session.add_all(logs)
await db_session.commit()
# Test queries that should use indexes
# Query by guild_id
guild_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
)
assert len(guild_logs.scalars().all()) == 10
# Query by target_id
target_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.target_id == sample_user_id)
)
assert len(target_logs.scalars().all()) == 1
# Query by is_automatic
auto_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.is_automatic == True)
@@ -255,11 +233,7 @@ class TestDatabaseIndexes:
assert len(auto_logs.scalars().all()) == 5
async def test_strike_indexes(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
self, test_guild, db_session, sample_user_id, sample_moderator_id
):
"""Test strike indexing for performance."""
# Create multiple strikes
@@ -275,18 +249,15 @@ class TestDatabaseIndexes:
is_active=bool(i % 2),
)
strikes.append(strike)
db_session.add_all(strikes)
await db_session.commit()
# Test active strikes query
active_strikes = await db_session.execute(
select(Strike).where(
Strike.guild_id == test_guild.id,
Strike.is_active == True
)
select(Strike).where(Strike.guild_id == test_guild.id, Strike.is_active == True)
)
assert len(active_strikes.scalars().all()) == 3 # indices 1, 3
assert len(active_strikes.scalars().all()) == 2 # indices 1, 3
class TestDatabaseSecurity:
@@ -304,11 +275,9 @@ class TestDatabaseSecurity:
)
db_session.add(guild)
await db_session.commit()
# Verify it was stored correctly
result = await db_session.execute(
select(Guild).where(Guild.id == valid_guild_id)
)
result = await db_session.execute(select(Guild).where(Guild.id == valid_guild_id))
stored_guild = result.scalar_one()
assert stored_guild.id == valid_guild_id
@@ -321,13 +290,11 @@ class TestDatabaseSecurity:
"' OR '1'='1",
"<script>alert('xss')</script>",
]
for malicious_input in malicious_inputs:
# Try to use malicious input in a query
# SQLAlchemy should prevent injection through parameterized queries
result = await db_session.execute(
select(Guild).where(Guild.name == malicious_input)
)
result = await db_session.execute(select(Guild).where(Guild.name == malicious_input))
# Should not find anything (and not crash)
assert result.scalar_one_or_none() is None
@@ -343,4 +310,5 @@ class TestDatabaseSecurity:
added_by=123456789012345678,
)
db_session.add(banned_word)
await db_session.commit()
await db_session.commit()
await db_session.rollback()