quick commit
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 6m9s
CI/CD Pipeline / Security Scanning (push) Successful in 26s
CI/CD Pipeline / Tests (3.11) (push) Failing after 5m24s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m23s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Deploy to Staging (push) Has been skipped
CI/CD Pipeline / Deploy to Production (push) Has been skipped
CI/CD Pipeline / Notification (push) Successful in 1s
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 6m9s
CI/CD Pipeline / Security Scanning (push) Successful in 26s
CI/CD Pipeline / Tests (3.11) (push) Failing after 5m24s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m23s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Deploy to Staging (push) Has been skipped
CI/CD Pipeline / Deploy to Production (push) Has been skipped
CI/CD Pipeline / Notification (push) Successful in 1s
This commit is contained in:
@@ -1,7 +1,56 @@
|
||||
"""Pytest fixtures for GuardDen tests."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parents[1]
|
||||
SRC_DIR = ROOT_DIR / "src"
|
||||
if str(SRC_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(SRC_DIR))
|
||||
|
||||
# Import after path setup
|
||||
from guardden.config import Settings
|
||||
from guardden.models.base import Base
|
||||
from guardden.models.guild import Guild, GuildSettings, BannedWord
|
||||
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
||||
from guardden.services.database import Database
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
parser.addini("asyncio_mode", "Asyncio mode for tests", default="auto")
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
config.addinivalue_line("markers", "asyncio: mark async tests")
|
||||
|
||||
|
||||
def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> bool | None:
|
||||
test_function = pyfuncitem.obj
|
||||
if inspect.iscoroutinefunction(test_function):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(test_function(**pyfuncitem.funcargs))
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
return True
|
||||
return None
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Basic Test Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def sample_guild_id() -> int:
|
||||
@@ -13,3 +62,320 @@ def sample_guild_id() -> int:
|
||||
def sample_user_id() -> int:
|
||||
"""Return a sample Discord user ID."""
|
||||
return 987654321098765432
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_moderator_id() -> int:
|
||||
"""Return a sample Discord moderator ID."""
|
||||
return 111111111111111111
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_owner_id() -> int:
|
||||
"""Return a sample Discord owner ID."""
|
||||
return 222222222222222222
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Configuration Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def test_settings() -> Settings:
|
||||
"""Return test configuration settings."""
|
||||
return Settings(
|
||||
discord_token="test_token_12345678901234567890",
|
||||
discord_prefix="!test",
|
||||
database_url="sqlite+aiosqlite:///test.db",
|
||||
database_pool_min=1,
|
||||
database_pool_max=1,
|
||||
ai_provider="none",
|
||||
log_level="DEBUG",
|
||||
allowed_guilds=[],
|
||||
owner_ids=[],
|
||||
data_dir=Path("/tmp/guardden_test"),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Database Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
async def test_database(test_settings: Settings) -> AsyncGenerator[Database, None]:
|
||||
"""Create a test database with in-memory SQLite."""
|
||||
# Use in-memory SQLite for tests
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
database = Database(test_settings)
|
||||
database._engine = engine
|
||||
database._session_factory = async_sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
yield database
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(test_database: Database) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create a database session for testing."""
|
||||
async with test_database.session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Model Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
async def test_guild(
|
||||
db_session: AsyncSession, sample_guild_id: int, sample_owner_id: int
|
||||
) -> Guild:
|
||||
"""Create a test guild with settings."""
|
||||
guild = Guild(
|
||||
id=sample_guild_id,
|
||||
name="Test Guild",
|
||||
owner_id=sample_owner_id,
|
||||
premium=False,
|
||||
)
|
||||
db_session.add(guild)
|
||||
|
||||
# Create associated settings
|
||||
settings = GuildSettings(
|
||||
guild_id=sample_guild_id,
|
||||
prefix="!",
|
||||
automod_enabled=True,
|
||||
ai_moderation_enabled=False,
|
||||
verification_enabled=False,
|
||||
)
|
||||
db_session.add(settings)
|
||||
|
||||
await db_session.commit()
|
||||
await db_session.refresh(guild)
|
||||
return guild
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_banned_word(
|
||||
db_session: AsyncSession, test_guild: Guild, sample_moderator_id: int
|
||||
) -> BannedWord:
|
||||
"""Create a test banned word."""
|
||||
banned_word = BannedWord(
|
||||
guild_id=test_guild.id,
|
||||
pattern="badword",
|
||||
is_regex=False,
|
||||
action="delete",
|
||||
reason="Inappropriate content",
|
||||
added_by=sample_moderator_id,
|
||||
)
|
||||
db_session.add(banned_word)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(banned_word)
|
||||
return banned_word
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_moderation_log(
|
||||
db_session: AsyncSession,
|
||||
test_guild: Guild,
|
||||
sample_user_id: int,
|
||||
sample_moderator_id: int
|
||||
) -> ModerationLog:
|
||||
"""Create a test moderation log entry."""
|
||||
mod_log = ModerationLog(
|
||||
guild_id=test_guild.id,
|
||||
target_id=sample_user_id,
|
||||
target_name="TestUser",
|
||||
moderator_id=sample_moderator_id,
|
||||
moderator_name="TestModerator",
|
||||
action="warn",
|
||||
reason="Test warning",
|
||||
is_automatic=False,
|
||||
)
|
||||
db_session.add(mod_log)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(mod_log)
|
||||
return mod_log
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_strike(
|
||||
db_session: AsyncSession,
|
||||
test_guild: Guild,
|
||||
sample_user_id: int,
|
||||
sample_moderator_id: int
|
||||
) -> Strike:
|
||||
"""Create a test strike."""
|
||||
strike = Strike(
|
||||
guild_id=test_guild.id,
|
||||
user_id=sample_user_id,
|
||||
user_name="TestUser",
|
||||
moderator_id=sample_moderator_id,
|
||||
reason="Test strike",
|
||||
points=1,
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(strike)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(strike)
|
||||
return strike
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Discord Mock Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_user(sample_user_id: int) -> MagicMock:
|
||||
"""Create a mock Discord user."""
|
||||
user = MagicMock()
|
||||
user.id = sample_user_id
|
||||
user.name = "TestUser"
|
||||
user.display_name = "Test User"
|
||||
user.mention = f"<@{sample_user_id}>"
|
||||
user.avatar = None
|
||||
user.bot = False
|
||||
user.send = AsyncMock()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_member(mock_discord_user: MagicMock) -> MagicMock:
|
||||
"""Create a mock Discord member."""
|
||||
member = MagicMock()
|
||||
member.id = mock_discord_user.id
|
||||
member.name = mock_discord_user.name
|
||||
member.display_name = mock_discord_user.display_name
|
||||
member.mention = mock_discord_user.mention
|
||||
member.avatar = mock_discord_user.avatar
|
||||
member.bot = mock_discord_user.bot
|
||||
member.send = mock_discord_user.send
|
||||
|
||||
# Member-specific attributes
|
||||
member.guild = MagicMock()
|
||||
member.top_role = MagicMock()
|
||||
member.top_role.position = 1
|
||||
member.roles = [MagicMock()]
|
||||
member.joined_at = datetime.now(timezone.utc)
|
||||
member.kick = AsyncMock()
|
||||
member.ban = AsyncMock()
|
||||
member.timeout = AsyncMock()
|
||||
|
||||
return member
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_guild(sample_guild_id: int, sample_owner_id: int) -> MagicMock:
|
||||
"""Create a mock Discord guild."""
|
||||
guild = MagicMock()
|
||||
guild.id = sample_guild_id
|
||||
guild.name = "Test Guild"
|
||||
guild.owner_id = sample_owner_id
|
||||
guild.member_count = 100
|
||||
guild.premium_tier = 0
|
||||
|
||||
# Methods
|
||||
guild.get_member = MagicMock(return_value=None)
|
||||
guild.get_channel = MagicMock(return_value=None)
|
||||
guild.leave = AsyncMock()
|
||||
guild.ban = AsyncMock()
|
||||
guild.unban = AsyncMock()
|
||||
|
||||
return guild
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_channel() -> MagicMock:
|
||||
"""Create a mock Discord channel."""
|
||||
channel = MagicMock()
|
||||
channel.id = 333333333333333333
|
||||
channel.name = "test-channel"
|
||||
channel.mention = "<#333333333333333333>"
|
||||
channel.send = AsyncMock()
|
||||
channel.delete_messages = AsyncMock()
|
||||
return channel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_message(
|
||||
mock_discord_member: MagicMock, mock_discord_channel: MagicMock
|
||||
) -> MagicMock:
|
||||
"""Create a mock Discord message."""
|
||||
message = MagicMock()
|
||||
message.id = 444444444444444444
|
||||
message.content = "Test message content"
|
||||
message.author = mock_discord_member
|
||||
message.channel = mock_discord_channel
|
||||
message.guild = mock_discord_member.guild
|
||||
message.created_at = datetime.now(timezone.utc)
|
||||
message.delete = AsyncMock()
|
||||
message.reply = AsyncMock()
|
||||
message.add_reaction = AsyncMock()
|
||||
return message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_context(
|
||||
mock_discord_member: MagicMock,
|
||||
mock_discord_guild: MagicMock,
|
||||
mock_discord_channel: MagicMock
|
||||
) -> MagicMock:
|
||||
"""Create a mock Discord command context."""
|
||||
ctx = MagicMock()
|
||||
ctx.author = mock_discord_member
|
||||
ctx.guild = mock_discord_guild
|
||||
ctx.channel = mock_discord_channel
|
||||
ctx.send = AsyncMock()
|
||||
ctx.reply = AsyncMock()
|
||||
return ctx
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Bot and Service Fixtures
|
||||
# ==============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot(test_database: Database) -> MagicMock:
|
||||
"""Create a mock GuardDen bot."""
|
||||
bot = MagicMock()
|
||||
bot.database = test_database
|
||||
bot.guild_config = MagicMock()
|
||||
bot.ai_provider = MagicMock()
|
||||
bot.rate_limiter = MagicMock()
|
||||
bot.user = MagicMock()
|
||||
bot.user.id = 555555555555555555
|
||||
bot.user.name = "GuardDen"
|
||||
return bot
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Test Environment Setup
|
||||
# ==============================================================================
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_environment() -> None:
|
||||
"""Set up test environment variables."""
|
||||
# Set test environment variables
|
||||
os.environ["GUARDDEN_DISCORD_TOKEN"] = "test_token_12345678901234567890"
|
||||
os.environ["GUARDDEN_DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
|
||||
os.environ["GUARDDEN_AI_PROVIDER"] = "none"
|
||||
os.environ["GUARDDEN_LOG_LEVEL"] = "DEBUG"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from guardden.services.ai.base import ContentCategory, ModerationResult
|
||||
from guardden.services.ai.base import ContentCategory, ModerationResult, parse_categories
|
||||
from guardden.services.ai.factory import NullProvider, create_ai_provider
|
||||
|
||||
|
||||
@@ -69,6 +69,14 @@ class TestModerationResult:
|
||||
assert result.severity == 100
|
||||
|
||||
|
||||
class TestParseCategories:
|
||||
"""Tests for category parsing helper."""
|
||||
|
||||
def test_parse_categories_filters_invalid(self) -> None:
|
||||
categories = parse_categories(["harassment", "unknown", "scam"])
|
||||
assert categories == [ContentCategory.HARASSMENT, ContentCategory.SCAM]
|
||||
|
||||
|
||||
class TestNullProvider:
|
||||
"""Tests for NullProvider."""
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from guardden.models import BannedWord
|
||||
from guardden.services.automod import AutomodService
|
||||
|
||||
|
||||
@@ -79,6 +78,14 @@ class TestScamDetection:
|
||||
result = automod.check_scam_links("Visit discord-verify.xyz to claim")
|
||||
assert result is not None
|
||||
|
||||
def test_allowlisted_domain(self, automod: AutomodService) -> None:
|
||||
"""Test allowlisted domains skip suspicious TLD checks."""
|
||||
result = automod.check_scam_links(
|
||||
"Visit https://discordapp.xyz for updates",
|
||||
allowlist=["discordapp.xyz"],
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_normal_url(self, automod: AutomodService) -> None:
|
||||
"""Test normal URLs pass."""
|
||||
result = automod.check_scam_links("Check out https://github.com/example")
|
||||
|
||||
210
tests/test_automod_security.py
Normal file
210
tests/test_automod_security.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Tests for automod security improvements."""
|
||||
|
||||
import pytest
|
||||
|
||||
from guardden.services.automod import normalize_domain, URL_PATTERN
|
||||
|
||||
|
||||
class TestDomainNormalization:
|
||||
"""Test domain normalization security improvements."""
|
||||
|
||||
def test_normalize_domain_valid(self):
|
||||
"""Test normalization of valid domains."""
|
||||
test_cases = [
|
||||
("example.com", "example.com"),
|
||||
("www.example.com", "example.com"),
|
||||
("http://example.com", "example.com"),
|
||||
("https://www.example.com", "example.com"),
|
||||
("EXAMPLE.COM", "example.com"),
|
||||
("Example.Com", "example.com"),
|
||||
]
|
||||
|
||||
for input_domain, expected in test_cases:
|
||||
result = normalize_domain(input_domain)
|
||||
assert result == expected
|
||||
|
||||
def test_normalize_domain_security_filters(self):
|
||||
"""Test that malicious domains are filtered out."""
|
||||
malicious_domains = [
|
||||
"example.com\x00", # null byte
|
||||
"example.com\n", # newline
|
||||
"example.com\r", # carriage return
|
||||
"example.com\t", # tab
|
||||
"example.com\x01", # control character
|
||||
"example com", # space in hostname
|
||||
"", # empty string
|
||||
" ", # space only
|
||||
"a" * 2001, # excessively long
|
||||
None, # None value
|
||||
123, # non-string value
|
||||
]
|
||||
|
||||
for malicious_domain in malicious_domains:
|
||||
result = normalize_domain(malicious_domain)
|
||||
assert result == "" # Should return empty string for invalid input
|
||||
|
||||
def test_normalize_domain_length_limits(self):
|
||||
"""Test that domain length limits are enforced."""
|
||||
# Test exactly at the limit
|
||||
valid_long_domain = "a" * 249 + ".com" # 253 chars total (RFC limit)
|
||||
result = normalize_domain(valid_long_domain)
|
||||
assert result != "" # Should be valid
|
||||
|
||||
# Test over the limit
|
||||
invalid_long_domain = "a" * 250 + ".com" # 254 chars total (over RFC limit)
|
||||
result = normalize_domain(invalid_long_domain)
|
||||
assert result == "" # Should be invalid
|
||||
|
||||
def test_normalize_domain_malformed_urls(self):
|
||||
"""Test handling of malformed URLs."""
|
||||
malformed_urls = [
|
||||
"http://", # incomplete URL
|
||||
"://example.com", # missing scheme
|
||||
"http:///example.com", # extra slash
|
||||
"http://example..com", # double dot
|
||||
"http://.example.com", # leading dot
|
||||
"http://example.com.", # trailing dot
|
||||
"ftp://example.com", # non-http scheme (should still work)
|
||||
]
|
||||
|
||||
for malformed_url in malformed_urls:
|
||||
result = normalize_domain(malformed_url)
|
||||
# Should either return valid domain or empty string
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_normalize_domain_injection_attempts(self):
|
||||
"""Test that domain normalization prevents injection."""
|
||||
injection_attempts = [
|
||||
"example.com'; DROP TABLE guilds; --",
|
||||
"example.com UNION SELECT * FROM users",
|
||||
"example.com\"><script>alert('xss')</script>",
|
||||
"example.com\\x00\\x01\\x02",
|
||||
"example.com\n\rmalicious",
|
||||
]
|
||||
|
||||
for attempt in injection_attempts:
|
||||
result = normalize_domain(attempt)
|
||||
# Should either return a safe domain or empty string
|
||||
if result:
|
||||
assert "script" not in result
|
||||
assert "DROP" not in result
|
||||
assert "UNION" not in result
|
||||
assert "\x00" not in result
|
||||
assert "\n" not in result
|
||||
assert "\r" not in result
|
||||
|
||||
|
||||
class TestUrlPatternSecurity:
|
||||
"""Test URL pattern security improvements."""
|
||||
|
||||
def test_url_pattern_matches_valid_urls(self):
|
||||
"""Test that URL pattern matches legitimate URLs."""
|
||||
valid_urls = [
|
||||
"https://example.com",
|
||||
"http://www.example.org",
|
||||
"https://subdomain.example.net",
|
||||
"http://example.io/path/to/resource",
|
||||
"https://example.com/path?query=value",
|
||||
"www.example.com",
|
||||
"example.gg",
|
||||
]
|
||||
|
||||
for url in valid_urls:
|
||||
matches = URL_PATTERN.findall(url)
|
||||
assert len(matches) >= 1, f"Failed to match valid URL: {url}"
|
||||
|
||||
def test_url_pattern_rejects_malicious_patterns(self):
|
||||
"""Test that URL pattern doesn't match malicious patterns."""
|
||||
# These should not be matched as URLs
|
||||
non_urls = [
|
||||
"javascript:alert('xss')",
|
||||
"data:text/html,<script>alert('xss')</script>",
|
||||
"file:///etc/passwd",
|
||||
"ftp://anonymous@server",
|
||||
"mailto:user@example.com",
|
||||
]
|
||||
|
||||
for non_url in non_urls:
|
||||
matches = URL_PATTERN.findall(non_url)
|
||||
# Should not match these protocols
|
||||
assert len(matches) == 0 or not any("javascript:" in match for match in matches)
|
||||
|
||||
def test_url_pattern_handles_edge_cases(self):
|
||||
"""Test URL pattern with edge cases."""
|
||||
edge_cases = [
|
||||
"http://" + "a" * 300 + ".com", # very long domain
|
||||
"https://example.com" + "a" * 2000, # very long path
|
||||
"https://192.168.1.1", # IP address (should not match)
|
||||
"https://[::1]", # IPv6 (should not match)
|
||||
"https://ex-ample.com", # hyphenated domain
|
||||
"https://example.123", # numeric TLD (should not match)
|
||||
]
|
||||
|
||||
for edge_case in edge_cases:
|
||||
matches = URL_PATTERN.findall(edge_case)
|
||||
# Should handle gracefully (either match or not, but no crashes)
|
||||
assert isinstance(matches, list)
|
||||
|
||||
|
||||
class TestAutomodIntegration:
|
||||
"""Test automod integration with security improvements."""
|
||||
|
||||
def test_url_processing_security(self):
|
||||
"""Test that URL processing handles malicious input safely."""
|
||||
from guardden.services.automod import detect_scam_links
|
||||
|
||||
# Mock allowlist and suspicious TLDs for testing
|
||||
allowlist = ["trusted.com", "example.org"]
|
||||
|
||||
# Test with malicious URLs
|
||||
malicious_content = [
|
||||
"Check out this link: https://evil.tk/steal-your-data",
|
||||
"Visit http://phishing.ml/discord-nitro-free",
|
||||
"Go to https://scam" + "." * 100 + "tk", # excessive dots
|
||||
"Link: https://example.com" + "x" * 5000, # excessively long
|
||||
]
|
||||
|
||||
for content in malicious_content:
|
||||
# Should not crash and should return appropriate result
|
||||
result = detect_scam_links(content, allowlist)
|
||||
assert result is None or hasattr(result, 'should_delete')
|
||||
|
||||
def test_domain_allowlist_security(self):
|
||||
"""Test that domain allowlist checking is secure."""
|
||||
from guardden.services.automod import is_allowed_domain
|
||||
|
||||
# Test with malicious allowlist entries
|
||||
malicious_allowlist = {
|
||||
"good.com",
|
||||
"evil.com\x00", # null byte
|
||||
"bad.com\n", # newline
|
||||
"trusted.org",
|
||||
}
|
||||
|
||||
test_domains = [
|
||||
"good.com",
|
||||
"evil.com",
|
||||
"bad.com",
|
||||
"trusted.org",
|
||||
"unknown.com",
|
||||
]
|
||||
|
||||
for domain in test_domains:
|
||||
# Should not crash
|
||||
result = is_allowed_domain(domain, malicious_allowlist)
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_regex_pattern_safety(self):
|
||||
"""Test that regex patterns are processed safely."""
|
||||
# This tests the circuit breaker functionality (when implemented)
|
||||
malicious_patterns = [
|
||||
"(.+)+", # catastrophic backtracking
|
||||
"a" * 1000, # very long pattern
|
||||
"(?:a|a)*", # another backtracking pattern
|
||||
"[" + "a-z" * 100 + "]", # excessive character class
|
||||
]
|
||||
|
||||
for pattern in malicious_patterns:
|
||||
# Should not cause infinite loops or crashes
|
||||
# This is a placeholder for when circuit breakers are implemented
|
||||
assert len(pattern) > 0 # Just ensure we're testing something
|
||||
237
tests/test_config.py
Normal file
237
tests/test_config.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Tests for configuration validation and security."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from guardden.config import Settings, _parse_id_list, _validate_discord_id, normalize_domain
|
||||
|
||||
|
||||
class TestDiscordIdValidation:
|
||||
"""Test Discord ID validation functions."""
|
||||
|
||||
def test_validate_discord_id_valid(self):
|
||||
"""Test validation of valid Discord IDs."""
|
||||
# Valid Discord snowflake IDs
|
||||
valid_ids = [
|
||||
"123456789012345678", # 18 digits
|
||||
"1234567890123456789", # 19 digits
|
||||
123456789012345678, # int format
|
||||
]
|
||||
|
||||
for valid_id in valid_ids:
|
||||
result = _validate_discord_id(valid_id)
|
||||
assert isinstance(result, int)
|
||||
assert result > 0
|
||||
|
||||
def test_validate_discord_id_invalid_format(self):
|
||||
"""Test validation rejects invalid formats."""
|
||||
invalid_ids = [
|
||||
"12345", # too short
|
||||
"12345678901234567890", # too long
|
||||
"abc123456789012345678", # contains letters
|
||||
"123-456-789", # contains hyphens
|
||||
"123 456 789", # contains spaces
|
||||
"", # empty
|
||||
"0", # zero
|
||||
"-123456789012345678", # negative
|
||||
]
|
||||
|
||||
for invalid_id in invalid_ids:
|
||||
with pytest.raises(ValueError):
|
||||
_validate_discord_id(invalid_id)
|
||||
|
||||
def test_validate_discord_id_out_of_range(self):
|
||||
"""Test validation rejects IDs outside valid range."""
|
||||
# Too small (before Discord existed)
|
||||
with pytest.raises(ValueError):
|
||||
_validate_discord_id("99999999999999999")
|
||||
|
||||
# Too large (exceeds 64-bit limit)
|
||||
with pytest.raises(ValueError):
|
||||
_validate_discord_id("99999999999999999999")
|
||||
|
||||
|
||||
class TestIdListParsing:
|
||||
"""Test ID list parsing functions."""
|
||||
|
||||
def test_parse_id_list_valid(self):
|
||||
"""Test parsing valid ID lists."""
|
||||
test_cases = [
|
||||
("123456789012345678", [123456789012345678]),
|
||||
("123456789012345678,234567890123456789", [123456789012345678, 234567890123456789]),
|
||||
("123456789012345678;234567890123456789", [123456789012345678, 234567890123456789]),
|
||||
([123456789012345678, 234567890123456789], [123456789012345678, 234567890123456789]),
|
||||
("", []),
|
||||
(None, []),
|
||||
]
|
||||
|
||||
for input_value, expected in test_cases:
|
||||
result = _parse_id_list(input_value)
|
||||
assert result == expected
|
||||
|
||||
def test_parse_id_list_filters_invalid(self):
|
||||
"""Test that invalid IDs are filtered out."""
|
||||
# Mix of valid and invalid IDs
|
||||
mixed_input = "123456789012345678,invalid,234567890123456789,12345"
|
||||
result = _parse_id_list(mixed_input)
|
||||
assert result == [123456789012345678, 234567890123456789]
|
||||
|
||||
def test_parse_id_list_removes_duplicates(self):
|
||||
"""Test that duplicate IDs are removed."""
|
||||
duplicate_input = "123456789012345678,123456789012345678,234567890123456789"
|
||||
result = _parse_id_list(duplicate_input)
|
||||
assert result == [123456789012345678, 234567890123456789]
|
||||
|
||||
def test_parse_id_list_security(self):
|
||||
"""Test that malicious input is rejected."""
|
||||
malicious_inputs = [
|
||||
"123456789012345678\x00", # null byte
|
||||
"123456789012345678\n234567890123456789", # newline
|
||||
"123456789012345678\r234567890123456789", # carriage return
|
||||
]
|
||||
|
||||
for malicious_input in malicious_inputs:
|
||||
result = _parse_id_list(malicious_input)
|
||||
# Should filter out malicious entries
|
||||
assert len(result) <= 1
|
||||
|
||||
|
||||
class TestSettingsValidation:
|
||||
"""Test Settings class validation."""
|
||||
|
||||
def test_discord_token_validation_valid(self):
|
||||
"""Test valid Discord token formats."""
|
||||
valid_tokens = [
|
||||
"MTIzNDU2Nzg5MDEyMzQ1Njc4.G1a2b3.c4d5e6f7g8h9i0j1k2l3m4n5o6p7q8r9s0",
|
||||
"Bot.MTIzNDU2Nzg5MDEyMzQ1Njc4.some_long_token_string_here",
|
||||
"a" * 60, # minimum reasonable length
|
||||
]
|
||||
|
||||
for token in valid_tokens:
|
||||
settings = Settings(discord_token=token)
|
||||
assert settings.discord_token.get_secret_value() == token
|
||||
|
||||
def test_discord_token_validation_invalid(self):
|
||||
"""Test invalid Discord token formats."""
|
||||
invalid_tokens = [
|
||||
"", # empty
|
||||
"short", # too short
|
||||
"token with spaces", # contains spaces
|
||||
"token\nwith\nnewlines", # contains newlines
|
||||
]
|
||||
|
||||
for token in invalid_tokens:
|
||||
with pytest.raises(ValidationError):
|
||||
Settings(discord_token=token)
|
||||
|
||||
def test_api_key_validation(self):
|
||||
"""Test API key validation."""
|
||||
# Valid API keys
|
||||
valid_key = "sk-" + "a" * 50
|
||||
settings = Settings(
|
||||
discord_token="valid_token_" + "a" * 50,
|
||||
ai_provider="anthropic",
|
||||
anthropic_api_key=valid_key
|
||||
)
|
||||
assert settings.anthropic_api_key.get_secret_value() == valid_key
|
||||
|
||||
# Invalid API key (too short)
|
||||
with pytest.raises(ValidationError):
|
||||
Settings(
|
||||
discord_token="valid_token_" + "a" * 50,
|
||||
ai_provider="anthropic",
|
||||
anthropic_api_key="short"
|
||||
)
|
||||
|
||||
def test_configuration_validation_ai_provider(self):
|
||||
"""Test AI provider configuration validation."""
|
||||
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
||||
|
||||
# Should pass with no AI provider
|
||||
settings.ai_provider = "none"
|
||||
settings.validate_configuration()
|
||||
|
||||
# Should fail with anthropic but no key
|
||||
settings.ai_provider = "anthropic"
|
||||
settings.anthropic_api_key = None
|
||||
with pytest.raises(ValueError, match="GUARDDEN_ANTHROPIC_API_KEY is required"):
|
||||
settings.validate_configuration()
|
||||
|
||||
# Should pass with anthropic and key
|
||||
settings.anthropic_api_key = "sk-" + "a" * 50
|
||||
settings.validate_configuration()
|
||||
|
||||
def test_configuration_validation_database_pool(self):
|
||||
"""Test database pool configuration validation."""
|
||||
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
||||
|
||||
# Should fail with min > max
|
||||
settings.database_pool_min = 10
|
||||
settings.database_pool_max = 5
|
||||
with pytest.raises(ValueError, match="database_pool_min cannot be greater"):
|
||||
settings.validate_configuration()
|
||||
|
||||
# Should fail with min < 1
|
||||
settings.database_pool_min = 0
|
||||
settings.database_pool_max = 5
|
||||
with pytest.raises(ValueError, match="database_pool_min must be at least 1"):
|
||||
settings.validate_configuration()
|
||||
|
||||
|
||||
class TestSecurityImprovements:
|
||||
"""Test security improvements in configuration."""
|
||||
|
||||
def test_id_validation_prevents_injection(self):
|
||||
"""Test that ID validation prevents injection attacks."""
|
||||
# Test various injection attempts
|
||||
injection_attempts = [
|
||||
"123456789012345678'; DROP TABLE guilds; --",
|
||||
"123456789012345678 UNION SELECT * FROM users",
|
||||
"123456789012345678\x00\x01\x02",
|
||||
"123456789012345678<script>alert('xss')</script>",
|
||||
]
|
||||
|
||||
for attempt in injection_attempts:
|
||||
# Should either raise an error or filter out the malicious input
|
||||
try:
|
||||
result = _validate_discord_id(attempt)
|
||||
# If it doesn't raise an error, it should be a valid ID
|
||||
assert isinstance(result, int)
|
||||
assert result > 0
|
||||
except ValueError:
|
||||
# This is expected for malicious input
|
||||
pass
|
||||
|
||||
def test_settings_with_malicious_env_vars(self):
|
||||
"""Test that settings handle malicious environment variables."""
|
||||
import os
|
||||
|
||||
# Save original values
|
||||
original_guilds = os.environ.get("GUARDDEN_ALLOWED_GUILDS")
|
||||
original_owners = os.environ.get("GUARDDEN_OWNER_IDS")
|
||||
|
||||
try:
|
||||
# Set malicious environment variables
|
||||
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678\x00,malicious"
|
||||
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789\n567890123456789012"
|
||||
|
||||
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
||||
|
||||
# Should filter out malicious entries
|
||||
assert len(settings.allowed_guilds) <= 1
|
||||
assert len(settings.owner_ids) <= 1
|
||||
|
||||
# Valid IDs should be preserved
|
||||
assert 123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0
|
||||
|
||||
finally:
|
||||
# Restore original values
|
||||
if original_guilds is not None:
|
||||
os.environ["GUARDDEN_ALLOWED_GUILDS"] = original_guilds
|
||||
else:
|
||||
os.environ.pop("GUARDDEN_ALLOWED_GUILDS", None)
|
||||
|
||||
if original_owners is not None:
|
||||
os.environ["GUARDDEN_OWNER_IDS"] = original_owners
|
||||
else:
|
||||
os.environ.pop("GUARDDEN_OWNER_IDS", None)
|
||||
346
tests/test_database_integration.py
Normal file
346
tests/test_database_integration.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""Tests for database integration and models."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import select
|
||||
|
||||
from guardden.models.guild import Guild, GuildSettings, BannedWord
|
||||
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
||||
from guardden.services.database import Database
|
||||
|
||||
|
||||
class TestDatabaseModels:
|
||||
"""Test database models and relationships."""
|
||||
|
||||
async def test_guild_creation(self, db_session, sample_guild_id, sample_owner_id):
|
||||
"""Test guild creation with settings."""
|
||||
guild = Guild(
|
||||
id=sample_guild_id,
|
||||
name="Test Guild",
|
||||
owner_id=sample_owner_id,
|
||||
premium=False,
|
||||
)
|
||||
db_session.add(guild)
|
||||
|
||||
settings = GuildSettings(
|
||||
guild_id=sample_guild_id,
|
||||
prefix="!",
|
||||
automod_enabled=True,
|
||||
ai_moderation_enabled=False,
|
||||
)
|
||||
db_session.add(settings)
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
# Test guild was created
|
||||
result = await db_session.execute(select(Guild).where(Guild.id == sample_guild_id))
|
||||
created_guild = result.scalar_one()
|
||||
|
||||
assert created_guild.id == sample_guild_id
|
||||
assert created_guild.name == "Test Guild"
|
||||
assert created_guild.owner_id == sample_owner_id
|
||||
assert not created_guild.premium
|
||||
|
||||
async def test_guild_settings_relationship(self, test_guild, db_session):
|
||||
"""Test guild-settings relationship."""
|
||||
# Load guild with settings
|
||||
result = await db_session.execute(
|
||||
select(Guild).where(Guild.id == test_guild.id)
|
||||
)
|
||||
guild_with_settings = result.scalar_one()
|
||||
|
||||
# Test relationship loading
|
||||
await db_session.refresh(guild_with_settings, ["settings"])
|
||||
assert guild_with_settings.settings is not None
|
||||
assert guild_with_settings.settings.guild_id == test_guild.id
|
||||
assert guild_with_settings.settings.prefix == "!"
|
||||
|
||||
async def test_banned_word_creation(self, test_guild, db_session, sample_moderator_id):
|
||||
"""Test banned word creation and relationship."""
|
||||
banned_word = BannedWord(
|
||||
guild_id=test_guild.id,
|
||||
pattern="testbadword",
|
||||
is_regex=False,
|
||||
action="delete",
|
||||
reason="Test ban",
|
||||
added_by=sample_moderator_id,
|
||||
)
|
||||
db_session.add(banned_word)
|
||||
await db_session.commit()
|
||||
|
||||
# Verify creation
|
||||
result = await db_session.execute(
|
||||
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
|
||||
)
|
||||
created_word = result.scalar_one()
|
||||
|
||||
assert created_word.pattern == "testbadword"
|
||||
assert not created_word.is_regex
|
||||
assert created_word.action == "delete"
|
||||
assert created_word.added_by == sample_moderator_id
|
||||
|
||||
async def test_moderation_log_creation(
|
||||
self,
|
||||
test_guild,
|
||||
db_session,
|
||||
sample_user_id,
|
||||
sample_moderator_id
|
||||
):
|
||||
"""Test moderation log creation."""
|
||||
mod_log = ModerationLog(
|
||||
guild_id=test_guild.id,
|
||||
target_id=sample_user_id,
|
||||
target_name="TestUser",
|
||||
moderator_id=sample_moderator_id,
|
||||
moderator_name="TestModerator",
|
||||
action="ban",
|
||||
reason="Test ban",
|
||||
is_automatic=False,
|
||||
)
|
||||
db_session.add(mod_log)
|
||||
await db_session.commit()
|
||||
|
||||
# Verify creation
|
||||
result = await db_session.execute(
|
||||
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
||||
)
|
||||
created_log = result.scalar_one()
|
||||
|
||||
assert created_log.action == "ban"
|
||||
assert created_log.target_id == sample_user_id
|
||||
assert created_log.moderator_id == sample_moderator_id
|
||||
assert not created_log.is_automatic
|
||||
|
||||
async def test_strike_creation(
|
||||
self,
|
||||
test_guild,
|
||||
db_session,
|
||||
sample_user_id,
|
||||
sample_moderator_id
|
||||
):
|
||||
"""Test strike creation and tracking."""
|
||||
strike = Strike(
|
||||
guild_id=test_guild.id,
|
||||
user_id=sample_user_id,
|
||||
user_name="TestUser",
|
||||
moderator_id=sample_moderator_id,
|
||||
reason="Test strike",
|
||||
points=1,
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(strike)
|
||||
await db_session.commit()
|
||||
|
||||
# Verify creation
|
||||
result = await db_session.execute(
|
||||
select(Strike).where(
|
||||
Strike.guild_id == test_guild.id,
|
||||
Strike.user_id == sample_user_id
|
||||
)
|
||||
)
|
||||
created_strike = result.scalar_one()
|
||||
|
||||
assert created_strike.points == 1
|
||||
assert created_strike.is_active
|
||||
assert created_strike.user_id == sample_user_id
|
||||
|
||||
async def test_cascade_deletion(
|
||||
self,
|
||||
test_guild,
|
||||
db_session,
|
||||
sample_user_id,
|
||||
sample_moderator_id
|
||||
):
|
||||
"""Test that deleting a guild cascades to related records."""
|
||||
# Add some related records
|
||||
banned_word = BannedWord(
|
||||
guild_id=test_guild.id,
|
||||
pattern="test",
|
||||
is_regex=False,
|
||||
action="delete",
|
||||
added_by=sample_moderator_id,
|
||||
)
|
||||
|
||||
mod_log = ModerationLog(
|
||||
guild_id=test_guild.id,
|
||||
target_id=sample_user_id,
|
||||
target_name="TestUser",
|
||||
moderator_id=sample_moderator_id,
|
||||
moderator_name="TestModerator",
|
||||
action="warn",
|
||||
reason="Test warning",
|
||||
is_automatic=False,
|
||||
)
|
||||
|
||||
strike = Strike(
|
||||
guild_id=test_guild.id,
|
||||
user_id=sample_user_id,
|
||||
user_name="TestUser",
|
||||
moderator_id=sample_moderator_id,
|
||||
reason="Test strike",
|
||||
points=1,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
db_session.add_all([banned_word, mod_log, strike])
|
||||
await db_session.commit()
|
||||
|
||||
# Delete the guild
|
||||
await db_session.delete(test_guild)
|
||||
await db_session.commit()
|
||||
|
||||
# Verify related records were deleted
|
||||
banned_words = await db_session.execute(
|
||||
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
|
||||
)
|
||||
assert len(banned_words.scalars().all()) == 0
|
||||
|
||||
mod_logs = await db_session.execute(
|
||||
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
||||
)
|
||||
assert len(mod_logs.scalars().all()) == 0
|
||||
|
||||
strikes = await db_session.execute(
|
||||
select(Strike).where(Strike.guild_id == test_guild.id)
|
||||
)
|
||||
assert len(strikes.scalars().all()) == 0
|
||||
|
||||
|
||||
class TestDatabaseIndexes:
|
||||
"""Test that database indexes work as expected."""
|
||||
|
||||
async def test_moderation_log_indexes(
|
||||
self,
|
||||
test_guild,
|
||||
db_session,
|
||||
sample_user_id,
|
||||
sample_moderator_id
|
||||
):
|
||||
"""Test moderation log indexing for performance."""
|
||||
# Create multiple moderation logs
|
||||
logs = []
|
||||
for i in range(10):
|
||||
log = ModerationLog(
|
||||
guild_id=test_guild.id,
|
||||
target_id=sample_user_id + i,
|
||||
target_name=f"TestUser{i}",
|
||||
moderator_id=sample_moderator_id,
|
||||
moderator_name="TestModerator",
|
||||
action="warn",
|
||||
reason=f"Test warning {i}",
|
||||
is_automatic=bool(i % 2),
|
||||
)
|
||||
logs.append(log)
|
||||
|
||||
db_session.add_all(logs)
|
||||
await db_session.commit()
|
||||
|
||||
# Test queries that should use indexes
|
||||
# Query by guild_id
|
||||
guild_logs = await db_session.execute(
|
||||
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
||||
)
|
||||
assert len(guild_logs.scalars().all()) == 10
|
||||
|
||||
# Query by target_id
|
||||
target_logs = await db_session.execute(
|
||||
select(ModerationLog).where(ModerationLog.target_id == sample_user_id)
|
||||
)
|
||||
assert len(target_logs.scalars().all()) == 1
|
||||
|
||||
# Query by is_automatic
|
||||
auto_logs = await db_session.execute(
|
||||
select(ModerationLog).where(ModerationLog.is_automatic == True)
|
||||
)
|
||||
assert len(auto_logs.scalars().all()) == 5
|
||||
|
||||
async def test_strike_indexes(
|
||||
self,
|
||||
test_guild,
|
||||
db_session,
|
||||
sample_user_id,
|
||||
sample_moderator_id
|
||||
):
|
||||
"""Test strike indexing for performance."""
|
||||
# Create multiple strikes
|
||||
strikes = []
|
||||
for i in range(5):
|
||||
strike = Strike(
|
||||
guild_id=test_guild.id,
|
||||
user_id=sample_user_id + i,
|
||||
user_name=f"TestUser{i}",
|
||||
moderator_id=sample_moderator_id,
|
||||
reason=f"Strike {i}",
|
||||
points=1,
|
||||
is_active=bool(i % 2),
|
||||
)
|
||||
strikes.append(strike)
|
||||
|
||||
db_session.add_all(strikes)
|
||||
await db_session.commit()
|
||||
|
||||
# Test active strikes query
|
||||
active_strikes = await db_session.execute(
|
||||
select(Strike).where(
|
||||
Strike.guild_id == test_guild.id,
|
||||
Strike.is_active == True
|
||||
)
|
||||
)
|
||||
assert len(active_strikes.scalars().all()) == 3 # indices 1, 3
|
||||
|
||||
|
||||
class TestDatabaseSecurity:
|
||||
"""Test database security features."""
|
||||
|
||||
async def test_snowflake_id_validation(self, db_session):
|
||||
"""Test that snowflake IDs are properly validated."""
|
||||
# Valid snowflake ID
|
||||
valid_guild_id = 123456789012345678
|
||||
guild = Guild(
|
||||
id=valid_guild_id,
|
||||
name="Valid Guild",
|
||||
owner_id=123456789012345679,
|
||||
premium=False,
|
||||
)
|
||||
db_session.add(guild)
|
||||
await db_session.commit()
|
||||
|
||||
# Verify it was stored correctly
|
||||
result = await db_session.execute(
|
||||
select(Guild).where(Guild.id == valid_guild_id)
|
||||
)
|
||||
stored_guild = result.scalar_one()
|
||||
assert stored_guild.id == valid_guild_id
|
||||
|
||||
async def test_sql_injection_prevention(self, db_session, test_guild):
|
||||
"""Test that SQL injection is prevented."""
|
||||
# Attempt to inject malicious SQL through user input
|
||||
malicious_inputs = [
|
||||
"'; DROP TABLE guilds; --",
|
||||
"' UNION SELECT * FROM guild_settings --",
|
||||
"' OR '1'='1",
|
||||
"<script>alert('xss')</script>",
|
||||
]
|
||||
|
||||
for malicious_input in malicious_inputs:
|
||||
# Try to use malicious input in a query
|
||||
# SQLAlchemy should prevent injection through parameterized queries
|
||||
result = await db_session.execute(
|
||||
select(Guild).where(Guild.name == malicious_input)
|
||||
)
|
||||
# Should not find anything (and not crash)
|
||||
assert result.scalar_one_or_none() is None
|
||||
|
||||
async def test_data_integrity_constraints(self, db_session, sample_guild_id):
|
||||
"""Test that database constraints are enforced."""
|
||||
# Test foreign key constraint
|
||||
with pytest.raises(Exception): # Should raise integrity error
|
||||
banned_word = BannedWord(
|
||||
guild_id=999999999999999999, # Non-existent guild
|
||||
pattern="test",
|
||||
is_regex=False,
|
||||
action="delete",
|
||||
added_by=123456789012345678,
|
||||
)
|
||||
db_session.add(banned_word)
|
||||
await db_session.commit()
|
||||
@@ -112,6 +112,18 @@ class TestRateLimiter:
|
||||
assert result.is_limited is False
|
||||
assert result.remaining == 999
|
||||
|
||||
def test_acquire_command_scopes_per_command(self, limiter: RateLimiter) -> None:
|
||||
"""Test per-command rate limits are independent."""
|
||||
for _ in range(5):
|
||||
result = limiter.acquire_command("config", user_id=1, guild_id=1)
|
||||
assert result.is_limited is False
|
||||
|
||||
limited = limiter.acquire_command("config", user_id=1, guild_id=1)
|
||||
assert limited.is_limited is True
|
||||
|
||||
other = limiter.acquire_command("other", user_id=1, guild_id=1)
|
||||
assert other.is_limited is False
|
||||
|
||||
def test_guild_scope(self, limiter: RateLimiter) -> None:
|
||||
"""Test guild-scoped rate limiting."""
|
||||
limiter.configure(
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from guardden.cogs.moderation import parse_duration
|
||||
from guardden.utils import parse_duration
|
||||
|
||||
|
||||
class TestParseDuration:
|
||||
|
||||
Reference in New Issue
Block a user