dev #11
@@ -1,73 +1,57 @@
|
||||
"""AI-powered moderation cog."""
|
||||
"""AI-powered moderation cog - Images & GIFs only, with cost controls."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.models import ModerationLog
|
||||
from guardden.services.ai.base import ContentCategory, ModerationResult
|
||||
from guardden.services.automod import URL_PATTERN, is_allowed_domain, normalize_domain
|
||||
from guardden.utils.notifications import send_moderation_notification
|
||||
from guardden.utils.ratelimit import RateLimitExceeded
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NSFW video domain blocklist
|
||||
NSFW_VIDEO_DOMAINS = [] # Loaded from config
|
||||
|
||||
# URL pattern for finding links
|
||||
URL_PATTERN = re.compile(
|
||||
r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
|
||||
)
|
||||
|
||||
|
||||
def _get_action_for_nsfw(category: str) -> str:
|
||||
"""Map NSFW category to suggested action."""
|
||||
mapping = {
|
||||
"suggestive": "warn",
|
||||
"suggestive": "none",
|
||||
"partial_nudity": "delete",
|
||||
"nudity": "delete",
|
||||
"explicit": "timeout",
|
||||
"explicit": "delete",
|
||||
}
|
||||
return mapping.get(category, "none")
|
||||
|
||||
|
||||
class AIModeration(commands.Cog):
|
||||
"""AI-powered content moderation."""
|
||||
"""AI-powered NSFW image detection with strict cost controls."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
# Track recently analyzed messages to avoid duplicates (deque auto-removes oldest)
|
||||
# Track recently analyzed messages to avoid duplicates (cost control)
|
||||
self._analyzed_messages: deque[int] = deque(maxlen=1000)
|
||||
|
||||
def cog_check(self, ctx: commands.Context) -> bool:
|
||||
"""Optional owner allowlist for AI commands."""
|
||||
if not ctx.guild:
|
||||
return False
|
||||
return self.bot.is_owner_allowed(ctx.author.id)
|
||||
|
||||
async def cog_before_invoke(self, ctx: commands.Context) -> None:
|
||||
if not ctx.command:
|
||||
return
|
||||
result = self.bot.rate_limiter.acquire_command(
|
||||
ctx.command.qualified_name,
|
||||
user_id=ctx.author.id,
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
)
|
||||
if result.is_limited:
|
||||
raise RateLimitExceeded(result.reset_after)
|
||||
|
||||
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
|
||||
if isinstance(error, RateLimitExceeded):
|
||||
await ctx.send(
|
||||
f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
|
||||
)
|
||||
# Load NSFW video domains from config
|
||||
global NSFW_VIDEO_DOMAINS
|
||||
NSFW_VIDEO_DOMAINS = bot.config_loader.get_setting("nsfw_video_domains", [])
|
||||
|
||||
def _should_analyze(self, message: discord.Message) -> bool:
|
||||
"""Determine if a message should be analyzed by AI."""
|
||||
# Skip if already analyzed
|
||||
# Skip if already analyzed (deduplication for cost control)
|
||||
if message.id in self._analyzed_messages:
|
||||
return False
|
||||
|
||||
# Skip short messages without media
|
||||
if len(message.content) < 20 and not message.attachments and not message.embeds:
|
||||
# Skip if no images/embeds
|
||||
if not message.attachments and not message.embeds:
|
||||
return False
|
||||
|
||||
# Skip messages from bots
|
||||
@@ -80,198 +64,22 @@ class AIModeration(commands.Cog):
|
||||
"""Track that a message has been analyzed."""
|
||||
self._analyzed_messages.append(message_id)
|
||||
|
||||
async def _handle_ai_result(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: ModerationResult,
|
||||
analysis_type: str,
|
||||
) -> None:
|
||||
"""Handle the result of AI analysis."""
|
||||
if not result.is_flagged:
|
||||
return
|
||||
def _has_nsfw_video_link(self, content: str) -> bool:
|
||||
"""Check if message contains NSFW video domain."""
|
||||
if not content:
|
||||
return False
|
||||
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config:
|
||||
return
|
||||
content_lower = content.lower()
|
||||
for domain in NSFW_VIDEO_DOMAINS:
|
||||
if domain.lower() in content_lower:
|
||||
logger.info(f"Blocked NSFW video domain: {domain}")
|
||||
return True
|
||||
|
||||
# Check NSFW-only filtering mode
|
||||
if config.nsfw_only_filtering:
|
||||
# Only process SEXUAL content when NSFW-only mode is enabled
|
||||
if ContentCategory.SEXUAL not in result.categories:
|
||||
logger.debug(
|
||||
"NSFW-only mode enabled, ignoring non-sexual content: categories=%s",
|
||||
[cat.value for cat in result.categories],
|
||||
)
|
||||
return
|
||||
|
||||
# Check if severity meets threshold based on sensitivity
|
||||
# Higher sensitivity = lower threshold needed to trigger
|
||||
threshold = 100 - config.ai_sensitivity # e.g., sensitivity 70 = threshold 30
|
||||
if result.severity < threshold:
|
||||
logger.debug(
|
||||
"AI flagged content but below threshold: severity=%s, threshold=%s",
|
||||
result.severity,
|
||||
threshold,
|
||||
)
|
||||
return
|
||||
|
||||
if result.confidence < config.ai_confidence_threshold:
|
||||
logger.debug(
|
||||
"AI flagged content but below confidence threshold: confidence=%s, threshold=%s",
|
||||
result.confidence,
|
||||
config.ai_confidence_threshold,
|
||||
)
|
||||
return
|
||||
|
||||
log_only = config.ai_log_only
|
||||
|
||||
# Determine action based on suggested action and severity
|
||||
should_delete = not log_only and result.suggested_action in ("delete", "timeout", "ban")
|
||||
should_timeout = (
|
||||
not log_only and result.suggested_action in ("timeout", "ban") and result.severity > 70
|
||||
)
|
||||
timeout_duration: int | None = None
|
||||
|
||||
# Delete message if needed
|
||||
if should_delete:
|
||||
try:
|
||||
await message.delete()
|
||||
except discord.Forbidden:
|
||||
logger.warning("Cannot delete message: missing permissions")
|
||||
except discord.NotFound:
|
||||
pass
|
||||
|
||||
# Timeout user for severe violations
|
||||
if should_timeout and isinstance(message.author, discord.Member):
|
||||
timeout_duration = 300 if result.severity < 90 else 3600 # 5 min or 1 hour
|
||||
try:
|
||||
await message.author.timeout(
|
||||
timedelta(seconds=timeout_duration),
|
||||
reason=f"AI Moderation: {result.explanation[:100]}",
|
||||
)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
await self._log_ai_db_action(
|
||||
message,
|
||||
result,
|
||||
analysis_type,
|
||||
log_only=log_only,
|
||||
timeout_duration=timeout_duration,
|
||||
)
|
||||
|
||||
# Log to mod channel
|
||||
await self._log_ai_action(message, result, analysis_type, log_only=log_only)
|
||||
|
||||
if log_only:
|
||||
return
|
||||
|
||||
# Notify user
|
||||
embed = discord.Embed(
|
||||
title=f"Message Flagged in {message.guild.name}",
|
||||
description=result.explanation,
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(
|
||||
name="Categories",
|
||||
value=", ".join(cat.value for cat in result.categories) or "Unknown",
|
||||
)
|
||||
if should_timeout:
|
||||
embed.add_field(name="Action", value="You have been timed out")
|
||||
|
||||
# Use notification utility to send DM with in-channel fallback
|
||||
if isinstance(message.channel, discord.TextChannel):
|
||||
await send_moderation_notification(
|
||||
user=message.author,
|
||||
channel=message.channel,
|
||||
embed=embed,
|
||||
send_in_channel=config.send_in_channel_warnings,
|
||||
)
|
||||
|
||||
async def _log_ai_action(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: ModerationResult,
|
||||
analysis_type: str,
|
||||
log_only: bool = False,
|
||||
) -> None:
|
||||
"""Log an AI moderation action."""
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.mod_log_channel_id:
|
||||
return
|
||||
|
||||
channel = message.guild.get_channel(config.mod_log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"AI Moderation - {analysis_type}",
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_author(
|
||||
name=str(message.author),
|
||||
icon_url=message.author.display_avatar.url,
|
||||
)
|
||||
|
||||
action_label = "log-only" if log_only else result.suggested_action
|
||||
embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True)
|
||||
embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True)
|
||||
embed.add_field(name="Action", value=action_label, inline=True)
|
||||
|
||||
categories = ", ".join(cat.value for cat in result.categories)
|
||||
embed.add_field(name="Categories", value=categories or "None", inline=False)
|
||||
embed.add_field(name="Explanation", value=result.explanation[:500], inline=False)
|
||||
|
||||
if message.content:
|
||||
content = (
|
||||
message.content[:500] + "..." if len(message.content) > 500 else message.content
|
||||
)
|
||||
embed.add_field(name="Content", value=f"```{content}```", inline=False)
|
||||
|
||||
embed.set_footer(text=f"User ID: {message.author.id} | Channel: #{message.channel.name}")
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
async def _log_ai_db_action(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: ModerationResult,
|
||||
analysis_type: str,
|
||||
log_only: bool,
|
||||
timeout_duration: int | None,
|
||||
) -> None:
|
||||
"""Log an AI moderation action to the database."""
|
||||
action = "ai_log" if log_only else f"ai_{result.suggested_action}"
|
||||
reason = result.explanation or f"AI moderation flagged content ({analysis_type})"
|
||||
expires_at = None
|
||||
if timeout_duration:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=timeout_duration)
|
||||
|
||||
async with self.bot.database.session() as session:
|
||||
entry = ModerationLog(
|
||||
guild_id=message.guild.id,
|
||||
target_id=message.author.id,
|
||||
target_name=str(message.author),
|
||||
moderator_id=self.bot.user.id if self.bot.user else 0,
|
||||
moderator_name=str(self.bot.user) if self.bot.user else "GuardDen",
|
||||
action=action,
|
||||
reason=reason,
|
||||
duration=timeout_duration,
|
||||
expires_at=expires_at,
|
||||
channel_id=message.channel.id,
|
||||
message_id=message.id,
|
||||
message_content=message.content,
|
||||
is_automatic=True,
|
||||
)
|
||||
session.add(entry)
|
||||
return False
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""Analyze messages with AI moderation."""
|
||||
logger.debug("AI moderation received message from %s", message.author)
|
||||
|
||||
"""Analyze messages for NSFW images with strict cost controls."""
|
||||
# Skip bot messages early
|
||||
if message.author.bot:
|
||||
return
|
||||
@@ -279,109 +87,119 @@ class AIModeration(commands.Cog):
|
||||
if not message.guild:
|
||||
return
|
||||
|
||||
logger.info(f"AI mod checking message from {message.author} in {message.guild.name}")
|
||||
|
||||
# Check if AI moderation is enabled for this guild
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.ai_moderation_enabled:
|
||||
logger.debug(f"AI moderation disabled for guild {message.guild.id}")
|
||||
# Get config from YAML
|
||||
config = self.bot.config_loader
|
||||
if not config.get_setting("ai_moderation.enabled", True):
|
||||
return
|
||||
|
||||
# Check if user is whitelisted
|
||||
if message.author.id in config.whitelisted_user_ids:
|
||||
logger.debug(f"Skipping whitelisted user {message.author}")
|
||||
# Check NSFW video domain blocklist first (no AI cost)
|
||||
if self._has_nsfw_video_link(message.content):
|
||||
try:
|
||||
await message.delete()
|
||||
logger.info(f"Deleted message with NSFW video link from {message.author}")
|
||||
except (discord.Forbidden, discord.NotFound):
|
||||
pass
|
||||
return
|
||||
|
||||
# Skip users with manage_messages permission (disabled for testing)
|
||||
# if isinstance(message.author, discord.Member):
|
||||
# if message.author.guild_permissions.manage_messages:
|
||||
# logger.debug(f"Skipping message from privileged user {message.author}")
|
||||
# return
|
||||
|
||||
# Check if should analyze (has images/embeds, not analyzed yet)
|
||||
if not self._should_analyze(message):
|
||||
logger.debug(f"Message {message.id} skipped by _should_analyze")
|
||||
return
|
||||
|
||||
self._track_message(message.id)
|
||||
logger.info(f"Analyzing message {message.id} from {message.author}")
|
||||
# Check rate limits (CRITICAL for cost control)
|
||||
max_guild_per_hour = config.get_setting("ai_moderation.max_checks_per_hour_per_guild", 25)
|
||||
max_user_per_hour = config.get_setting("ai_moderation.max_checks_per_user_per_hour", 5)
|
||||
|
||||
# Analyze text content
|
||||
if message.content and len(message.content) >= 20:
|
||||
result = await self.bot.ai_provider.moderate_text(
|
||||
content=message.content,
|
||||
context=f"Discord server: {message.guild.name}, channel: {message.channel.name}",
|
||||
sensitivity=config.ai_sensitivity,
|
||||
rate_limit_result = self.bot.ai_rate_limiter.is_limited(
|
||||
message.guild.id,
|
||||
message.author.id,
|
||||
max_guild_per_hour,
|
||||
max_user_per_hour,
|
||||
)
|
||||
|
||||
if result.is_flagged:
|
||||
await self._handle_ai_result(message, result, "Text Analysis")
|
||||
return # Don't continue if already flagged
|
||||
if rate_limit_result["is_limited"]:
|
||||
logger.warning(
|
||||
f"AI rate limit hit: {rate_limit_result['reason']} "
|
||||
f"(guild: {rate_limit_result['guild_checks_this_hour']}/{max_guild_per_hour}, "
|
||||
f"user: {rate_limit_result['user_checks_this_hour']}/{max_user_per_hour})"
|
||||
)
|
||||
return
|
||||
|
||||
# Get AI settings
|
||||
sensitivity = config.get_setting("ai_moderation.sensitivity", 80)
|
||||
nsfw_only_filtering = config.get_setting("ai_moderation.nsfw_only_filtering", True)
|
||||
max_images = config.get_setting("ai_moderation.max_images_per_message", 2)
|
||||
max_size_mb = config.get_setting("ai_moderation.max_image_size_mb", 3)
|
||||
max_size_bytes = max_size_mb * 1024 * 1024
|
||||
check_embeds = config.get_setting("ai_moderation.check_embed_images", True)
|
||||
|
||||
# Analyze images if NSFW detection is enabled (limit to 3 per message)
|
||||
images_analyzed = 0
|
||||
if config.nsfw_detection_enabled and message.attachments:
|
||||
logger.info(f"Checking {len(message.attachments)} attachments for NSFW content")
|
||||
|
||||
# Analyze image attachments
|
||||
if message.attachments:
|
||||
for attachment in message.attachments:
|
||||
if images_analyzed >= 3:
|
||||
if images_analyzed >= max_images:
|
||||
break
|
||||
if attachment.content_type and attachment.content_type.startswith("image/"):
|
||||
|
||||
# Skip non-images
|
||||
if not attachment.content_type or not attachment.content_type.startswith("image/"):
|
||||
continue
|
||||
|
||||
# Skip large files (cost control)
|
||||
if attachment.size > max_size_bytes:
|
||||
logger.debug(f"Skipping large image: {attachment.size} bytes > {max_size_bytes}")
|
||||
continue
|
||||
|
||||
images_analyzed += 1
|
||||
logger.info(f"Analyzing image: {attachment.url[:80]}...")
|
||||
|
||||
logger.info(f"Analyzing image {images_analyzed}/{max_images} from {message.author}")
|
||||
|
||||
# AI check
|
||||
try:
|
||||
image_result = await self.bot.ai_provider.analyze_image(
|
||||
image_url=attachment.url,
|
||||
sensitivity=config.ai_sensitivity,
|
||||
sensitivity=sensitivity,
|
||||
)
|
||||
logger.info(
|
||||
f"Image result: nsfw={image_result.is_nsfw}, category={image_result.nsfw_category}, "
|
||||
f"severity={image_result.nsfw_severity}, violent={image_result.is_violent}, conf={image_result.confidence}"
|
||||
except Exception as e:
|
||||
logger.error(f"AI image analysis failed: {e}", exc_info=True)
|
||||
continue
|
||||
|
||||
logger.debug(
|
||||
f"Image result: nsfw={image_result.is_nsfw}, "
|
||||
f"category={image_result.nsfw_category}, "
|
||||
f"confidence={image_result.confidence}"
|
||||
)
|
||||
|
||||
# Filter based on NSFW-only mode setting
|
||||
should_flag_image = False
|
||||
categories = []
|
||||
# Track AI usage
|
||||
self.bot.ai_rate_limiter.track_usage(message.guild.id, message.author.id)
|
||||
self._track_message(message.id)
|
||||
|
||||
if config.nsfw_only_filtering:
|
||||
# In NSFW-only mode, only flag sexual content
|
||||
# Filter based on NSFW-only mode
|
||||
should_flag = False
|
||||
if nsfw_only_filtering:
|
||||
# Only flag sexual content
|
||||
if image_result.is_nsfw:
|
||||
should_flag_image = True
|
||||
categories.append(ContentCategory.SEXUAL)
|
||||
should_flag = True
|
||||
else:
|
||||
# Normal mode: flag all inappropriate content
|
||||
if image_result.is_nsfw:
|
||||
should_flag_image = True
|
||||
categories.append(ContentCategory.SEXUAL)
|
||||
if image_result.is_violent:
|
||||
should_flag_image = True
|
||||
categories.append(ContentCategory.VIOLENCE)
|
||||
if image_result.is_disturbing:
|
||||
should_flag_image = True
|
||||
# Flag all inappropriate content
|
||||
if image_result.is_nsfw or image_result.is_violent or image_result.is_disturbing:
|
||||
should_flag = True
|
||||
|
||||
if should_flag_image:
|
||||
# Use nsfw_severity if available, otherwise use None for default calculation
|
||||
severity_override = (
|
||||
image_result.nsfw_severity if image_result.nsfw_severity > 0 else None
|
||||
if should_flag:
|
||||
# Delete message (no logging, no timeout, no DM)
|
||||
try:
|
||||
await message.delete()
|
||||
logger.info(
|
||||
f"Deleted NSFW image from {message.author} in {message.guild.name}: "
|
||||
f"category={image_result.nsfw_category}, confidence={image_result.confidence:.2f}"
|
||||
)
|
||||
|
||||
# Include NSFW category in explanation for better logging
|
||||
explanation = image_result.description
|
||||
if image_result.nsfw_category and image_result.nsfw_category != "none":
|
||||
explanation = f"[{image_result.nsfw_category}] {explanation}"
|
||||
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=image_result.confidence,
|
||||
categories=categories,
|
||||
explanation=explanation,
|
||||
suggested_action=_get_action_for_nsfw(image_result.nsfw_category),
|
||||
severity_override=severity_override,
|
||||
)
|
||||
await self._handle_ai_result(message, result, "Image Analysis")
|
||||
except (discord.Forbidden, discord.NotFound):
|
||||
pass
|
||||
return
|
||||
|
||||
# Also analyze images from embeds (GIFs from Discord's GIF picker use embeds)
|
||||
if config.nsfw_detection_enabled and message.embeds:
|
||||
# Optionally check embed images (GIFs from Discord picker)
|
||||
if check_embeds and message.embeds:
|
||||
for embed in message.embeds:
|
||||
if images_analyzed >= 3:
|
||||
if images_analyzed >= max_images:
|
||||
break
|
||||
|
||||
# Check embed image or thumbnail (GIFs often use thumbnail)
|
||||
@@ -391,272 +209,56 @@ class AIModeration(commands.Cog):
|
||||
elif embed.thumbnail and embed.thumbnail.url:
|
||||
image_url = embed.thumbnail.url
|
||||
|
||||
if image_url:
|
||||
if not image_url:
|
||||
continue
|
||||
|
||||
images_analyzed += 1
|
||||
logger.info(f"Analyzing embed image: {image_url[:80]}...")
|
||||
|
||||
logger.info(f"Analyzing embed image {images_analyzed}/{max_images} from {message.author}")
|
||||
|
||||
# AI check
|
||||
try:
|
||||
image_result = await self.bot.ai_provider.analyze_image(
|
||||
image_url=image_url,
|
||||
sensitivity=config.ai_sensitivity,
|
||||
sensitivity=sensitivity,
|
||||
)
|
||||
logger.info(
|
||||
f"Embed image result: nsfw={image_result.is_nsfw}, category={image_result.nsfw_category}, "
|
||||
f"severity={image_result.nsfw_severity}, violent={image_result.is_violent}, conf={image_result.confidence}"
|
||||
)
|
||||
|
||||
# Filter based on NSFW-only mode setting
|
||||
should_flag_image = False
|
||||
categories = []
|
||||
|
||||
if config.nsfw_only_filtering:
|
||||
# In NSFW-only mode, only flag sexual content
|
||||
if image_result.is_nsfw:
|
||||
should_flag_image = True
|
||||
categories.append(ContentCategory.SEXUAL)
|
||||
else:
|
||||
# Normal mode: flag all inappropriate content
|
||||
if image_result.is_nsfw:
|
||||
should_flag_image = True
|
||||
categories.append(ContentCategory.SEXUAL)
|
||||
if image_result.is_violent:
|
||||
should_flag_image = True
|
||||
categories.append(ContentCategory.VIOLENCE)
|
||||
if image_result.is_disturbing:
|
||||
should_flag_image = True
|
||||
|
||||
if should_flag_image:
|
||||
# Use nsfw_severity if available, otherwise use None for default calculation
|
||||
severity_override = (
|
||||
image_result.nsfw_severity if image_result.nsfw_severity > 0 else None
|
||||
)
|
||||
|
||||
# Include NSFW category in explanation for better logging
|
||||
explanation = image_result.description
|
||||
if image_result.nsfw_category and image_result.nsfw_category != "none":
|
||||
explanation = f"[{image_result.nsfw_category}] {explanation}"
|
||||
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=image_result.confidence,
|
||||
categories=categories,
|
||||
explanation=explanation,
|
||||
suggested_action=_get_action_for_nsfw(image_result.nsfw_category),
|
||||
severity_override=severity_override,
|
||||
)
|
||||
await self._handle_ai_result(message, result, "Image Analysis")
|
||||
return
|
||||
|
||||
# Analyze URLs for phishing
|
||||
urls = URL_PATTERN.findall(message.content)
|
||||
allowlist = {normalize_domain(domain) for domain in config.scam_allowlist if domain}
|
||||
for url in urls[:3]: # Limit to first 3 URLs
|
||||
hostname = normalize_domain(url)
|
||||
if allowlist and is_allowed_domain(hostname, allowlist):
|
||||
except Exception as e:
|
||||
logger.error(f"AI embed image analysis failed: {e}", exc_info=True)
|
||||
continue
|
||||
phishing_result = await self.bot.ai_provider.analyze_phishing(
|
||||
url=url,
|
||||
message_content=message.content,
|
||||
|
||||
logger.debug(
|
||||
f"Embed image result: nsfw={image_result.is_nsfw}, "
|
||||
f"category={image_result.nsfw_category}, "
|
||||
f"confidence={image_result.confidence}"
|
||||
)
|
||||
|
||||
if phishing_result.is_phishing and phishing_result.confidence > 0.7:
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=phishing_result.confidence,
|
||||
categories=[ContentCategory.SCAM],
|
||||
explanation=phishing_result.explanation,
|
||||
suggested_action="delete",
|
||||
)
|
||||
await self._handle_ai_result(message, result, "Phishing Detection")
|
||||
return
|
||||
# Track AI usage
|
||||
self.bot.ai_rate_limiter.track_usage(message.guild.id, message.author.id)
|
||||
self._track_message(message.id)
|
||||
|
||||
@commands.group(name="ai", invoke_without_command=True)
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_cmd(self, ctx: commands.Context) -> None:
|
||||
"""View AI moderation settings."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="AI Moderation Settings",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="AI Moderation",
|
||||
value="✅ Enabled" if config and config.ai_moderation_enabled else "❌ Disabled",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="NSFW Detection",
|
||||
value="✅ Enabled" if config and config.nsfw_detection_enabled else "❌ Disabled",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Sensitivity",
|
||||
value=f"{config.ai_sensitivity}/100" if config else "50/100",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Confidence Threshold",
|
||||
value=f"{config.ai_confidence_threshold:.2f}" if config else "0.70",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Log Only",
|
||||
value="✅ Enabled" if config and config.ai_log_only else "❌ Disabled",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="NSFW-Only Mode",
|
||||
value="✅ Enabled" if config and config.nsfw_only_filtering else "❌ Disabled",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="AI Provider",
|
||||
value=self.bot.settings.ai_provider.capitalize(),
|
||||
inline=True,
|
||||
)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@ai_cmd.command(name="enable")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_enable(self, ctx: commands.Context) -> None:
|
||||
"""Enable AI moderation."""
|
||||
if self.bot.settings.ai_provider == "none":
|
||||
await ctx.send(
|
||||
"AI moderation is not configured. Set `GUARDDEN_AI_PROVIDER` and API key."
|
||||
)
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_moderation_enabled=True)
|
||||
await ctx.send("✅ AI moderation enabled.")
|
||||
|
||||
@ai_cmd.command(name="disable")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_disable(self, ctx: commands.Context) -> None:
|
||||
"""Disable AI moderation."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_moderation_enabled=False)
|
||||
await ctx.send("❌ AI moderation disabled.")
|
||||
|
||||
@ai_cmd.command(name="sensitivity")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_sensitivity(self, ctx: commands.Context, level: int) -> None:
|
||||
"""Set AI sensitivity level (0-100). Higher = more strict."""
|
||||
if not 0 <= level <= 100:
|
||||
await ctx.send("Sensitivity must be between 0 and 100.")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_sensitivity=level)
|
||||
await ctx.send(f"AI sensitivity set to {level}/100.")
|
||||
|
||||
@ai_cmd.command(name="threshold")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_threshold(self, ctx: commands.Context, value: float) -> None:
|
||||
"""Set AI confidence threshold (0.0-1.0)."""
|
||||
if not 0.0 <= value <= 1.0:
|
||||
await ctx.send("Threshold must be between 0.0 and 1.0.")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_confidence_threshold=value)
|
||||
await ctx.send(f"AI confidence threshold set to {value:.2f}.")
|
||||
|
||||
@ai_cmd.command(name="logonly")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_logonly(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable log-only mode for AI moderation."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_log_only=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"AI log-only mode {status}.")
|
||||
|
||||
@ai_cmd.command(name="nsfw")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_nsfw(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable NSFW image detection."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, nsfw_detection_enabled=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"NSFW detection {status}.")
|
||||
|
||||
@ai_cmd.command(name="nsfwonly")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_nsfw_only(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable NSFW-only filtering mode.
|
||||
|
||||
When enabled, only sexual/nude content will be filtered.
|
||||
Violence, harassment, and other content types will be allowed.
|
||||
"""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, nsfw_only_filtering=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
|
||||
if enabled:
|
||||
embed = discord.Embed(
|
||||
title="NSFW-Only Mode Enabled",
|
||||
description="⚠️ **Important:** Only sexual and nude content will now be filtered.\n"
|
||||
"Violence, harassment, hate speech, and other content types will be **allowed**.",
|
||||
color=discord.Color.orange(),
|
||||
)
|
||||
embed.add_field(
|
||||
name="What will be filtered:",
|
||||
value="• Sexual content\n• Nude images\n• Explicit material",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="What will be allowed:",
|
||||
value="• Violence and gore\n• Harassment\n• Hate speech\n• Self-harm content",
|
||||
inline=True,
|
||||
)
|
||||
embed.set_footer(text="Use '!ai nsfwonly false' to return to normal filtering")
|
||||
# Filter based on NSFW-only mode
|
||||
should_flag = False
|
||||
if nsfw_only_filtering:
|
||||
# Only flag sexual content
|
||||
if image_result.is_nsfw:
|
||||
should_flag = True
|
||||
else:
|
||||
embed = discord.Embed(
|
||||
title="NSFW-Only Mode Disabled",
|
||||
description="✅ Normal content filtering restored.\n"
|
||||
"All inappropriate content types will now be filtered.",
|
||||
color=discord.Color.green(),
|
||||
# Flag all inappropriate content
|
||||
if image_result.is_nsfw or image_result.is_violent or image_result.is_disturbing:
|
||||
should_flag = True
|
||||
|
||||
if should_flag:
|
||||
# Delete message (no logging, no timeout, no DM)
|
||||
try:
|
||||
await message.delete()
|
||||
logger.info(
|
||||
f"Deleted NSFW embed from {message.author} in {message.guild.name}: "
|
||||
f"category={image_result.nsfw_category}, confidence={image_result.confidence:.2f}"
|
||||
)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@ai_cmd.command(name="analyze")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_analyze(self, ctx: commands.Context, *, text: str) -> None:
|
||||
"""Test AI analysis on text (does not take action)."""
|
||||
if self.bot.settings.ai_provider == "none":
|
||||
await ctx.send("AI moderation is not configured.")
|
||||
except (discord.Forbidden, discord.NotFound):
|
||||
pass
|
||||
return
|
||||
|
||||
async with ctx.typing():
|
||||
result = await self.bot.ai_provider.moderate_text(
|
||||
content=text,
|
||||
context=f"Test analysis in {ctx.guild.name}",
|
||||
sensitivity=50,
|
||||
)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="AI Analysis Result",
|
||||
color=discord.Color.red() if result.is_flagged else discord.Color.green(),
|
||||
)
|
||||
|
||||
embed.add_field(name="Flagged", value="Yes" if result.is_flagged else "No", inline=True)
|
||||
embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True)
|
||||
embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True)
|
||||
embed.add_field(name="Suggested Action", value=result.suggested_action, inline=True)
|
||||
|
||||
if result.categories:
|
||||
categories = ", ".join(cat.value for cat in result.categories)
|
||||
embed.add_field(name="Categories", value=categories, inline=False)
|
||||
|
||||
if result.explanation:
|
||||
embed.add_field(name="Explanation", value=result.explanation[:1000], inline=False)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the AI Moderation cog."""
|
||||
|
||||
@@ -1,331 +1,81 @@
|
||||
"""Automod cog for automatic content moderation."""
|
||||
"""Automod cog for automatic spam detection - Minimal Version."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.models import ModerationLog, Strike
|
||||
from guardden.services.automod import (
|
||||
AutomodResult,
|
||||
AutomodService,
|
||||
SpamConfig,
|
||||
normalize_domain,
|
||||
)
|
||||
from guardden.utils.notifications import send_moderation_notification
|
||||
from guardden.utils.ratelimit import RateLimitExceeded
|
||||
from guardden.services.automod import AutomodResult, AutomodService, SpamConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Automod(commands.Cog):
    """Automatic spam detection."""

    def __init__(self, bot: GuardDen) -> None:
        self.bot = bot
        # Stateless service implementing the actual detection heuristics.
        self.automod = AutomodService()
