update
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
This commit is contained in:
@@ -7,11 +7,11 @@ import sys
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
@@ -23,7 +23,7 @@ if str(SRC_DIR) not in sys.path:
|
||||
# Import after path setup
|
||||
from guardden.config import Settings
|
||||
from guardden.models.base import Base
|
||||
from guardden.models.guild import Guild, GuildSettings, BannedWord
|
||||
from guardden.models.guild import BannedWord, Guild, GuildSettings
|
||||
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
||||
from guardden.services.database import Database
|
||||
|
||||
@@ -52,6 +52,7 @@ def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> bool | None:
|
||||
# Basic Test Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_guild_id() -> int:
|
||||
"""Return a sample Discord guild ID."""
|
||||
@@ -80,11 +81,12 @@ def sample_owner_id() -> int:
|
||||
# Configuration Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_settings() -> Settings:
|
||||
"""Return test configuration settings."""
|
||||
return Settings(
|
||||
discord_token="test_token_12345678901234567890",
|
||||
discord_token="a" * 60,
|
||||
discord_prefix="!test",
|
||||
database_url="sqlite+aiosqlite:///test.db",
|
||||
database_pool_min=1,
|
||||
@@ -101,6 +103,7 @@ def test_settings() -> Settings:
|
||||
# Database Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_database(test_settings: Settings) -> AsyncGenerator[Database, None]:
|
||||
"""Create a test database with in-memory SQLite."""
|
||||
@@ -111,19 +114,26 @@ async def test_database(test_settings: Settings) -> AsyncGenerator[Database, Non
|
||||
poolclass=StaticPool,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def _enable_sqlite_foreign_keys(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("PRAGMA foreign_keys=ON"))
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
database = Database(test_settings)
|
||||
database._engine = engine
|
||||
database._session_factory = async_sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
|
||||
yield database
|
||||
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@@ -138,10 +148,9 @@ async def db_session(test_database: Database) -> AsyncGenerator[AsyncSession, No
|
||||
# Model Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_guild(
|
||||
db_session: AsyncSession, sample_guild_id: int, sample_owner_id: int
|
||||
) -> Guild:
|
||||
async def test_guild(db_session: AsyncSession, sample_guild_id: int, sample_owner_id: int) -> Guild:
|
||||
"""Create a test guild with settings."""
|
||||
guild = Guild(
|
||||
id=sample_guild_id,
|
||||
@@ -150,7 +159,7 @@ async def test_guild(
|
||||
premium=False,
|
||||
)
|
||||
db_session.add(guild)
|
||||
|
||||
|
||||
# Create associated settings
|
||||
settings = GuildSettings(
|
||||
guild_id=sample_guild_id,
|
||||
@@ -160,7 +169,7 @@ async def test_guild(
|
||||
verification_enabled=False,
|
||||
)
|
||||
db_session.add(settings)
|
||||
|
||||
|
||||
await db_session.commit()
|
||||
await db_session.refresh(guild)
|
||||
return guild
|
||||
@@ -187,10 +196,7 @@ async def test_banned_word(
|
||||
|
||||
@pytest.fixture
|
||||
async def test_moderation_log(
|
||||
db_session: AsyncSession,
|
||||
test_guild: Guild,
|
||||
sample_user_id: int,
|
||||
sample_moderator_id: int
|
||||
db_session: AsyncSession, test_guild: Guild, sample_user_id: int, sample_moderator_id: int
|
||||
) -> ModerationLog:
|
||||
"""Create a test moderation log entry."""
|
||||
mod_log = ModerationLog(
|
||||
@@ -211,10 +217,7 @@ async def test_moderation_log(
|
||||
|
||||
@pytest.fixture
|
||||
async def test_strike(
|
||||
db_session: AsyncSession,
|
||||
test_guild: Guild,
|
||||
sample_user_id: int,
|
||||
sample_moderator_id: int
|
||||
db_session: AsyncSession, test_guild: Guild, sample_user_id: int, sample_moderator_id: int
|
||||
) -> Strike:
|
||||
"""Create a test strike."""
|
||||
strike = Strike(
|
||||
@@ -236,6 +239,7 @@ async def test_strike(
|
||||
# Discord Mock Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_user(sample_user_id: int) -> MagicMock:
|
||||
"""Create a mock Discord user."""
|
||||
@@ -261,7 +265,7 @@ def mock_discord_member(mock_discord_user: MagicMock) -> MagicMock:
|
||||
member.avatar = mock_discord_user.avatar
|
||||
member.bot = mock_discord_user.bot
|
||||
member.send = mock_discord_user.send
|
||||
|
||||
|
||||
# Member-specific attributes
|
||||
member.guild = MagicMock()
|
||||
member.top_role = MagicMock()
|
||||
@@ -271,7 +275,7 @@ def mock_discord_member(mock_discord_user: MagicMock) -> MagicMock:
|
||||
member.kick = AsyncMock()
|
||||
member.ban = AsyncMock()
|
||||
member.timeout = AsyncMock()
|
||||
|
||||
|
||||
return member
|
||||
|
||||
|
||||
@@ -284,14 +288,14 @@ def mock_discord_guild(sample_guild_id: int, sample_owner_id: int) -> MagicMock:
|
||||
guild.owner_id = sample_owner_id
|
||||
guild.member_count = 100
|
||||
guild.premium_tier = 0
|
||||
|
||||
|
||||
# Methods
|
||||
guild.get_member = MagicMock(return_value=None)
|
||||
guild.get_channel = MagicMock(return_value=None)
|
||||
guild.leave = AsyncMock()
|
||||
guild.ban = AsyncMock()
|
||||
guild.unban = AsyncMock()
|
||||
|
||||
|
||||
return guild
|
||||
|
||||
|
||||
@@ -327,9 +331,7 @@ def mock_discord_message(
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_context(
|
||||
mock_discord_member: MagicMock,
|
||||
mock_discord_guild: MagicMock,
|
||||
mock_discord_channel: MagicMock
|
||||
mock_discord_member: MagicMock, mock_discord_guild: MagicMock, mock_discord_channel: MagicMock
|
||||
) -> MagicMock:
|
||||
"""Create a mock Discord command context."""
|
||||
ctx = MagicMock()
|
||||
@@ -345,6 +347,7 @@ def mock_discord_context(
|
||||
# Bot and Service Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot(test_database: Database) -> MagicMock:
|
||||
"""Create a mock GuardDen bot."""
|
||||
@@ -363,6 +366,7 @@ def mock_bot(test_database: Database) -> MagicMock:
|
||||
# Test Environment Setup
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_environment() -> None:
|
||||
"""Set up test environment variables."""
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from guardden.config import Settings, _parse_id_list, _validate_discord_id, normalize_domain
|
||||
from guardden.config import Settings, _parse_id_list, _validate_discord_id
|
||||
from guardden.services.automod import normalize_domain
|
||||
|
||||
|
||||
class TestDiscordIdValidation:
|
||||
@@ -17,7 +18,7 @@ class TestDiscordIdValidation:
|
||||
"1234567890123456789", # 19 digits
|
||||
123456789012345678, # int format
|
||||
]
|
||||
|
||||
|
||||
for valid_id in valid_ids:
|
||||
result = _validate_discord_id(valid_id)
|
||||
assert isinstance(result, int)
|
||||
@@ -35,7 +36,7 @@ class TestDiscordIdValidation:
|
||||
"0", # zero
|
||||
"-123456789012345678", # negative
|
||||
]
|
||||
|
||||
|
||||
for invalid_id in invalid_ids:
|
||||
with pytest.raises(ValueError):
|
||||
_validate_discord_id(invalid_id)
|
||||
@@ -45,7 +46,7 @@ class TestDiscordIdValidation:
|
||||
# Too small (before Discord existed)
|
||||
with pytest.raises(ValueError):
|
||||
_validate_discord_id("99999999999999999")
|
||||
|
||||
|
||||
# Too large (exceeds 64-bit limit)
|
||||
with pytest.raises(ValueError):
|
||||
_validate_discord_id("99999999999999999999")
|
||||
@@ -64,7 +65,7 @@ class TestIdListParsing:
|
||||
("", []),
|
||||
(None, []),
|
||||
]
|
||||
|
||||
|
||||
for input_value, expected in test_cases:
|
||||
result = _parse_id_list(input_value)
|
||||
assert result == expected
|
||||
@@ -89,7 +90,7 @@ class TestIdListParsing:
|
||||
"123456789012345678\n234567890123456789", # newline
|
||||
"123456789012345678\r234567890123456789", # carriage return
|
||||
]
|
||||
|
||||
|
||||
for malicious_input in malicious_inputs:
|
||||
result = _parse_id_list(malicious_input)
|
||||
# Should filter out malicious entries
|
||||
@@ -106,7 +107,7 @@ class TestSettingsValidation:
|
||||
"Bot.MTIzNDU2Nzg5MDEyMzQ1Njc4.some_long_token_string_here",
|
||||
"a" * 60, # minimum reasonable length
|
||||
]
|
||||
|
||||
|
||||
for token in valid_tokens:
|
||||
settings = Settings(discord_token=token)
|
||||
assert settings.discord_token.get_secret_value() == token
|
||||
@@ -119,7 +120,7 @@ class TestSettingsValidation:
|
||||
"token with spaces", # contains spaces
|
||||
"token\nwith\nnewlines", # contains newlines
|
||||
]
|
||||
|
||||
|
||||
for token in invalid_tokens:
|
||||
with pytest.raises(ValidationError):
|
||||
Settings(discord_token=token)
|
||||
@@ -131,7 +132,7 @@ class TestSettingsValidation:
|
||||
settings = Settings(
|
||||
discord_token="valid_token_" + "a" * 50,
|
||||
ai_provider="anthropic",
|
||||
anthropic_api_key=valid_key
|
||||
anthropic_api_key=valid_key,
|
||||
)
|
||||
assert settings.anthropic_api_key.get_secret_value() == valid_key
|
||||
|
||||
@@ -140,23 +141,23 @@ class TestSettingsValidation:
|
||||
Settings(
|
||||
discord_token="valid_token_" + "a" * 50,
|
||||
ai_provider="anthropic",
|
||||
anthropic_api_key="short"
|
||||
anthropic_api_key="short",
|
||||
)
|
||||
|
||||
def test_configuration_validation_ai_provider(self):
|
||||
"""Test AI provider configuration validation."""
|
||||
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
||||
|
||||
|
||||
# Should pass with no AI provider
|
||||
settings.ai_provider = "none"
|
||||
settings.validate_configuration()
|
||||
|
||||
|
||||
# Should fail with anthropic but no key
|
||||
settings.ai_provider = "anthropic"
|
||||
settings.anthropic_api_key = None
|
||||
with pytest.raises(ValueError, match="GUARDDEN_ANTHROPIC_API_KEY is required"):
|
||||
settings.validate_configuration()
|
||||
|
||||
|
||||
# Should pass with anthropic and key
|
||||
settings.anthropic_api_key = "sk-" + "a" * 50
|
||||
settings.validate_configuration()
|
||||
@@ -164,13 +165,13 @@ class TestSettingsValidation:
|
||||
def test_configuration_validation_database_pool(self):
|
||||
"""Test database pool configuration validation."""
|
||||
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
||||
|
||||
|
||||
# Should fail with min > max
|
||||
settings.database_pool_min = 10
|
||||
settings.database_pool_max = 5
|
||||
with pytest.raises(ValueError, match="database_pool_min cannot be greater"):
|
||||
settings.validate_configuration()
|
||||
|
||||
|
||||
# Should fail with min < 1
|
||||
settings.database_pool_min = 0
|
||||
settings.database_pool_max = 5
|
||||
@@ -190,7 +191,7 @@ class TestSecurityImprovements:
|
||||
"123456789012345678\x00\x01\x02",
|
||||
"123456789012345678<script>alert('xss')</script>",
|
||||
]
|
||||
|
||||
|
||||
for attempt in injection_attempts:
|
||||
# Should either raise an error or filter out the malicious input
|
||||
try:
|
||||
@@ -205,33 +206,41 @@ class TestSecurityImprovements:
|
||||
def test_settings_with_malicious_env_vars(self):
|
||||
"""Test that settings handle malicious environment variables."""
|
||||
import os
|
||||
|
||||
|
||||
# Save original values
|
||||
original_guilds = os.environ.get("GUARDDEN_ALLOWED_GUILDS")
|
||||
original_owners = os.environ.get("GUARDDEN_OWNER_IDS")
|
||||
|
||||
|
||||
try:
|
||||
# Set malicious environment variables
|
||||
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678\x00,malicious"
|
||||
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789\n567890123456789012"
|
||||
|
||||
try:
|
||||
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678\x00,malicious"
|
||||
except ValueError:
|
||||
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678,malicious"
|
||||
try:
|
||||
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789\n567890123456789012"
|
||||
except ValueError:
|
||||
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789,567890123456789012"
|
||||
|
||||
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
||||
|
||||
|
||||
# Should filter out malicious entries
|
||||
assert len(settings.allowed_guilds) <= 1
|
||||
assert len(settings.owner_ids) <= 1
|
||||
|
||||
|
||||
# Valid IDs should be preserved
|
||||
assert 123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0
|
||||
|
||||
assert (
|
||||
123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0
|
||||
)
|
||||
|
||||
finally:
|
||||
# Restore original values
|
||||
if original_guilds is not None:
|
||||
os.environ["GUARDDEN_ALLOWED_GUILDS"] = original_guilds
|
||||
else:
|
||||
os.environ.pop("GUARDDEN_ALLOWED_GUILDS", None)
|
||||
|
||||
|
||||
if original_owners is not None:
|
||||
os.environ["GUARDDEN_OWNER_IDS"] = original_owners
|
||||
else:
|
||||
os.environ.pop("GUARDDEN_OWNER_IDS", None)
|
||||
os.environ.pop("GUARDDEN_OWNER_IDS", None)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Tests for database integration and models."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from guardden.models.guild import Guild, GuildSettings, BannedWord
|
||||
from guardden.models.guild import BannedWord, Guild, GuildSettings
|
||||
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
||||
from guardden.services.database import Database
|
||||
|
||||
@@ -21,7 +22,7 @@ class TestDatabaseModels:
|
||||
premium=False,
|
||||
)
|
||||
db_session.add(guild)
|
||||
|
||||
|
||||
settings = GuildSettings(
|
||||
guild_id=sample_guild_id,
|
||||
prefix="!",
|
||||
@@ -29,13 +30,13 @@ class TestDatabaseModels:
|
||||
ai_moderation_enabled=False,
|
||||
)
|
||||
db_session.add(settings)
|
||||
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
# Test guild was created
|
||||
result = await db_session.execute(select(Guild).where(Guild.id == sample_guild_id))
|
||||
created_guild = result.scalar_one()
|
||||
|
||||
|
||||
assert created_guild.id == sample_guild_id
|
||||
assert created_guild.name == "Test Guild"
|
||||
assert created_guild.owner_id == sample_owner_id
|
||||
@@ -44,11 +45,9 @@ class TestDatabaseModels:
|
||||
async def test_guild_settings_relationship(self, test_guild, db_session):
|
||||
"""Test guild-settings relationship."""
|
||||
# Load guild with settings
|
||||
result = await db_session.execute(
|
||||
select(Guild).where(Guild.id == test_guild.id)
|
||||
)
|
||||
result = await db_session.execute(select(Guild).where(Guild.id == test_guild.id))
|
||||
guild_with_settings = result.scalar_one()
|
||||
|
||||
|
||||
# Test relationship loading
|
||||
await db_session.refresh(guild_with_settings, ["settings"])
|
||||
assert guild_with_settings.settings is not None
|
||||
@@ -67,24 +66,20 @@ class TestDatabaseModels:
|
||||
)
|
||||
db_session.add(banned_word)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
# Verify creation
|
||||
result = await db_session.execute(
|
||||
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
|
||||
)
|
||||
created_word = result.scalar_one()
|
||||
|
||||
|
||||
assert created_word.pattern == "testbadword"
|
||||
assert not created_word.is_regex
|
||||
assert created_word.action == "delete"
|
||||
assert created_word.added_by == sample_moderator_id
|
||||
|
||||
async def test_moderation_log_creation(
|
||||
self,
|
||||
test_guild,
|
||||
db_session,
|
||||
sample_user_id,
|
||||
sample_moderator_id
|
||||
self, test_guild, db_session, sample_user_id, sample_moderator_id
|
||||
):
|
||||
"""Test moderation log creation."""
|
||||
mod_log = ModerationLog(
|
||||
@@ -99,24 +94,20 @@ class TestDatabaseModels:
|
||||
)
|
||||
db_session.add(mod_log)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
# Verify creation
|
||||
result = await db_session.execute(
|
||||
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
||||
)
|
||||
created_log = result.scalar_one()
|
||||
|
||||
|
||||
assert created_log.action == "ban"
|
||||
assert created_log.target_id == sample_user_id
|
||||
assert created_log.moderator_id == sample_moderator_id
|
||||
assert not created_log.is_automatic
|
||||
|
||||
async def test_strike_creation(
|
||||
self,
|
||||
test_guild,
|
||||
db_session,
|
||||
sample_user_id,
|
||||
sample_moderator_id
|
||||
self, test_guild, db_session, sample_user_id, sample_moderator_id
|
||||
):
|
||||
"""Test strike creation and tracking."""
|
||||
strike = Strike(
|
||||
@@ -130,26 +121,19 @@ class TestDatabaseModels:
|
||||
)
|
||||
db_session.add(strike)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
# Verify creation
|
||||
result = await db_session.execute(
|
||||
select(Strike).where(
|
||||
Strike.guild_id == test_guild.id,
|
||||
Strike.user_id == sample_user_id
|
||||
)
|
||||
select(Strike).where(Strike.guild_id == test_guild.id, Strike.user_id == sample_user_id)
|
||||
)
|
||||
created_strike = result.scalar_one()
|
||||
|
||||
|
||||
assert created_strike.points == 1
|
||||
assert created_strike.is_active
|
||||
assert created_strike.user_id == sample_user_id
|
||||
|
||||
async def test_cascade_deletion(
|
||||
self,
|
||||
test_guild,
|
||||
db_session,
|
||||
sample_user_id,
|
||||
sample_moderator_id
|
||||
self, test_guild, db_session, sample_user_id, sample_moderator_id
|
||||
):
|
||||
"""Test that deleting a guild cascades to related records."""
|
||||
# Add some related records
|
||||
@@ -160,7 +144,7 @@ class TestDatabaseModels:
|
||||
action="delete",
|
||||
added_by=sample_moderator_id,
|
||||
)
|
||||
|
||||
|
||||
mod_log = ModerationLog(
|
||||
guild_id=test_guild.id,
|
||||
target_id=sample_user_id,
|
||||
@@ -171,7 +155,7 @@ class TestDatabaseModels:
|
||||
reason="Test warning",
|
||||
is_automatic=False,
|
||||
)
|
||||
|
||||
|
||||
strike = Strike(
|
||||
guild_id=test_guild.id,
|
||||
user_id=sample_user_id,
|
||||
@@ -181,28 +165,26 @@ class TestDatabaseModels:
|
||||
points=1,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
|
||||
db_session.add_all([banned_word, mod_log, strike])
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
# Delete the guild
|
||||
await db_session.delete(test_guild)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
# Verify related records were deleted
|
||||
banned_words = await db_session.execute(
|
||||
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
|
||||
)
|
||||
assert len(banned_words.scalars().all()) == 0
|
||||
|
||||
|
||||
mod_logs = await db_session.execute(
|
||||
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
||||
)
|
||||
assert len(mod_logs.scalars().all()) == 0
|
||||
|
||||
strikes = await db_session.execute(
|
||||
select(Strike).where(Strike.guild_id == test_guild.id)
|
||||
)
|
||||
|
||||
strikes = await db_session.execute(select(Strike).where(Strike.guild_id == test_guild.id))
|
||||
assert len(strikes.scalars().all()) == 0
|
||||
|
||||
|
||||
@@ -210,11 +192,7 @@ class TestDatabaseIndexes:
|
||||
"""Test that database indexes work as expected."""
|
||||
|
||||
async def test_moderation_log_indexes(
|
||||
self,
|
||||
test_guild,
|
||||
db_session,
|
||||
sample_user_id,
|
||||
sample_moderator_id
|
||||
self, test_guild, db_session, sample_user_id, sample_moderator_id
|
||||
):
|
||||
"""Test moderation log indexing for performance."""
|
||||
# Create multiple moderation logs
|
||||
@@ -231,23 +209,23 @@ class TestDatabaseIndexes:
|
||||
is_automatic=bool(i % 2),
|
||||
)
|
||||
logs.append(log)
|
||||
|
||||
|
||||
db_session.add_all(logs)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
# Test queries that should use indexes
|
||||
# Query by guild_id
|
||||
guild_logs = await db_session.execute(
|
||||
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
||||
)
|
||||
assert len(guild_logs.scalars().all()) == 10
|
||||
|
||||
|
||||
# Query by target_id
|
||||
target_logs = await db_session.execute(
|
||||
select(ModerationLog).where(ModerationLog.target_id == sample_user_id)
|
||||
)
|
||||
assert len(target_logs.scalars().all()) == 1
|
||||
|
||||
|
||||
# Query by is_automatic
|
||||
auto_logs = await db_session.execute(
|
||||
select(ModerationLog).where(ModerationLog.is_automatic == True)
|
||||
@@ -255,11 +233,7 @@ class TestDatabaseIndexes:
|
||||
assert len(auto_logs.scalars().all()) == 5
|
||||
|
||||
async def test_strike_indexes(
|
||||
self,
|
||||
test_guild,
|
||||
db_session,
|
||||
sample_user_id,
|
||||
sample_moderator_id
|
||||
self, test_guild, db_session, sample_user_id, sample_moderator_id
|
||||
):
|
||||
"""Test strike indexing for performance."""
|
||||
# Create multiple strikes
|
||||
@@ -275,18 +249,15 @@ class TestDatabaseIndexes:
|
||||
is_active=bool(i % 2),
|
||||
)
|
||||
strikes.append(strike)
|
||||
|
||||
|
||||
db_session.add_all(strikes)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
# Test active strikes query
|
||||
active_strikes = await db_session.execute(
|
||||
select(Strike).where(
|
||||
Strike.guild_id == test_guild.id,
|
||||
Strike.is_active == True
|
||||
)
|
||||
select(Strike).where(Strike.guild_id == test_guild.id, Strike.is_active == True)
|
||||
)
|
||||
assert len(active_strikes.scalars().all()) == 3 # indices 1, 3
|
||||
assert len(active_strikes.scalars().all()) == 2 # indices 1, 3
|
||||
|
||||
|
||||
class TestDatabaseSecurity:
|
||||
@@ -304,11 +275,9 @@ class TestDatabaseSecurity:
|
||||
)
|
||||
db_session.add(guild)
|
||||
await db_session.commit()
|
||||
|
||||
|
||||
# Verify it was stored correctly
|
||||
result = await db_session.execute(
|
||||
select(Guild).where(Guild.id == valid_guild_id)
|
||||
)
|
||||
result = await db_session.execute(select(Guild).where(Guild.id == valid_guild_id))
|
||||
stored_guild = result.scalar_one()
|
||||
assert stored_guild.id == valid_guild_id
|
||||
|
||||
@@ -321,13 +290,11 @@ class TestDatabaseSecurity:
|
||||
"' OR '1'='1",
|
||||
"<script>alert('xss')</script>",
|
||||
]
|
||||
|
||||
|
||||
for malicious_input in malicious_inputs:
|
||||
# Try to use malicious input in a query
|
||||
# SQLAlchemy should prevent injection through parameterized queries
|
||||
result = await db_session.execute(
|
||||
select(Guild).where(Guild.name == malicious_input)
|
||||
)
|
||||
result = await db_session.execute(select(Guild).where(Guild.name == malicious_input))
|
||||
# Should not find anything (and not crash)
|
||||
assert result.scalar_one_or_none() is None
|
||||
|
||||
@@ -343,4 +310,5 @@ class TestDatabaseSecurity:
|
||||
added_by=123456789012345678,
|
||||
)
|
||||
db_session.add(banned_word)
|
||||
await db_session.commit()
|
||||
await db_session.commit()
|
||||
await db_session.rollback()
|
||||
|
||||
Reference in New Issue
Block a user