"""Tests for automod security improvements."""

import pytest

from guardden.services.automod import normalize_domain, URL_PATTERN


class TestDomainNormalization:
    """Test domain normalization security improvements."""

    def test_normalize_domain_valid(self):
        """Valid domains are lower-cased and stripped of scheme and ``www.``."""
        test_cases = [
            ("example.com", "example.com"),
            ("www.example.com", "example.com"),
            ("http://example.com", "example.com"),
            ("https://www.example.com", "example.com"),
            ("EXAMPLE.COM", "example.com"),
            ("Example.Com", "example.com"),
        ]
        for input_domain, expected in test_cases:
            result = normalize_domain(input_domain)
            assert result == expected, (
                f"normalize_domain({input_domain!r}) -> {result!r}, "
                f"expected {expected!r}"
            )

    def test_normalize_domain_security_filters(self):
        """Test that malicious or non-string inputs normalize to ``""``."""
        malicious_domains = [
            "example.com\x00",  # null byte
            "example.com\n",  # newline
            "example.com\r",  # carriage return
            "example.com\t",  # tab
            "example.com\x01",  # control character
            "example com",  # space in hostname
            "",  # empty string
            " ",  # space only
            "a" * 2001,  # excessively long
            None,  # None value
            123,  # non-string value
        ]
        for malicious_domain in malicious_domains:
            result = normalize_domain(malicious_domain)
            # Invalid input must be rejected with an empty string.
            assert result == "", (
                f"normalize_domain({malicious_domain!r:.60}) -> {result!r}, "
                f"expected ''"
            )

    def test_normalize_domain_length_limits(self):
        """Test that the 253-character total domain length limit is enforced.

        Every label is kept at or below the 63-character DNS label limit so
        that only the overall length is under test, not label validity.
        """
        label = "a" * 63

        # 63 + 1 + 63 + 1 + 63 + 1 + 57 + len(".com") == 253 chars (RFC limit)
        valid_long_domain = ".".join([label, label, label, "a" * 57]) + ".com"
        assert len(valid_long_domain) == 253
        result = normalize_domain(valid_long_domain)
        assert result != "", "a 253-character domain should be accepted"

        # One more character gives 254 chars total (over the RFC limit).
        invalid_long_domain = ".".join([label, label, label, "a" * 58]) + ".com"
        assert len(invalid_long_domain) == 254
        result = normalize_domain(invalid_long_domain)
        assert result == "", "a 254-character domain should be rejected"

    def test_normalize_domain_malformed_urls(self):
        """Test that malformed URLs are handled without raising."""
        malformed_urls = [
            "http://",  # incomplete URL
            "://example.com",  # missing scheme
            "http:///example.com",  # extra slash
            "http://example..com",  # double dot
            "http://.example.com",  # leading dot
            "http://example.com.",  # trailing dot
            "ftp://example.com",  # non-http scheme (should still work)
        ]
        for malformed_url in malformed_urls:
            result = normalize_domain(malformed_url)
            # Should either return a valid domain or an empty string.
            assert isinstance(result, str), (
                f"normalize_domain({malformed_url!r}) returned "
                f"{type(result).__name__}, expected str"
            )

    def test_normalize_domain_injection_attempts(self):
        """Test that domain normalization prevents injection."""
        injection_attempts = [
            "example.com'; DROP TABLE guilds; --",
            "example.com UNION SELECT * FROM users",
            "example.com\"><script>alert('xss')</script>",
            "example.com\\x00\\x01\\x02",
            "example.com\n\rmalicious",
        ]
        for attempt in injection_attempts:
            result = normalize_domain(attempt)
            # Either a safe domain or an empty string; a non-empty result must
            # not carry any part of the injected payload.
            if result:
                for forbidden in ("script", "DROP", "UNION", "\x00", "\n", "\r"):
                    assert forbidden not in result, (
                        f"normalize_domain({attempt!r}) -> {result!r} "
                        f"contains {forbidden!r}"
                    )