||||
|
||||
def cog_check(self, ctx: commands.Context) -> bool:
    """Allow automod commands only inside a guild, gated by the owner allowlist."""
    return bool(ctx.guild) and self.bot.is_owner_allowed(ctx.author.id)
|
||||
def _spam_config(self, config=None) -> SpamConfig:
    """Build spam thresholds.

    Merge-conflict resolution: one branch read thresholds from the guild
    config row, the other from the YAML config loader.  This merged version
    accepts an optional guild config (backward compatible with both call
    sites): its values win when given, otherwise YAML settings are used.
    """
    if config:
        return SpamConfig(
            message_rate_limit=config.message_rate_limit,
            message_rate_window=config.message_rate_window,
            duplicate_threshold=config.duplicate_threshold,
            mention_limit=config.mention_limit,
            mention_rate_limit=config.mention_rate_limit,
            mention_rate_window=config.mention_rate_window,
        )
    loader = self.bot.config_loader
    return SpamConfig(
        message_rate_limit=loader.get_setting("automod.message_rate_limit", 5),
        message_rate_window=loader.get_setting("automod.message_rate_window", 5),
        duplicate_threshold=loader.get_setting("automod.duplicate_threshold", 3),
        mention_limit=loader.get_setting("automod.mention_limit", 5),
        mention_rate_limit=loader.get_setting("automod.mention_rate_limit", 10),
        mention_rate_window=loader.get_setting("automod.mention_rate_window", 60),
    )

