Implement GuardDen Discord moderation bot
Features: - Core moderation: warn, kick, ban, timeout, strike system - Automod: banned words filter, scam detection, anti-spam, link filtering - AI moderation: Claude/OpenAI integration, NSFW detection, phishing analysis - Verification system: button, captcha, math, emoji challenges - Rate limiting system with configurable scopes - Event logging: joins, leaves, message edits/deletes, voice activity - Per-guild configuration with caching - Docker deployment support Bug fixes applied: - Fixed await on session.delete() in guild_config.py - Fixed memory leak in AI moderation message tracking (use deque) - Added error handling to bot shutdown - Added error handling to timeout command - Removed unused Literal import - Added prefix validation - Added image analysis limit (3 per message) - Fixed test mock for SQLAlchemy model
This commit is contained in:
3
src/guardden/__init__.py
Normal file
3
src/guardden/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""GuardDen - A comprehensive Discord moderation bot."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
40
src/guardden/__main__.py
Normal file
40
src/guardden/__main__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Entry point for GuardDen bot."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.config import get_settings
|
||||
from guardden.utils import setup_logging
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the GuardDen bot."""
|
||||
try:
|
||||
settings = get_settings()
|
||||
except Exception as e:
|
||||
print(f"Failed to load configuration: {e}", file=sys.stderr)
|
||||
print("Make sure GUARDDEN_DISCORD_TOKEN is set.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
setup_logging(settings.log_level)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
bot = GuardDen(settings)
|
||||
|
||||
async def runner() -> None:
|
||||
async with bot:
|
||||
await bot.start(settings.discord_token.get_secret_value())
|
||||
|
||||
try:
|
||||
asyncio.run(runner())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received keyboard interrupt, shutting down...")
|
||||
except Exception as e:
|
||||
logger.exception(f"Fatal error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
131
src/guardden/bot.py
Normal file
131
src/guardden/bot.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Main bot class for GuardDen."""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.config import Settings
|
||||
from guardden.services.ai import AIProvider, create_ai_provider
|
||||
from guardden.services.database import Database
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from guardden.services.guild_config import GuildConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuardDen(commands.Bot):
|
||||
"""The main GuardDen Discord bot."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.members = True
|
||||
intents.voice_states = True
|
||||
|
||||
super().__init__(
|
||||
command_prefix=self._get_prefix,
|
||||
intents=intents,
|
||||
help_command=commands.DefaultHelpCommand(),
|
||||
)
|
||||
|
||||
# Services
|
||||
self.database = Database(settings)
|
||||
self.guild_config: "GuildConfigService | None" = None
|
||||
self.ai_provider: AIProvider | None = None
|
||||
|
||||
async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]:
|
||||
"""Get the command prefix for a guild."""
|
||||
if not message.guild:
|
||||
return [self.settings.discord_prefix]
|
||||
|
||||
if self.guild_config:
|
||||
config = await self.guild_config.get_config(message.guild.id)
|
||||
if config:
|
||||
return [config.prefix]
|
||||
|
||||
return [self.settings.discord_prefix]
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Called when the bot is starting up."""
|
||||
logger.info("Starting GuardDen setup...")
|
||||
|
||||
# Connect to database
|
||||
await self.database.connect()
|
||||
await self.database.create_tables()
|
||||
|
||||
# Initialize services
|
||||
from guardden.services.guild_config import GuildConfigService
|
||||
|
||||
self.guild_config = GuildConfigService(self.database)
|
||||
|
||||
# Initialize AI provider
|
||||
api_key = None
|
||||
if self.settings.ai_provider == "anthropic" and self.settings.anthropic_api_key:
|
||||
api_key = self.settings.anthropic_api_key.get_secret_value()
|
||||
elif self.settings.ai_provider == "openai" and self.settings.openai_api_key:
|
||||
api_key = self.settings.openai_api_key.get_secret_value()
|
||||
|
||||
self.ai_provider = create_ai_provider(self.settings.ai_provider, api_key)
|
||||
|
||||
# Load cogs
|
||||
await self._load_cogs()
|
||||
|
||||
logger.info("GuardDen setup complete")
|
||||
|
||||
async def _load_cogs(self) -> None:
|
||||
"""Load all cog extensions."""
|
||||
cogs = [
|
||||
"guardden.cogs.events",
|
||||
"guardden.cogs.moderation",
|
||||
"guardden.cogs.admin",
|
||||
"guardden.cogs.automod",
|
||||
"guardden.cogs.ai_moderation",
|
||||
"guardden.cogs.verification",
|
||||
]
|
||||
|
||||
for cog in cogs:
|
||||
try:
|
||||
await self.load_extension(cog)
|
||||
logger.info(f"Loaded cog: {cog}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load cog {cog}: {e}")
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
"""Called when the bot is fully connected and ready."""
|
||||
if self.user:
|
||||
logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
logger.info(f"Connected to {len(self.guilds)} guild(s)")
|
||||
|
||||
# Set presence
|
||||
activity = discord.Activity(
|
||||
type=discord.ActivityType.watching,
|
||||
name="over your community",
|
||||
)
|
||||
await self.change_presence(activity=activity)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up when shutting down."""
|
||||
logger.info("Shutting down GuardDen...")
|
||||
if self.ai_provider:
|
||||
try:
|
||||
await self.ai_provider.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing AI provider: {e}")
|
||||
await self.database.disconnect()
|
||||
await super().close()
|
||||
|
||||
async def on_guild_join(self, guild: discord.Guild) -> None:
|
||||
"""Called when the bot joins a new guild."""
|
||||
logger.info(f"Joined guild: {guild.name} (ID: {guild.id})")
|
||||
|
||||
if self.guild_config:
|
||||
await self.guild_config.create_guild(guild)
|
||||
|
||||
async def on_guild_remove(self, guild: discord.Guild) -> None:
|
||||
"""Called when the bot is removed from a guild."""
|
||||
logger.info(f"Removed from guild: {guild.name} (ID: {guild.id})")
|
||||
1
src/guardden/cogs/__init__.py
Normal file
1
src/guardden/cogs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Discord cogs for GuardDen."""
|
||||
255
src/guardden/cogs/admin.py
Normal file
255
src/guardden/cogs/admin.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Admin commands for bot configuration."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Admin(commands.Cog):
|
||||
"""Administrative commands for bot configuration."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
|
||||
async def cog_check(self, ctx: commands.Context) -> bool:
|
||||
"""Ensure only administrators can use these commands."""
|
||||
if not ctx.guild:
|
||||
return False
|
||||
return ctx.author.guild_permissions.administrator
|
||||
|
||||
@commands.group(name="config", invoke_without_command=True)
|
||||
@commands.guild_only()
|
||||
async def config(self, ctx: commands.Context) -> None:
|
||||
"""View or modify bot configuration."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
if not config:
|
||||
await ctx.send("No configuration found. Run a config command to initialize.")
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"Configuration for {ctx.guild.name}",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
# General settings
|
||||
embed.add_field(name="Prefix", value=f"`{config.prefix}`", inline=True)
|
||||
embed.add_field(name="Locale", value=config.locale, inline=True)
|
||||
embed.add_field(name="\u200b", value="\u200b", inline=True)
|
||||
|
||||
# Channels
|
||||
log_ch = ctx.guild.get_channel(config.log_channel_id) if config.log_channel_id else None
|
||||
mod_log_ch = (
|
||||
ctx.guild.get_channel(config.mod_log_channel_id) if config.mod_log_channel_id else None
|
||||
)
|
||||
welcome_ch = (
|
||||
ctx.guild.get_channel(config.welcome_channel_id) if config.welcome_channel_id else None
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Log Channel", value=log_ch.mention if log_ch else "Not set", inline=True
|
||||
)
|
||||
embed.add_field(
|
||||
name="Mod Log Channel",
|
||||
value=mod_log_ch.mention if mod_log_ch else "Not set",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Welcome Channel",
|
||||
value=welcome_ch.mention if welcome_ch else "Not set",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
# Features
|
||||
features = []
|
||||
if config.automod_enabled:
|
||||
features.append("AutoMod")
|
||||
if config.anti_spam_enabled:
|
||||
features.append("Anti-Spam")
|
||||
if config.link_filter_enabled:
|
||||
features.append("Link Filter")
|
||||
if config.ai_moderation_enabled:
|
||||
features.append("AI Moderation")
|
||||
if config.verification_enabled:
|
||||
features.append("Verification")
|
||||
|
||||
embed.add_field(
|
||||
name="Enabled Features",
|
||||
value=", ".join(features) if features else "None",
|
||||
inline=False,
|
||||
)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@config.command(name="prefix")
|
||||
@commands.guild_only()
|
||||
async def config_prefix(self, ctx: commands.Context, prefix: str) -> None:
|
||||
"""Set the command prefix for this server."""
|
||||
if not prefix or not prefix.strip():
|
||||
await ctx.send("Prefix cannot be empty or whitespace only.")
|
||||
return
|
||||
|
||||
if len(prefix) > 10:
|
||||
await ctx.send("Prefix must be 10 characters or less.")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, prefix=prefix)
|
||||
await ctx.send(f"Command prefix set to `{prefix}`")
|
||||
|
||||
@config.command(name="logchannel")
|
||||
@commands.guild_only()
|
||||
async def config_log_channel(
|
||||
self, ctx: commands.Context, channel: discord.TextChannel | None = None
|
||||
) -> None:
|
||||
"""Set the channel for general event logs."""
|
||||
channel_id = channel.id if channel else None
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, log_channel_id=channel_id)
|
||||
|
||||
if channel:
|
||||
await ctx.send(f"Log channel set to {channel.mention}")
|
||||
else:
|
||||
await ctx.send("Log channel has been disabled.")
|
||||
|
||||
@config.command(name="modlogchannel")
|
||||
@commands.guild_only()
|
||||
async def config_mod_log_channel(
|
||||
self, ctx: commands.Context, channel: discord.TextChannel | None = None
|
||||
) -> None:
|
||||
"""Set the channel for moderation action logs."""
|
||||
channel_id = channel.id if channel else None
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, mod_log_channel_id=channel_id)
|
||||
|
||||
if channel:
|
||||
await ctx.send(f"Moderation log channel set to {channel.mention}")
|
||||
else:
|
||||
await ctx.send("Moderation log channel has been disabled.")
|
||||
|
||||
@config.command(name="welcomechannel")
|
||||
@commands.guild_only()
|
||||
async def config_welcome_channel(
|
||||
self, ctx: commands.Context, channel: discord.TextChannel | None = None
|
||||
) -> None:
|
||||
"""Set the welcome channel for new members."""
|
||||
channel_id = channel.id if channel else None
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, welcome_channel_id=channel_id)
|
||||
|
||||
if channel:
|
||||
await ctx.send(f"Welcome channel set to {channel.mention}")
|
||||
else:
|
||||
await ctx.send("Welcome channel has been disabled.")
|
||||
|
||||
@config.command(name="muterole")
|
||||
@commands.guild_only()
|
||||
async def config_mute_role(
|
||||
self, ctx: commands.Context, role: discord.Role | None = None
|
||||
) -> None:
|
||||
"""Set the role to assign when muting members."""
|
||||
role_id = role.id if role else None
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, mute_role_id=role_id)
|
||||
|
||||
if role:
|
||||
await ctx.send(f"Mute role set to {role.mention}")
|
||||
else:
|
||||
await ctx.send("Mute role has been cleared.")
|
||||
|
||||
@config.command(name="automod")
|
||||
@commands.guild_only()
|
||||
async def config_automod(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable automod features."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, automod_enabled=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"AutoMod has been {status}.")
|
||||
|
||||
@config.command(name="antispam")
|
||||
@commands.guild_only()
|
||||
async def config_antispam(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable anti-spam protection."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, anti_spam_enabled=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"Anti-spam has been {status}.")
|
||||
|
||||
@config.command(name="linkfilter")
|
||||
@commands.guild_only()
|
||||
async def config_linkfilter(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable link filtering."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, link_filter_enabled=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"Link filter has been {status}.")
|
||||
|
||||
@commands.group(name="bannedwords", aliases=["bw"], invoke_without_command=True)
|
||||
@commands.guild_only()
|
||||
async def banned_words(self, ctx: commands.Context) -> None:
|
||||
"""Manage banned words list."""
|
||||
words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
|
||||
|
||||
if not words:
|
||||
await ctx.send("No banned words configured.")
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Banned Words",
|
||||
color=discord.Color.red(),
|
||||
)
|
||||
|
||||
for word in words[:25]: # Discord embed limit
|
||||
word_type = "Regex" if word.is_regex else "Text"
|
||||
embed.add_field(
|
||||
name=f"#{word.id}: {word.pattern[:30]}",
|
||||
value=f"Type: {word_type} | Action: {word.action}",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
if len(words) > 25:
|
||||
embed.set_footer(text=f"Showing 25 of {len(words)} banned words")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@banned_words.command(name="add")
|
||||
@commands.guild_only()
|
||||
async def banned_words_add(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
pattern: str,
|
||||
action: Literal["delete", "warn", "strike"] = "delete",
|
||||
is_regex: bool = False,
|
||||
) -> None:
|
||||
"""Add a banned word or pattern."""
|
||||
word = await self.bot.guild_config.add_banned_word(
|
||||
guild_id=ctx.guild.id,
|
||||
pattern=pattern,
|
||||
added_by=ctx.author.id,
|
||||
is_regex=is_regex,
|
||||
action=action,
|
||||
)
|
||||
|
||||
word_type = "regex pattern" if is_regex else "word"
|
||||
await ctx.send(f"Added banned {word_type}: `{pattern}` (ID: {word.id}, Action: {action})")
|
||||
|
||||
@banned_words.command(name="remove", aliases=["delete"])
|
||||
@commands.guild_only()
|
||||
async def banned_words_remove(self, ctx: commands.Context, word_id: int) -> None:
|
||||
"""Remove a banned word by ID."""
|
||||
success = await self.bot.guild_config.remove_banned_word(ctx.guild.id, word_id)
|
||||
|
||||
if success:
|
||||
await ctx.send(f"Removed banned word #{word_id}")
|
||||
else:
|
||||
await ctx.send(f"Banned word #{word_id} not found.")
|
||||
|
||||
@commands.command(name="sync")
|
||||
@commands.is_owner()
|
||||
async def sync_commands(self, ctx: commands.Context) -> None:
|
||||
"""Sync slash commands (bot owner only)."""
|
||||
await self.bot.tree.sync()
|
||||
await ctx.send("Slash commands synced.")
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the Admin cog."""
|
||||
await bot.add_cog(Admin(bot))
|
||||
366
src/guardden/cogs/ai_moderation.py
Normal file
366
src/guardden/cogs/ai_moderation.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""AI-powered moderation cog."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.services.ai.base import ContentCategory, ModerationResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# URL pattern for extraction
|
||||
URL_PATTERN = re.compile(
|
||||
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
class AIModeration(commands.Cog):
|
||||
"""AI-powered content moderation."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
# Track recently analyzed messages to avoid duplicates (deque auto-removes oldest)
|
||||
self._analyzed_messages: deque[int] = deque(maxlen=1000)
|
||||
|
||||
def _should_analyze(self, message: discord.Message) -> bool:
|
||||
"""Determine if a message should be analyzed by AI."""
|
||||
# Skip if already analyzed
|
||||
if message.id in self._analyzed_messages:
|
||||
return False
|
||||
|
||||
# Skip short messages
|
||||
if len(message.content) < 20 and not message.attachments:
|
||||
return False
|
||||
|
||||
# Skip messages from bots
|
||||
if message.author.bot:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _track_message(self, message_id: int) -> None:
|
||||
"""Track that a message has been analyzed."""
|
||||
self._analyzed_messages.append(message_id)
|
||||
|
||||
async def _handle_ai_result(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: ModerationResult,
|
||||
analysis_type: str,
|
||||
) -> None:
|
||||
"""Handle the result of AI analysis."""
|
||||
if not result.is_flagged:
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config:
|
||||
return
|
||||
|
||||
# Check if severity meets threshold based on sensitivity
|
||||
# Higher sensitivity = lower threshold needed to trigger
|
||||
threshold = 100 - config.ai_sensitivity # e.g., sensitivity 70 = threshold 30
|
||||
if result.severity < threshold:
|
||||
logger.debug(
|
||||
f"AI flagged content but below threshold: "
|
||||
f"severity={result.severity}, threshold={threshold}"
|
||||
)
|
||||
return
|
||||
|
||||
# Determine action based on suggested action and severity
|
||||
should_delete = result.suggested_action in ("delete", "timeout", "ban")
|
||||
should_timeout = result.suggested_action in ("timeout", "ban") and result.severity > 70
|
||||
|
||||
# Delete message if needed
|
||||
if should_delete:
|
||||
try:
|
||||
await message.delete()
|
||||
except discord.Forbidden:
|
||||
logger.warning(f"Cannot delete message: missing permissions")
|
||||
except discord.NotFound:
|
||||
pass
|
||||
|
||||
# Timeout user for severe violations
|
||||
if should_timeout and isinstance(message.author, discord.Member):
|
||||
timeout_duration = 300 if result.severity < 90 else 3600 # 5 min or 1 hour
|
||||
try:
|
||||
await message.author.timeout(
|
||||
timedelta(seconds=timeout_duration),
|
||||
reason=f"AI Moderation: {result.explanation[:100]}",
|
||||
)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
# Log to mod channel
|
||||
await self._log_ai_action(message, result, analysis_type)
|
||||
|
||||
# Notify user
|
||||
try:
|
||||
embed = discord.Embed(
|
||||
title=f"Message Flagged in {message.guild.name}",
|
||||
description=result.explanation,
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(
|
||||
name="Categories",
|
||||
value=", ".join(cat.value for cat in result.categories) or "Unknown",
|
||||
)
|
||||
if should_timeout:
|
||||
embed.add_field(name="Action", value="You have been timed out")
|
||||
await message.author.send(embed=embed)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
async def _log_ai_action(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: ModerationResult,
|
||||
analysis_type: str,
|
||||
) -> None:
|
||||
"""Log an AI moderation action."""
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.mod_log_channel_id:
|
||||
return
|
||||
|
||||
channel = message.guild.get_channel(config.mod_log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"AI Moderation - {analysis_type}",
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_author(
|
||||
name=str(message.author),
|
||||
icon_url=message.author.display_avatar.url,
|
||||
)
|
||||
|
||||
embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True)
|
||||
embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True)
|
||||
embed.add_field(name="Action", value=result.suggested_action, inline=True)
|
||||
|
||||
categories = ", ".join(cat.value for cat in result.categories)
|
||||
embed.add_field(name="Categories", value=categories or "None", inline=False)
|
||||
embed.add_field(name="Explanation", value=result.explanation[:500], inline=False)
|
||||
|
||||
if message.content:
|
||||
content = (
|
||||
message.content[:500] + "..." if len(message.content) > 500 else message.content
|
||||
)
|
||||
embed.add_field(name="Content", value=f"```{content}```", inline=False)
|
||||
|
||||
embed.set_footer(text=f"User ID: {message.author.id} | Channel: #{message.channel.name}")
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""Analyze messages with AI moderation."""
|
||||
if not message.guild:
|
||||
return
|
||||
|
||||
# Check if AI moderation is enabled for this guild
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.ai_moderation_enabled:
|
||||
return
|
||||
|
||||
# Skip users with manage_messages permission
|
||||
if isinstance(message.author, discord.Member):
|
||||
if message.author.guild_permissions.manage_messages:
|
||||
return
|
||||
|
||||
if not self._should_analyze(message):
|
||||
return
|
||||
|
||||
self._track_message(message.id)
|
||||
|
||||
# Analyze text content
|
||||
if message.content and len(message.content) >= 20:
|
||||
result = await self.bot.ai_provider.moderate_text(
|
||||
content=message.content,
|
||||
context=f"Discord server: {message.guild.name}, channel: {message.channel.name}",
|
||||
sensitivity=config.ai_sensitivity,
|
||||
)
|
||||
|
||||
if result.is_flagged:
|
||||
await self._handle_ai_result(message, result, "Text Analysis")
|
||||
return # Don't continue if already flagged
|
||||
|
||||
# Analyze images if NSFW detection is enabled (limit to 3 per message)
|
||||
if config.nsfw_detection_enabled and message.attachments:
|
||||
images_analyzed = 0
|
||||
for attachment in message.attachments:
|
||||
if images_analyzed >= 3:
|
||||
break
|
||||
if attachment.content_type and attachment.content_type.startswith("image/"):
|
||||
images_analyzed += 1
|
||||
image_result = await self.bot.ai_provider.analyze_image(
|
||||
image_url=attachment.url,
|
||||
sensitivity=config.ai_sensitivity,
|
||||
)
|
||||
|
||||
if (
|
||||
image_result.is_nsfw
|
||||
or image_result.is_violent
|
||||
or image_result.is_disturbing
|
||||
):
|
||||
# Convert to ModerationResult format
|
||||
categories = []
|
||||
if image_result.is_nsfw:
|
||||
categories.append(ContentCategory.SEXUAL)
|
||||
if image_result.is_violent:
|
||||
categories.append(ContentCategory.VIOLENCE)
|
||||
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=image_result.confidence,
|
||||
categories=categories,
|
||||
explanation=image_result.description,
|
||||
suggested_action="delete",
|
||||
)
|
||||
await self._handle_ai_result(message, result, "Image Analysis")
|
||||
return
|
||||
|
||||
# Analyze URLs for phishing
|
||||
urls = URL_PATTERN.findall(message.content)
|
||||
for url in urls[:3]: # Limit to first 3 URLs
|
||||
phishing_result = await self.bot.ai_provider.analyze_phishing(
|
||||
url=url,
|
||||
message_content=message.content,
|
||||
)
|
||||
|
||||
if phishing_result.is_phishing and phishing_result.confidence > 0.7:
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=phishing_result.confidence,
|
||||
categories=[ContentCategory.SCAM],
|
||||
explanation=phishing_result.explanation,
|
||||
suggested_action="delete",
|
||||
)
|
||||
await self._handle_ai_result(message, result, "Phishing Detection")
|
||||
return
|
||||
|
||||
@commands.group(name="ai", invoke_without_command=True)
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_cmd(self, ctx: commands.Context) -> None:
|
||||
"""View AI moderation settings."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="AI Moderation Settings",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="AI Moderation",
|
||||
value="✅ Enabled" if config and config.ai_moderation_enabled else "❌ Disabled",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="NSFW Detection",
|
||||
value="✅ Enabled" if config and config.nsfw_detection_enabled else "❌ Disabled",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Sensitivity",
|
||||
value=f"{config.ai_sensitivity}/100" if config else "50/100",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="AI Provider",
|
||||
value=self.bot.settings.ai_provider.capitalize(),
|
||||
inline=True,
|
||||
)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@ai_cmd.command(name="enable")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_enable(self, ctx: commands.Context) -> None:
|
||||
"""Enable AI moderation."""
|
||||
if self.bot.settings.ai_provider == "none":
|
||||
await ctx.send(
|
||||
"AI moderation is not configured. Set `GUARDDEN_AI_PROVIDER` and API key."
|
||||
)
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_moderation_enabled=True)
|
||||
await ctx.send("✅ AI moderation enabled.")
|
||||
|
||||
@ai_cmd.command(name="disable")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_disable(self, ctx: commands.Context) -> None:
|
||||
"""Disable AI moderation."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_moderation_enabled=False)
|
||||
await ctx.send("❌ AI moderation disabled.")
|
||||
|
||||
@ai_cmd.command(name="sensitivity")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_sensitivity(self, ctx: commands.Context, level: int) -> None:
|
||||
"""Set AI sensitivity level (0-100). Higher = more strict."""
|
||||
if not 0 <= level <= 100:
|
||||
await ctx.send("Sensitivity must be between 0 and 100.")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_sensitivity=level)
|
||||
await ctx.send(f"AI sensitivity set to {level}/100.")
|
||||
|
||||
@ai_cmd.command(name="nsfw")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_nsfw(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable NSFW image detection."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, nsfw_detection_enabled=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"NSFW detection {status}.")
|
||||
|
||||
@ai_cmd.command(name="analyze")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_analyze(self, ctx: commands.Context, *, text: str) -> None:
|
||||
"""Test AI analysis on text (does not take action)."""
|
||||
if self.bot.settings.ai_provider == "none":
|
||||
await ctx.send("AI moderation is not configured.")
|
||||
return
|
||||
|
||||
async with ctx.typing():
|
||||
result = await self.bot.ai_provider.moderate_text(
|
||||
content=text,
|
||||
context=f"Test analysis in {ctx.guild.name}",
|
||||
sensitivity=50,
|
||||
)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="AI Analysis Result",
|
||||
color=discord.Color.red() if result.is_flagged else discord.Color.green(),
|
||||
)
|
||||
|
||||
embed.add_field(name="Flagged", value="Yes" if result.is_flagged else "No", inline=True)
|
||||
embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True)
|
||||
embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True)
|
||||
embed.add_field(name="Suggested Action", value=result.suggested_action, inline=True)
|
||||
|
||||
if result.categories:
|
||||
categories = ", ".join(cat.value for cat in result.categories)
|
||||
embed.add_field(name="Categories", value=categories, inline=False)
|
||||
|
||||
if result.explanation:
|
||||
embed.add_field(name="Explanation", value=result.explanation[:1000], inline=False)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the AI Moderation cog."""
|
||||
await bot.add_cog(AIModeration(bot))
|
||||
267
src/guardden/cogs/automod.py
Normal file
267
src/guardden/cogs/automod.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Automod cog for automatic content moderation."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.services.automod import AutomodResult, AutomodService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Automod(commands.Cog):
|
||||
"""Automatic content moderation."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
self.automod = AutomodService()
|
||||
|
||||
async def _handle_violation(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: AutomodResult,
|
||||
) -> None:
|
||||
"""Handle an automod violation."""
|
||||
# Delete the message
|
||||
if result.should_delete:
|
||||
try:
|
||||
await message.delete()
|
||||
except discord.Forbidden:
|
||||
logger.warning(f"Cannot delete message in {message.guild}: missing permissions")
|
||||
except discord.NotFound:
|
||||
pass # Already deleted
|
||||
|
||||
# Apply timeout
|
||||
if result.should_timeout and result.timeout_duration > 0:
|
||||
try:
|
||||
await message.author.timeout(
|
||||
timedelta(seconds=result.timeout_duration),
|
||||
reason=f"Automod: {result.reason}",
|
||||
)
|
||||
except discord.Forbidden:
|
||||
logger.warning(f"Cannot timeout {message.author}: missing permissions")
|
||||
|
||||
# Log the action
|
||||
await self._log_automod_action(message, result)
|
||||
|
||||
# Notify the user via DM
|
||||
try:
|
||||
embed = discord.Embed(
|
||||
title=f"Message Removed in {message.guild.name}",
|
||||
description=result.reason,
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
if result.should_timeout:
|
||||
embed.add_field(
|
||||
name="Timeout",
|
||||
value=f"You have been timed out for {result.timeout_duration} seconds.",
|
||||
)
|
||||
await message.author.send(embed=embed)
|
||||
except discord.Forbidden:
|
||||
pass # User has DMs disabled
|
||||
|
||||
async def _log_automod_action(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: AutomodResult,
|
||||
) -> None:
|
||||
"""Log an automod action to the mod log channel."""
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.mod_log_channel_id:
|
||||
return
|
||||
|
||||
channel = message.guild.get_channel(config.mod_log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Automod Action",
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_author(
|
||||
name=str(message.author),
|
||||
icon_url=message.author.display_avatar.url,
|
||||
)
|
||||
embed.add_field(name="Filter", value=result.matched_filter, inline=True)
|
||||
embed.add_field(name="Channel", value=message.channel.mention, inline=True)
|
||||
embed.add_field(name="Reason", value=result.reason, inline=False)
|
||||
|
||||
if message.content:
|
||||
content = (
|
||||
message.content[:500] + "..." if len(message.content) > 500 else message.content
|
||||
)
|
||||
embed.add_field(name="Message Content", value=f"```{content}```", inline=False)
|
||||
|
||||
actions = []
|
||||
if result.should_delete:
|
||||
actions.append("Message deleted")
|
||||
if result.should_warn:
|
||||
actions.append("User warned")
|
||||
if result.should_strike:
|
||||
actions.append("Strike added")
|
||||
if result.should_timeout:
|
||||
actions.append(f"Timeout ({result.timeout_duration}s)")
|
||||
|
||||
embed.add_field(name="Actions Taken", value=", ".join(actions) or "None", inline=False)
|
||||
embed.set_footer(text=f"User ID: {message.author.id}")
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""Check all messages for automod violations."""
|
||||
# Ignore DMs, bots, and empty messages
|
||||
if not message.guild or message.author.bot or not message.content:
|
||||
return
|
||||
|
||||
# Ignore users with manage_messages permission
|
||||
if isinstance(message.author, discord.Member):
|
||||
if message.author.guild_permissions.manage_messages:
|
||||
return
|
||||
|
||||
# Get guild config
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.automod_enabled:
|
||||
return
|
||||
|
||||
result: AutomodResult | None = None
|
||||
|
||||
# Check banned words
|
||||
banned_words = await self.bot.guild_config.get_banned_words(message.guild.id)
|
||||
if banned_words:
|
||||
result = self.automod.check_banned_words(message.content, banned_words)
|
||||
|
||||
# Check scam links (if link filter enabled)
|
||||
if not result and config.link_filter_enabled:
|
||||
result = self.automod.check_scam_links(message.content)
|
||||
|
||||
# Check spam
|
||||
if not result and config.anti_spam_enabled:
|
||||
result = self.automod.check_spam(message, anti_spam_enabled=True)
|
||||
|
||||
# Check invite links (if link filter enabled)
|
||||
if not result and config.link_filter_enabled:
|
||||
result = self.automod.check_invite_links(message.content, allow_invites=False)
|
||||
|
||||
# Handle violation if found
|
||||
if result:
|
||||
logger.info(
|
||||
f"Automod triggered in {message.guild.name}: "
|
||||
f"{result.matched_filter} by {message.author}"
|
||||
)
|
||||
await self._handle_violation(message, result)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None:
|
||||
"""Check edited messages for automod violations."""
|
||||
# Only check if content changed
|
||||
if before.content == after.content:
|
||||
return
|
||||
|
||||
# Reuse on_message logic
|
||||
await self.on_message(after)
|
||||
|
||||
@commands.group(name="automod", invoke_without_command=True)
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def automod_cmd(self, ctx: commands.Context) -> None:
|
||||
"""View automod status and configuration."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Automod Configuration",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Automod Enabled",
|
||||
value="✅ Yes" if config and config.automod_enabled else "❌ No",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Anti-Spam",
|
||||
value="✅ Yes" if config and config.anti_spam_enabled else "❌ No",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Link Filter",
|
||||
value="✅ Yes" if config and config.link_filter_enabled else "❌ No",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
# Show thresholds
|
||||
embed.add_field(
|
||||
name="Rate Limit",
|
||||
value=f"{self.automod.message_rate_limit} msgs / {self.automod.message_rate_window}s",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Duplicate Threshold",
|
||||
value=f"{self.automod.duplicate_threshold} same messages",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Mention Limit",
|
||||
value=f"{self.automod.mention_limit} per message",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
banned_words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
|
||||
embed.add_field(
|
||||
name="Banned Words",
|
||||
value=f"{len(banned_words)} configured",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@automod_cmd.command(name="test")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def automod_test(self, ctx: commands.Context, *, text: str) -> None:
|
||||
"""Test a message against automod filters (does not take action)."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
results = []
|
||||
|
||||
# Check banned words
|
||||
banned_words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
|
||||
result = self.automod.check_banned_words(text, banned_words)
|
||||
if result:
|
||||
results.append(f"**Banned Words**: {result.reason}")
|
||||
|
||||
# Check scam links
|
||||
result = self.automod.check_scam_links(text)
|
||||
if result:
|
||||
results.append(f"**Scam Detection**: {result.reason}")
|
||||
|
||||
# Check invite links
|
||||
result = self.automod.check_invite_links(text, allow_invites=False)
|
||||
if result:
|
||||
results.append(f"**Invite Links**: {result.reason}")
|
||||
|
||||
# Check caps
|
||||
result = self.automod.check_all_caps(text)
|
||||
if result:
|
||||
results.append(f"**Excessive Caps**: {result.reason}")
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Automod Test Results",
|
||||
color=discord.Color.red() if results else discord.Color.green(),
|
||||
)
|
||||
|
||||
if results:
|
||||
embed.description = "\n".join(results)
|
||||
else:
|
||||
embed.description = "✅ No violations detected"
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the Automod cog."""
|
||||
await bot.add_cog(Automod(bot))
|
||||
237
src/guardden/cogs/events.py
Normal file
237
src/guardden/cogs/events.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Event handlers for logging and monitoring."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Events(commands.Cog):
|
||||
"""Handles Discord events for logging and monitoring."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_member_join(self, member: discord.Member) -> None:
|
||||
"""Called when a member joins a guild."""
|
||||
logger.debug(f"Member joined: {member} in {member.guild}")
|
||||
|
||||
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||
if not config or not config.log_channel_id:
|
||||
return
|
||||
|
||||
channel = member.guild.get_channel(config.log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Joined",
|
||||
description=f"{member.mention} ({member})",
|
||||
color=discord.Color.green(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_thumbnail(url=member.display_avatar.url)
|
||||
embed.add_field(
|
||||
name="Account Created", value=discord.utils.format_dt(member.created_at, "R")
|
||||
)
|
||||
embed.add_field(name="Member ID", value=str(member.id))
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_member_remove(self, member: discord.Member) -> None:
|
||||
"""Called when a member leaves a guild."""
|
||||
logger.debug(f"Member left: {member} from {member.guild}")
|
||||
|
||||
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||
if not config or not config.log_channel_id:
|
||||
return
|
||||
|
||||
channel = member.guild.get_channel(config.log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Left",
|
||||
description=f"{member} ({member.id})",
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_thumbnail(url=member.display_avatar.url)
|
||||
|
||||
if member.joined_at:
|
||||
embed.add_field(name="Joined", value=discord.utils.format_dt(member.joined_at, "R"))
|
||||
|
||||
roles = [r.mention for r in member.roles if r != member.guild.default_role]
|
||||
if roles:
|
||||
embed.add_field(name="Roles", value=", ".join(roles[:10]), inline=False)
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message_delete(self, message: discord.Message) -> None:
|
||||
"""Called when a message is deleted."""
|
||||
if message.author.bot or not message.guild:
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.log_channel_id:
|
||||
return
|
||||
|
||||
channel = message.guild.get_channel(config.log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Message Deleted",
|
||||
description=f"In {message.channel.mention}",
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_author(name=str(message.author), icon_url=message.author.display_avatar.url)
|
||||
|
||||
if message.content:
|
||||
content = message.content[:1024] if len(message.content) > 1024 else message.content
|
||||
embed.add_field(name="Content", value=content, inline=False)
|
||||
|
||||
if message.attachments:
|
||||
attachments = "\n".join(a.filename for a in message.attachments)
|
||||
embed.add_field(name="Attachments", value=attachments, inline=False)
|
||||
|
||||
embed.set_footer(text=f"Author ID: {message.author.id} | Message ID: {message.id}")
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None:
|
||||
"""Called when a message is edited."""
|
||||
if before.author.bot or not before.guild:
|
||||
return
|
||||
|
||||
if before.content == after.content:
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(before.guild.id)
|
||||
if not config or not config.log_channel_id:
|
||||
return
|
||||
|
||||
channel = before.guild.get_channel(config.log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Message Edited",
|
||||
description=f"In {before.channel.mention} | [Jump to message]({after.jump_url})",
|
||||
color=discord.Color.blue(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_author(name=str(before.author), icon_url=before.author.display_avatar.url)
|
||||
|
||||
before_content = before.content[:1024] if len(before.content) > 1024 else before.content
|
||||
after_content = after.content[:1024] if len(after.content) > 1024 else after.content
|
||||
|
||||
embed.add_field(name="Before", value=before_content or "*empty*", inline=False)
|
||||
embed.add_field(name="After", value=after_content or "*empty*", inline=False)
|
||||
embed.set_footer(text=f"Author ID: {before.author.id}")
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_voice_state_update(
|
||||
self,
|
||||
member: discord.Member,
|
||||
before: discord.VoiceState,
|
||||
after: discord.VoiceState,
|
||||
) -> None:
|
||||
"""Called when a member's voice state changes."""
|
||||
if member.bot:
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||
if not config or not config.log_channel_id:
|
||||
return
|
||||
|
||||
channel = member.guild.get_channel(config.log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = None
|
||||
|
||||
if before.channel is None and after.channel is not None:
|
||||
embed = discord.Embed(
|
||||
title="Voice Channel Joined",
|
||||
description=f"{member.mention} joined {after.channel.mention}",
|
||||
color=discord.Color.green(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
elif before.channel is not None and after.channel is None:
|
||||
embed = discord.Embed(
|
||||
title="Voice Channel Left",
|
||||
description=f"{member.mention} left {before.channel.mention}",
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
elif before.channel != after.channel and before.channel and after.channel:
|
||||
embed = discord.Embed(
|
||||
title="Voice Channel Moved",
|
||||
description=f"{member.mention} moved from {before.channel.mention} to {after.channel.mention}",
|
||||
color=discord.Color.blue(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
if embed:
|
||||
embed.set_author(name=str(member), icon_url=member.display_avatar.url)
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_member_ban(self, guild: discord.Guild, user: discord.User) -> None:
|
||||
"""Called when a user is banned."""
|
||||
config = await self.bot.guild_config.get_config(guild.id)
|
||||
if not config or not config.mod_log_channel_id:
|
||||
return
|
||||
|
||||
channel = guild.get_channel(config.mod_log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Banned",
|
||||
description=f"{user} ({user.id})",
|
||||
color=discord.Color.dark_red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_thumbnail(url=user.display_avatar.url)
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_member_unban(self, guild: discord.Guild, user: discord.User) -> None:
|
||||
"""Called when a user is unbanned."""
|
||||
config = await self.bot.guild_config.get_config(guild.id)
|
||||
if not config or not config.mod_log_channel_id:
|
||||
return
|
||||
|
||||
channel = guild.get_channel(config.mod_log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Unbanned",
|
||||
description=f"{user} ({user.id})",
|
||||
color=discord.Color.green(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_thumbnail(url=user.display_avatar.url)
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the Events cog."""
|
||||
await bot.add_cog(Events(bot))
|
||||
466
src/guardden/cogs/moderation.py
Normal file
466
src/guardden/cogs/moderation.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""Moderation commands and automod features."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.models import ModerationLog, Strike
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_duration(duration_str: str) -> timedelta | None:
|
||||
"""Parse a duration string like '1h', '30m', '7d' into a timedelta."""
|
||||
match = re.match(r"^(\d+)([smhdw])$", duration_str.lower())
|
||||
if not match:
|
||||
return None
|
||||
|
||||
amount = int(match.group(1))
|
||||
unit = match.group(2)
|
||||
|
||||
units = {
|
||||
"s": timedelta(seconds=amount),
|
||||
"m": timedelta(minutes=amount),
|
||||
"h": timedelta(hours=amount),
|
||||
"d": timedelta(days=amount),
|
||||
"w": timedelta(weeks=amount),
|
||||
}
|
||||
|
||||
return units.get(unit)
|
||||
|
||||
|
||||
class Moderation(commands.Cog):
|
||||
"""Moderation commands for server management."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
|
||||
async def _log_action(
|
||||
self,
|
||||
guild: discord.Guild,
|
||||
target: discord.Member | discord.User,
|
||||
moderator: discord.Member | discord.User,
|
||||
action: str,
|
||||
reason: str | None = None,
|
||||
duration: int | None = None,
|
||||
channel: discord.TextChannel | None = None,
|
||||
message: discord.Message | None = None,
|
||||
is_automatic: bool = False,
|
||||
) -> None:
|
||||
"""Log a moderation action to the database."""
|
||||
expires_at = None
|
||||
if duration:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration)
|
||||
|
||||
async with self.bot.database.session() as session:
|
||||
log_entry = ModerationLog(
|
||||
guild_id=guild.id,
|
||||
target_id=target.id,
|
||||
target_name=str(target),
|
||||
moderator_id=moderator.id,
|
||||
moderator_name=str(moderator),
|
||||
action=action,
|
||||
reason=reason,
|
||||
duration=duration,
|
||||
expires_at=expires_at,
|
||||
channel_id=channel.id if channel else None,
|
||||
message_id=message.id if message else None,
|
||||
message_content=message.content if message else None,
|
||||
is_automatic=is_automatic,
|
||||
)
|
||||
session.add(log_entry)
|
||||
|
||||
async def _get_strike_count(self, guild_id: int, user_id: int) -> int:
|
||||
"""Get the total active strike count for a user."""
|
||||
async with self.bot.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(func.sum(Strike.points)).where(
|
||||
Strike.guild_id == guild_id,
|
||||
Strike.user_id == user_id,
|
||||
Strike.is_active == True,
|
||||
)
|
||||
)
|
||||
total = result.scalar()
|
||||
return total or 0
|
||||
|
||||
async def _add_strike(
|
||||
self,
|
||||
guild: discord.Guild,
|
||||
user: discord.Member,
|
||||
moderator: discord.Member | discord.User,
|
||||
reason: str,
|
||||
points: int = 1,
|
||||
) -> int:
|
||||
"""Add a strike to a user and return their new total."""
|
||||
async with self.bot.database.session() as session:
|
||||
strike = Strike(
|
||||
guild_id=guild.id,
|
||||
user_id=user.id,
|
||||
user_name=str(user),
|
||||
moderator_id=moderator.id,
|
||||
reason=reason,
|
||||
points=points,
|
||||
)
|
||||
session.add(strike)
|
||||
|
||||
return await self._get_strike_count(guild.id, user.id)
|
||||
|
||||
@commands.command(name="warn")
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def warn(
|
||||
self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided"
|
||||
) -> None:
|
||||
"""Warn a member."""
|
||||
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||
await ctx.send("You cannot warn someone with a higher or equal role.")
|
||||
return
|
||||
|
||||
await self._log_action(ctx.guild, member, ctx.author, "warn", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Warning Issued",
|
||||
description=f"{member.mention} has been warned.",
|
||||
color=discord.Color.yellow(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
# Try to DM the user
|
||||
try:
|
||||
dm_embed = discord.Embed(
|
||||
title=f"Warning in {ctx.guild.name}",
|
||||
description=f"You have been warned.",
|
||||
color=discord.Color.yellow(),
|
||||
)
|
||||
dm_embed.add_field(name="Reason", value=reason)
|
||||
await member.send(embed=dm_embed)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
@commands.command(name="strike")
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def strike(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
member: discord.Member,
|
||||
points: int = 1,
|
||||
*,
|
||||
reason: str = "No reason provided",
|
||||
) -> None:
|
||||
"""Add a strike to a member."""
|
||||
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||
await ctx.send("You cannot strike someone with a higher or equal role.")
|
||||
return
|
||||
|
||||
total_strikes = await self._add_strike(ctx.guild, member, ctx.author, reason, points)
|
||||
await self._log_action(ctx.guild, member, ctx.author, "strike", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Strike Added",
|
||||
description=f"{member.mention} has received {points} strike(s).",
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.add_field(name="Total Strikes", value=str(total_strikes))
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
# Check for automatic actions based on strike thresholds
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
if config and config.strike_actions:
|
||||
for threshold, action_config in sorted(
|
||||
config.strike_actions.items(), key=lambda x: int(x[0]), reverse=True
|
||||
):
|
||||
if total_strikes >= int(threshold):
|
||||
action = action_config.get("action")
|
||||
if action == "ban":
|
||||
await ctx.invoke(
|
||||
self.ban, member=member, reason=f"Automatic: {total_strikes} strikes"
|
||||
)
|
||||
elif action == "kick":
|
||||
await ctx.invoke(
|
||||
self.kick, member=member, reason=f"Automatic: {total_strikes} strikes"
|
||||
)
|
||||
elif action == "timeout":
|
||||
duration = action_config.get("duration", 3600)
|
||||
await ctx.invoke(
|
||||
self.timeout,
|
||||
member=member,
|
||||
duration=f"{duration}s",
|
||||
reason=f"Automatic: {total_strikes} strikes",
|
||||
)
|
||||
break
|
||||
|
||||
@commands.command(name="strikes")
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def strikes(self, ctx: commands.Context, member: discord.Member) -> None:
|
||||
"""View strikes for a member."""
|
||||
async with self.bot.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(Strike)
|
||||
.where(
|
||||
Strike.guild_id == ctx.guild.id,
|
||||
Strike.user_id == member.id,
|
||||
Strike.is_active == True,
|
||||
)
|
||||
.order_by(Strike.created_at.desc())
|
||||
.limit(10)
|
||||
)
|
||||
user_strikes = result.scalars().all()
|
||||
|
||||
total = await self._get_strike_count(ctx.guild.id, member.id)
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"Strikes for {member}",
|
||||
description=f"Total active strikes: **{total}**",
|
||||
color=discord.Color.orange(),
|
||||
)
|
||||
|
||||
if user_strikes:
|
||||
for strike in user_strikes:
|
||||
embed.add_field(
|
||||
name=f"Strike #{strike.id} ({strike.points} pts)",
|
||||
value=f"{strike.reason}\n*{strike.created_at.strftime('%Y-%m-%d')}*",
|
||||
inline=False,
|
||||
)
|
||||
else:
|
||||
embed.description = f"{member.mention} has no active strikes."
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.command(name="timeout", aliases=["mute"])
|
||||
@commands.has_permissions(moderate_members=True)
|
||||
@commands.guild_only()
|
||||
async def timeout(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
member: discord.Member,
|
||||
duration: str = "1h",
|
||||
*,
|
||||
reason: str = "No reason provided",
|
||||
) -> None:
|
||||
"""Timeout a member (e.g., !timeout @user 1h Spamming)."""
|
||||
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||
await ctx.send("You cannot timeout someone with a higher or equal role.")
|
||||
return
|
||||
|
||||
delta = parse_duration(duration)
|
||||
if not delta:
|
||||
await ctx.send("Invalid duration. Use format like: 30m, 1h, 7d")
|
||||
return
|
||||
|
||||
if delta > timedelta(days=28):
|
||||
await ctx.send("Timeout duration cannot exceed 28 days.")
|
||||
return
|
||||
|
||||
try:
|
||||
await member.timeout(delta, reason=f"{ctx.author}: {reason}")
|
||||
except discord.Forbidden:
|
||||
await ctx.send("I don't have permission to timeout this user.")
|
||||
return
|
||||
except discord.HTTPException as e:
|
||||
await ctx.send(f"Failed to timeout user: {e}")
|
||||
return
|
||||
|
||||
await self._log_action(
|
||||
ctx.guild, member, ctx.author, "timeout", reason, int(delta.total_seconds())
|
||||
)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Timed Out",
|
||||
description=f"{member.mention} has been timed out for {duration}.",
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.command(name="untimeout", aliases=["unmute"])
|
||||
@commands.has_permissions(moderate_members=True)
|
||||
@commands.guild_only()
|
||||
async def untimeout(
|
||||
self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided"
|
||||
) -> None:
|
||||
"""Remove timeout from a member."""
|
||||
await member.timeout(None, reason=f"{ctx.author}: {reason}")
|
||||
await self._log_action(ctx.guild, member, ctx.author, "unmute", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Timeout Removed",
|
||||
description=f"{member.mention}'s timeout has been removed.",
|
||||
color=discord.Color.green(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.command(name="kick")
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def kick(
|
||||
self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided"
|
||||
) -> None:
|
||||
"""Kick a member from the server."""
|
||||
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||
await ctx.send("You cannot kick someone with a higher or equal role.")
|
||||
return
|
||||
|
||||
# Try to DM the user before kicking
|
||||
try:
|
||||
dm_embed = discord.Embed(
|
||||
title=f"Kicked from {ctx.guild.name}",
|
||||
description=f"You have been kicked from the server.",
|
||||
color=discord.Color.red(),
|
||||
)
|
||||
dm_embed.add_field(name="Reason", value=reason)
|
||||
await member.send(embed=dm_embed)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
await member.kick(reason=f"{ctx.author}: {reason}")
|
||||
await self._log_action(ctx.guild, member, ctx.author, "kick", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Kicked",
|
||||
description=f"{member} has been kicked from the server.",
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.command(name="ban")
|
||||
@commands.has_permissions(ban_members=True)
|
||||
@commands.guild_only()
|
||||
async def ban(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
member: discord.Member | discord.User,
|
||||
*,
|
||||
reason: str = "No reason provided",
|
||||
) -> None:
|
||||
"""Ban a member from the server."""
|
||||
if isinstance(member, discord.Member):
|
||||
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||
await ctx.send("You cannot ban someone with a higher or equal role.")
|
||||
return
|
||||
|
||||
# Try to DM the user before banning
|
||||
try:
|
||||
dm_embed = discord.Embed(
|
||||
title=f"Banned from {ctx.guild.name}",
|
||||
description=f"You have been banned from the server.",
|
||||
color=discord.Color.dark_red(),
|
||||
)
|
||||
dm_embed.add_field(name="Reason", value=reason)
|
||||
await member.send(embed=dm_embed)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
await ctx.guild.ban(member, reason=f"{ctx.author}: {reason}", delete_message_days=0)
|
||||
await self._log_action(ctx.guild, member, ctx.author, "ban", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Banned",
|
||||
description=f"{member} has been banned from the server.",
|
||||
color=discord.Color.dark_red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.command(name="unban")
|
||||
@commands.has_permissions(ban_members=True)
|
||||
@commands.guild_only()
|
||||
async def unban(
|
||||
self, ctx: commands.Context, user_id: int, *, reason: str = "No reason provided"
|
||||
) -> None:
|
||||
"""Unban a user by their ID."""
|
||||
try:
|
||||
user = await self.bot.fetch_user(user_id)
|
||||
await ctx.guild.unban(user, reason=f"{ctx.author}: {reason}")
|
||||
await self._log_action(ctx.guild, user, ctx.author, "unban", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="User Unbanned",
|
||||
description=f"{user} has been unbanned.",
|
||||
color=discord.Color.green(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
except discord.NotFound:
|
||||
await ctx.send("User not found or not banned.")
|
||||
except discord.Forbidden:
|
||||
await ctx.send("I don't have permission to unban this user.")
|
||||
|
||||
@commands.command(name="purge", aliases=["clear"])
|
||||
@commands.has_permissions(manage_messages=True)
|
||||
@commands.guild_only()
|
||||
async def purge(self, ctx: commands.Context, amount: int) -> None:
|
||||
"""Delete multiple messages at once (max 100)."""
|
||||
if amount < 1 or amount > 100:
|
||||
await ctx.send("Please specify a number between 1 and 100.")
|
||||
return
|
||||
|
||||
deleted = await ctx.channel.purge(limit=amount + 1) # +1 to include the command message
|
||||
|
||||
msg = await ctx.send(f"Deleted {len(deleted) - 1} message(s).")
|
||||
await msg.delete(delay=3)
|
||||
|
||||
@commands.command(name="modlogs", aliases=["history"])
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def modlogs(self, ctx: commands.Context, member: discord.Member | discord.User) -> None:
|
||||
"""View moderation history for a user."""
|
||||
async with self.bot.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(ModerationLog)
|
||||
.where(ModerationLog.guild_id == ctx.guild.id, ModerationLog.target_id == member.id)
|
||||
.order_by(ModerationLog.created_at.desc())
|
||||
.limit(10)
|
||||
)
|
||||
logs = result.scalars().all()
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"Moderation History for {member}",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
if logs:
|
||||
for log in logs:
|
||||
value = f"**Reason:** {log.reason or 'None'}\n**By:** {log.moderator_name}\n*{log.created_at.strftime('%Y-%m-%d %H:%M')}*"
|
||||
embed.add_field(name=f"{log.action.upper()} (#{log.id})", value=value, inline=False)
|
||||
else:
|
||||
embed.description = "No moderation history found."
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the Moderation cog."""
|
||||
await bot.add_cog(Moderation(bot))
|
||||
423
src/guardden/cogs/verification.py
Normal file
423
src/guardden/cogs/verification.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""Verification cog for new member verification."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import discord
|
||||
from discord import ui
|
||||
from discord.ext import commands, tasks
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.services.verification import (
|
||||
ChallengeType,
|
||||
PendingVerification,
|
||||
VerificationService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VerifyButton(ui.Button["VerificationView"]):
|
||||
"""Button for simple verification."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
style=discord.ButtonStyle.success,
|
||||
label="Verify",
|
||||
custom_id="verify_button",
|
||||
)
|
||||
|
||||
async def callback(self, interaction: discord.Interaction) -> None:
|
||||
if self.view is None:
|
||||
return
|
||||
|
||||
success, message = await self.view.cog.complete_verification(
|
||||
interaction.guild.id,
|
||||
interaction.user.id,
|
||||
"verified",
|
||||
)
|
||||
|
||||
if success:
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
# Disable the button
|
||||
self.disabled = True
|
||||
self.label = "Verified"
|
||||
await interaction.message.edit(view=self.view)
|
||||
else:
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
|
||||
|
||||
class EmojiButton(ui.Button["EmojiVerificationView"]):
|
||||
"""Button for emoji selection verification."""
|
||||
|
||||
def __init__(self, emoji: str, row: int = 0) -> None:
|
||||
super().__init__(
|
||||
style=discord.ButtonStyle.secondary,
|
||||
label=emoji,
|
||||
custom_id=f"emoji_{emoji}",
|
||||
row=row,
|
||||
)
|
||||
self.emoji_value = emoji
|
||||
|
||||
async def callback(self, interaction: discord.Interaction) -> None:
|
||||
if self.view is None:
|
||||
return
|
||||
|
||||
success, message = await self.view.cog.complete_verification(
|
||||
interaction.guild.id,
|
||||
interaction.user.id,
|
||||
self.emoji_value,
|
||||
)
|
||||
|
||||
if success:
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
# Disable all buttons
|
||||
for item in self.view.children:
|
||||
if isinstance(item, ui.Button):
|
||||
item.disabled = True
|
||||
await interaction.message.edit(view=self.view)
|
||||
else:
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
|
||||
|
||||
class VerificationView(ui.View):
|
||||
"""View for button verification."""
|
||||
|
||||
def __init__(self, cog: "Verification", timeout: float = 600) -> None:
|
||||
super().__init__(timeout=timeout)
|
||||
self.cog = cog
|
||||
self.add_item(VerifyButton())
|
||||
|
||||
|
||||
class EmojiVerificationView(ui.View):
|
||||
"""View for emoji selection verification."""
|
||||
|
||||
def __init__(self, cog: "Verification", options: list[str], timeout: float = 600) -> None:
|
||||
super().__init__(timeout=timeout)
|
||||
self.cog = cog
|
||||
for i, emoji in enumerate(options):
|
||||
self.add_item(EmojiButton(emoji, row=i // 4))
|
||||
|
||||
|
||||
class CaptchaModal(ui.Modal):
|
||||
"""Modal for captcha/math input."""
|
||||
|
||||
answer = ui.TextInput(
|
||||
label="Your Answer",
|
||||
placeholder="Enter the answer here...",
|
||||
max_length=50,
|
||||
)
|
||||
|
||||
def __init__(self, cog: "Verification", title: str = "Verification") -> None:
|
||||
super().__init__(title=title)
|
||||
self.cog = cog
|
||||
|
||||
async def on_submit(self, interaction: discord.Interaction) -> None:
|
||||
success, message = await self.cog.complete_verification(
|
||||
interaction.guild.id,
|
||||
interaction.user.id,
|
||||
self.answer.value,
|
||||
)
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
|
||||
|
||||
class AnswerButton(ui.Button["AnswerView"]):
|
||||
"""Button to open the answer modal."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
style=discord.ButtonStyle.primary,
|
||||
label="Submit Answer",
|
||||
custom_id="submit_answer",
|
||||
)
|
||||
|
||||
async def callback(self, interaction: discord.Interaction) -> None:
|
||||
if self.view is None:
|
||||
return
|
||||
modal = CaptchaModal(self.view.cog)
|
||||
await interaction.response.send_modal(modal)
|
||||
|
||||
|
||||
class AnswerView(ui.View):
|
||||
"""View with button to open answer modal."""
|
||||
|
||||
def __init__(self, cog: "Verification", timeout: float = 600) -> None:
|
||||
super().__init__(timeout=timeout)
|
||||
self.cog = cog
|
||||
self.add_item(AnswerButton())
|
||||
|
||||
|
||||
class Verification(commands.Cog):
|
||||
"""Member verification system."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
self.service = VerificationService()
|
||||
self.cleanup_task.start()
|
||||
|
||||
def cog_unload(self) -> None:
|
||||
self.cleanup_task.cancel()
|
||||
|
||||
@tasks.loop(minutes=5)
|
||||
async def cleanup_task(self) -> None:
|
||||
"""Periodically clean up expired verifications."""
|
||||
count = self.service.cleanup_expired()
|
||||
if count > 0:
|
||||
logger.debug(f"Cleaned up {count} expired verifications")
|
||||
|
||||
@cleanup_task.before_loop
|
||||
async def before_cleanup(self) -> None:
|
||||
await self.bot.wait_until_ready()
|
||||
|
||||
async def complete_verification(
|
||||
self, guild_id: int, user_id: int, response: str
|
||||
) -> tuple[bool, str]:
|
||||
"""Complete a verification and assign role if successful."""
|
||||
success, message = self.service.verify(guild_id, user_id, response)
|
||||
|
||||
if success:
|
||||
# Assign verified role
|
||||
guild = self.bot.get_guild(guild_id)
|
||||
if guild:
|
||||
member = guild.get_member(user_id)
|
||||
config = await self.bot.guild_config.get_config(guild_id)
|
||||
|
||||
if member and config and config.verified_role_id:
|
||||
role = guild.get_role(config.verified_role_id)
|
||||
if role:
|
||||
try:
|
||||
await member.add_roles(role, reason="Verification completed")
|
||||
logger.info(f"Verified {member} in {guild.name}")
|
||||
except discord.Forbidden:
|
||||
logger.warning(f"Cannot assign verified role in {guild.name}")
|
||||
|
||||
return success, message
|
||||
|
||||
async def send_verification(
|
||||
self,
|
||||
member: discord.Member,
|
||||
channel: discord.TextChannel,
|
||||
challenge_type: ChallengeType,
|
||||
) -> None:
|
||||
"""Send a verification challenge to a member."""
|
||||
pending = self.service.create_challenge(
|
||||
user_id=member.id,
|
||||
guild_id=member.guild.id,
|
||||
challenge_type=challenge_type,
|
||||
)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Verification Required",
|
||||
description=pending.challenge.question,
|
||||
color=discord.Color.blue(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_footer(
|
||||
text=f"Expires in 10 minutes • {pending.challenge.max_attempts} attempts allowed"
|
||||
)
|
||||
|
||||
# Create appropriate view based on challenge type
|
||||
if challenge_type == ChallengeType.BUTTON:
|
||||
view = VerificationView(self)
|
||||
elif challenge_type == ChallengeType.EMOJI:
|
||||
view = EmojiVerificationView(self, pending.challenge.options)
|
||||
else:
|
||||
# Captcha or Math - use modal
|
||||
view = AnswerView(self)
|
||||
|
||||
try:
|
||||
# Try to DM the user first
|
||||
dm_channel = await member.create_dm()
|
||||
msg = await dm_channel.send(embed=embed, view=view)
|
||||
pending.message_id = msg.id
|
||||
pending.channel_id = dm_channel.id
|
||||
except discord.Forbidden:
|
||||
# Fall back to channel mention
|
||||
msg = await channel.send(
|
||||
content=member.mention,
|
||||
embed=embed,
|
||||
view=view,
|
||||
)
|
||||
pending.message_id = msg.id
|
||||
pending.channel_id = channel.id
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_member_join(self, member: discord.Member) -> None:
|
||||
"""Handle new member joins for verification."""
|
||||
if member.bot:
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||
if not config or not config.verification_enabled:
|
||||
return
|
||||
|
||||
# Determine verification channel
|
||||
channel_id = config.welcome_channel_id or config.log_channel_id
|
||||
if not channel_id:
|
||||
return
|
||||
|
||||
channel = member.guild.get_channel(channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
# Get challenge type from config
|
||||
try:
|
||||
challenge_type = ChallengeType(config.verification_type)
|
||||
except ValueError:
|
||||
challenge_type = ChallengeType.BUTTON
|
||||
|
||||
await self.send_verification(member, channel, challenge_type)
|
||||
|
||||
@commands.group(name="verify", invoke_without_command=True)
|
||||
@commands.guild_only()
|
||||
async def verify_cmd(self, ctx: commands.Context) -> None:
|
||||
"""Request a verification challenge."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
if not config or not config.verification_enabled:
|
||||
await ctx.send("Verification is not enabled on this server.")
|
||||
return
|
||||
|
||||
# Check if already verified
|
||||
if config.verified_role_id:
|
||||
role = ctx.guild.get_role(config.verified_role_id)
|
||||
if role and role in ctx.author.roles:
|
||||
await ctx.send("You are already verified!")
|
||||
return
|
||||
|
||||
# Check for existing pending verification
|
||||
pending = self.service.get_pending(ctx.guild.id, ctx.author.id)
|
||||
if pending and not pending.challenge.is_expired:
|
||||
await ctx.send("You already have a pending verification. Please complete it first.")
|
||||
return
|
||||
|
||||
# Get challenge type
|
||||
try:
|
||||
challenge_type = ChallengeType(config.verification_type)
|
||||
except ValueError:
|
||||
challenge_type = ChallengeType.BUTTON
|
||||
|
||||
await self.send_verification(ctx.author, ctx.channel, challenge_type)
|
||||
await ctx.message.delete(delay=1)
|
||||
|
||||
@verify_cmd.command(name="setup")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_setup(self, ctx: commands.Context) -> None:
|
||||
"""View verification setup status."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Verification Setup",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Enabled",
|
||||
value="✅ Yes" if config and config.verification_enabled else "❌ No",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Type",
|
||||
value=config.verification_type if config else "button",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
if config and config.verified_role_id:
|
||||
role = ctx.guild.get_role(config.verified_role_id)
|
||||
embed.add_field(
|
||||
name="Verified Role",
|
||||
value=role.mention if role else "Not found",
|
||||
inline=True,
|
||||
)
|
||||
else:
|
||||
embed.add_field(name="Verified Role", value="Not set", inline=True)
|
||||
|
||||
pending_count = self.service.get_pending_count(ctx.guild.id)
|
||||
embed.add_field(name="Pending Verifications", value=str(pending_count), inline=True)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@verify_cmd.command(name="enable")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_enable(self, ctx: commands.Context) -> None:
|
||||
"""Enable verification for new members."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
if not config or not config.verified_role_id:
|
||||
await ctx.send("Please set a verified role first with `!verify role @role`")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, verification_enabled=True)
|
||||
await ctx.send("✅ Verification enabled for new members.")
|
||||
|
||||
@verify_cmd.command(name="disable")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_disable(self, ctx: commands.Context) -> None:
|
||||
"""Disable verification."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, verification_enabled=False)
|
||||
await ctx.send("❌ Verification disabled.")
|
||||
|
||||
@verify_cmd.command(name="role")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_role(self, ctx: commands.Context, role: discord.Role) -> None:
|
||||
"""Set the role given upon verification."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, verified_role_id=role.id)
|
||||
await ctx.send(f"Verified role set to {role.mention}")
|
||||
|
||||
@verify_cmd.command(name="type")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_type(self, ctx: commands.Context, vtype: str) -> None:
|
||||
"""Set verification type (button, captcha, math, emoji)."""
|
||||
try:
|
||||
challenge_type = ChallengeType(vtype.lower())
|
||||
except ValueError:
|
||||
valid = ", ".join(t.value for t in ChallengeType if t != ChallengeType.QUESTIONS)
|
||||
await ctx.send(f"Invalid type. Valid options: {valid}")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(
|
||||
ctx.guild.id, verification_type=challenge_type.value
|
||||
)
|
||||
await ctx.send(f"Verification type set to **{challenge_type.value}**")
|
||||
|
||||
@verify_cmd.command(name="test")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_test(self, ctx: commands.Context, vtype: str = "button") -> None:
|
||||
"""Test verification (sends challenge to you)."""
|
||||
try:
|
||||
challenge_type = ChallengeType(vtype.lower())
|
||||
except ValueError:
|
||||
challenge_type = ChallengeType.BUTTON
|
||||
|
||||
await self.send_verification(ctx.author, ctx.channel, challenge_type)
|
||||
|
||||
@verify_cmd.command(name="reset")
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def verify_reset(self, ctx: commands.Context, member: discord.Member) -> None:
|
||||
"""Reset verification for a member (remove role and cancel pending)."""
|
||||
# Cancel any pending verification
|
||||
self.service.cancel(ctx.guild.id, member.id)
|
||||
|
||||
# Remove verified role
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
if config and config.verified_role_id:
|
||||
role = ctx.guild.get_role(config.verified_role_id)
|
||||
if role and role in member.roles:
|
||||
try:
|
||||
await member.remove_roles(role, reason=f"Verification reset by {ctx.author}")
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
await ctx.send(f"Reset verification for {member.mention}")
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the Verification cog."""
|
||||
await bot.add_cog(Verification(bot))
|
||||
50
src/guardden/config.py
Normal file
50
src/guardden/config.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Configuration management for GuardDen."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
env_prefix="GUARDDEN_",
|
||||
)
|
||||
|
||||
# Discord settings
|
||||
discord_token: SecretStr = Field(..., description="Discord bot token")
|
||||
discord_prefix: str = Field(default="!", description="Default command prefix")
|
||||
|
||||
# Database settings
|
||||
database_url: SecretStr = Field(
|
||||
default=SecretStr("postgresql://guardden:guardden@localhost:5432/guardden"),
|
||||
description="PostgreSQL connection URL",
|
||||
)
|
||||
database_pool_min: int = Field(default=5, description="Minimum database pool size")
|
||||
database_pool_max: int = Field(default=20, description="Maximum database pool size")
|
||||
|
||||
# AI settings (optional)
|
||||
ai_provider: Literal["anthropic", "openai", "none"] = Field(
|
||||
default="none", description="AI provider for content moderation"
|
||||
)
|
||||
anthropic_api_key: SecretStr | None = Field(default=None, description="Anthropic API key")
|
||||
openai_api_key: SecretStr | None = Field(default=None, description="OpenAI API key")
|
||||
|
||||
# Logging
|
||||
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
|
||||
default="INFO", description="Logging level"
|
||||
)
|
||||
|
||||
# Paths
|
||||
data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""Get application settings instance."""
|
||||
return Settings()
|
||||
15
src/guardden/models/__init__.py
Normal file
15
src/guardden/models/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Database models for GuardDen."""
|
||||
|
||||
from guardden.models.base import Base
|
||||
from guardden.models.guild import BannedWord, Guild, GuildSettings
|
||||
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"Guild",
|
||||
"GuildSettings",
|
||||
"BannedWord",
|
||||
"ModerationLog",
|
||||
"Strike",
|
||||
"UserNote",
|
||||
]
|
||||
32
src/guardden/models/base.py
Normal file
32
src/guardden/models/base.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Base model and database utilities."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BigInteger, DateTime, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin that adds created_at and updated_at timestamps."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
# Type alias for Discord snowflake IDs (64-bit integers)
|
||||
SnowflakeID = BigInteger
|
||||
117
src/guardden/models/guild.py
Normal file
117
src/guardden/models/guild.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Guild-related database models."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Boolean, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from guardden.models.base import Base, SnowflakeID, TimestampMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from guardden.models.moderation import ModerationLog, Strike
|
||||
|
||||
|
||||
class Guild(Base, TimestampMixin):
|
||||
"""Represents a Discord guild (server) configuration."""
|
||||
|
||||
__tablename__ = "guilds"
|
||||
|
||||
id: Mapped[int] = mapped_column(SnowflakeID, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
owner_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
premium: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Relationships
|
||||
settings: Mapped["GuildSettings"] = relationship(
|
||||
back_populates="guild", uselist=False, cascade="all, delete-orphan"
|
||||
)
|
||||
banned_words: Mapped[list["BannedWord"]] = relationship(
|
||||
back_populates="guild", cascade="all, delete-orphan"
|
||||
)
|
||||
moderation_logs: Mapped[list["ModerationLog"]] = relationship(
|
||||
back_populates="guild", cascade="all, delete-orphan"
|
||||
)
|
||||
strikes: Mapped[list["Strike"]] = relationship(
|
||||
back_populates="guild", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class GuildSettings(Base, TimestampMixin):
|
||||
"""Per-guild bot settings and configuration."""
|
||||
|
||||
__tablename__ = "guild_settings"
|
||||
|
||||
guild_id: Mapped[int] = mapped_column(
|
||||
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
||||
# General settings
|
||||
prefix: Mapped[str] = mapped_column(String(10), default="!", nullable=False)
|
||||
locale: Mapped[str] = mapped_column(String(10), default="en", nullable=False)
|
||||
|
||||
# Channel configuration (stored as snowflake IDs)
|
||||
log_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
mod_log_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
welcome_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
|
||||
# Role configuration
|
||||
mute_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
verified_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
mod_role_ids: Mapped[dict] = mapped_column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Moderation settings
|
||||
automod_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
anti_spam_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
link_filter_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Strike thresholds (actions at each threshold)
|
||||
strike_actions: Mapped[dict] = mapped_column(
|
||||
JSONB,
|
||||
default=lambda: {
|
||||
"1": {"action": "warn"},
|
||||
"3": {"action": "timeout", "duration": 3600},
|
||||
"5": {"action": "kick"},
|
||||
"7": {"action": "ban"},
|
||||
},
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# AI moderation settings
|
||||
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=50, nullable=False) # 0-100 scale
|
||||
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Verification settings
|
||||
verification_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
verification_type: Mapped[str] = mapped_column(
|
||||
String(20), default="button", nullable=False
|
||||
) # button, captcha, questions
|
||||
|
||||
# Relationship
|
||||
guild: Mapped["Guild"] = relationship(back_populates="settings")
|
||||
|
||||
|
||||
class BannedWord(Base, TimestampMixin):
|
||||
"""Banned words/phrases for a guild with regex support."""
|
||||
|
||||
__tablename__ = "banned_words"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(
|
||||
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
pattern: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
is_regex: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
action: Mapped[str] = mapped_column(
|
||||
String(20), default="delete", nullable=False
|
||||
) # delete, warn, strike
|
||||
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
# Who added this and when
|
||||
added_by: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
|
||||
# Relationship
|
||||
guild: Mapped["Guild"] = relationship(back_populates="banned_words")
|
||||
101
src/guardden/models/moderation.py
Normal file
101
src/guardden/models/moderation.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Moderation-related database models."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from guardden.models.base import Base, SnowflakeID, TimestampMixin
|
||||
from guardden.models.guild import Guild
|
||||
|
||||
|
||||
class ModAction(str, Enum):
|
||||
"""Types of moderation actions."""
|
||||
|
||||
WARN = "warn"
|
||||
TIMEOUT = "timeout"
|
||||
KICK = "kick"
|
||||
BAN = "ban"
|
||||
UNBAN = "unban"
|
||||
UNMUTE = "unmute"
|
||||
NOTE = "note"
|
||||
STRIKE = "strike"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class ModerationLog(Base, TimestampMixin):
|
||||
"""Log of all moderation actions taken."""
|
||||
|
||||
__tablename__ = "moderation_logs"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(
|
||||
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Target and moderator
|
||||
target_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
target_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
moderator_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
|
||||
# Action details
|
||||
action: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
duration: Mapped[int | None] = mapped_column(Integer, nullable=True) # Duration in seconds
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Context
|
||||
channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
message_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
message_content: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
# Was this an automatic action?
|
||||
is_automatic: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Relationship
|
||||
guild: Mapped["Guild"] = relationship(back_populates="moderation_logs")
|
||||
|
||||
|
||||
class Strike(Base, TimestampMixin):
|
||||
"""User strikes/warnings tracking."""
|
||||
|
||||
__tablename__ = "strikes"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(
|
||||
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
user_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
|
||||
reason: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
points: Mapped[int] = mapped_column(Integer, default=1, nullable=False)
|
||||
|
||||
# Strikes can expire
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
|
||||
# Reference to the moderation log entry
|
||||
mod_log_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("moderation_logs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
# Relationship
|
||||
guild: Mapped["Guild"] = relationship(back_populates="strikes")
|
||||
|
||||
|
||||
class UserNote(Base, TimestampMixin):
|
||||
"""Moderator notes on users."""
|
||||
|
||||
__tablename__ = "user_notes"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
|
||||
user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
16
src/guardden/services/__init__.py
Normal file
16
src/guardden/services/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Services for GuardDen."""
|
||||
|
||||
from guardden.services.automod import AutomodService
|
||||
from guardden.services.database import Database
|
||||
from guardden.services.ratelimit import RateLimiter, get_rate_limiter, ratelimit
|
||||
from guardden.services.verification import ChallengeType, VerificationService
|
||||
|
||||
__all__ = [
|
||||
"AutomodService",
|
||||
"ChallengeType",
|
||||
"Database",
|
||||
"RateLimiter",
|
||||
"VerificationService",
|
||||
"get_rate_limiter",
|
||||
"ratelimit",
|
||||
]
|
||||
6
src/guardden/services/ai/__init__.py
Normal file
6
src/guardden/services/ai/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""AI services for content moderation."""
|
||||
|
||||
from guardden.services.ai.base import AIProvider, ModerationResult
|
||||
from guardden.services.ai.factory import create_ai_provider
|
||||
|
||||
__all__ = ["AIProvider", "ModerationResult", "create_ai_provider"]
|
||||
261
src/guardden/services/ai/anthropic_provider.py
Normal file
261
src/guardden/services/ai/anthropic_provider.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""Anthropic Claude AI provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from guardden.services.ai.base import (
|
||||
AIProvider,
|
||||
ContentCategory,
|
||||
ImageAnalysisResult,
|
||||
ModerationResult,
|
||||
PhishingAnalysisResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Content moderation system prompt
|
||||
MODERATION_SYSTEM_PROMPT = """You are a content moderation AI for a Discord server. Analyze the given message and determine if it violates community guidelines.
|
||||
|
||||
Categories to check:
|
||||
- harassment: Personal attacks, bullying, intimidation
|
||||
- hate_speech: Discrimination, slurs, dehumanization based on identity
|
||||
- sexual: Explicit sexual content, sexual solicitation
|
||||
- violence: Threats, graphic violence, encouraging harm
|
||||
- self_harm: Suicide, self-injury content or encouragement
|
||||
- spam: Repetitive, promotional, or low-quality content
|
||||
- scam: Phishing attempts, fraudulent offers, impersonation
|
||||
- misinformation: Dangerous false information
|
||||
|
||||
Respond in this exact JSON format:
|
||||
{
|
||||
"is_flagged": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"categories": ["category1", "category2"],
|
||||
"explanation": "Brief explanation",
|
||||
"suggested_action": "none/warn/delete/timeout/ban"
|
||||
}
|
||||
|
||||
Be balanced - flag genuinely problematic content but allow normal conversation, jokes, and mild language. Consider context."""
|
||||
|
||||
IMAGE_ANALYSIS_PROMPT = """Analyze this image for content moderation purposes. Check for:
|
||||
- NSFW content (nudity, sexual content)
|
||||
- Violence or gore
|
||||
- Disturbing or shocking content
|
||||
- Any content inappropriate for a general audience
|
||||
|
||||
Respond in this exact JSON format:
|
||||
{
|
||||
"is_nsfw": true/false,
|
||||
"is_violent": true/false,
|
||||
"is_disturbing": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"description": "Brief description of the image",
|
||||
"categories": ["category1", "category2"]
|
||||
}
|
||||
|
||||
Be accurate but not overly sensitive - artistic nudity or mild violence in appropriate contexts may be acceptable."""
|
||||
|
||||
PHISHING_ANALYSIS_PROMPT = """Analyze this URL and message context for phishing or scam indicators.
|
||||
|
||||
Check for:
|
||||
- Domain impersonation (typosquatting, lookalike domains)
|
||||
- Urgency tactics ("act now", "limited time")
|
||||
- Requests for credentials or personal info
|
||||
- Too-good-to-be-true offers
|
||||
- Suspicious redirects or URL shorteners
|
||||
- Mismatched or hidden URLs
|
||||
|
||||
Respond in this exact JSON format:
|
||||
{
|
||||
"is_phishing": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"risk_factors": ["factor1", "factor2"],
|
||||
"explanation": "Brief explanation"
|
||||
}"""
|
||||
|
||||
|
||||
class AnthropicProvider(AIProvider):
|
||||
"""AI provider using Anthropic's Claude API."""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "claude-3-haiku-20240307") -> None:
|
||||
"""
|
||||
Initialize Anthropic provider.
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key
|
||||
model: Model to use (default: claude-3-haiku for speed/cost)
|
||||
"""
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError:
|
||||
raise ImportError("anthropic package required. Install with: pip install anthropic")
|
||||
|
||||
self.client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
self.model = model
|
||||
logger.info(f"Initialized Anthropic provider with model: {model}")
|
||||
|
||||
async def _call_api(self, system: str, user_content: Any, max_tokens: int = 500) -> str:
|
||||
"""Make an API call to Claude."""
|
||||
try:
|
||||
message = await self.client.messages.create(
|
||||
model=self.model,
|
||||
max_tokens=max_tokens,
|
||||
system=system,
|
||||
messages=[{"role": "user", "content": user_content}],
|
||||
)
|
||||
return message.content[0].text
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
def _parse_json_response(self, response: str) -> dict:
|
||||
"""Parse JSON from response, handling markdown code blocks."""
|
||||
import json
|
||||
|
||||
# Remove markdown code blocks if present
|
||||
text = response.strip()
|
||||
if text.startswith("```"):
|
||||
lines = text.split("\n")
|
||||
# Remove first and last lines (```json and ```)
|
||||
text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
||||
|
||||
return json.loads(text)
|
||||
|
||||
async def moderate_text(
|
||||
self,
|
||||
content: str,
|
||||
context: str | None = None,
|
||||
sensitivity: int = 50,
|
||||
) -> ModerationResult:
|
||||
"""Analyze text content for policy violations."""
|
||||
# Adjust prompt based on sensitivity
|
||||
sensitivity_note = ""
|
||||
if sensitivity < 30:
|
||||
sensitivity_note = "\n\nBe lenient - only flag clearly problematic content."
|
||||
elif sensitivity > 70:
|
||||
sensitivity_note = "\n\nBe strict - flag anything potentially problematic."
|
||||
|
||||
system = MODERATION_SYSTEM_PROMPT + sensitivity_note
|
||||
|
||||
user_message = f"Message to analyze:\n{content}"
|
||||
if context:
|
||||
user_message = f"Context: {context}\n\n{user_message}"
|
||||
|
||||
try:
|
||||
response = await self._call_api(system, user_message)
|
||||
data = self._parse_json_response(response)
|
||||
|
||||
categories = [
|
||||
ContentCategory(cat)
|
||||
for cat in data.get("categories", [])
|
||||
if cat in ContentCategory.__members__.values()
|
||||
]
|
||||
|
||||
return ModerationResult(
|
||||
is_flagged=data.get("is_flagged", False),
|
||||
confidence=float(data.get("confidence", 0.0)),
|
||||
categories=categories,
|
||||
explanation=data.get("explanation", ""),
|
||||
suggested_action=data.get("suggested_action", "none"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moderating text: {e}")
|
||||
return ModerationResult(
|
||||
is_flagged=False,
|
||||
explanation=f"Error analyzing content: {str(e)}",
|
||||
)
|
||||
|
||||
async def analyze_image(
|
||||
self,
|
||||
image_url: str,
|
||||
sensitivity: int = 50,
|
||||
) -> ImageAnalysisResult:
|
||||
"""Analyze an image for NSFW or inappropriate content."""
|
||||
import base64
|
||||
|
||||
import aiohttp
|
||||
|
||||
sensitivity_note = ""
|
||||
if sensitivity < 30:
|
||||
sensitivity_note = "\n\nBe lenient - only flag explicit content."
|
||||
elif sensitivity > 70:
|
||||
sensitivity_note = "\n\nBe strict - flag suggestive content as well."
|
||||
|
||||
system = IMAGE_ANALYSIS_PROMPT + sensitivity_note
|
||||
|
||||
try:
|
||||
# Download image and convert to base64
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as resp:
|
||||
if resp.status != 200:
|
||||
return ImageAnalysisResult(
|
||||
description=f"Failed to download image: HTTP {resp.status}"
|
||||
)
|
||||
|
||||
content_type = resp.content_type or "image/jpeg"
|
||||
image_data = await resp.read()
|
||||
|
||||
# Check file size (max 20MB for Claude)
|
||||
if len(image_data) > 20 * 1024 * 1024:
|
||||
return ImageAnalysisResult(description="Image too large to analyze")
|
||||
|
||||
base64_image = base64.standard_b64encode(image_data).decode("utf-8")
|
||||
|
||||
# Create multimodal message
|
||||
user_content = [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": content_type,
|
||||
"data": base64_image,
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "Analyze this image for content moderation."},
|
||||
]
|
||||
|
||||
response = await self._call_api(system, user_content)
|
||||
data = self._parse_json_response(response)
|
||||
|
||||
return ImageAnalysisResult(
|
||||
is_nsfw=data.get("is_nsfw", False),
|
||||
is_violent=data.get("is_violent", False),
|
||||
is_disturbing=data.get("is_disturbing", False),
|
||||
confidence=float(data.get("confidence", 0.0)),
|
||||
description=data.get("description", ""),
|
||||
categories=data.get("categories", []),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing image: {e}")
|
||||
return ImageAnalysisResult(description=f"Error analyzing image: {str(e)}")
|
||||
|
||||
async def analyze_phishing(
|
||||
self,
|
||||
url: str,
|
||||
message_content: str | None = None,
|
||||
) -> PhishingAnalysisResult:
|
||||
"""Analyze a URL for phishing/scam indicators."""
|
||||
user_message = f"URL to analyze: {url}"
|
||||
if message_content:
|
||||
user_message += f"\n\nFull message context:\n{message_content}"
|
||||
|
||||
try:
|
||||
response = await self._call_api(PHISHING_ANALYSIS_PROMPT, user_message)
|
||||
data = self._parse_json_response(response)
|
||||
|
||||
return PhishingAnalysisResult(
|
||||
is_phishing=data.get("is_phishing", False),
|
||||
confidence=float(data.get("confidence", 0.0)),
|
||||
risk_factors=data.get("risk_factors", []),
|
||||
explanation=data.get("explanation", ""),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing phishing: {e}")
|
||||
return PhishingAnalysisResult(explanation=f"Error analyzing URL: {str(e)}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
await self.client.close()
|
||||
149
src/guardden/services/ai/base.py
Normal file
149
src/guardden/services/ai/base.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Base classes for AI providers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class ContentCategory(str, Enum):
|
||||
"""Categories of problematic content."""
|
||||
|
||||
SAFE = "safe"
|
||||
HARASSMENT = "harassment"
|
||||
HATE_SPEECH = "hate_speech"
|
||||
SEXUAL = "sexual"
|
||||
VIOLENCE = "violence"
|
||||
SELF_HARM = "self_harm"
|
||||
SPAM = "spam"
|
||||
SCAM = "scam"
|
||||
MISINFORMATION = "misinformation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModerationResult:
|
||||
"""Result of AI content moderation."""
|
||||
|
||||
is_flagged: bool = False
|
||||
confidence: float = 0.0 # 0.0 to 1.0
|
||||
categories: list[ContentCategory] = field(default_factory=list)
|
||||
explanation: str = ""
|
||||
suggested_action: Literal["none", "warn", "delete", "timeout", "ban"] = "none"
|
||||
|
||||
@property
|
||||
def severity(self) -> int:
|
||||
"""Get severity score 0-100 based on confidence and categories."""
|
||||
if not self.is_flagged:
|
||||
return 0
|
||||
|
||||
# Base severity from confidence
|
||||
severity = int(self.confidence * 50)
|
||||
|
||||
# Add severity based on category
|
||||
high_severity = {
|
||||
ContentCategory.HATE_SPEECH,
|
||||
ContentCategory.SELF_HARM,
|
||||
ContentCategory.SCAM,
|
||||
}
|
||||
medium_severity = {
|
||||
ContentCategory.HARASSMENT,
|
||||
ContentCategory.VIOLENCE,
|
||||
ContentCategory.SEXUAL,
|
||||
}
|
||||
|
||||
for cat in self.categories:
|
||||
if cat in high_severity:
|
||||
severity += 30
|
||||
elif cat in medium_severity:
|
||||
severity += 20
|
||||
else:
|
||||
severity += 10
|
||||
|
||||
return min(severity, 100)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageAnalysisResult:
|
||||
"""Result of AI image analysis."""
|
||||
|
||||
is_nsfw: bool = False
|
||||
is_violent: bool = False
|
||||
is_disturbing: bool = False
|
||||
confidence: float = 0.0
|
||||
description: str = ""
|
||||
categories: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhishingAnalysisResult:
|
||||
"""Result of AI phishing/scam analysis."""
|
||||
|
||||
is_phishing: bool = False
|
||||
confidence: float = 0.0
|
||||
risk_factors: list[str] = field(default_factory=list)
|
||||
explanation: str = ""
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
"""Abstract base class for AI providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def moderate_text(
|
||||
self,
|
||||
content: str,
|
||||
context: str | None = None,
|
||||
sensitivity: int = 50,
|
||||
) -> ModerationResult:
|
||||
"""
|
||||
Analyze text content for policy violations.
|
||||
|
||||
Args:
|
||||
content: The text to analyze
|
||||
context: Optional context about the conversation/server
|
||||
sensitivity: 0-100, higher means more strict
|
||||
|
||||
Returns:
|
||||
ModerationResult with analysis
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def analyze_image(
|
||||
self,
|
||||
image_url: str,
|
||||
sensitivity: int = 50,
|
||||
) -> ImageAnalysisResult:
|
||||
"""
|
||||
Analyze an image for NSFW or inappropriate content.
|
||||
|
||||
Args:
|
||||
image_url: URL of the image to analyze
|
||||
sensitivity: 0-100, higher means more strict
|
||||
|
||||
Returns:
|
||||
ImageAnalysisResult with analysis
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def analyze_phishing(
|
||||
self,
|
||||
url: str,
|
||||
message_content: str | None = None,
|
||||
) -> PhishingAnalysisResult:
|
||||
"""
|
||||
Analyze a URL for phishing/scam indicators.
|
||||
|
||||
Args:
|
||||
url: The URL to analyze
|
||||
message_content: Optional full message for context
|
||||
|
||||
Returns:
|
||||
PhishingAnalysisResult with analysis
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
pass
|
||||
67
src/guardden/services/ai/factory.py
Normal file
67
src/guardden/services/ai/factory.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Factory for creating AI providers."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from guardden.services.ai.base import AIProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NullProvider(AIProvider):
|
||||
"""Null provider that does nothing (for when AI is disabled)."""
|
||||
|
||||
async def moderate_text(self, content, context=None, sensitivity=50):
|
||||
from guardden.services.ai.base import ModerationResult
|
||||
|
||||
return ModerationResult()
|
||||
|
||||
async def analyze_image(self, image_url, sensitivity=50):
|
||||
from guardden.services.ai.base import ImageAnalysisResult
|
||||
|
||||
return ImageAnalysisResult()
|
||||
|
||||
async def analyze_phishing(self, url, message_content=None):
|
||||
from guardden.services.ai.base import PhishingAnalysisResult
|
||||
|
||||
return PhishingAnalysisResult()
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
|
||||
def create_ai_provider(
|
||||
provider: Literal["anthropic", "openai", "none"],
|
||||
api_key: str | None = None,
|
||||
) -> AIProvider:
|
||||
"""
|
||||
Create an AI provider instance.
|
||||
|
||||
Args:
|
||||
provider: The provider type to create
|
||||
api_key: API key for the provider
|
||||
|
||||
Returns:
|
||||
AIProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is unknown or API key is missing
|
||||
"""
|
||||
if provider == "none":
|
||||
logger.info("AI moderation disabled")
|
||||
return NullProvider()
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(f"API key required for {provider} provider")
|
||||
|
||||
if provider == "anthropic":
|
||||
from guardden.services.ai.anthropic_provider import AnthropicProvider
|
||||
|
||||
return AnthropicProvider(api_key)
|
||||
|
||||
if provider == "openai":
|
||||
from guardden.services.ai.openai_provider import OpenAIProvider
|
||||
|
||||
return OpenAIProvider(api_key)
|
||||
|
||||
raise ValueError(f"Unknown AI provider: {provider}")
|
||||
213
src/guardden/services/ai/openai_provider.py
Normal file
213
src/guardden/services/ai/openai_provider.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""OpenAI AI provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from guardden.services.ai.base import (
|
||||
AIProvider,
|
||||
ContentCategory,
|
||||
ImageAnalysisResult,
|
||||
ModerationResult,
|
||||
PhishingAnalysisResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(AIProvider):
|
||||
"""AI provider using OpenAI's API."""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "gpt-4o-mini") -> None:
|
||||
"""
|
||||
Initialize OpenAI provider.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key
|
||||
model: Model to use (default: gpt-4o-mini for speed/cost)
|
||||
"""
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
raise ImportError("openai package required. Install with: pip install openai")
|
||||
|
||||
self.client = openai.AsyncOpenAI(api_key=api_key)
|
||||
self.model = model
|
||||
logger.info(f"Initialized OpenAI provider with model: {model}")
|
||||
|
||||
async def _call_api(
|
||||
self,
|
||||
system: str,
|
||||
user_content: Any,
|
||||
max_tokens: int = 500,
|
||||
) -> str:
|
||||
"""Make an API call to OpenAI."""
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
max_tokens=max_tokens,
|
||||
messages=[
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user_content},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
|
||||
def _parse_json_response(self, response: str) -> dict:
|
||||
"""Parse JSON from response."""
|
||||
import json
|
||||
|
||||
return json.loads(response)
|
||||
|
||||
async def moderate_text(
|
||||
self,
|
||||
content: str,
|
||||
context: str | None = None,
|
||||
sensitivity: int = 50,
|
||||
) -> ModerationResult:
|
||||
"""Analyze text content for policy violations."""
|
||||
# First, use OpenAI's built-in moderation API for quick check
|
||||
try:
|
||||
mod_response = await self.client.moderations.create(input=content)
|
||||
results = mod_response.results[0]
|
||||
|
||||
# Map OpenAI categories to our categories
|
||||
category_mapping = {
|
||||
"harassment": ContentCategory.HARASSMENT,
|
||||
"harassment/threatening": ContentCategory.HARASSMENT,
|
||||
"hate": ContentCategory.HATE_SPEECH,
|
||||
"hate/threatening": ContentCategory.HATE_SPEECH,
|
||||
"self-harm": ContentCategory.SELF_HARM,
|
||||
"self-harm/intent": ContentCategory.SELF_HARM,
|
||||
"self-harm/instructions": ContentCategory.SELF_HARM,
|
||||
"sexual": ContentCategory.SEXUAL,
|
||||
"sexual/minors": ContentCategory.SEXUAL,
|
||||
"violence": ContentCategory.VIOLENCE,
|
||||
"violence/graphic": ContentCategory.VIOLENCE,
|
||||
}
|
||||
|
||||
flagged_categories = []
|
||||
max_score = 0.0
|
||||
|
||||
for category, score in results.category_scores.model_dump().items():
|
||||
if score > 0.5: # Threshold
|
||||
if category in category_mapping:
|
||||
flagged_categories.append(category_mapping[category])
|
||||
max_score = max(max_score, score)
|
||||
|
||||
# Adjust threshold based on sensitivity
|
||||
threshold = 0.3 + (0.4 * (100 - sensitivity) / 100) # 0.3 to 0.7
|
||||
|
||||
if results.flagged or max_score > threshold:
|
||||
return ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=max_score,
|
||||
categories=list(set(flagged_categories)),
|
||||
explanation="Content flagged by moderation API",
|
||||
suggested_action="delete" if max_score > 0.8 else "warn",
|
||||
)
|
||||
|
||||
return ModerationResult(is_flagged=False, confidence=1.0 - max_score)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moderating text: {e}")
|
||||
return ModerationResult(
|
||||
is_flagged=False,
|
||||
explanation=f"Error analyzing content: {str(e)}",
|
||||
)
|
||||
|
||||
async def analyze_image(
|
||||
self,
|
||||
image_url: str,
|
||||
sensitivity: int = 50,
|
||||
) -> ImageAnalysisResult:
|
||||
"""Analyze an image for NSFW or inappropriate content."""
|
||||
system = """Analyze this image for content moderation. Respond in JSON format:
|
||||
{
|
||||
"is_nsfw": true/false,
|
||||
"is_violent": true/false,
|
||||
"is_disturbing": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"description": "Brief description",
|
||||
"categories": ["category1"]
|
||||
}"""
|
||||
|
||||
sensitivity_note = ""
|
||||
if sensitivity < 30:
|
||||
sensitivity_note = " Be lenient - only flag explicit content."
|
||||
elif sensitivity > 70:
|
||||
sensitivity_note = " Be strict - flag suggestive content."
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model="gpt-4o-mini", # Use vision-capable model
|
||||
max_tokens=500,
|
||||
messages=[
|
||||
{"role": "system", "content": system + sensitivity_note},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Analyze this image for moderation."},
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
],
|
||||
},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
data = self._parse_json_response(response.choices[0].message.content or "{}")
|
||||
|
||||
return ImageAnalysisResult(
|
||||
is_nsfw=data.get("is_nsfw", False),
|
||||
is_violent=data.get("is_violent", False),
|
||||
is_disturbing=data.get("is_disturbing", False),
|
||||
confidence=float(data.get("confidence", 0.0)),
|
||||
description=data.get("description", ""),
|
||||
categories=data.get("categories", []),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing image: {e}")
|
||||
return ImageAnalysisResult(description=f"Error analyzing image: {str(e)}")
|
||||
|
||||
async def analyze_phishing(
|
||||
self,
|
||||
url: str,
|
||||
message_content: str | None = None,
|
||||
) -> PhishingAnalysisResult:
|
||||
"""Analyze a URL for phishing/scam indicators."""
|
||||
system = """Analyze the URL for phishing/scam indicators. Respond in JSON:
|
||||
{
|
||||
"is_phishing": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"risk_factors": ["factor1"],
|
||||
"explanation": "Brief explanation"
|
||||
}
|
||||
|
||||
Check for: domain impersonation, urgency tactics, credential requests, too-good-to-be-true offers."""
|
||||
|
||||
user_message = f"URL: {url}"
|
||||
if message_content:
|
||||
user_message += f"\n\nMessage context: {message_content}"
|
||||
|
||||
try:
|
||||
response = await self._call_api(system, user_message)
|
||||
data = self._parse_json_response(response)
|
||||
|
||||
return PhishingAnalysisResult(
|
||||
is_phishing=data.get("is_phishing", False),
|
||||
confidence=float(data.get("confidence", 0.0)),
|
||||
risk_factors=data.get("risk_factors", []),
|
||||
explanation=data.get("explanation", ""),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing phishing: {e}")
|
||||
return PhishingAnalysisResult(explanation=f"Error analyzing URL: {str(e)}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
await self.client.close()
|
||||
301
src/guardden/services/automod.py
Normal file
301
src/guardden/services/automod.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Automod service for content filtering and spam detection."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import NamedTuple
|
||||
|
||||
import discord
|
||||
|
||||
from guardden.models import BannedWord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Known scam/phishing patterns
|
||||
SCAM_PATTERNS = [
|
||||
# Discord scam patterns
|
||||
r"discord(?:[-.]?(?:gift|nitro|free|claim|steam))[\w.-]*\.(?!com|gg)[a-z]{2,}",
|
||||
r"(?:free|claim|get)[-.\s]?(?:discord[-.\s]?)?nitro",
|
||||
r"(?:steam|discord)[-.\s]?community[-.\s]?(?:giveaway|gift)",
|
||||
# Generic phishing
|
||||
r"(?:verify|confirm)[-.\s]?(?:your)?[-.\s]?account",
|
||||
r"(?:suspended|locked|limited)[-.\s]?account",
|
||||
r"click[-.\s]?(?:here|this)[-.\s]?(?:to[-.\s]?)?(?:verify|claim|get)",
|
||||
# Crypto scams
|
||||
r"(?:free|claim|airdrop)[-.\s]?(?:crypto|bitcoin|eth|nft)",
|
||||
r"(?:double|2x)[-.\s]?your[-.\s]?(?:crypto|bitcoin|eth)",
|
||||
]
|
||||
|
||||
# Suspicious TLDs often used in phishing
|
||||
SUSPICIOUS_TLDS = {
|
||||
".xyz",
|
||||
".top",
|
||||
".club",
|
||||
".work",
|
||||
".click",
|
||||
".link",
|
||||
".info",
|
||||
".ru",
|
||||
".cn",
|
||||
".tk",
|
||||
".ml",
|
||||
".ga",
|
||||
".cf",
|
||||
".gq",
|
||||
}
|
||||
|
||||
# URL pattern for extraction
|
||||
URL_PATTERN = re.compile(
|
||||
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*|"
|
||||
r"(?:www\.)?[-\w]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|ru|cn)[^\s]*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
class SpamRecord(NamedTuple):
|
||||
"""Record of a message for spam tracking."""
|
||||
|
||||
content_hash: str
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserSpamTracker:
|
||||
"""Tracks spam behavior for a single user."""
|
||||
|
||||
messages: list[SpamRecord] = field(default_factory=list)
|
||||
mention_count: int = 0
|
||||
last_mention_time: datetime | None = None
|
||||
duplicate_count: int = 0
|
||||
last_action_time: datetime | None = None
|
||||
|
||||
def cleanup(self, max_age: timedelta = timedelta(minutes=1)) -> None:
|
||||
"""Remove old messages from tracking."""
|
||||
cutoff = datetime.now(timezone.utc) - max_age
|
||||
self.messages = [m for m in self.messages if m.timestamp > cutoff]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutomodResult:
|
||||
"""Result of automod check."""
|
||||
|
||||
should_delete: bool = False
|
||||
should_warn: bool = False
|
||||
should_strike: bool = False
|
||||
should_timeout: bool = False
|
||||
timeout_duration: int = 0 # seconds
|
||||
reason: str = ""
|
||||
matched_filter: str = ""
|
||||
|
||||
|
||||
class AutomodService:
|
||||
"""Service for automatic content moderation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Compile scam patterns
|
||||
self._scam_patterns = [re.compile(p, re.IGNORECASE) for p in SCAM_PATTERNS]
|
||||
|
||||
# Per-guild, per-user spam tracking
|
||||
# Structure: {guild_id: {user_id: UserSpamTracker}}
|
||||
self._spam_trackers: dict[int, dict[int, UserSpamTracker]] = defaultdict(
|
||||
lambda: defaultdict(UserSpamTracker)
|
||||
)
|
||||
|
||||
# Spam thresholds
|
||||
self.message_rate_limit = 5 # messages per window
|
||||
self.message_rate_window = 5 # seconds
|
||||
self.duplicate_threshold = 3 # same message count
|
||||
self.mention_limit = 5 # mentions per message
|
||||
self.mention_rate_limit = 10 # mentions per window
|
||||
self.mention_rate_window = 60 # seconds
|
||||
|
||||
def _get_content_hash(self, content: str) -> str:
|
||||
"""Get a normalized hash of message content for duplicate detection."""
|
||||
# Normalize: lowercase, remove extra spaces, remove special chars
|
||||
normalized = re.sub(r"[^\w\s]", "", content.lower())
|
||||
normalized = re.sub(r"\s+", " ", normalized).strip()
|
||||
return normalized
|
||||
|
||||
def check_banned_words(
|
||||
self, content: str, banned_words: list[BannedWord]
|
||||
) -> AutomodResult | None:
|
||||
"""Check message against banned words list."""
|
||||
content_lower = content.lower()
|
||||
|
||||
for banned in banned_words:
|
||||
matched = False
|
||||
|
||||
if banned.is_regex:
|
||||
try:
|
||||
if re.search(banned.pattern, content, re.IGNORECASE):
|
||||
matched = True
|
||||
except re.error:
|
||||
logger.warning(f"Invalid regex pattern: {banned.pattern}")
|
||||
continue
|
||||
else:
|
||||
if banned.pattern.lower() in content_lower:
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
result = AutomodResult(
|
||||
should_delete=True,
|
||||
reason=banned.reason or f"Matched banned word filter",
|
||||
matched_filter=f"banned_word:{banned.id}",
|
||||
)
|
||||
|
||||
if banned.action == "warn":
|
||||
result.should_warn = True
|
||||
elif banned.action == "strike":
|
||||
result.should_strike = True
|
||||
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
def check_scam_links(self, content: str) -> AutomodResult | None:
|
||||
"""Check message for scam/phishing patterns."""
|
||||
# Check for known scam patterns
|
||||
for pattern in self._scam_patterns:
|
||||
if pattern.search(content):
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
reason="Message matched known scam/phishing pattern",
|
||||
matched_filter="scam_pattern",
|
||||
)
|
||||
|
||||
# Check URLs for suspicious TLDs
|
||||
urls = URL_PATTERN.findall(content)
|
||||
for url in urls:
|
||||
url_lower = url.lower()
|
||||
for tld in SUSPICIOUS_TLDS:
|
||||
if tld in url_lower:
|
||||
# Additional check: is it trying to impersonate a known domain?
|
||||
impersonation_keywords = [
|
||||
"discord",
|
||||
"steam",
|
||||
"nitro",
|
||||
"gift",
|
||||
"free",
|
||||
"login",
|
||||
"verify",
|
||||
]
|
||||
if any(kw in url_lower for kw in impersonation_keywords):
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
reason=f"Suspicious link detected: {url[:50]}",
|
||||
matched_filter="suspicious_link",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_spam(
|
||||
self, message: discord.Message, anti_spam_enabled: bool = True
|
||||
) -> AutomodResult | None:
|
||||
"""Check message for spam behavior."""
|
||||
if not anti_spam_enabled:
|
||||
return None
|
||||
|
||||
guild_id = message.guild.id
|
||||
user_id = message.author.id
|
||||
tracker = self._spam_trackers[guild_id][user_id]
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Cleanup old records
|
||||
tracker.cleanup()
|
||||
|
||||
# Check message rate
|
||||
content_hash = self._get_content_hash(message.content)
|
||||
tracker.messages.append(SpamRecord(content_hash, now))
|
||||
|
||||
# Rate limit check
|
||||
recent_window = now - timedelta(seconds=self.message_rate_window)
|
||||
recent_messages = [m for m in tracker.messages if m.timestamp > recent_window]
|
||||
|
||||
if len(recent_messages) > self.message_rate_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
timeout_duration=60, # 1 minute timeout
|
||||
reason=f"Sending messages too fast ({len(recent_messages)} in {self.message_rate_window}s)",
|
||||
matched_filter="rate_limit",
|
||||
)
|
||||
|
||||
# Duplicate message check
|
||||
duplicate_count = sum(1 for m in tracker.messages if m.content_hash == content_hash)
|
||||
if duplicate_count >= self.duplicate_threshold:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
reason=f"Duplicate message detected ({duplicate_count} times)",
|
||||
matched_filter="duplicate",
|
||||
)
|
||||
|
||||
# Mass mention check
|
||||
mention_count = len(message.mentions) + len(message.role_mentions)
|
||||
if message.mention_everyone:
|
||||
mention_count += 100 # Treat @everyone as many mentions
|
||||
|
||||
if mention_count > self.mention_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
timeout_duration=300, # 5 minute timeout
|
||||
reason=f"Mass mentions detected ({mention_count} mentions)",
|
||||
matched_filter="mass_mention",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_invite_links(self, content: str, allow_invites: bool = True) -> AutomodResult | None:
|
||||
"""Check for Discord invite links."""
|
||||
if allow_invites:
|
||||
return None
|
||||
|
||||
invite_pattern = re.compile(
|
||||
r"(?:https?://)?(?:www\.)?(?:discord\.(?:gg|io|me|li)|discordapp\.com/invite)/[\w-]+",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
if invite_pattern.search(content):
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
reason="Discord invite links are not allowed",
|
||||
matched_filter="invite_link",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_all_caps(
|
||||
self, content: str, threshold: float = 0.7, min_length: int = 10
|
||||
) -> AutomodResult | None:
|
||||
"""Check for excessive caps usage."""
|
||||
# Only check messages with enough letters
|
||||
letters = [c for c in content if c.isalpha()]
|
||||
if len(letters) < min_length:
|
||||
return None
|
||||
|
||||
caps_count = sum(1 for c in letters if c.isupper())
|
||||
caps_ratio = caps_count / len(letters)
|
||||
|
||||
if caps_ratio > threshold:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
reason="Excessive caps usage",
|
||||
matched_filter="caps",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def reset_user_tracker(self, guild_id: int, user_id: int) -> None:
|
||||
"""Reset spam tracking for a user."""
|
||||
if guild_id in self._spam_trackers:
|
||||
self._spam_trackers[guild_id].pop(user_id, None)
|
||||
|
||||
def cleanup_guild(self, guild_id: int) -> None:
|
||||
"""Remove all tracking data for a guild."""
|
||||
self._spam_trackers.pop(guild_id, None)
|
||||
99
src/guardden/services/database.py
Normal file
99
src/guardden/services/database.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Database connection and session management."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import asyncpg
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from guardden.config import Settings
|
||||
from guardden.models import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
"""Manages database connections and sessions."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
self._engine = None
|
||||
self._session_factory = None
|
||||
self._pool: asyncpg.Pool | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize database connection pool."""
|
||||
db_url = self.settings.database_url.get_secret_value()
|
||||
|
||||
# Create SQLAlchemy async engine
|
||||
# Convert postgresql:// to postgresql+asyncpg://
|
||||
if db_url.startswith("postgresql://"):
|
||||
sqlalchemy_url = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
else:
|
||||
sqlalchemy_url = db_url
|
||||
|
||||
self._engine = create_async_engine(
|
||||
sqlalchemy_url,
|
||||
pool_size=self.settings.database_pool_min,
|
||||
max_overflow=self.settings.database_pool_max - self.settings.database_pool_min,
|
||||
echo=self.settings.log_level == "DEBUG",
|
||||
)
|
||||
|
||||
self._session_factory = async_sessionmaker(
|
||||
self._engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
# Also create a raw asyncpg pool for performance-critical operations
|
||||
self._pool = await asyncpg.create_pool(
|
||||
db_url,
|
||||
min_size=self.settings.database_pool_min,
|
||||
max_size=self.settings.database_pool_max,
|
||||
)
|
||||
|
||||
logger.info("Database connection established")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Close all database connections."""
|
||||
if self._pool:
|
||||
await self._pool.close()
|
||||
self._pool = None
|
||||
|
||||
if self._engine:
|
||||
await self._engine.dispose()
|
||||
self._engine = None
|
||||
|
||||
logger.info("Database connections closed")
|
||||
|
||||
async def create_tables(self) -> None:
|
||||
"""Create all database tables."""
|
||||
if not self._engine:
|
||||
raise RuntimeError("Database not connected")
|
||||
|
||||
async with self._engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
logger.info("Database tables created")
|
||||
|
||||
@asynccontextmanager
|
||||
async def session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get a database session context manager."""
|
||||
if not self._session_factory:
|
||||
raise RuntimeError("Database not connected")
|
||||
|
||||
async with self._session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
@property
|
||||
def pool(self) -> asyncpg.Pool:
|
||||
"""Get the raw asyncpg connection pool."""
|
||||
if not self._pool:
|
||||
raise RuntimeError("Database not connected")
|
||||
return self._pool
|
||||
145
src/guardden/services/guild_config.py
Normal file
145
src/guardden/services/guild_config.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Guild configuration service."""
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
|
||||
import discord
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from guardden.models import BannedWord, Guild, GuildSettings
|
||||
from guardden.services.database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuildConfigService:
|
||||
"""Manages guild configurations with caching."""
|
||||
|
||||
def __init__(self, database: Database) -> None:
|
||||
self.database = database
|
||||
self._cache: dict[int, GuildSettings] = {}
|
||||
|
||||
async def get_config(self, guild_id: int) -> GuildSettings | None:
|
||||
"""Get guild configuration, using cache if available."""
|
||||
if guild_id in self._cache:
|
||||
return self._cache[guild_id]
|
||||
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(GuildSettings).where(GuildSettings.guild_id == guild_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if settings:
|
||||
self._cache[guild_id] = settings
|
||||
|
||||
return settings
|
||||
|
||||
async def get_guild(self, guild_id: int) -> Guild | None:
|
||||
"""Get full guild data including settings and banned words."""
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(Guild)
|
||||
.options(selectinload(Guild.settings), selectinload(Guild.banned_words))
|
||||
.where(Guild.id == guild_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create_guild(self, guild: discord.Guild) -> Guild:
|
||||
"""Create a new guild entry with default settings."""
|
||||
async with self.database.session() as session:
|
||||
# Check if guild already exists
|
||||
existing = await session.get(Guild, guild.id)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
# Create new guild
|
||||
db_guild = Guild(
|
||||
id=guild.id,
|
||||
name=guild.name,
|
||||
owner_id=guild.owner_id,
|
||||
)
|
||||
session.add(db_guild)
|
||||
await session.flush()
|
||||
|
||||
# Create default settings
|
||||
settings = GuildSettings(guild_id=guild.id)
|
||||
session.add(settings)
|
||||
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"Created guild config for {guild.name} (ID: {guild.id})")
|
||||
return db_guild
|
||||
|
||||
async def update_settings(self, guild_id: int, **kwargs) -> GuildSettings | None:
|
||||
"""Update guild settings."""
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(GuildSettings).where(GuildSettings.guild_id == guild_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(settings, key):
|
||||
setattr(settings, key, value)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Invalidate cache
|
||||
self._cache.pop(guild_id, None)
|
||||
|
||||
return settings
|
||||
|
||||
def invalidate_cache(self, guild_id: int) -> None:
|
||||
"""Remove a guild from the cache."""
|
||||
self._cache.pop(guild_id, None)
|
||||
|
||||
async def get_banned_words(self, guild_id: int) -> list[BannedWord]:
|
||||
"""Get all banned words for a guild."""
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(BannedWord).where(BannedWord.guild_id == guild_id)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def add_banned_word(
|
||||
self,
|
||||
guild_id: int,
|
||||
pattern: str,
|
||||
added_by: int,
|
||||
is_regex: bool = False,
|
||||
action: str = "delete",
|
||||
reason: str | None = None,
|
||||
) -> BannedWord:
|
||||
"""Add a banned word to a guild."""
|
||||
async with self.database.session() as session:
|
||||
banned_word = BannedWord(
|
||||
guild_id=guild_id,
|
||||
pattern=pattern,
|
||||
is_regex=is_regex,
|
||||
action=action,
|
||||
reason=reason,
|
||||
added_by=added_by,
|
||||
)
|
||||
session.add(banned_word)
|
||||
await session.commit()
|
||||
return banned_word
|
||||
|
||||
async def remove_banned_word(self, guild_id: int, word_id: int) -> bool:
|
||||
"""Remove a banned word from a guild."""
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(BannedWord).where(BannedWord.id == word_id, BannedWord.guild_id == guild_id)
|
||||
)
|
||||
word = result.scalar_one_or_none()
|
||||
|
||||
if word:
|
||||
session.delete(word)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
return False
|
||||
300
src/guardden/services/ratelimit.py
Normal file
300
src/guardden/services/ratelimit.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Rate limiting service for command and action throttling."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitScope(str, Enum):
|
||||
"""Scope of rate limiting."""
|
||||
|
||||
USER = "user" # Per user globally
|
||||
MEMBER = "member" # Per user per guild
|
||||
CHANNEL = "channel" # Per channel
|
||||
GUILD = "guild" # Per guild
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitBucket:
|
||||
"""Tracks rate limit state for a single bucket."""
|
||||
|
||||
max_requests: int
|
||||
window_seconds: float
|
||||
requests: list[datetime] = field(default_factory=list)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Remove expired requests from tracking."""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(seconds=self.window_seconds)
|
||||
self.requests = [r for r in self.requests if r > cutoff]
|
||||
|
||||
def is_limited(self) -> bool:
|
||||
"""Check if this bucket is rate limited."""
|
||||
self.cleanup()
|
||||
return len(self.requests) >= self.max_requests
|
||||
|
||||
def record(self) -> None:
|
||||
"""Record a request."""
|
||||
self.requests.append(datetime.now(timezone.utc))
|
||||
|
||||
def remaining(self) -> int:
|
||||
"""Get remaining requests in current window."""
|
||||
self.cleanup()
|
||||
return max(0, self.max_requests - len(self.requests))
|
||||
|
||||
def reset_after(self) -> float:
|
||||
"""Get seconds until rate limit resets."""
|
||||
if not self.requests:
|
||||
return 0.0
|
||||
self.cleanup()
|
||||
if not self.requests:
|
||||
return 0.0
|
||||
oldest = min(self.requests)
|
||||
reset_time = oldest + timedelta(seconds=self.window_seconds)
|
||||
remaining = (reset_time - datetime.now(timezone.utc)).total_seconds()
|
||||
return max(0.0, remaining)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitConfig:
|
||||
"""Configuration for a rate limit."""
|
||||
|
||||
max_requests: int
|
||||
window_seconds: float
|
||||
scope: RateLimitScope = RateLimitScope.MEMBER
|
||||
|
||||
def create_bucket(self) -> RateLimitBucket:
|
||||
return RateLimitBucket(
|
||||
max_requests=self.max_requests,
|
||||
window_seconds=self.window_seconds,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitResult:
|
||||
"""Result of a rate limit check."""
|
||||
|
||||
is_limited: bool
|
||||
remaining: int
|
||||
reset_after: float
|
||||
bucket_key: str
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""General-purpose rate limiter."""
|
||||
|
||||
# Default rate limits for various actions
|
||||
DEFAULT_LIMITS = {
|
||||
"command": RateLimitConfig(5, 10, RateLimitScope.MEMBER), # 5 commands per 10s
|
||||
"moderation": RateLimitConfig(10, 60, RateLimitScope.MEMBER), # 10 mod actions per minute
|
||||
"verification": RateLimitConfig(3, 300, RateLimitScope.MEMBER), # 3 verifications per 5 min
|
||||
"message": RateLimitConfig(10, 10, RateLimitScope.MEMBER), # 10 messages per 10s
|
||||
"api_call": RateLimitConfig(
|
||||
30, 60, RateLimitScope.GUILD
|
||||
), # 30 API calls per minute per guild
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Buckets: {action: {bucket_key: RateLimitBucket}}
|
||||
self._buckets: dict[str, dict[str, RateLimitBucket]] = defaultdict(dict)
|
||||
self._configs: dict[str, RateLimitConfig] = dict(self.DEFAULT_LIMITS)
|
||||
|
||||
def configure(self, action: str, config: RateLimitConfig) -> None:
|
||||
"""Configure rate limit for an action."""
|
||||
self._configs[action] = config
|
||||
# Clear existing buckets for this action
|
||||
self._buckets[action].clear()
|
||||
|
||||
def _get_bucket_key(
|
||||
self,
|
||||
scope: RateLimitScope,
|
||||
user_id: int | None = None,
|
||||
guild_id: int | None = None,
|
||||
channel_id: int | None = None,
|
||||
) -> str:
|
||||
"""Generate a bucket key based on scope."""
|
||||
if scope == RateLimitScope.USER:
|
||||
return f"user:{user_id}"
|
||||
elif scope == RateLimitScope.MEMBER:
|
||||
return f"member:{guild_id}:{user_id}"
|
||||
elif scope == RateLimitScope.CHANNEL:
|
||||
return f"channel:{channel_id}"
|
||||
elif scope == RateLimitScope.GUILD:
|
||||
return f"guild:{guild_id}"
|
||||
return f"unknown:{user_id}:{guild_id}"
|
||||
|
||||
def check(
|
||||
self,
|
||||
action: str,
|
||||
user_id: int | None = None,
|
||||
guild_id: int | None = None,
|
||||
channel_id: int | None = None,
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Check if an action is rate limited.
|
||||
|
||||
Does not record the request - use `acquire()` for that.
|
||||
"""
|
||||
config = self._configs.get(action)
|
||||
if not config:
|
||||
return RateLimitResult(
|
||||
is_limited=False,
|
||||
remaining=999,
|
||||
reset_after=0,
|
||||
bucket_key="",
|
||||
)
|
||||
|
||||
bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
|
||||
bucket = self._buckets[action].get(bucket_key)
|
||||
|
||||
if not bucket:
|
||||
return RateLimitResult(
|
||||
is_limited=False,
|
||||
remaining=config.max_requests,
|
||||
reset_after=0,
|
||||
bucket_key=bucket_key,
|
||||
)
|
||||
|
||||
return RateLimitResult(
|
||||
is_limited=bucket.is_limited(),
|
||||
remaining=bucket.remaining(),
|
||||
reset_after=bucket.reset_after(),
|
||||
bucket_key=bucket_key,
|
||||
)
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
action: str,
|
||||
user_id: int | None = None,
|
||||
guild_id: int | None = None,
|
||||
channel_id: int | None = None,
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Attempt to acquire a rate limit slot.
|
||||
|
||||
Records the request if not limited.
|
||||
"""
|
||||
config = self._configs.get(action)
|
||||
if not config:
|
||||
return RateLimitResult(
|
||||
is_limited=False,
|
||||
remaining=999,
|
||||
reset_after=0,
|
||||
bucket_key="",
|
||||
)
|
||||
|
||||
bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
|
||||
|
||||
if bucket_key not in self._buckets[action]:
|
||||
self._buckets[action][bucket_key] = config.create_bucket()
|
||||
|
||||
bucket = self._buckets[action][bucket_key]
|
||||
|
||||
if bucket.is_limited():
|
||||
return RateLimitResult(
|
||||
is_limited=True,
|
||||
remaining=0,
|
||||
reset_after=bucket.reset_after(),
|
||||
bucket_key=bucket_key,
|
||||
)
|
||||
|
||||
bucket.record()
|
||||
|
||||
return RateLimitResult(
|
||||
is_limited=False,
|
||||
remaining=bucket.remaining(),
|
||||
reset_after=bucket.reset_after(),
|
||||
bucket_key=bucket_key,
|
||||
)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
action: str,
|
||||
user_id: int | None = None,
|
||||
guild_id: int | None = None,
|
||||
channel_id: int | None = None,
|
||||
) -> bool:
|
||||
"""Reset rate limit for a specific bucket."""
|
||||
config = self._configs.get(action)
|
||||
if not config:
|
||||
return False
|
||||
|
||||
bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
|
||||
return self._buckets[action].pop(bucket_key, None) is not None
|
||||
|
||||
def cleanup(self) -> int:
|
||||
"""Clean up empty and expired buckets. Returns count removed."""
|
||||
removed = 0
|
||||
for action in list(self._buckets.keys()):
|
||||
for key in list(self._buckets[action].keys()):
|
||||
bucket = self._buckets[action][key]
|
||||
bucket.cleanup()
|
||||
if not bucket.requests:
|
||||
del self._buckets[action][key]
|
||||
removed += 1
|
||||
return removed
|
||||
|
||||
|
||||
# Global rate limiter instance
|
||||
_rate_limiter: RateLimiter | None = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""Get or create the global rate limiter instance."""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
def ratelimit(
|
||||
action: str = "command",
|
||||
max_requests: int | None = None,
|
||||
window_seconds: float | None = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator for rate limiting commands.
|
||||
|
||||
Usage:
|
||||
@ratelimit("moderation", max_requests=5, window_seconds=60)
|
||||
async def my_command(self, ctx):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
async def wrapper(self, ctx, *args, **kwargs):
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# Configure if custom limits provided
|
||||
if max_requests is not None and window_seconds is not None:
|
||||
limiter.configure(
|
||||
action,
|
||||
RateLimitConfig(max_requests, window_seconds, RateLimitScope.MEMBER),
|
||||
)
|
||||
|
||||
result = limiter.acquire(
|
||||
action,
|
||||
user_id=ctx.author.id,
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
)
|
||||
|
||||
if result.is_limited:
|
||||
await ctx.send(
|
||||
f"You're being rate limited. Try again in {result.reset_after:.1f} seconds.",
|
||||
delete_after=5,
|
||||
)
|
||||
return
|
||||
|
||||
return await func(self, ctx, *args, **kwargs)
|
||||
|
||||
# Preserve function metadata
|
||||
wrapper.__name__ = func.__name__
|
||||
wrapper.__doc__ = func.__doc__
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
300
src/guardden/services/verification.py
Normal file
300
src/guardden/services/verification.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Verification service for new member challenges."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChallengeType(str, Enum):
|
||||
"""Types of verification challenges."""
|
||||
|
||||
BUTTON = "button" # Simple button click
|
||||
CAPTCHA = "captcha" # Text-based captcha
|
||||
MATH = "math" # Simple math problem
|
||||
EMOJI = "emoji" # Select correct emoji
|
||||
QUESTIONS = "questions" # Custom questions
|
||||
|
||||
|
||||
@dataclass
|
||||
class Challenge:
|
||||
"""Represents a verification challenge."""
|
||||
|
||||
challenge_type: ChallengeType
|
||||
question: str
|
||||
answer: str
|
||||
options: list[str] = field(default_factory=list) # For multiple choice
|
||||
expires_at: datetime = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
)
|
||||
attempts: int = 0
|
||||
max_attempts: int = 3
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
def check_answer(self, response: str) -> bool:
|
||||
"""Check if the response is correct."""
|
||||
self.attempts += 1
|
||||
return response.strip().lower() == self.answer.lower()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingVerification:
|
||||
"""Tracks a pending verification for a user."""
|
||||
|
||||
user_id: int
|
||||
guild_id: int
|
||||
challenge: Challenge
|
||||
message_id: int | None = None
|
||||
channel_id: int | None = None
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class ChallengeGenerator(ABC):
|
||||
"""Abstract base class for challenge generators."""
|
||||
|
||||
@abstractmethod
|
||||
def generate(self) -> Challenge:
|
||||
"""Generate a new challenge."""
|
||||
pass
|
||||
|
||||
|
||||
class ButtonChallengeGenerator(ChallengeGenerator):
|
||||
"""Generates simple button click challenges."""
|
||||
|
||||
def generate(self) -> Challenge:
|
||||
return Challenge(
|
||||
challenge_type=ChallengeType.BUTTON,
|
||||
question="Click the button below to verify you're human.",
|
||||
answer="verified",
|
||||
)
|
||||
|
||||
|
||||
class CaptchaChallengeGenerator(ChallengeGenerator):
|
||||
"""Generates text-based captcha challenges."""
|
||||
|
||||
def __init__(self, length: int = 6) -> None:
|
||||
self.length = length
|
||||
|
||||
def generate(self) -> Challenge:
|
||||
# Generate random alphanumeric code (avoiding confusing chars)
|
||||
chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||
code = "".join(random.choices(chars, k=self.length))
|
||||
|
||||
# Create visual representation with some obfuscation
|
||||
visual = self._create_visual(code)
|
||||
|
||||
return Challenge(
|
||||
challenge_type=ChallengeType.CAPTCHA,
|
||||
question=f"Enter the code shown below:\n```\n{visual}\n```",
|
||||
answer=code,
|
||||
)
|
||||
|
||||
def _create_visual(self, code: str) -> str:
|
||||
"""Create a simple text-based visual captcha."""
|
||||
lines = []
|
||||
# Add some noise characters
|
||||
noise_chars = ".-*~^"
|
||||
|
||||
for _ in range(2):
|
||||
lines.append("".join(random.choices(noise_chars, k=len(code) * 2)))
|
||||
|
||||
# Add the code with spacing
|
||||
spaced = " ".join(code)
|
||||
lines.append(spaced)
|
||||
|
||||
for _ in range(2):
|
||||
lines.append("".join(random.choices(noise_chars, k=len(code) * 2)))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class MathChallengeGenerator(ChallengeGenerator):
|
||||
"""Generates simple math problem challenges."""
|
||||
|
||||
def generate(self) -> Challenge:
|
||||
# Generate simple addition/subtraction/multiplication
|
||||
operation = random.choice(["+", "-", "*"])
|
||||
|
||||
if operation == "*":
|
||||
a = random.randint(2, 10)
|
||||
b = random.randint(2, 10)
|
||||
else:
|
||||
a = random.randint(10, 50)
|
||||
b = random.randint(1, 20)
|
||||
|
||||
if operation == "+":
|
||||
answer = a + b
|
||||
elif operation == "-":
|
||||
# Ensure positive result
|
||||
if b > a:
|
||||
a, b = b, a
|
||||
answer = a - b
|
||||
else:
|
||||
answer = a * b
|
||||
|
||||
return Challenge(
|
||||
challenge_type=ChallengeType.MATH,
|
||||
question=f"Solve this math problem: **{a} {operation} {b} = ?**",
|
||||
answer=str(answer),
|
||||
)
|
||||
|
||||
|
||||
class EmojiChallengeGenerator(ChallengeGenerator):
|
||||
"""Generates emoji selection challenges."""
|
||||
|
||||
EMOJI_SETS = [
|
||||
("animals", ["🐶", "🐱", "🐭", "🐹", "🐰", "🦊", "🐻", "🐼"]),
|
||||
("fruits", ["🍎", "🍐", "🍊", "🍋", "🍌", "🍉", "🍇", "🍓"]),
|
||||
("weather", ["☀️", "🌙", "⭐", "🌧️", "❄️", "🌈", "⚡", "🌪️"]),
|
||||
("sports", ["⚽", "🏀", "🏈", "⚾", "🎾", "🏐", "🏉", "🎱"]),
|
||||
]
|
||||
|
||||
def generate(self) -> Challenge:
|
||||
category, emojis = random.choice(self.EMOJI_SETS)
|
||||
target = random.choice(emojis)
|
||||
|
||||
# Create options with the target and some others
|
||||
options = [target]
|
||||
other_emojis = [e for e in emojis if e != target]
|
||||
options.extend(random.sample(other_emojis, min(3, len(other_emojis))))
|
||||
random.shuffle(options)
|
||||
|
||||
return Challenge(
|
||||
challenge_type=ChallengeType.EMOJI,
|
||||
question=f"Select the {self._emoji_name(target)} emoji:",
|
||||
answer=target,
|
||||
options=options,
|
||||
)
|
||||
|
||||
def _emoji_name(self, emoji: str) -> str:
|
||||
"""Get a description of the emoji."""
|
||||
names = {
|
||||
"🐶": "dog",
|
||||
"🐱": "cat",
|
||||
"🐭": "mouse",
|
||||
"🐹": "hamster",
|
||||
"🐰": "rabbit",
|
||||
"🦊": "fox",
|
||||
"🐻": "bear",
|
||||
"🐼": "panda",
|
||||
"🍎": "apple",
|
||||
"🍐": "pear",
|
||||
"🍊": "orange",
|
||||
"🍋": "lemon",
|
||||
"🍌": "banana",
|
||||
"🍉": "watermelon",
|
||||
"🍇": "grapes",
|
||||
"🍓": "strawberry",
|
||||
"☀️": "sun",
|
||||
"🌙": "moon",
|
||||
"⭐": "star",
|
||||
"🌧️": "rain",
|
||||
"❄️": "snowflake",
|
||||
"🌈": "rainbow",
|
||||
"⚡": "lightning",
|
||||
"🌪️": "tornado",
|
||||
"⚽": "soccer ball",
|
||||
"🏀": "basketball",
|
||||
"🏈": "football",
|
||||
"⚾": "baseball",
|
||||
"🎾": "tennis",
|
||||
"🏐": "volleyball",
|
||||
"🏉": "rugby",
|
||||
"🎱": "pool ball",
|
||||
}
|
||||
return names.get(emoji, "correct")
|
||||
|
||||
|
||||
class VerificationService:
|
||||
"""Service for managing member verification."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Pending verifications: {(guild_id, user_id): PendingVerification}
|
||||
self._pending: dict[tuple[int, int], PendingVerification] = {}
|
||||
|
||||
# Challenge generators
|
||||
self._generators: dict[ChallengeType, ChallengeGenerator] = {
|
||||
ChallengeType.BUTTON: ButtonChallengeGenerator(),
|
||||
ChallengeType.CAPTCHA: CaptchaChallengeGenerator(),
|
||||
ChallengeType.MATH: MathChallengeGenerator(),
|
||||
ChallengeType.EMOJI: EmojiChallengeGenerator(),
|
||||
}
|
||||
|
||||
def create_challenge(
|
||||
self,
|
||||
user_id: int,
|
||||
guild_id: int,
|
||||
challenge_type: ChallengeType = ChallengeType.BUTTON,
|
||||
) -> PendingVerification:
|
||||
"""Create a new verification challenge for a user."""
|
||||
generator = self._generators.get(challenge_type)
|
||||
if not generator:
|
||||
generator = self._generators[ChallengeType.BUTTON]
|
||||
|
||||
challenge = generator.generate()
|
||||
pending = PendingVerification(
|
||||
user_id=user_id,
|
||||
guild_id=guild_id,
|
||||
challenge=challenge,
|
||||
)
|
||||
|
||||
self._pending[(guild_id, user_id)] = pending
|
||||
return pending
|
||||
|
||||
def get_pending(self, guild_id: int, user_id: int) -> PendingVerification | None:
|
||||
"""Get a pending verification for a user."""
|
||||
return self._pending.get((guild_id, user_id))
|
||||
|
||||
def verify(self, guild_id: int, user_id: int, response: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Attempt to verify a user's response.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message)
|
||||
"""
|
||||
pending = self._pending.get((guild_id, user_id))
|
||||
|
||||
if not pending:
|
||||
return False, "No pending verification found."
|
||||
|
||||
if pending.challenge.is_expired:
|
||||
self._pending.pop((guild_id, user_id), None)
|
||||
return False, "Verification expired. Please request a new one."
|
||||
|
||||
if pending.challenge.attempts >= pending.challenge.max_attempts:
|
||||
self._pending.pop((guild_id, user_id), None)
|
||||
return False, "Too many failed attempts. Please request a new verification."
|
||||
|
||||
if pending.challenge.check_answer(response):
|
||||
self._pending.pop((guild_id, user_id), None)
|
||||
return True, "Verification successful!"
|
||||
|
||||
remaining = pending.challenge.max_attempts - pending.challenge.attempts
|
||||
return False, f"Incorrect. {remaining} attempt(s) remaining."
|
||||
|
||||
def cancel(self, guild_id: int, user_id: int) -> bool:
|
||||
"""Cancel a pending verification."""
|
||||
return self._pending.pop((guild_id, user_id), None) is not None
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""Remove expired verifications. Returns count of removed."""
|
||||
expired = [key for key, pending in self._pending.items() if pending.challenge.is_expired]
|
||||
for key in expired:
|
||||
self._pending.pop(key, None)
|
||||
return len(expired)
|
||||
|
||||
def get_pending_count(self, guild_id: int) -> int:
|
||||
"""Get count of pending verifications for a guild."""
|
||||
return sum(1 for (gid, _) in self._pending if gid == guild_id)
|
||||
5
src/guardden/utils/__init__.py
Normal file
5
src/guardden/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Utility functions for GuardDen."""
|
||||
|
||||
from guardden.utils.logging import setup_logging
|
||||
|
||||
__all__ = ["setup_logging"]
|
||||
27
src/guardden/utils/logging.py
Normal file
27
src/guardden/utils/logging.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Logging configuration for GuardDen."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Literal
|
||||
|
||||
|
||||
def setup_logging(level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO") -> None:
|
||||
"""Configure logging for the application."""
|
||||
log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
|
||||
date_format = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
# Configure root logger
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, level),
|
||||
format=log_format,
|
||||
datefmt=date_format,
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
# Reduce noise from third-party libraries
|
||||
logging.getLogger("discord").setLevel(logging.WARNING)
|
||||
logging.getLogger("discord.http").setLevel(logging.WARNING)
|
||||
logging.getLogger("asyncio").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(
|
||||
logging.DEBUG if level == "DEBUG" else logging.WARNING
|
||||
)
|
||||
Reference in New Issue
Block a user