class TestUrlPatternSecurity:
    """Test URL pattern security improvements."""

    def test_url_pattern_matches_valid_urls(self):
        """Test that URL pattern matches legitimate URLs."""
        valid_urls = [
            "https://example.com",
            "http://www.example.org",
            "https://subdomain.example.net",
            "http://example.io/path/to/resource",
            "https://example.com/path?query=value",
            "www.example.com",
            "example.gg",
        ]
        for url in valid_urls:
            matches = URL_PATTERN.findall(url)
            assert len(matches) >= 1, f"Failed to match valid URL: {url}"

    def test_url_pattern_rejects_malicious_patterns(self):
        """Test that URL pattern never yields a ``javascript:`` match."""
        # Non-http(s) schemes; the assertion below only forbids matches that
        # carry a "javascript:" scheme.
        non_urls = [
            "javascript:alert('xss')",
            "data:text/html,<script>alert('xss')</script>",
            "file:///etc/passwd",
            "ftp://anonymous@server",
            "mailto:user@example.com",
        ]
        for non_url in non_urls:
            matches = URL_PATTERN.findall(non_url)
            # findall yields tuples when the pattern has several groups;
            # str() keeps the substring check meaningful in either case.
            assert not any("javascript:" in str(match) for match in matches), (
                f"URL_PATTERN matched a javascript: URL in {non_url!r}: {matches!r}"
            )

    def test_url_pattern_handles_edge_cases(self):
        """Test URL pattern with edge cases (must not raise)."""
        edge_cases = [
            "http://" + "a" * 300 + ".com",  # very long domain
            "https://example.com/" + "a" * 2000,  # very long path
            "https://192.168.1.1",  # IP address (ideally not matched)
            "https://[::1]",  # IPv6 (ideally not matched)
            "https://ex-ample.com",  # hyphenated domain
            "https://example.123",  # numeric TLD (ideally not matched)
        ]
        for edge_case in edge_cases:
            matches = URL_PATTERN.findall(edge_case)
            # Should handle gracefully (either match or not, but no crashes).
            assert isinstance(matches, list)


class TestAutomodIntegration:
    """Test automod integration with security improvements."""

    def test_url_processing_security(self):
        """Test that URL processing handles malicious input safely."""
        from guardden.services.automod import detect_scam_links

        # Allowlist used for testing.
        allowlist = ["trusted.com", "example.org"]

        # Test with malicious URLs
        malicious_content = [
            "Check out this link: https://evil.tk/steal-your-data",
            "Visit http://phishing.ml/discord-nitro-free",
            "Go to https://scam" + "." * 100 + "tk",  # excessive dots
            "Link: https://example.com" + "x" * 5000,  # excessively long
        ]
        for content in malicious_content:
            # Should not crash and should return an appropriate result.
            result = detect_scam_links(content, allowlist)
            assert result is None or hasattr(result, "should_delete"), (
                f"unexpected result {result!r} for {content[:60]!r}"
            )

    def test_domain_allowlist_security(self):
        """Test that domain allowlist checking is secure."""
        from guardden.services.automod import is_allowed_domain

        # Test with malicious allowlist entries
        malicious_allowlist = {
            "good.com",
            "evil.com\x00",  # null byte
            "bad.com\n",  # newline
            "trusted.org",
        }
        test_domains = [
            "good.com",
            "evil.com",
            "bad.com",
            "trusted.org",
            "unknown.com",
        ]
        for domain in test_domains:
            # Should not crash
            result = is_allowed_domain(domain, malicious_allowlist)
            assert isinstance(result, bool)

        # Clean entries must still behave normally next to malformed ones.
        assert is_allowed_domain("good.com", malicious_allowlist) is True
        assert is_allowed_domain("unknown.com", malicious_allowlist) is False

    @pytest.mark.skip(reason="regex circuit breaker not implemented yet")
    def test_regex_pattern_safety(self):
        """Test that regex patterns are processed safely.

        Placeholder for the circuit breaker: skipped rather than passing
        vacuously, so it does not count as real coverage.
        """
        malicious_patterns = [
            "(.+)+",  # catastrophic backtracking
            "a" * 1000,  # very long pattern
            "(?:a|a)*",  # another backtracking pattern
            "[" + "a-z" * 100 + "]",  # excessive character class
        ]
        # TODO: once the circuit breaker exists, assert that evaluating each
        # pattern against adversarial input is aborted within a time budget.
        for pattern in malicious_patterns:
            assert pattern  # placeholder: fixtures above are non-empty