async def cog_before_invoke(self, ctx: commands.Context) -> None:
    """Apply per-command rate limiting before any automod command runs."""
    if not ctx.command:
        return
    result = self.bot.rate_limiter.acquire_command(
        ctx.command.qualified_name,
        user_id=ctx.author.id,
        guild_id=ctx.guild.id if ctx.guild else None,
        channel_id=ctx.channel.id,
    )
    if result.is_limited:
        raise RateLimitExceeded(result.reset_after)

async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
    """Report rate-limit errors to the invoker; other errors propagate normally."""
    if isinstance(error, RateLimitExceeded):
        await ctx.send(
            f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
        )
|
||||
|
||||
async def _get_strike_count(self, guild_id: int, user_id: int) -> int:
    """Return the total active strike points a user has accumulated in a guild."""
    query = select(func.sum(Strike.points)).where(
        Strike.guild_id == guild_id,
        Strike.user_id == user_id,
        Strike.is_active == True,  # noqa: E712 -- SQLAlchemy expression, not a bool test
    )
    async with self.bot.database.session() as session:
        total = (await session.execute(query)).scalar()
    # SUM over zero rows yields NULL/None -> treat as zero strikes.
    return total or 0
|
||||
|
||||
async def _add_strike(
    self,
    guild: discord.Guild,
    member: discord.Member,
    reason: str,
) -> int:
    """Record a one-point automod strike and return the member's new total."""
    bot_id = self.bot.user.id if self.bot.user else 0
    async with self.bot.database.session() as session:
        session.add(
            Strike(
                guild_id=guild.id,
                user_id=member.id,
                user_name=str(member),
                moderator_id=bot_id,  # the bot itself is the acting moderator
                reason=reason,
                points=1,
            )
        )
    return await self._get_strike_count(guild.id, member.id)
