quick commit
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 6m9s
CI/CD Pipeline / Security Scanning (push) Successful in 26s
CI/CD Pipeline / Tests (3.11) (push) Failing after 5m24s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m23s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Deploy to Staging (push) Has been skipped
CI/CD Pipeline / Deploy to Production (push) Has been skipped
CI/CD Pipeline / Notification (push) Successful in 1s

This commit is contained in:
2026-01-17 20:24:43 +01:00
parent 95cc3cdb8f
commit 831eed8dbc
82 changed files with 8860 additions and 167 deletions

View File

@@ -1,7 +1,56 @@
"""Pytest fixtures for GuardDen tests."""
import pytest
import asyncio
import inspect
import os
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
from typing import AsyncGenerator
import pytest
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
ROOT_DIR = Path(__file__).resolve().parents[1]
SRC_DIR = ROOT_DIR / "src"
if str(SRC_DIR) not in sys.path:
sys.path.insert(0, str(SRC_DIR))
# Import after path setup
from guardden.config import Settings
from guardden.models.base import Base
from guardden.models.guild import Guild, GuildSettings, BannedWord
from guardden.models.moderation import ModerationLog, Strike, UserNote
from guardden.services.database import Database
def pytest_addoption(parser: pytest.Parser) -> None:
parser.addini("asyncio_mode", "Asyncio mode for tests", default="auto")
def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line("markers", "asyncio: mark async tests")
def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> bool | None:
test_function = pyfuncitem.obj
if inspect.iscoroutinefunction(test_function):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(test_function(**pyfuncitem.funcargs))
loop.close()
asyncio.set_event_loop(None)
return True
return None
# ==============================================================================
# Basic Test Fixtures
# ==============================================================================
@pytest.fixture
def sample_guild_id() -> int:
@@ -13,3 +62,320 @@ def sample_guild_id() -> int:
def sample_user_id() -> int:
"""Return a sample Discord user ID."""
return 987654321098765432
@pytest.fixture
def sample_moderator_id() -> int:
"""Return a sample Discord moderator ID."""
return 111111111111111111
@pytest.fixture
def sample_owner_id() -> int:
"""Return a sample Discord owner ID."""
return 222222222222222222
# ==============================================================================
# Configuration Fixtures
# ==============================================================================
@pytest.fixture
def test_settings() -> Settings:
"""Return test configuration settings."""
return Settings(
discord_token="test_token_12345678901234567890",
discord_prefix="!test",
database_url="sqlite+aiosqlite:///test.db",
database_pool_min=1,
database_pool_max=1,
ai_provider="none",
log_level="DEBUG",
allowed_guilds=[],
owner_ids=[],
data_dir=Path("/tmp/guardden_test"),
)
# ==============================================================================
# Database Fixtures
# ==============================================================================
@pytest.fixture
async def test_database(test_settings: Settings) -> AsyncGenerator[Database, None]:
"""Create a test database with in-memory SQLite."""
# Use in-memory SQLite for tests
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
database = Database(test_settings)
database._engine = engine
database._session_factory = async_sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
yield database
await engine.dispose()
@pytest.fixture
async def db_session(test_database: Database) -> AsyncGenerator[AsyncSession, None]:
"""Create a database session for testing."""
async with test_database.session() as session:
yield session
# ==============================================================================
# Model Fixtures
# ==============================================================================
@pytest.fixture
async def test_guild(
db_session: AsyncSession, sample_guild_id: int, sample_owner_id: int
) -> Guild:
"""Create a test guild with settings."""
guild = Guild(
id=sample_guild_id,
name="Test Guild",
owner_id=sample_owner_id,
premium=False,
)
db_session.add(guild)
# Create associated settings
settings = GuildSettings(
guild_id=sample_guild_id,
prefix="!",
automod_enabled=True,
ai_moderation_enabled=False,
verification_enabled=False,
)
db_session.add(settings)
await db_session.commit()
await db_session.refresh(guild)
return guild
@pytest.fixture
async def test_banned_word(
db_session: AsyncSession, test_guild: Guild, sample_moderator_id: int
) -> BannedWord:
"""Create a test banned word."""
banned_word = BannedWord(
guild_id=test_guild.id,
pattern="badword",
is_regex=False,
action="delete",
reason="Inappropriate content",
added_by=sample_moderator_id,
)
db_session.add(banned_word)
await db_session.commit()
await db_session.refresh(banned_word)
return banned_word
@pytest.fixture
async def test_moderation_log(
db_session: AsyncSession,
test_guild: Guild,
sample_user_id: int,
sample_moderator_id: int
) -> ModerationLog:
"""Create a test moderation log entry."""
mod_log = ModerationLog(
guild_id=test_guild.id,
target_id=sample_user_id,
target_name="TestUser",
moderator_id=sample_moderator_id,
moderator_name="TestModerator",
action="warn",
reason="Test warning",
is_automatic=False,
)
db_session.add(mod_log)
await db_session.commit()
await db_session.refresh(mod_log)
return mod_log
@pytest.fixture
async def test_strike(
db_session: AsyncSession,
test_guild: Guild,
sample_user_id: int,
sample_moderator_id: int
) -> Strike:
"""Create a test strike."""
strike = Strike(
guild_id=test_guild.id,
user_id=sample_user_id,
user_name="TestUser",
moderator_id=sample_moderator_id,
reason="Test strike",
points=1,
is_active=True,
)
db_session.add(strike)
await db_session.commit()
await db_session.refresh(strike)
return strike
# ==============================================================================
# Discord Mock Fixtures
# ==============================================================================
@pytest.fixture
def mock_discord_user(sample_user_id: int) -> MagicMock:
"""Create a mock Discord user."""
user = MagicMock()
user.id = sample_user_id
user.name = "TestUser"
user.display_name = "Test User"
user.mention = f"<@{sample_user_id}>"
user.avatar = None
user.bot = False
user.send = AsyncMock()
return user
@pytest.fixture
def mock_discord_member(mock_discord_user: MagicMock) -> MagicMock:
"""Create a mock Discord member."""
member = MagicMock()
member.id = mock_discord_user.id
member.name = mock_discord_user.name
member.display_name = mock_discord_user.display_name
member.mention = mock_discord_user.mention
member.avatar = mock_discord_user.avatar
member.bot = mock_discord_user.bot
member.send = mock_discord_user.send
# Member-specific attributes
member.guild = MagicMock()
member.top_role = MagicMock()
member.top_role.position = 1
member.roles = [MagicMock()]
member.joined_at = datetime.now(timezone.utc)
member.kick = AsyncMock()
member.ban = AsyncMock()
member.timeout = AsyncMock()
return member
@pytest.fixture
def mock_discord_guild(sample_guild_id: int, sample_owner_id: int) -> MagicMock:
"""Create a mock Discord guild."""
guild = MagicMock()
guild.id = sample_guild_id
guild.name = "Test Guild"
guild.owner_id = sample_owner_id
guild.member_count = 100
guild.premium_tier = 0
# Methods
guild.get_member = MagicMock(return_value=None)
guild.get_channel = MagicMock(return_value=None)
guild.leave = AsyncMock()
guild.ban = AsyncMock()
guild.unban = AsyncMock()
return guild
@pytest.fixture
def mock_discord_channel() -> MagicMock:
"""Create a mock Discord channel."""
channel = MagicMock()
channel.id = 333333333333333333
channel.name = "test-channel"
channel.mention = "<#333333333333333333>"
channel.send = AsyncMock()
channel.delete_messages = AsyncMock()
return channel
@pytest.fixture
def mock_discord_message(
mock_discord_member: MagicMock, mock_discord_channel: MagicMock
) -> MagicMock:
"""Create a mock Discord message."""
message = MagicMock()
message.id = 444444444444444444
message.content = "Test message content"
message.author = mock_discord_member
message.channel = mock_discord_channel
message.guild = mock_discord_member.guild
message.created_at = datetime.now(timezone.utc)
message.delete = AsyncMock()
message.reply = AsyncMock()
message.add_reaction = AsyncMock()
return message
@pytest.fixture
def mock_discord_context(
mock_discord_member: MagicMock,
mock_discord_guild: MagicMock,
mock_discord_channel: MagicMock
) -> MagicMock:
"""Create a mock Discord command context."""
ctx = MagicMock()
ctx.author = mock_discord_member
ctx.guild = mock_discord_guild
ctx.channel = mock_discord_channel
ctx.send = AsyncMock()
ctx.reply = AsyncMock()
return ctx
# ==============================================================================
# Bot and Service Fixtures
# ==============================================================================
@pytest.fixture
def mock_bot(test_database: Database) -> MagicMock:
"""Create a mock GuardDen bot."""
bot = MagicMock()
bot.database = test_database
bot.guild_config = MagicMock()
bot.ai_provider = MagicMock()
bot.rate_limiter = MagicMock()
bot.user = MagicMock()
bot.user.id = 555555555555555555
bot.user.name = "GuardDen"
return bot
# ==============================================================================
# Test Environment Setup
# ==============================================================================
@pytest.fixture(autouse=True)
def setup_test_environment() -> None:
"""Set up test environment variables."""
# Set test environment variables
os.environ["GUARDDEN_DISCORD_TOKEN"] = "test_token_12345678901234567890"
os.environ["GUARDDEN_DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
os.environ["GUARDDEN_AI_PROVIDER"] = "none"
os.environ["GUARDDEN_LOG_LEVEL"] = "DEBUG"
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for the test session."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()

View File

@@ -2,7 +2,7 @@
import pytest
from guardden.services.ai.base import ContentCategory, ModerationResult
from guardden.services.ai.base import ContentCategory, ModerationResult, parse_categories
from guardden.services.ai.factory import NullProvider, create_ai_provider
@@ -69,6 +69,14 @@ class TestModerationResult:
assert result.severity == 100
class TestParseCategories:
"""Tests for category parsing helper."""
def test_parse_categories_filters_invalid(self) -> None:
categories = parse_categories(["harassment", "unknown", "scam"])
assert categories == [ContentCategory.HARASSMENT, ContentCategory.SCAM]
class TestNullProvider:
"""Tests for NullProvider."""

View File

@@ -2,7 +2,6 @@
import pytest
from guardden.models import BannedWord
from guardden.services.automod import AutomodService
@@ -79,6 +78,14 @@ class TestScamDetection:
result = automod.check_scam_links("Visit discord-verify.xyz to claim")
assert result is not None
def test_allowlisted_domain(self, automod: AutomodService) -> None:
"""Test allowlisted domains skip suspicious TLD checks."""
result = automod.check_scam_links(
"Visit https://discordapp.xyz for updates",
allowlist=["discordapp.xyz"],
)
assert result is None
def test_normal_url(self, automod: AutomodService) -> None:
"""Test normal URLs pass."""
result = automod.check_scam_links("Check out https://github.com/example")

View File

@@ -0,0 +1,210 @@
"""Tests for automod security improvements."""
import pytest
from guardden.services.automod import normalize_domain, URL_PATTERN
class TestDomainNormalization:
"""Test domain normalization security improvements."""
def test_normalize_domain_valid(self):
"""Test normalization of valid domains."""
test_cases = [
("example.com", "example.com"),
("www.example.com", "example.com"),
("http://example.com", "example.com"),
("https://www.example.com", "example.com"),
("EXAMPLE.COM", "example.com"),
("Example.Com", "example.com"),
]
for input_domain, expected in test_cases:
result = normalize_domain(input_domain)
assert result == expected
def test_normalize_domain_security_filters(self):
"""Test that malicious domains are filtered out."""
malicious_domains = [
"example.com\x00", # null byte
"example.com\n", # newline
"example.com\r", # carriage return
"example.com\t", # tab
"example.com\x01", # control character
"example com", # space in hostname
"", # empty string
" ", # space only
"a" * 2001, # excessively long
None, # None value
123, # non-string value
]
for malicious_domain in malicious_domains:
result = normalize_domain(malicious_domain)
assert result == "" # Should return empty string for invalid input
def test_normalize_domain_length_limits(self):
"""Test that domain length limits are enforced."""
# Test exactly at the limit
valid_long_domain = "a" * 249 + ".com" # 253 chars total (RFC limit)
result = normalize_domain(valid_long_domain)
assert result != "" # Should be valid
# Test over the limit
invalid_long_domain = "a" * 250 + ".com" # 254 chars total (over RFC limit)
result = normalize_domain(invalid_long_domain)
assert result == "" # Should be invalid
def test_normalize_domain_malformed_urls(self):
"""Test handling of malformed URLs."""
malformed_urls = [
"http://", # incomplete URL
"://example.com", # missing scheme
"http:///example.com", # extra slash
"http://example..com", # double dot
"http://.example.com", # leading dot
"http://example.com.", # trailing dot
"ftp://example.com", # non-http scheme (should still work)
]
for malformed_url in malformed_urls:
result = normalize_domain(malformed_url)
# Should either return valid domain or empty string
assert isinstance(result, str)
def test_normalize_domain_injection_attempts(self):
"""Test that domain normalization prevents injection."""
injection_attempts = [
"example.com'; DROP TABLE guilds; --",
"example.com UNION SELECT * FROM users",
"example.com\"><script>alert('xss')</script>",
"example.com\\x00\\x01\\x02",
"example.com\n\rmalicious",
]
for attempt in injection_attempts:
result = normalize_domain(attempt)
# Should either return a safe domain or empty string
if result:
assert "script" not in result
assert "DROP" not in result
assert "UNION" not in result
assert "\x00" not in result
assert "\n" not in result
assert "\r" not in result
class TestUrlPatternSecurity:
"""Test URL pattern security improvements."""
def test_url_pattern_matches_valid_urls(self):
"""Test that URL pattern matches legitimate URLs."""
valid_urls = [
"https://example.com",
"http://www.example.org",
"https://subdomain.example.net",
"http://example.io/path/to/resource",
"https://example.com/path?query=value",
"www.example.com",
"example.gg",
]
for url in valid_urls:
matches = URL_PATTERN.findall(url)
assert len(matches) >= 1, f"Failed to match valid URL: {url}"
def test_url_pattern_rejects_malicious_patterns(self):
"""Test that URL pattern doesn't match malicious patterns."""
# These should not be matched as URLs
non_urls = [
"javascript:alert('xss')",
"data:text/html,<script>alert('xss')</script>",
"file:///etc/passwd",
"ftp://anonymous@server",
"mailto:user@example.com",
]
for non_url in non_urls:
matches = URL_PATTERN.findall(non_url)
# Should not match these protocols
assert len(matches) == 0 or not any("javascript:" in match for match in matches)
def test_url_pattern_handles_edge_cases(self):
"""Test URL pattern with edge cases."""
edge_cases = [
"http://" + "a" * 300 + ".com", # very long domain
"https://example.com" + "a" * 2000, # very long path
"https://192.168.1.1", # IP address (should not match)
"https://[::1]", # IPv6 (should not match)
"https://ex-ample.com", # hyphenated domain
"https://example.123", # numeric TLD (should not match)
]
for edge_case in edge_cases:
matches = URL_PATTERN.findall(edge_case)
# Should handle gracefully (either match or not, but no crashes)
assert isinstance(matches, list)
class TestAutomodIntegration:
"""Test automod integration with security improvements."""
def test_url_processing_security(self):
"""Test that URL processing handles malicious input safely."""
from guardden.services.automod import detect_scam_links
# Mock allowlist and suspicious TLDs for testing
allowlist = ["trusted.com", "example.org"]
# Test with malicious URLs
malicious_content = [
"Check out this link: https://evil.tk/steal-your-data",
"Visit http://phishing.ml/discord-nitro-free",
"Go to https://scam" + "." * 100 + "tk", # excessive dots
"Link: https://example.com" + "x" * 5000, # excessively long
]
for content in malicious_content:
# Should not crash and should return appropriate result
result = detect_scam_links(content, allowlist)
assert result is None or hasattr(result, 'should_delete')
def test_domain_allowlist_security(self):
"""Test that domain allowlist checking is secure."""
from guardden.services.automod import is_allowed_domain
# Test with malicious allowlist entries
malicious_allowlist = {
"good.com",
"evil.com\x00", # null byte
"bad.com\n", # newline
"trusted.org",
}
test_domains = [
"good.com",
"evil.com",
"bad.com",
"trusted.org",
"unknown.com",
]
for domain in test_domains:
# Should not crash
result = is_allowed_domain(domain, malicious_allowlist)
assert isinstance(result, bool)
def test_regex_pattern_safety(self):
"""Test that regex patterns are processed safely."""
# This tests the circuit breaker functionality (when implemented)
malicious_patterns = [
"(.+)+", # catastrophic backtracking
"a" * 1000, # very long pattern
"(?:a|a)*", # another backtracking pattern
"[" + "a-z" * 100 + "]", # excessive character class
]
for pattern in malicious_patterns:
# Should not cause infinite loops or crashes
# This is a placeholder for when circuit breakers are implemented
assert len(pattern) > 0 # Just ensure we're testing something

237
tests/test_config.py Normal file
View File

@@ -0,0 +1,237 @@
"""Tests for configuration validation and security."""
import pytest
from pydantic import ValidationError
from guardden.config import Settings, _parse_id_list, _validate_discord_id, normalize_domain
class TestDiscordIdValidation:
"""Test Discord ID validation functions."""
def test_validate_discord_id_valid(self):
"""Test validation of valid Discord IDs."""
# Valid Discord snowflake IDs
valid_ids = [
"123456789012345678", # 18 digits
"1234567890123456789", # 19 digits
123456789012345678, # int format
]
for valid_id in valid_ids:
result = _validate_discord_id(valid_id)
assert isinstance(result, int)
assert result > 0
def test_validate_discord_id_invalid_format(self):
"""Test validation rejects invalid formats."""
invalid_ids = [
"12345", # too short
"12345678901234567890", # too long
"abc123456789012345678", # contains letters
"123-456-789", # contains hyphens
"123 456 789", # contains spaces
"", # empty
"0", # zero
"-123456789012345678", # negative
]
for invalid_id in invalid_ids:
with pytest.raises(ValueError):
_validate_discord_id(invalid_id)
def test_validate_discord_id_out_of_range(self):
"""Test validation rejects IDs outside valid range."""
# Too small (before Discord existed)
with pytest.raises(ValueError):
_validate_discord_id("99999999999999999")
# Too large (exceeds 64-bit limit)
with pytest.raises(ValueError):
_validate_discord_id("99999999999999999999")
class TestIdListParsing:
"""Test ID list parsing functions."""
def test_parse_id_list_valid(self):
"""Test parsing valid ID lists."""
test_cases = [
("123456789012345678", [123456789012345678]),
("123456789012345678,234567890123456789", [123456789012345678, 234567890123456789]),
("123456789012345678;234567890123456789", [123456789012345678, 234567890123456789]),
([123456789012345678, 234567890123456789], [123456789012345678, 234567890123456789]),
("", []),
(None, []),
]
for input_value, expected in test_cases:
result = _parse_id_list(input_value)
assert result == expected
def test_parse_id_list_filters_invalid(self):
"""Test that invalid IDs are filtered out."""
# Mix of valid and invalid IDs
mixed_input = "123456789012345678,invalid,234567890123456789,12345"
result = _parse_id_list(mixed_input)
assert result == [123456789012345678, 234567890123456789]
def test_parse_id_list_removes_duplicates(self):
"""Test that duplicate IDs are removed."""
duplicate_input = "123456789012345678,123456789012345678,234567890123456789"
result = _parse_id_list(duplicate_input)
assert result == [123456789012345678, 234567890123456789]
def test_parse_id_list_security(self):
"""Test that malicious input is rejected."""
malicious_inputs = [
"123456789012345678\x00", # null byte
"123456789012345678\n234567890123456789", # newline
"123456789012345678\r234567890123456789", # carriage return
]
for malicious_input in malicious_inputs:
result = _parse_id_list(malicious_input)
# Should filter out malicious entries
assert len(result) <= 1
class TestSettingsValidation:
"""Test Settings class validation."""
def test_discord_token_validation_valid(self):
"""Test valid Discord token formats."""
valid_tokens = [
"MTIzNDU2Nzg5MDEyMzQ1Njc4.G1a2b3.c4d5e6f7g8h9i0j1k2l3m4n5o6p7q8r9s0",
"Bot.MTIzNDU2Nzg5MDEyMzQ1Njc4.some_long_token_string_here",
"a" * 60, # minimum reasonable length
]
for token in valid_tokens:
settings = Settings(discord_token=token)
assert settings.discord_token.get_secret_value() == token
def test_discord_token_validation_invalid(self):
"""Test invalid Discord token formats."""
invalid_tokens = [
"", # empty
"short", # too short
"token with spaces", # contains spaces
"token\nwith\nnewlines", # contains newlines
]
for token in invalid_tokens:
with pytest.raises(ValidationError):
Settings(discord_token=token)
def test_api_key_validation(self):
"""Test API key validation."""
# Valid API keys
valid_key = "sk-" + "a" * 50
settings = Settings(
discord_token="valid_token_" + "a" * 50,
ai_provider="anthropic",
anthropic_api_key=valid_key
)
assert settings.anthropic_api_key.get_secret_value() == valid_key
# Invalid API key (too short)
with pytest.raises(ValidationError):
Settings(
discord_token="valid_token_" + "a" * 50,
ai_provider="anthropic",
anthropic_api_key="short"
)
def test_configuration_validation_ai_provider(self):
"""Test AI provider configuration validation."""
settings = Settings(discord_token="valid_token_" + "a" * 50)
# Should pass with no AI provider
settings.ai_provider = "none"
settings.validate_configuration()
# Should fail with anthropic but no key
settings.ai_provider = "anthropic"
settings.anthropic_api_key = None
with pytest.raises(ValueError, match="GUARDDEN_ANTHROPIC_API_KEY is required"):
settings.validate_configuration()
# Should pass with anthropic and key
settings.anthropic_api_key = "sk-" + "a" * 50
settings.validate_configuration()
def test_configuration_validation_database_pool(self):
"""Test database pool configuration validation."""
settings = Settings(discord_token="valid_token_" + "a" * 50)
# Should fail with min > max
settings.database_pool_min = 10
settings.database_pool_max = 5
with pytest.raises(ValueError, match="database_pool_min cannot be greater"):
settings.validate_configuration()
# Should fail with min < 1
settings.database_pool_min = 0
settings.database_pool_max = 5
with pytest.raises(ValueError, match="database_pool_min must be at least 1"):
settings.validate_configuration()
class TestSecurityImprovements:
"""Test security improvements in configuration."""
def test_id_validation_prevents_injection(self):
"""Test that ID validation prevents injection attacks."""
# Test various injection attempts
injection_attempts = [
"123456789012345678'; DROP TABLE guilds; --",
"123456789012345678 UNION SELECT * FROM users",
"123456789012345678\x00\x01\x02",
"123456789012345678<script>alert('xss')</script>",
]
for attempt in injection_attempts:
# Should either raise an error or filter out the malicious input
try:
result = _validate_discord_id(attempt)
# If it doesn't raise an error, it should be a valid ID
assert isinstance(result, int)
assert result > 0
except ValueError:
# This is expected for malicious input
pass
def test_settings_with_malicious_env_vars(self):
"""Test that settings handle malicious environment variables."""
import os
# Save original values
original_guilds = os.environ.get("GUARDDEN_ALLOWED_GUILDS")
original_owners = os.environ.get("GUARDDEN_OWNER_IDS")
try:
# Set malicious environment variables
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678\x00,malicious"
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789\n567890123456789012"
settings = Settings(discord_token="valid_token_" + "a" * 50)
# Should filter out malicious entries
assert len(settings.allowed_guilds) <= 1
assert len(settings.owner_ids) <= 1
# Valid IDs should be preserved
assert 123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0
finally:
# Restore original values
if original_guilds is not None:
os.environ["GUARDDEN_ALLOWED_GUILDS"] = original_guilds
else:
os.environ.pop("GUARDDEN_ALLOWED_GUILDS", None)
if original_owners is not None:
os.environ["GUARDDEN_OWNER_IDS"] = original_owners
else:
os.environ.pop("GUARDDEN_OWNER_IDS", None)

View File

@@ -0,0 +1,346 @@
"""Tests for database integration and models."""
import pytest
from datetime import datetime, timezone
from sqlalchemy import select
from guardden.models.guild import Guild, GuildSettings, BannedWord
from guardden.models.moderation import ModerationLog, Strike, UserNote
from guardden.services.database import Database
class TestDatabaseModels:
"""Test database models and relationships."""
async def test_guild_creation(self, db_session, sample_guild_id, sample_owner_id):
"""Test guild creation with settings."""
guild = Guild(
id=sample_guild_id,
name="Test Guild",
owner_id=sample_owner_id,
premium=False,
)
db_session.add(guild)
settings = GuildSettings(
guild_id=sample_guild_id,
prefix="!",
automod_enabled=True,
ai_moderation_enabled=False,
)
db_session.add(settings)
await db_session.commit()
# Test guild was created
result = await db_session.execute(select(Guild).where(Guild.id == sample_guild_id))
created_guild = result.scalar_one()
assert created_guild.id == sample_guild_id
assert created_guild.name == "Test Guild"
assert created_guild.owner_id == sample_owner_id
assert not created_guild.premium
async def test_guild_settings_relationship(self, test_guild, db_session):
"""Test guild-settings relationship."""
# Load guild with settings
result = await db_session.execute(
select(Guild).where(Guild.id == test_guild.id)
)
guild_with_settings = result.scalar_one()
# Test relationship loading
await db_session.refresh(guild_with_settings, ["settings"])
assert guild_with_settings.settings is not None
assert guild_with_settings.settings.guild_id == test_guild.id
assert guild_with_settings.settings.prefix == "!"
async def test_banned_word_creation(self, test_guild, db_session, sample_moderator_id):
"""Test banned word creation and relationship."""
banned_word = BannedWord(
guild_id=test_guild.id,
pattern="testbadword",
is_regex=False,
action="delete",
reason="Test ban",
added_by=sample_moderator_id,
)
db_session.add(banned_word)
await db_session.commit()
# Verify creation
result = await db_session.execute(
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
)
created_word = result.scalar_one()
assert created_word.pattern == "testbadword"
assert not created_word.is_regex
assert created_word.action == "delete"
assert created_word.added_by == sample_moderator_id
async def test_moderation_log_creation(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
):
"""Test moderation log creation."""
mod_log = ModerationLog(
guild_id=test_guild.id,
target_id=sample_user_id,
target_name="TestUser",
moderator_id=sample_moderator_id,
moderator_name="TestModerator",
action="ban",
reason="Test ban",
is_automatic=False,
)
db_session.add(mod_log)
await db_session.commit()
# Verify creation
result = await db_session.execute(
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
)
created_log = result.scalar_one()
assert created_log.action == "ban"
assert created_log.target_id == sample_user_id
assert created_log.moderator_id == sample_moderator_id
assert not created_log.is_automatic
async def test_strike_creation(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
):
"""Test strike creation and tracking."""
strike = Strike(
guild_id=test_guild.id,
user_id=sample_user_id,
user_name="TestUser",
moderator_id=sample_moderator_id,
reason="Test strike",
points=1,
is_active=True,
)
db_session.add(strike)
await db_session.commit()
# Verify creation
result = await db_session.execute(
select(Strike).where(
Strike.guild_id == test_guild.id,
Strike.user_id == sample_user_id
)
)
created_strike = result.scalar_one()
assert created_strike.points == 1
assert created_strike.is_active
assert created_strike.user_id == sample_user_id
async def test_cascade_deletion(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
):
"""Test that deleting a guild cascades to related records."""
# Add some related records
banned_word = BannedWord(
guild_id=test_guild.id,
pattern="test",
is_regex=False,
action="delete",
added_by=sample_moderator_id,
)
mod_log = ModerationLog(
guild_id=test_guild.id,
target_id=sample_user_id,
target_name="TestUser",
moderator_id=sample_moderator_id,
moderator_name="TestModerator",
action="warn",
reason="Test warning",
is_automatic=False,
)
strike = Strike(
guild_id=test_guild.id,
user_id=sample_user_id,
user_name="TestUser",
moderator_id=sample_moderator_id,
reason="Test strike",
points=1,
is_active=True,
)
db_session.add_all([banned_word, mod_log, strike])
await db_session.commit()
# Delete the guild
await db_session.delete(test_guild)
await db_session.commit()
# Verify related records were deleted
banned_words = await db_session.execute(
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
)
assert len(banned_words.scalars().all()) == 0
mod_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
)
assert len(mod_logs.scalars().all()) == 0
strikes = await db_session.execute(
select(Strike).where(Strike.guild_id == test_guild.id)
)
assert len(strikes.scalars().all()) == 0
class TestDatabaseIndexes:
"""Test that database indexes work as expected."""
async def test_moderation_log_indexes(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
):
"""Test moderation log indexing for performance."""
# Create multiple moderation logs
logs = []
for i in range(10):
log = ModerationLog(
guild_id=test_guild.id,
target_id=sample_user_id + i,
target_name=f"TestUser{i}",
moderator_id=sample_moderator_id,
moderator_name="TestModerator",
action="warn",
reason=f"Test warning {i}",
is_automatic=bool(i % 2),
)
logs.append(log)
db_session.add_all(logs)
await db_session.commit()
# Test queries that should use indexes
# Query by guild_id
guild_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
)
assert len(guild_logs.scalars().all()) == 10
# Query by target_id
target_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.target_id == sample_user_id)
)
assert len(target_logs.scalars().all()) == 1
# Query by is_automatic
auto_logs = await db_session.execute(
select(ModerationLog).where(ModerationLog.is_automatic == True)
)
assert len(auto_logs.scalars().all()) == 5
async def test_strike_indexes(
self,
test_guild,
db_session,
sample_user_id,
sample_moderator_id
):
"""Test strike indexing for performance."""
# Create multiple strikes
strikes = []
for i in range(5):
strike = Strike(
guild_id=test_guild.id,
user_id=sample_user_id + i,
user_name=f"TestUser{i}",
moderator_id=sample_moderator_id,
reason=f"Strike {i}",
points=1,
is_active=bool(i % 2),
)
strikes.append(strike)
db_session.add_all(strikes)
await db_session.commit()
# Test active strikes query
active_strikes = await db_session.execute(
select(Strike).where(
Strike.guild_id == test_guild.id,
Strike.is_active == True
)
)
assert len(active_strikes.scalars().all()) == 3 # indices 1, 3
class TestDatabaseSecurity:
"""Test database security features."""
async def test_snowflake_id_validation(self, db_session):
"""Test that snowflake IDs are properly validated."""
# Valid snowflake ID
valid_guild_id = 123456789012345678
guild = Guild(
id=valid_guild_id,
name="Valid Guild",
owner_id=123456789012345679,
premium=False,
)
db_session.add(guild)
await db_session.commit()
# Verify it was stored correctly
result = await db_session.execute(
select(Guild).where(Guild.id == valid_guild_id)
)
stored_guild = result.scalar_one()
assert stored_guild.id == valid_guild_id
async def test_sql_injection_prevention(self, db_session, test_guild):
"""Test that SQL injection is prevented."""
# Attempt to inject malicious SQL through user input
malicious_inputs = [
"'; DROP TABLE guilds; --",
"' UNION SELECT * FROM guild_settings --",
"' OR '1'='1",
"<script>alert('xss')</script>",
]
for malicious_input in malicious_inputs:
# Try to use malicious input in a query
# SQLAlchemy should prevent injection through parameterized queries
result = await db_session.execute(
select(Guild).where(Guild.name == malicious_input)
)
# Should not find anything (and not crash)
assert result.scalar_one_or_none() is None
async def test_data_integrity_constraints(self, db_session, sample_guild_id):
"""Test that database constraints are enforced."""
# Test foreign key constraint
with pytest.raises(Exception): # Should raise integrity error
banned_word = BannedWord(
guild_id=999999999999999999, # Non-existent guild
pattern="test",
is_regex=False,
action="delete",
added_by=123456789012345678,
)
db_session.add(banned_word)
await db_session.commit()

View File

@@ -112,6 +112,18 @@ class TestRateLimiter:
assert result.is_limited is False
assert result.remaining == 999
def test_acquire_command_scopes_per_command(self, limiter: RateLimiter) -> None:
"""Test per-command rate limits are independent."""
for _ in range(5):
result = limiter.acquire_command("config", user_id=1, guild_id=1)
assert result.is_limited is False
limited = limiter.acquire_command("config", user_id=1, guild_id=1)
assert limited.is_limited is True
other = limiter.acquire_command("other", user_id=1, guild_id=1)
assert other.is_limited is False
def test_guild_scope(self, limiter: RateLimiter) -> None:
"""Test guild-scoped rate limiting."""
limiter.configure(

View File

@@ -4,7 +4,7 @@ from datetime import timedelta
import pytest
from guardden.cogs.moderation import parse_duration
from guardden.utils import parse_duration
class TestParseDuration: