Files
GuardDen/tests/test_config.py
latte abef368a68
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
update
2026-01-17 21:57:04 +01:00

247 lines
9.2 KiB
Python

"""Tests for configuration validation and security."""
import pytest
from pydantic import ValidationError
from guardden.config import Settings, _parse_id_list, _validate_discord_id
from guardden.services.automod import normalize_domain
class TestDiscordIdValidation:
    """Exercise ``_validate_discord_id`` with accepted and rejected inputs."""

    def test_validate_discord_id_valid(self):
        """Accepted snowflake IDs come back as positive integers."""
        accepted = (
            "123456789012345678",  # 18-digit string
            "1234567890123456789",  # 19-digit string
            123456789012345678,  # already an int
        )
        for candidate in accepted:
            parsed = _validate_discord_id(candidate)
            assert isinstance(parsed, int)
            assert parsed > 0

    def test_validate_discord_id_invalid_format(self):
        """Malformed inputs must raise ``ValueError``."""
        rejected = (
            "12345",  # too short
            "12345678901234567890",  # too long
            "abc123456789012345678",  # contains letters
            "123-456-789",  # contains hyphens
            "123 456 789",  # contains spaces
            "",  # empty
            "0",  # zero
            "-123456789012345678",  # negative
        )
        for candidate in rejected:
            with pytest.raises(ValueError):
                _validate_discord_id(candidate)

    def test_validate_discord_id_out_of_range(self):
        """IDs outside the valid range must raise ``ValueError``."""
        too_small = "99999999999999999"  # before Discord existed
        too_large = "99999999999999999999"  # exceeds the 64-bit limit
        for candidate in (too_small, too_large):
            with pytest.raises(ValueError):
                _validate_discord_id(candidate)
class TestIdListParsing:
    """Exercise ``_parse_id_list`` with well-formed, mixed and hostile input."""

    def test_parse_id_list_valid(self):
        """Strings (comma or semicolon separated), lists, "" and None all parse."""
        first, second = 123456789012345678, 234567890123456789
        # Each expected/input list is built fresh so no object is shared
        # between what goes in and what the result is compared against.
        scenarios = [
            ("123456789012345678", [first]),
            ("123456789012345678,234567890123456789", [first, second]),
            ("123456789012345678;234567890123456789", [first, second]),
            ([first, second], [first, second]),
            ("", []),
            (None, []),
        ]
        for raw, expected in scenarios:
            assert _parse_id_list(raw) == expected

    def test_parse_id_list_filters_invalid(self):
        """Invalid entries are dropped while the valid ones are kept."""
        # Valid and invalid IDs interleaved in one string.
        raw = "123456789012345678,invalid,234567890123456789,12345"
        assert _parse_id_list(raw) == [123456789012345678, 234567890123456789]

    def test_parse_id_list_removes_duplicates(self):
        """A repeated ID appears only once in the result."""
        raw = "123456789012345678,123456789012345678,234567890123456789"
        assert _parse_id_list(raw) == [123456789012345678, 234567890123456789]

    def test_parse_id_list_security(self):
        """Entries carrying control characters yield at most one ID."""
        hostile = (
            "123456789012345678\x00",  # null byte
            "123456789012345678\n234567890123456789",  # newline
            "123456789012345678\r234567890123456789",  # carriage return
        )
        for payload in hostile:
            # Malicious entries should have been filtered out.
            assert len(_parse_id_list(payload)) <= 1
class TestSettingsValidation:
    """Exercise validation performed by the ``Settings`` class."""

    def test_discord_token_validation_valid(self):
        """Accepted token formats are stored and readable via the secret."""
        accepted = (
            "MTIzNDU2Nzg5MDEyMzQ1Njc4.G1a2b3.c4d5e6f7g8h9i0j1k2l3m4n5o6p7q8r9s0",
            "Bot.MTIzNDU2Nzg5MDEyMzQ1Njc4.some_long_token_string_here",
            "a" * 60,  # minimum reasonable length
        )
        for token in accepted:
            loaded = Settings(discord_token=token)
            assert loaded.discord_token.get_secret_value() == token

    def test_discord_token_validation_invalid(self):
        """Rejected token formats raise ``ValidationError``."""
        rejected = (
            "",  # empty
            "short",  # too short
            "token with spaces",  # contains spaces
            "token\nwith\nnewlines",  # contains newlines
        )
        for token in rejected:
            with pytest.raises(ValidationError):
                Settings(discord_token=token)

    def test_api_key_validation(self):
        """A long API key is accepted; a too-short one is rejected."""
        token = "valid_token_" + "a" * 50
        good_key = "sk-" + "a" * 50

        # Valid key round-trips through the secret wrapper.
        configured = Settings(
            discord_token=token,
            ai_provider="anthropic",
            anthropic_api_key=good_key,
        )
        assert configured.anthropic_api_key.get_secret_value() == good_key

        # Too-short key is refused at construction time.
        with pytest.raises(ValidationError):
            Settings(
                discord_token=token,
                ai_provider="anthropic",
                anthropic_api_key="short",
            )

    def test_configuration_validation_ai_provider(self):
        """``validate_configuration`` requires a key once anthropic is chosen."""
        cfg = Settings(discord_token="valid_token_" + "a" * 50)

        # No AI provider: nothing else is required.
        cfg.ai_provider = "none"
        cfg.validate_configuration()

        # Anthropic selected but key missing: must be refused.
        cfg.ai_provider = "anthropic"
        cfg.anthropic_api_key = None
        with pytest.raises(ValueError, match="GUARDDEN_ANTHROPIC_API_KEY is required"):
            cfg.validate_configuration()

        # Anthropic selected and key present: accepted.
        cfg.anthropic_api_key = "sk-" + "a" * 50
        cfg.validate_configuration()

    def test_configuration_validation_database_pool(self):
        """``validate_configuration`` rejects inconsistent pool bounds."""
        cfg = Settings(discord_token="valid_token_" + "a" * 50)

        # min > max
        cfg.database_pool_min, cfg.database_pool_max = 10, 5
        with pytest.raises(ValueError, match="database_pool_min cannot be greater"):
            cfg.validate_configuration()

        # min < 1
        cfg.database_pool_min, cfg.database_pool_max = 0, 5
        with pytest.raises(ValueError, match="database_pool_min must be at least 1"):
            cfg.validate_configuration()
class TestSecurityImprovements:
    """Test security improvements in configuration."""

    def test_id_validation_prevents_injection(self):
        """Injection-style payloads never yield anything but a positive int.

        For each payload ``_validate_discord_id`` may either raise
        ``ValueError`` or return a value; a returned value must be a
        positive ``int``.
        """
        injection_attempts = [
            "123456789012345678'; DROP TABLE guilds; --",
            "123456789012345678 UNION SELECT * FROM users",
            "123456789012345678\x00\x01\x02",
            "123456789012345678<script>alert('xss')</script>",
        ]
        # NOTE(review): every payload contains non-digit characters, so a
        # stricter `pytest.raises(ValueError)` may be appropriate here —
        # confirm against `_validate_discord_id` before tightening.
        for attempt in injection_attempts:
            try:
                result = _validate_discord_id(attempt)
            except ValueError:
                # Rejection is the expected outcome for malicious input.
                continue
            else:
                # Not rejected: the result must still be a sane ID.
                assert isinstance(result, int)
                assert result > 0

    def test_settings_with_malicious_env_vars(self, monkeypatch):
        """Settings built from hostile environment variables stay sanitized.

        Uses pytest's ``monkeypatch`` fixture so both variables are restored
        (or removed) automatically after the test, even if it fails.
        """
        # Setting a value with an embedded NUL byte can raise ValueError;
        # in that case fall back to a payload without the NUL.
        try:
            monkeypatch.setenv("GUARDDEN_ALLOWED_GUILDS", "123456789012345678\x00,malicious")
        except ValueError:
            monkeypatch.setenv("GUARDDEN_ALLOWED_GUILDS", "123456789012345678,malicious")

        try:
            monkeypatch.setenv("GUARDDEN_OWNER_IDS", "234567890123456789\n567890123456789012")
        except ValueError:
            # The fallback keeps a single well-formed ID plus junk so it stays
            # consistent with the `<= 1` assertion below (two well-formed IDs
            # here would contradict it).
            monkeypatch.setenv("GUARDDEN_OWNER_IDS", "234567890123456789,malicious")

        settings = Settings(discord_token="valid_token_" + "a" * 50)

        # Should filter out malicious entries
        assert len(settings.allowed_guilds) <= 1
        assert len(settings.owner_ids) <= 1

        # Valid IDs should be preserved
        assert (
            123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0
        )