|
||||
|
||||
async def _apply_strike_actions(
    self,
    member: discord.Member,
    total_strikes: int,
    config,
) -> None:
    """Apply the highest configured escalation whose threshold was reached.

    ``config.strike_actions`` maps a point threshold to an action dict,
    e.g. ``{"5": {"action": "timeout", "duration": 3600}}``.  Thresholds
    are scanned from highest to lowest and only the first match fires.
    """
    if not config or not config.strike_actions:
        return

    reason = f"Automod: {total_strikes} strikes"
    for threshold, action_config in sorted(
        config.strike_actions.items(), key=lambda item: int(item[0]), reverse=True
    ):
        if total_strikes < int(threshold):
            continue
        action = action_config.get("action")
        try:
            if action == "ban":
                await member.ban(reason=reason)
            elif action == "kick":
                await member.kick(reason=reason)
            elif action == "timeout":
                duration = action_config.get("duration", 3600)
                await member.timeout(
                    timedelta(seconds=duration),
                    reason=reason,
                )
        except discord.Forbidden:
            # Fix: the original let Forbidden propagate and kill the
            # message listener; log and move on, consistent with the
            # delete/timeout handling in _handle_violation.
            logger.warning(
                f"Cannot apply strike action '{action}' to {member}: missing permissions"
            )
        break
|
||||
|
||||
async def _log_database_action(
    self,
    message: discord.Message,
    result: AutomodResult,
) -> None:
    """Persist a ModerationLog row describing the automod action taken.

    The most severe flag on *result* wins when collapsing it to a single
    action string: timeout > strike > warn > delete.
    """
    async with self.bot.database.session() as session:
        action = "delete"
        if result.should_timeout:
            action = "timeout"
        elif result.should_strike:
            action = "strike"
        elif result.should_warn:
            action = "warn"

        # Timeouts carry an expiry so the log reflects when they lapse.
        expires_at = None
        if result.timeout_duration:
            expires_at = datetime.now(timezone.utc) + timedelta(seconds=result.timeout_duration)

        log_entry = ModerationLog(
            guild_id=message.guild.id,
            target_id=message.author.id,
            target_name=str(message.author),
            # The bot itself is recorded as the moderator for automatic actions.
            moderator_id=self.bot.user.id if self.bot.user else 0,
            moderator_name=str(self.bot.user) if self.bot.user else "GuardDen",
            action=action,
            reason=result.reason,
            duration=result.timeout_duration or None,
            expires_at=expires_at,
            channel_id=message.channel.id,
            message_id=message.id,
            message_content=message.content,
            is_automatic=True,
        )
        session.add(log_entry)
|
||||
|
||||
async def _handle_violation(
    self,
    message: discord.Message,
    result: AutomodResult,
) -> None:
    """Handle an automod violation: delete, timeout, log, strike, and notify.

    Merge-conflict resolution: the conflicted file kept both branches'
    docstrings/comments; the full behavior is retained here because every
    helper this method relies on still exists in the file.
    """
    # Delete the offending message (best effort).
    if result.should_delete:
        try:
            await message.delete()
            logger.info(
                f"Automod deleted message from {message.author} in {message.guild.name}: {result.reason}"
            )
        except discord.Forbidden:
            logger.warning(f"Cannot delete message in {message.guild}: missing permissions")
        except discord.NotFound:
            pass  # Already deleted

    # Apply timeout.
    if result.should_timeout and result.timeout_duration > 0:
        try:
            await message.author.timeout(
                timedelta(seconds=result.timeout_duration),
                reason=f"Automod: {result.reason}",
            )
        except discord.Forbidden:
            logger.warning(f"Cannot timeout {message.author}: missing permissions")

    # Log to the database and the mod-log channel.
    await self._log_database_action(message, result)
    await self._log_automod_action(message, result)

    # Fix: the conflicted version fetched the guild config twice; fetch it
    # once for both strike escalation and the notification below.
    config = await self.bot.guild_config.get_config(message.guild.id)

    # Apply strike escalation if configured.
    if (result.should_warn or result.should_strike) and isinstance(
        message.author, discord.Member
    ):
        total = await self._add_strike(message.guild, message.author, result.reason)
        await self._apply_strike_actions(message.author, total, config)

    # Notify the user.
    embed = discord.Embed(
        title=f"Message Removed in {message.guild.name}",
        description=result.reason,
        color=discord.Color.orange(),
        timestamp=datetime.now(timezone.utc),
    )
    if result.should_timeout:
        embed.add_field(
            name="Timeout",
            value=f"You have been timed out for {result.timeout_duration} seconds.",
        )

    # Use notification utility to send DM with in-channel fallback.
    if isinstance(message.channel, discord.TextChannel):
        await send_moderation_notification(
            user=message.author,
            channel=message.channel,
            embed=embed,
            send_in_channel=config.send_in_channel_warnings if config else False,
        )
|
||||
|
||||
async def _log_automod_action(
    self,
    message: discord.Message,
    result: AutomodResult,
) -> None:
    """Post an embed describing the automod action to the guild's mod log channel."""
    config = await self.bot.guild_config.get_config(message.guild.id)
    if not config or not config.mod_log_channel_id:
        return  # no mod log configured for this guild

    channel = message.guild.get_channel(config.mod_log_channel_id)
    if not channel or not isinstance(channel, discord.TextChannel):
        return  # channel deleted or not a text channel

    embed = discord.Embed(
        title="Automod Action",
        color=discord.Color.orange(),
        timestamp=datetime.now(timezone.utc),
    )
    embed.set_author(
        name=str(message.author),
        icon_url=message.author.display_avatar.url,
    )
    embed.add_field(name="Filter", value=result.matched_filter, inline=True)
    embed.add_field(name="Channel", value=message.channel.mention, inline=True)
    embed.add_field(name="Reason", value=result.reason, inline=False)

    if message.content:
        # Truncate to stay well inside Discord's 1024-char field limit.
        content = (
            message.content[:500] + "..." if len(message.content) > 500 else message.content
        )
        embed.add_field(name="Message Content", value=f"```{content}```", inline=False)

    # Summarize every action flag the result carried.
    actions = []
    if result.should_delete:
        actions.append("Message deleted")
    if result.should_warn:
        actions.append("User warned")
    if result.should_strike:
        actions.append("Strike added")
    if result.should_timeout:
        actions.append(f"Timeout ({result.timeout_duration}s)")

    embed.add_field(name="Actions Taken", value=", ".join(actions) or "None", inline=False)
    embed.set_footer(text=f"User ID: {message.author.id}")

    await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
async def on_message(self, message: discord.Message) -> None:
    """Check all messages for spam violations.

    Merge-conflict resolution: the conflicted body interleaved the legacy
    guild-config pipeline (banned words, scam links, invites) with the new
    minimal YAML-driven spam-only pipeline, and lost the ``return`` of the
    moderator bypass.  This reconstructs the minimal version the file's
    header commits to.
    """
    # Skip DMs, bots, and empty messages
    if not message.guild or message.author.bot or not message.content:
        return

    # Users with manage_messages bypass automod entirely.
    # (Fix: the conflicted version dropped this `return`.)
    if isinstance(message.author, discord.Member):
        if message.author.guild_permissions.manage_messages:
            return

    # Master switch and spam toggle come from the YAML config.
    config = self.bot.config_loader
    if not config.get_setting("automod.enabled", True):
        return

    result: AutomodResult | None = None

    # Check spam ONLY (no banned words, no scam links, no invites)
    if config.get_setting("automod.anti_spam_enabled", True):
        spam_config = self._spam_config()
        result = self.automod.check_spam(
            message,
            anti_spam_enabled=True,
            spam_config=spam_config,
        )

    # Handle violation if found
    if result:
        logger.info(
            f"Automod triggered in {message.guild.name}: "
            f"{result.matched_filter} by {message.author}"
        )
        await self._handle_violation(message, result)
|
||||
|
||||
@commands.Cog.listener()
async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None:
    """Re-run spam checks when a message's content changes.

    (Merge fix: the conflicted file kept two duplicate docstrings here.)
    """
    # Only check if content changed
    if before.content == after.content:
        return
    # Reuse on_message logic
    await self.on_message(after)
|
||||
|
||||
@commands.group(name="automod", invoke_without_command=True)
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def automod_cmd(self, ctx: commands.Context) -> None:
    """View automod status and configuration for this guild."""
    config = await self.bot.guild_config.get_config(ctx.guild.id)

    embed = discord.Embed(
        title="Automod Configuration",
        color=discord.Color.blue(),
    )

    # Feature toggles (all shown as off when no config row exists yet).
    embed.add_field(
        name="Automod Enabled",
        value="✅ Yes" if config and config.automod_enabled else "❌ No",
        inline=True,
    )
    embed.add_field(
        name="Anti-Spam",
        value="✅ Yes" if config and config.anti_spam_enabled else "❌ No",
        inline=True,
    )
    embed.add_field(
        name="Link Filter",
        value="✅ Yes" if config and config.link_filter_enabled else "❌ No",
        inline=True,
    )

    spam_config = self._spam_config(config)

    # Show thresholds
    embed.add_field(
        name="Rate Limit",
        value=f"{spam_config.message_rate_limit} msgs / {spam_config.message_rate_window}s",
        inline=True,
    )
    embed.add_field(
        name="Duplicate Threshold",
        value=f"{spam_config.duplicate_threshold} same messages",
        inline=True,
    )
    embed.add_field(
        name="Mention Limit",
        value=f"{spam_config.mention_limit} per message",
        inline=True,
    )
    embed.add_field(
        name="Mention Rate",
        value=f"{spam_config.mention_rate_limit} mentions / {spam_config.mention_rate_window}s",
        inline=True,
    )

    banned_words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
    embed.add_field(
        name="Banned Words",
        value=f"{len(banned_words)} configured",
        inline=True,
    )

    await ctx.send(embed=embed)
|
||||
|
||||
@automod_cmd.command(name="threshold")
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def automod_threshold(
    self,
    ctx: commands.Context,
    setting: Literal[
        "message_rate_limit",
        "message_rate_window",
        "duplicate_threshold",
        "mention_limit",
        "mention_rate_limit",
        "mention_rate_window",
    ],
    value: int,
) -> None:
    """Update a single automod threshold."""
    # Zero or negative thresholds would disable detection entirely.
    if value < 1:
        await ctx.send("Threshold values must be positive.")
        return

    await self.bot.guild_config.update_settings(ctx.guild.id, **{setting: value})
    await ctx.send(f"Updated `{setting}` to {value}.")
|
||||
|
||||
@automod_cmd.group(name="allowlist", invoke_without_command=True)
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def automod_allowlist(self, ctx: commands.Context) -> None:
    """Show the scam link allowlist."""
    config = await self.bot.guild_config.get_config(ctx.guild.id)
    domains = sorted(config.scam_allowlist) if config else []
    if domains:
        # Cap the listing at 20 entries to keep the reply short.
        listing = "\n".join(f"- `{entry}`" for entry in domains[:20])
        await ctx.send(f"Allowed domains:\n{listing}")
    else:
        await ctx.send("No allowlisted domains configured.")
|
||||
|
||||
@automod_allowlist.command(name="add")
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def automod_allowlist_add(self, ctx: commands.Context, domain: str) -> None:
    """Add a domain to the scam link allowlist."""
    normalized = normalize_domain(domain)
    if not normalized:
        await ctx.send("Provide a valid domain or URL to allowlist.")
        return

    config = await self.bot.guild_config.get_config(ctx.guild.id)
    current = list(config.scam_allowlist) if config else []

    if normalized in current:
        await ctx.send(f"`{normalized}` is already allowlisted.")
        return

    updated = current + [normalized]
    await self.bot.guild_config.update_settings(ctx.guild.id, scam_allowlist=updated)
    await ctx.send(f"Added `{normalized}` to the allowlist.")
|
||||
|
||||
@automod_allowlist.command(name="remove")
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def automod_allowlist_remove(self, ctx: commands.Context, domain: str) -> None:
    """Remove a domain from the scam link allowlist."""
    normalized = normalize_domain(domain)
    if not normalized:
        # Consistency fix: the `add` subcommand rejects unparseable input,
        # but `remove` previously fell through and reported that "``" is
        # not in the allowlist.
        await ctx.send("Provide a valid domain or URL to remove.")
        return

    config = await self.bot.guild_config.get_config(ctx.guild.id)
    allowlist = list(config.scam_allowlist) if config else []

    if normalized not in allowlist:
        await ctx.send(f"`{normalized}` is not in the allowlist.")
        return

    allowlist.remove(normalized)
    await self.bot.guild_config.update_settings(ctx.guild.id, scam_allowlist=allowlist)
    await ctx.send(f"Removed `{normalized}` from the allowlist.")
|
||||
|
||||
@automod_cmd.command(name="test")
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def automod_test(self, ctx: commands.Context, *, text: str) -> None:
    """Test a message against automod filters (does not take action)."""
    config = await self.bot.guild_config.get_config(ctx.guild.id)
    results = []

    # Check banned words
    banned_words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
    result = self.automod.check_banned_words(text, banned_words)
    if result:
        results.append(f"**Banned Words**: {result.reason}")

    # Check scam links
    result = self.automod.check_scam_links(
        text, allowlist=config.scam_allowlist if config else []
    )
    if result:
        results.append(f"**Scam Detection**: {result.reason}")

    # Check invite links
    result = self.automod.check_invite_links(text, allow_invites=False)
    if result:
        results.append(f"**Invite Links**: {result.reason}")

    # Check caps
    result = self.automod.check_all_caps(text)
    if result:
        results.append(f"**Excessive Caps**: {result.reason}")

    # Red embed when anything tripped, green otherwise.
    embed = discord.Embed(
        title="Automod Test Results",
        color=discord.Color.red() if results else discord.Color.green(),
    )

    if results:
        embed.description = "\n".join(results)
    else:
        embed.description = "✅ No violations detected"

    await ctx.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
    """Load the Automod cog."""
    # Restored: the add_cog call was lost in the merge; this is the
    # standard discord.py extension entry point.
    await bot.add_cog(Automod(bot))
||||
@@ -1,14 +1,11 @@
|
||||
"""Automod service for spam detection - Minimal Version."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import signal
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, NamedTuple, Sequence
|
||||
from urllib.parse import urlparse
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
    import discord
else:
    # At runtime, fall back to a tiny stub so this module stays importable
    # (and testable) without discord.py installed.
    try:
        import discord  # type: ignore
    except ModuleNotFoundError:  # pragma: no cover

        class _DiscordStub:
            class Message:  # minimal stub so type hints resolve at runtime
                pass

        discord = _DiscordStub()  # type: ignore
|
||||
|
||||
from guardden.models.guild import BannedWord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Circuit breaker for regex safety
|
||||
class RegexTimeoutError(Exception):
    """Raised when a regex search exceeds the circuit breaker's time budget."""
|
||||
|
||||
|
||||
class RegexCircuitBreaker:
    """Circuit breaker to prevent catastrophic backtracking in regex patterns.

    Patterns that time out are disabled for ``failure_threshold``; obviously
    dangerous patterns are rejected up front without being executed.
    """

    def __init__(self, timeout_seconds: float = 0.1):
        self.timeout_seconds = timeout_seconds
        # pattern -> time of last timeout; consulted by is_pattern_disabled().
        self.failed_patterns: dict[str, datetime] = {}
        self.failure_threshold = timedelta(minutes=5)  # Disable pattern for 5 minutes after failure

    def _timeout_handler(self, signum, frame):
        """Signal handler for regex timeout."""
        raise RegexTimeoutError("Regex execution timed out")

    def is_pattern_disabled(self, pattern: str) -> bool:
        """Check if a pattern is temporarily disabled due to timeouts."""
        if pattern not in self.failed_patterns:
            return False

        failure_time = self.failed_patterns[pattern]
        if datetime.now(timezone.utc) - failure_time > self.failure_threshold:
            # Re-enable the pattern after threshold time
            del self.failed_patterns[pattern]
            return False

        return True

    def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
        """Safely execute a regex search with timeout protection.

        Returns False (rather than raising) for disabled, dangerous,
        invalid, or timed-out patterns.
        """
        if self.is_pattern_disabled(pattern):
            logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
            return False

        # Basic pattern validation to catch obviously problematic patterns
        if self._is_dangerous_pattern(pattern):
            logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
            return False

        old_handler = None
        timer_armed = False
        try:
            # BUG FIX: the original called signal.alarm(int(timeout * 1000))
            # with a "convert to milliseconds" comment, but alarm() takes
            # whole SECONDS -- a 0.1 s budget became a 100 s alarm.
            # setitimer() supports sub-second intervals.
            # NOTE(review): signals only fire in the main thread; from other
            # threads signal.signal raises and we fall through the generic
            # except below, same as the original structure.
            if hasattr(signal, "SIGALRM"):
                old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
                signal.setitimer(signal.ITIMER_REAL, self.timeout_seconds)
                timer_armed = True

            start_time = time.perf_counter()

            # Compile and execute regex
            compiled_pattern = re.compile(pattern, flags)
            result = bool(compiled_pattern.search(text))

            execution_time = time.perf_counter() - start_time

            # Log slow patterns for monitoring
            if execution_time > self.timeout_seconds * 0.8:
                logger.warning(
                    f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
                )

            return result

        except RegexTimeoutError:
            # Pattern took too long, disable it temporarily
            self.failed_patterns[pattern] = datetime.now(timezone.utc)
            logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
            return False

        except re.error as e:
            logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
            return False

        except Exception as e:
            logger.error(f"Unexpected error in regex execution: {e}")
            return False

        finally:
            # Cancel any pending timer and restore the previous handler.
            if timer_armed and old_handler is not None:
                signal.setitimer(signal.ITIMER_REAL, 0)
                signal.signal(signal.SIGALRM, old_handler)

    def _is_dangerous_pattern(self, pattern: str) -> bool:
        """Basic heuristic to detect potentially dangerous regex patterns."""
        # Literal substrings commonly seen in catastrophic-backtracking patterns.
        dangerous_indicators = [
            r"(\w+)+",  # Nested quantifiers
            r"(\d+)+",  # Nested quantifiers on digits
            r"(.+)+",  # Nested quantifiers on anything
            r"(.*)+",  # Nested quantifiers on anything (greedy)
            r"(\w*)+",  # Nested quantifiers with *
            r"(\S+)+",  # Nested quantifiers on non-whitespace
        ]

        # Check for excessively long patterns
        if len(pattern) > 500:
            return True

        # Check for nested quantifiers (simplified detection)
        if "+)+" in pattern or "*)+" in pattern or "?)+" in pattern:
            return True

        # Check for excessive repetition operators
        if pattern.count("+") > 10 or pattern.count("*") > 10:
            return True

        # Check for specific dangerous patterns (substring match, not regex)
        for dangerous in dangerous_indicators:
            if dangerous in pattern:
                return True

        return False
|
||||
|
||||
|
||||
# Global circuit breaker instance shared by all regex checks in this module.
_regex_circuit_breaker = RegexCircuitBreaker()
|
||||
|
||||
|
||||
# Known scam/phishing patterns.
# NOTE(review): compiled with re.IGNORECASE in AutomodService.__init__.
SCAM_PATTERNS = [
    # Discord scam patterns (fake gift/nitro domains and giveaway bait)
    r"discord(?:[-.]?(?:gift|nitro|free|claim|steam))[\w.-]*\.(?!com|gg)[a-z]{2,}",
    r"(?:free|claim|get)[-.\s]?(?:discord[-.\s]?)?nitro",
    r"(?:steam|discord)[-.\s]?community[-.\s]?(?:giveaway|gift)",
    # Generic phishing
    r"(?:verify|confirm)[-.\s]?(?:your)?[-.\s]?account",
    r"(?:suspended|locked|limited)[-.\s]?account",
    r"click[-.\s]?(?:here|this)[-.\s]?(?:to[-.\s]?)?(?:verify|claim|get)",
    # Crypto scams
    r"(?:free|claim|airdrop)[-.\s]?(?:crypto|bitcoin|eth|nft)",
    r"(?:double|2x)[-.\s]?your[-.\s]?(?:crypto|bitcoin|eth)",
]

# Suspicious TLDs often used in phishing
SUSPICIOUS_TLDS = {
    ".xyz",
    ".top",
    ".club",
    ".work",
    ".click",
    ".link",
    ".info",
    ".ru",
    ".cn",
    ".tk",
    ".ml",
    ".ga",
    ".cf",
    ".gq",
}

# URL pattern for extraction - more restrictive for security.
# First alternative: any schemed URL; second: bare domains limited to a
# fixed set of common TLDs (so arbitrary dotted text does not match).
URL_PATTERN = re.compile(
    r"https?://(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?:/[^\s]*)?|"
    r"(?:www\.)?[a-zA-Z0-9-]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|gov|edu)(?:/[^\s]*)?",
    re.IGNORECASE,
)
|
||||
|
||||
|
||||
class SpamRecord(NamedTuple):
|
||||
"""Record of a message for spam tracking."""
|
||||
|
||||
content_hash: str
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserSpamTracker:
|
||||
"""Tracks spam behavior for a single user."""
|
||||
|
||||
messages: list[SpamRecord] = field(default_factory=list)
|
||||
mention_count: int = 0
|
||||
last_mention_time: datetime | None = None
|
||||
duplicate_count: int = 0
|
||||
last_action_time: datetime | None = None
|
||||
|
||||
def cleanup(self, max_age: timedelta = timedelta(minutes=1)) -> None:
|
||||
"""Remove old messages from tracking."""
|
||||
cutoff = datetime.now(timezone.utc) - max_age
|
||||
self.messages = [m for m in self.messages if m.timestamp > cutoff]
|
||||
|
||||
|
||||
@dataclass
class AutomodResult:
    """Result of automod check."""

    # NOTE(review): a second, conflicting AutomodResult definition appears
    # later in this file (merge residue).  At runtime the later definition
    # wins -- the duplication should be resolved.
    # Flags are additive: one violation may both delete and warn.
    should_delete: bool = False
    should_warn: bool = False
    should_strike: bool = False
    should_timeout: bool = False
    timeout_duration: int = 0  # seconds
    reason: str = ""
    matched_filter: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SpamConfig:
    """Spam detection thresholds.

    Merge fix: the conflicted file carried two docstrings and lost the
    ``mention_limit`` / ``mention_rate_limit`` fields at an embedded diff
    cut; defaults restored from the YAML fallbacks used in the cog's
    ``_spam_config``.
    """

    message_rate_limit: int = 5  # max messages per window
    message_rate_window: int = 5  # seconds
    duplicate_threshold: int = 3  # identical messages before flagging
    mention_limit: int = 5  # max mentions in a single message
    mention_rate_limit: int = 10  # max mentions per window
    mention_rate_window: int = 60  # seconds
|
||||
|
||||
|
||||
def normalize_domain(value: str) -> str:
    """Normalize a domain or URL to a bare lowercase hostname for allowlist checks.

    Returns "" for anything that does not parse into a plausible DNS
    hostname (control characters, overlong input, malformed labels, etc.).
    A leading ``www.`` is stripped so "www.example.com" and "example.com"
    compare equal.
    """
    if not value or not isinstance(value, str):
        return ""

    # Reject control characters before any parsing.
    if any(char in value for char in ["\x00", "\n", "\r", "\t"]):
        return ""

    text = value.strip().lower()
    if not text or len(text) > 2000:  # Prevent excessively long URLs
        return ""

    try:
        if "://" not in text:
            text = f"http://{text}"  # urlparse needs a scheme to find the host

        parsed = urlparse(text)
        hostname = parsed.hostname or ""

        if not hostname or len(hostname) > 253:  # RFC hostname length limit
            return ""

        # Check for malicious characters surviving the parse.
        if any(char in hostname for char in [" ", "\x00", "\n", "\r", "\t"]):
            return ""

        if not re.fullmatch(r"[a-z0-9.-]+", hostname):
            return ""
        if hostname.startswith(".") or hostname.endswith(".") or ".." in hostname:
            return ""
        for label in hostname.split("."):
            if not label:
                return ""
            if label.startswith("-") or label.endswith("-"):
                return ""

        # Remove www prefix
        if hostname.startswith("www."):
            hostname = hostname[4:]

        return hostname
    except Exception:
        # Fix: the original caught (ValueError, UnicodeError, Exception),
        # which is redundant.  urlparse can raise assorted exceptions on
        # malicious input, so the broad net is kept deliberately.
        return ""
|
||||
@dataclass
class AutomodResult:
    """Result of an automod check.

    Merge-conflict resolution: this duplicate definition previously
    overrode the earlier one with required positional fields and
    ``timeout_duration: int | None = None``, which breaks callers that
    construct results keyword-only and compare ``timeout_duration > 0``.
    Consolidated to match the earlier (canonical) field contract.
    """

    matched_filter: str = ""
    reason: str = ""
    should_delete: bool = False
    should_warn: bool = False
    should_strike: bool = False
    should_timeout: bool = False
    timeout_duration: int = 0  # seconds; 0 means no timeout
|
||||
|
||||
|
||||
def is_allowed_domain(hostname: str, allowlist: set[str]) -> bool:
    """Return True when *hostname* equals an allowlisted domain or is a subdomain of one."""
    if not hostname:
        return False
    return any(
        hostname == allowed or hostname.endswith(f".{allowed}")
        for allowed in allowlist
    )
|
||||
class SpamTracker:
    """Track user spam behavior: message rates, duplicate contents, mention rates."""

    def __init__(self):
        # guild_id -> user_id -> message timestamps (epoch seconds)
        self.message_times: dict[int, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list))
        # guild_id -> user_id -> recent message contents for duplicate detection
        self.message_contents: dict[int, dict[int, list[str]]] = defaultdict(lambda: defaultdict(list))
        # guild_id -> user_id -> mention timestamps (epoch seconds)
        self.mention_times: dict[int, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list))
        # Last time cleanup_old_entries actually ran.
        self.last_cleanup = time.time()

    def cleanup_old_entries(self):
        """Prune stale tracking data to bound memory; runs at most every 5 minutes."""
        now = time.time()
        if now - self.last_cleanup < 300:
            return

        cutoff = now - 3600  # keep only the last hour of timestamps

        # Drop expired timestamps, then empty users, then empty guilds.
        for store in (self.message_times, self.mention_times):
            for gid in list(store):
                per_user = store[gid]
                for uid in list(per_user):
                    per_user[uid] = [stamp for stamp in per_user[uid] if stamp > cutoff]
                    if not per_user[uid]:
                        del per_user[uid]
                if not per_user:
                    del store[gid]

        # Duplicate-detection buffers: retain only the 10 newest messages per user.
        for gid in list(self.message_contents):
            per_user = self.message_contents[gid]
            for uid in list(per_user):
                per_user[uid] = per_user[uid][-10:]
                if not per_user[uid]:
                    del per_user[uid]
            if not per_user:
                del self.message_contents[gid]

        self.last_cleanup = now
|
||||
|
||||
|
||||
class AutomodService:
    """Service for automatic content moderation.

    Provides static content filters (banned words, scam/phishing links,
    Discord invites, excessive caps) plus per-guild, per-user spam
    heuristics (message rate, duplicate messages, mention abuse) backed
    by a shared ``SpamTracker``.
    """

    def __init__(self) -> None:
        # Compile scam patterns once so every message check reuses them.
        self._scam_patterns = [re.compile(p, re.IGNORECASE) for p in SCAM_PATTERNS]
        # Invite-link pattern, compiled once instead of per call.
        self._invite_pattern = re.compile(
            r"(?:https?://)?(?:www\.)?(?:discord\.(?:gg|io|me|li)|discordapp\.com/invite)/[\w-]+",
            re.IGNORECASE,
        )
        # Per-guild, per-user spam tracking state (timestamps + contents).
        self.spam_tracker = SpamTracker()
        # Fallback thresholds used when a guild supplies no SpamConfig.
        self.default_spam_config = SpamConfig()

    def _get_content_hash(self, content: str) -> str:
        """Return a normalized form of *content* for duplicate detection.

        Despite the name this is not a cryptographic hash: it lowercases,
        strips non-alphanumeric characters, and collapses whitespace, so
        trivially re-styled duplicates compare equal.
        """
        normalized = content.lower()
        # Remove special characters (simple linear pass, no regex needed).
        normalized = "".join(c for c in normalized if c.isalnum() or c.isspace())
        # Collapse all runs of whitespace to single spaces.
        return " ".join(normalized.split())

    def check_banned_words(
        self, content: str, banned_words: Sequence[BannedWord]
    ) -> AutomodResult | None:
        """Check *content* against the guild's banned-word filters.

        Returns an ``AutomodResult`` (always requesting deletion) for the
        first matching filter, escalating to warn/strike according to the
        filter's configured action; ``None`` when nothing matches.
        """
        content_lower = content.lower()

        for banned in banned_words:
            if banned.is_regex:
                # Circuit breaker guards against catastrophic backtracking
                # in operator-supplied regex patterns.
                matched = bool(
                    _regex_circuit_breaker.safe_regex_search(
                        banned.pattern, content, re.IGNORECASE
                    )
                )
            else:
                matched = banned.pattern.lower() in content_lower

            if matched:
                result = AutomodResult(
                    should_delete=True,
                    reason=banned.reason or "Matched banned word filter",
                    matched_filter=f"banned_word:{banned.id}",
                )
                if banned.action == "warn":
                    result.should_warn = True
                elif banned.action == "strike":
                    result.should_strike = True
                return result

        return None

    def check_scam_links(
        self, content: str, allowlist: list[str] | None = None
    ) -> AutomodResult | None:
        """Check *content* for scam/phishing patterns and suspicious links.

        Args:
            content: Raw message text.
            allowlist: Domains exempt from the suspicious-link heuristic.

        Returns:
            ``AutomodResult`` when a known scam pattern or a suspicious
            link is found, ``None`` otherwise.
        """
        # Known scam/phishing text patterns take priority over URL heuristics.
        for pattern in self._scam_patterns:
            if pattern.search(content):
                return AutomodResult(
                    should_delete=True,
                    should_warn=True,
                    reason="Message matched known scam/phishing pattern",
                    matched_filter="scam_pattern",
                )

        allowlist_set = {normalize_domain(domain) for domain in allowlist or [] if domain}

        # Heuristic: suspicious TLD combined with an impersonation keyword.
        impersonation_keywords = (
            "discord",
            "steam",
            "nitro",
            "gift",
            "free",
            "login",
            "verify",
        )
        for url in URL_PATTERN.findall(content):
            # Bound processing cost on absurdly long URLs.
            if len(url) > 2000:
                continue

            url_lower = url.lower()
            hostname = normalize_domain(url)
            # Skip if hostname normalization failed (security check).
            if not hostname:
                continue
            if allowlist_set and is_allowed_domain(hostname, allowlist_set):
                continue

            if any(tld in url_lower for tld in SUSPICIOUS_TLDS) and any(
                kw in url_lower for kw in impersonation_keywords
            ):
                return AutomodResult(
                    should_delete=True,
                    should_warn=True,
                    reason=f"Suspicious link detected: {url[:50]}",
                    matched_filter="suspicious_link",
                )

        return None

    def check_spam(
        self,
        message: discord.Message,
        anti_spam_enabled: bool = True,
        spam_config: SpamConfig | None = None,
    ) -> AutomodResult | None:
        """Check a message for spam behavior.

        Args:
            message: Discord message to check.
            anti_spam_enabled: Whether spam detection is enabled.
            spam_config: Spam configuration settings; falls back to
                ``self.default_spam_config`` when ``None``.

        Returns:
            ``AutomodResult`` if spam is detected, ``None`` otherwise.
        """
        if not anti_spam_enabled:
            return None

        # Skip DM messages — spam tracking is per guild.
        if message.guild is None:
            return None

        config = spam_config or self.default_spam_config
        guild_id = message.guild.id
        user_id = message.author.id
        now = time.time()

        # Opportunistic cleanup of stale tracking data (self-throttled).
        self.spam_tracker.cleanup_old_entries()

        # Check 1: message rate limiting.
        message_times = self.spam_tracker.message_times[guild_id][user_id]
        cutoff_time = now - config.message_rate_window
        message_times = [ts for ts in message_times if ts > cutoff_time]
        message_times.append(now)
        self.spam_tracker.message_times[guild_id][user_id] = message_times

        if len(message_times) > config.message_rate_limit:
            return AutomodResult(
                should_delete=True,
                should_timeout=True,
                timeout_duration=60,  # 1 minute timeout
                reason=(
                    f"Sending messages too fast ({len(message_times)} in "
                    f"{config.message_rate_window}s)"
                ),
                matched_filter="rate_limit",
            )

        # Check 2: duplicate messages (exact-content repeats in last 10).
        message_contents = self.spam_tracker.message_contents[guild_id][user_id]
        message_contents.append(message.content)
        message_contents = message_contents[-10:]  # Keep last 10
        self.spam_tracker.message_contents[guild_id][user_id] = message_contents

        duplicate_count = message_contents.count(message.content)
        if duplicate_count >= config.duplicate_threshold:
            return AutomodResult(
                should_delete=True,
                should_warn=True,
                reason=f"Duplicate message detected ({duplicate_count} times)",
                matched_filter="duplicate",
            )

        # Check 3: mass mentions in a single message.
        mention_count = len(message.mentions)
        if mention_count > config.mention_limit:
            return AutomodResult(
                should_delete=True,
                should_timeout=True,
                timeout_duration=300,  # 5 minute timeout
                reason=f"Mass mentions detected ({mention_count} mentions)",
                matched_filter="mass_mention",
            )

        # Check 4: mention rate limiting across recent messages.
        if mention_count > 0:
            mention_times = self.spam_tracker.mention_times[guild_id][user_id]
            mention_cutoff = now - config.mention_rate_window
            mention_times = [ts for ts in mention_times if ts > mention_cutoff]
            # Record one timestamp per mention in this message.
            mention_times.extend([now] * mention_count)
            self.spam_tracker.mention_times[guild_id][user_id] = mention_times

            if len(mention_times) > config.mention_rate_limit:
                return AutomodResult(
                    should_delete=True,
                    should_timeout=True,
                    timeout_duration=300,
                    reason=(
                        "Too many mentions in a short period "
                        f"({len(mention_times)} in {config.mention_rate_window}s)"
                    ),
                    matched_filter="mention_rate",
                )

        return None

    def check_invite_links(self, content: str, allow_invites: bool = True) -> AutomodResult | None:
        """Check *content* for Discord invite links.

        Returns ``None`` immediately when invites are allowed for the guild.
        """
        if allow_invites:
            return None

        if self._invite_pattern.search(content):
            return AutomodResult(
                should_delete=True,
                reason="Discord invite links are not allowed",
                matched_filter="invite_link",
            )

        return None

    def check_all_caps(
        self, content: str, threshold: float = 0.7, min_length: int = 10
    ) -> AutomodResult | None:
        """Flag messages whose letters are mostly uppercase.

        Args:
            content: Raw message text.
            threshold: Uppercase ratio above which the message is flagged.
            min_length: Minimum letter count before the check applies, so
                short shouts like "OK" are not punished.
        """
        letters = [c for c in content if c.isalpha()]
        if len(letters) < min_length:
            return None

        caps_ratio = sum(1 for c in letters if c.isupper()) / len(letters)
        if caps_ratio > threshold:
            return AutomodResult(
                should_delete=True,
                reason="Excessive caps usage",
                matched_filter="caps",
            )

        return None

    def reset_user_tracker(self, guild_id: int, user_id: int) -> None:
        """Reset all spam-tracking state for one user in a guild."""
        # NOTE(review): repointed at spam_tracker after the merge removed
        # the old _spam_trackers dict.
        tracker = self.spam_tracker
        for store in (tracker.message_times, tracker.message_contents, tracker.mention_times):
            if guild_id in store:
                store[guild_id].pop(user_id, None)

    def cleanup_guild(self, guild_id: int) -> None:
        """Remove all spam-tracking data for a guild (e.g. on guild leave)."""
        tracker = self.spam_tracker
        for store in (tracker.message_times, tracker.message_contents, tracker.mention_times):
            store.pop(guild_id, None)
||||
# Module-level singleton so the convenience wrappers below share one
# service instance (and its compiled patterns / spam-tracking state).
_automod_service = AutomodService()
def detect_scam_links(content: str, allowlist: list[str] | None = None) -> AutomodResult | None:
    """Convenience wrapper around the shared service's scam-link check.

    Args:
        content: Raw message text to scan.
        allowlist: Domains exempt from the suspicious-link heuristic.

    Returns:
        ``AutomodResult`` when a scam pattern or suspicious link is found,
        ``None`` otherwise.
    """
    return _automod_service.check_scam_links(content, allowlist)