quick commit
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 6m9s
CI/CD Pipeline / Security Scanning (push) Successful in 26s
CI/CD Pipeline / Tests (3.11) (push) Failing after 5m24s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m23s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Deploy to Staging (push) Has been skipped
CI/CD Pipeline / Deploy to Production (push) Has been skipped
CI/CD Pipeline / Notification (push) Successful in 1s

This commit is contained in:
2026-01-17 20:24:43 +01:00
parent 95cc3cdb8f
commit 831eed8dbc
82 changed files with 8860 additions and 167 deletions

View File

@@ -1,6 +1,8 @@
"""Main bot class for GuardDen."""
import inspect
import logging
import platform
from typing import TYPE_CHECKING
import discord
@@ -9,11 +11,14 @@ from discord.ext import commands
from guardden.config import Settings
from guardden.services.ai import AIProvider, create_ai_provider
from guardden.services.database import Database
from guardden.services.ratelimit import RateLimiter
from guardden.utils.logging import get_logger, get_logging_middleware, setup_logging
if TYPE_CHECKING:
from guardden.services.guild_config import GuildConfigService
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
logging_middleware = get_logging_middleware()
class GuardDen(commands.Bot):
@@ -37,6 +42,7 @@ class GuardDen(commands.Bot):
self.database = Database(settings)
self.guild_config: "GuildConfigService | None" = None
self.ai_provider: AIProvider | None = None
self.rate_limiter = RateLimiter()
async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]:
"""Get the command prefix for a guild."""
@@ -50,10 +56,32 @@ class GuardDen(commands.Bot):
return [self.settings.discord_prefix]
def is_guild_allowed(self, guild_id: int) -> bool:
    """Return True when the guild may use the bot.

    An empty ``allowed_guilds`` setting means every guild is permitted.
    """
    allowlist = self.settings.allowed_guilds
    if not allowlist:
        return True
    return guild_id in allowlist
def is_owner_allowed(self, user_id: int) -> bool:
    """Return True when the user has elevated (owner) access.

    An empty ``owner_ids`` setting defers to per-guild admin checks instead.
    """
    owners = self.settings.owner_ids
    return True if not owners else user_id in owners
async def setup_hook(self) -> None:
"""Called when the bot is starting up."""
logger.info("Starting GuardDen setup...")
self.settings.validate_configuration()
logger.info(
"Configuration loaded: ai_provider=%s, log_level=%s, allowed_guilds=%s, owner_ids=%s",
self.settings.ai_provider,
self.settings.log_level,
self.settings.allowed_guilds or "all",
self.settings.owner_ids or "admins",
)
logger.info(
"Runtime versions: python=%s, discord.py=%s",
platform.python_version(),
discord.__version__,
)
# Connect to database
await self.database.connect()
await self.database.create_tables()
@@ -86,14 +114,27 @@ class GuardDen(commands.Bot):
"guardden.cogs.automod",
"guardden.cogs.ai_moderation",
"guardden.cogs.verification",
"guardden.cogs.health",
]
failed_cogs = []
for cog in cogs:
try:
await self.load_extension(cog)
logger.info(f"Loaded cog: {cog}")
except ImportError as e:
logger.error(f"Failed to import cog {cog}: {e}")
failed_cogs.append(cog)
except commands.ExtensionError as e:
logger.error(f"Discord extension error loading {cog}: {e}")
failed_cogs.append(cog)
except Exception as e:
logger.error(f"Failed to load cog {cog}: {e}")
logger.error(f"Unexpected error loading cog {cog}: {e}", exc_info=True)
failed_cogs.append(cog)
if failed_cogs:
logger.warning(f"Failed to load {len(failed_cogs)} cog(s): {', '.join(failed_cogs)}")
# Don't fail startup if some cogs fail to load, but log it prominently
async def on_ready(self) -> None:
"""Called when the bot is fully connected and ready."""
@@ -103,9 +144,30 @@ class GuardDen(commands.Bot):
# Ensure all guilds have database entries
if self.guild_config:
initialized = 0
failed_guilds = []
for guild in self.guilds:
await self.guild_config.create_guild(guild)
logger.info(f"Initialized config for {len(self.guilds)} guild(s)")
try:
if not self.is_guild_allowed(guild.id):
logger.warning(
"Leaving unauthorized guild %s (ID: %s)", guild.name, guild.id
)
try:
await guild.leave()
except discord.HTTPException as e:
logger.error(f"Failed to leave guild {guild.id}: {e}")
continue
await self.guild_config.create_guild(guild)
initialized += 1
except Exception as e:
logger.error(f"Failed to initialize config for guild {guild.id} ({guild.name}): {e}", exc_info=True)
failed_guilds.append(guild.id)
logger.info("Initialized config for %s guild(s)", initialized)
if failed_guilds:
logger.warning(f"Failed to initialize {len(failed_guilds)} guild(s): {failed_guilds}")
# Set presence
activity = discord.Activity(
@@ -117,6 +179,7 @@ class GuardDen(commands.Bot):
async def close(self) -> None:
"""Clean up when shutting down."""
logger.info("Shutting down GuardDen...")
await self._shutdown_cogs()
if self.ai_provider:
try:
await self.ai_provider.close()
@@ -125,10 +188,30 @@ class GuardDen(commands.Bot):
await self.database.disconnect()
await super().close()
async def _shutdown_cogs(self) -> None:
    """Invoke each loaded cog's ``cog_unload`` hook so background tasks stop.

    Supports both sync and async unload hooks; any failure is logged and
    never aborts the shutdown sequence.
    """
    for loaded in list(self.cogs.values()):
        hook = getattr(loaded, "cog_unload", None)
        if hook is None:
            continue
        try:
            outcome = hook()
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as e:
            logger.error("Error during cog unload (%s): %s", loaded.qualified_name, e)
async def on_guild_join(self, guild: discord.Guild) -> None:
    """Handle joining a guild: enforce the allowlist, then seed its config."""
    logger.info(f"Joined guild: {guild.name} (ID: {guild.id})")
    allowed = self.is_guild_allowed(guild.id)
    if not allowed:
        # Not on the allowlist: leave immediately and do no further setup.
        logger.warning(
            "Guild %s (ID: %s) not in allowlist, leaving.", guild.name, guild.id
        )
        await guild.leave()
        return
    # Seed a config row only once the guild-config service is initialized.
    if self.guild_config:
        await self.guild_config.create_guild(guild)

View File

@@ -7,6 +7,7 @@ import discord
from discord.ext import commands
from guardden.bot import GuardDen
from guardden.utils.ratelimit import RateLimitExceeded
logger = logging.getLogger(__name__)
@@ -17,12 +18,32 @@ class Admin(commands.Cog):
def __init__(self, bot: GuardDen) -> None:
self.bot = bot
async def cog_check(self, ctx: commands.Context) -> bool:
def cog_check(self, ctx: commands.Context) -> bool:
    """Gate admin commands: guild context, owner allowlist, admin permission."""
    in_guild = ctx.guild is not None
    allowed_owner = in_guild and self.bot.is_owner_allowed(ctx.author.id)
    return bool(allowed_owner and ctx.author.guild_permissions.administrator)
async def cog_before_invoke(self, ctx: commands.Context) -> None:
    """Enforce the shared per-command rate limit before execution.

    Raises:
        RateLimitExceeded: when the user/guild/channel bucket is exhausted;
            ``cog_command_error`` converts this into a friendly reply.
    """
    command = ctx.command
    if command is None:
        return
    verdict = self.bot.rate_limiter.acquire_command(
        command.qualified_name,
        user_id=ctx.author.id,
        guild_id=ctx.guild.id if ctx.guild else None,
        channel_id=ctx.channel.id,
    )
    if verdict.is_limited:
        raise RateLimitExceeded(verdict.reset_after)
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
    """Turn rate-limit errors into a user-facing message; ignore the rest."""
    if not isinstance(error, RateLimitExceeded):
        return
    await ctx.send(
        f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
    )
@commands.group(name="config", invoke_without_command=True)
@commands.guild_only()
async def config(self, ctx: commands.Context) -> None:

View File

@@ -1,7 +1,6 @@
"""AI-powered moderation cog."""
import logging
import re
from collections import deque
from datetime import datetime, timedelta, timezone
@@ -9,16 +8,13 @@ import discord
from discord.ext import commands
from guardden.bot import GuardDen
from guardden.models import ModerationLog
from guardden.services.ai.base import ContentCategory, ModerationResult
from guardden.services.automod import URL_PATTERN, is_allowed_domain, normalize_domain
from guardden.utils.ratelimit import RateLimitExceeded
logger = logging.getLogger(__name__)
# URL pattern for extraction
URL_PATTERN = re.compile(
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*",
re.IGNORECASE,
)
class AIModeration(commands.Cog):
"""AI-powered content moderation."""
@@ -28,6 +24,30 @@ class AIModeration(commands.Cog):
# Track recently analyzed messages to avoid duplicates (deque auto-removes oldest)
self._analyzed_messages: deque[int] = deque(maxlen=1000)
def cog_check(self, ctx: commands.Context) -> bool:
    """Restrict AI commands to guild contexts and allowlisted owners."""
    return bool(ctx.guild and self.bot.is_owner_allowed(ctx.author.id))
async def cog_before_invoke(self, ctx: commands.Context) -> None:
    """Enforce the shared per-command rate limit before execution.

    Raises:
        RateLimitExceeded: when the user/guild/channel bucket is exhausted;
            ``cog_command_error`` converts this into a friendly reply.
    """
    command = ctx.command
    if command is None:
        return
    verdict = self.bot.rate_limiter.acquire_command(
        command.qualified_name,
        user_id=ctx.author.id,
        guild_id=ctx.guild.id if ctx.guild else None,
        channel_id=ctx.channel.id,
    )
    if verdict.is_limited:
        raise RateLimitExceeded(verdict.reset_after)
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
    """Turn rate-limit errors into a user-facing message; ignore the rest."""
    if not isinstance(error, RateLimitExceeded):
        return
    await ctx.send(
        f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
    )
def _should_analyze(self, message: discord.Message) -> bool:
"""Determine if a message should be analyzed by AI."""
# Skip if already analyzed
@@ -67,21 +87,37 @@ class AIModeration(commands.Cog):
threshold = 100 - config.ai_sensitivity # e.g., sensitivity 70 = threshold 30
if result.severity < threshold:
logger.debug(
f"AI flagged content but below threshold: "
f"severity={result.severity}, threshold={threshold}"
"AI flagged content but below threshold: severity=%s, threshold=%s",
result.severity,
threshold,
)
return
if result.confidence < config.ai_confidence_threshold:
logger.debug(
"AI flagged content but below confidence threshold: confidence=%s, threshold=%s",
result.confidence,
config.ai_confidence_threshold,
)
return
log_only = config.ai_log_only
# Determine action based on suggested action and severity
should_delete = result.suggested_action in ("delete", "timeout", "ban")
should_timeout = result.suggested_action in ("timeout", "ban") and result.severity > 70
should_delete = not log_only and result.suggested_action in ("delete", "timeout", "ban")
should_timeout = (
not log_only
and result.suggested_action in ("timeout", "ban")
and result.severity > 70
)
timeout_duration: int | None = None
# Delete message if needed
if should_delete:
try:
await message.delete()
except discord.Forbidden:
logger.warning(f"Cannot delete message: missing permissions")
logger.warning("Cannot delete message: missing permissions")
except discord.NotFound:
pass
@@ -96,8 +132,19 @@ class AIModeration(commands.Cog):
except discord.Forbidden:
pass
await self._log_ai_db_action(
message,
result,
analysis_type,
log_only=log_only,
timeout_duration=timeout_duration,
)
# Log to mod channel
await self._log_ai_action(message, result, analysis_type)
await self._log_ai_action(message, result, analysis_type, log_only=log_only)
if log_only:
return
# Notify user
try:
@@ -122,6 +169,7 @@ class AIModeration(commands.Cog):
message: discord.Message,
result: ModerationResult,
analysis_type: str,
log_only: bool = False,
) -> None:
"""Log an AI moderation action."""
config = await self.bot.guild_config.get_config(message.guild.id)
@@ -142,9 +190,10 @@ class AIModeration(commands.Cog):
icon_url=message.author.display_avatar.url,
)
action_label = "log-only" if log_only else result.suggested_action
embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True)
embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True)
embed.add_field(name="Action", value=result.suggested_action, inline=True)
embed.add_field(name="Action", value=action_label, inline=True)
categories = ", ".join(cat.value for cat in result.categories)
embed.add_field(name="Categories", value=categories or "None", inline=False)
@@ -160,10 +209,43 @@ class AIModeration(commands.Cog):
await channel.send(embed=embed)
async def _log_ai_db_action(
    self,
    message: discord.Message,
    result: ModerationResult,
    analysis_type: str,
    log_only: bool,
    timeout_duration: int | None,
) -> None:
    """Persist an AI moderation decision to the moderation log table.

    In log-only mode the action is recorded as ``ai_log``; otherwise the
    AI's suggested action is recorded with an ``ai_`` prefix.
    """
    action = "ai_log" if log_only else f"ai_{result.suggested_action}"
    reason = result.explanation or f"AI moderation flagged content ({analysis_type})"
    # A timeout gets an explicit expiry so it can be surfaced/cleared later.
    expires_at = (
        datetime.now(timezone.utc) + timedelta(seconds=timeout_duration)
        if timeout_duration
        else None
    )
    bot_user = self.bot.user
    async with self.bot.database.session() as session:
        session.add(
            ModerationLog(
                guild_id=message.guild.id,
                target_id=message.author.id,
                target_name=str(message.author),
                moderator_id=bot_user.id if bot_user else 0,
                moderator_name=str(bot_user) if bot_user else "GuardDen",
                action=action,
                reason=reason,
                duration=timeout_duration,
                expires_at=expires_at,
                channel_id=message.channel.id,
                message_id=message.id,
                message_content=message.content,
                is_automatic=True,
            )
        )
@commands.Cog.listener()
async def on_message(self, message: discord.Message) -> None:
"""Analyze messages with AI moderation."""
print(f"[AI_MOD] Received message from {message.author}", flush=True)
logger.debug("AI moderation received message from %s", message.author)
# Skip bot messages early
if message.author.bot:
@@ -247,7 +329,11 @@ class AIModeration(commands.Cog):
# Analyze URLs for phishing
urls = URL_PATTERN.findall(message.content)
allowlist = {normalize_domain(domain) for domain in config.scam_allowlist if domain}
for url in urls[:3]: # Limit to first 3 URLs
hostname = normalize_domain(url)
if allowlist and is_allowed_domain(hostname, allowlist):
continue
phishing_result = await self.bot.ai_provider.analyze_phishing(
url=url,
message_content=message.content,
@@ -291,6 +377,16 @@ class AIModeration(commands.Cog):
value=f"{config.ai_sensitivity}/100" if config else "50/100",
inline=True,
)
embed.add_field(
name="Confidence Threshold",
value=f"{config.ai_confidence_threshold:.2f}" if config else "0.70",
inline=True,
)
embed.add_field(
name="Log Only",
value="✅ Enabled" if config and config.ai_log_only else "❌ Disabled",
inline=True,
)
embed.add_field(
name="AI Provider",
value=self.bot.settings.ai_provider.capitalize(),
@@ -333,6 +429,27 @@ class AIModeration(commands.Cog):
await self.bot.guild_config.update_settings(ctx.guild.id, ai_sensitivity=level)
await ctx.send(f"AI sensitivity set to {level}/100.")
@ai_cmd.command(name="threshold")
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def ai_threshold(self, ctx: commands.Context, value: float) -> None:
    """Set AI confidence threshold (0.0-1.0)."""
    # Chained comparison also rejects NaN (any comparison with NaN is False).
    if not 0.0 <= value <= 1.0:
        await ctx.send("Threshold must be between 0.0 and 1.0.")
        return
    await self.bot.guild_config.update_settings(
        ctx.guild.id, ai_confidence_threshold=value
    )
    await ctx.send(f"AI confidence threshold set to {value:.2f}.")
@ai_cmd.command(name="logonly")
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def ai_logonly(self, ctx: commands.Context, enabled: bool) -> None:
    """Enable or disable log-only mode for AI moderation.

    In log-only mode AI violations are logged but no delete/timeout occurs.
    """
    await self.bot.guild_config.update_settings(ctx.guild.id, ai_log_only=enabled)
    state = "enabled" if enabled else "disabled"
    await ctx.send(f"AI log-only mode {state}.")
@ai_cmd.command(name="nsfw")
@commands.has_permissions(administrator=True)
@commands.guild_only()

View File

@@ -2,12 +2,21 @@
import logging
from datetime import datetime, timedelta, timezone
from typing import Literal
import discord
from discord.ext import commands
from sqlalchemy import func, select
from guardden.bot import GuardDen
from guardden.services.automod import AutomodResult, AutomodService
from guardden.models import ModerationLog, Strike
from guardden.services.automod import (
AutomodResult,
AutomodService,
SpamConfig,
normalize_domain,
)
from guardden.utils.ratelimit import RateLimitExceeded
logger = logging.getLogger(__name__)
@@ -19,6 +28,135 @@ class Automod(commands.Cog):
self.bot = bot
self.automod = AutomodService()
def cog_check(self, ctx: commands.Context) -> bool:
    """Restrict automod commands to guild contexts and allowlisted owners."""
    return bool(ctx.guild and self.bot.is_owner_allowed(ctx.author.id))
async def cog_before_invoke(self, ctx: commands.Context) -> None:
    """Enforce the shared per-command rate limit before execution.

    Raises:
        RateLimitExceeded: when the user/guild/channel bucket is exhausted;
            ``cog_command_error`` converts this into a friendly reply.
    """
    command = ctx.command
    if command is None:
        return
    verdict = self.bot.rate_limiter.acquire_command(
        command.qualified_name,
        user_id=ctx.author.id,
        guild_id=ctx.guild.id if ctx.guild else None,
        channel_id=ctx.channel.id,
    )
    if verdict.is_limited:
        raise RateLimitExceeded(verdict.reset_after)
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
    """Turn rate-limit errors into a user-facing message; ignore the rest."""
    if not isinstance(error, RateLimitExceeded):
        return
    await ctx.send(
        f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
    )
def _spam_config(self, config) -> SpamConfig:
    """Translate a guild-config row into the automod service's SpamConfig.

    Falls back to the service defaults when no config row exists yet.
    """
    if config:
        return SpamConfig(
            message_rate_limit=config.message_rate_limit,
            message_rate_window=config.message_rate_window,
            duplicate_threshold=config.duplicate_threshold,
            mention_limit=config.mention_limit,
            mention_rate_limit=config.mention_rate_limit,
            mention_rate_window=config.mention_rate_window,
        )
    return self.automod.default_spam_config
async def _get_strike_count(self, guild_id: int, user_id: int) -> int:
    """Return the total active strike points for a user in a guild.

    Returns 0 when the user has no active strikes (SUM over an empty set
    yields NULL, i.e. a None scalar).
    """
    async with self.bot.database.session() as session:
        result = await session.execute(
            select(func.sum(Strike.points)).where(
                Strike.guild_id == guild_id,
                Strike.user_id == user_id,
                # Idiomatic SQLAlchemy boolean test; `== True` is flagged by
                # flake8 E712 and generates the same SQL less portably.
                Strike.is_active.is_(True),
            )
        )
        total = result.scalar()
        return total or 0
async def _add_strike(
    self,
    guild: discord.Guild,
    member: discord.Member,
    reason: str,
) -> int:
    """Record a one-point automod strike and return the user's new total.

    Returns:
        The sum of the member's active strike points after this strike.
    """
    async with self.bot.database.session() as session:
        strike = Strike(
            guild_id=guild.id,
            user_id=member.id,
            user_name=str(member),
            # The bot itself is the moderator for automatic strikes; 0 is a
            # sentinel for "bot user not yet available".
            moderator_id=self.bot.user.id if self.bot.user else 0,
            reason=reason,
            points=1,
        )
        session.add(strike)
    # NOTE(review): the recount below opens a fresh session; it assumes the
    # context manager above commits on exit — confirm Database.session() does.
    return await self._get_strike_count(guild.id, member.id)
async def _apply_strike_actions(
    self,
    member: discord.Member,
    total_strikes: int,
    config,
) -> None:
    """Apply the highest configured escalation whose threshold is reached.

    ``config.strike_actions`` maps a threshold (stringified or int key) to
    ``{"action": "ban"|"kick"|"timeout", "duration": seconds}``. Only the
    single highest matching threshold is applied, then processing stops.
    Permission/API failures are logged rather than propagated so a missing
    bot permission never aborts automod handling of a message.
    """
    if not config or not config.strike_actions:
        return
    # Highest thresholds first so the most severe matching action wins.
    for threshold, action_config in sorted(
        config.strike_actions.items(), key=lambda item: int(item[0]), reverse=True
    ):
        if total_strikes < int(threshold):
            continue
        action = action_config.get("action")
        try:
            if action == "ban":
                await member.ban(reason=f"Automod: {total_strikes} strikes")
            elif action == "kick":
                await member.kick(reason=f"Automod: {total_strikes} strikes")
            elif action == "timeout":
                duration = action_config.get("duration", 3600)
                await member.timeout(
                    timedelta(seconds=duration),
                    reason=f"Automod: {total_strikes} strikes",
                )
        except discord.Forbidden:
            logger.warning(
                "Cannot apply strike action %s to %s: missing permissions",
                action,
                member,
            )
        except discord.HTTPException as e:
            logger.error(
                "Failed to apply strike action %s to %s: %s", action, member, e
            )
        break
async def _log_database_action(
    self,
    message: discord.Message,
    result: AutomodResult,
) -> None:
    """Write an automatic moderation-log entry for an automod violation."""
    # Most severe applicable label wins; plain deletions are the default.
    if result.should_timeout:
        action = "timeout"
    elif result.should_strike:
        action = "strike"
    elif result.should_warn:
        action = "warn"
    else:
        action = "delete"
    expires_at = None
    if result.timeout_duration:
        expires_at = datetime.now(timezone.utc) + timedelta(
            seconds=result.timeout_duration
        )
    bot_user = self.bot.user
    async with self.bot.database.session() as session:
        session.add(
            ModerationLog(
                guild_id=message.guild.id,
                target_id=message.author.id,
                target_name=str(message.author),
                moderator_id=bot_user.id if bot_user else 0,
                moderator_name=str(bot_user) if bot_user else "GuardDen",
                action=action,
                reason=result.reason,
                duration=result.timeout_duration or None,
                expires_at=expires_at,
                channel_id=message.channel.id,
                message_id=message.id,
                message_content=message.content,
                is_automatic=True,
            )
        )
async def _handle_violation(
self,
message: discord.Message,
@@ -45,8 +183,15 @@ class Automod(commands.Cog):
logger.warning(f"Cannot timeout {message.author}: missing permissions")
# Log the action
await self._log_database_action(message, result)
await self._log_automod_action(message, result)
# Apply strike escalation if configured
if (result.should_warn or result.should_strike) and isinstance(message.author, discord.Member):
total = await self._add_strike(message.guild, message.author, result.reason)
config = await self.bot.guild_config.get_config(message.guild.id)
await self._apply_strike_actions(message.author, total, config)
# Notify the user via DM
try:
embed = discord.Embed(
@@ -136,13 +281,22 @@ class Automod(commands.Cog):
if banned_words:
result = self.automod.check_banned_words(message.content, banned_words)
spam_config = self._spam_config(config)
# Check scam links (if link filter enabled)
if not result and config.link_filter_enabled:
result = self.automod.check_scam_links(message.content)
result = self.automod.check_scam_links(
message.content,
allowlist=config.scam_allowlist,
)
# Check spam
if not result and config.anti_spam_enabled:
result = self.automod.check_spam(message, anti_spam_enabled=True)
result = self.automod.check_spam(
message,
anti_spam_enabled=True,
spam_config=spam_config,
)
# Check invite links (if link filter enabled)
if not result and config.link_filter_enabled:
@@ -194,20 +348,27 @@ class Automod(commands.Cog):
inline=True,
)
spam_config = self._spam_config(config)
# Show thresholds
embed.add_field(
name="Rate Limit",
value=f"{self.automod.message_rate_limit} msgs / {self.automod.message_rate_window}s",
value=f"{spam_config.message_rate_limit} msgs / {spam_config.message_rate_window}s",
inline=True,
)
embed.add_field(
name="Duplicate Threshold",
value=f"{self.automod.duplicate_threshold} same messages",
value=f"{spam_config.duplicate_threshold} same messages",
inline=True,
)
embed.add_field(
name="Mention Limit",
value=f"{self.automod.mention_limit} per message",
value=f"{spam_config.mention_limit} per message",
inline=True,
)
embed.add_field(
name="Mention Rate",
value=f"{spam_config.mention_rate_limit} mentions / {spam_config.mention_rate_window}s",
inline=True,
)
@@ -220,6 +381,82 @@ class Automod(commands.Cog):
await ctx.send(embed=embed)
@automod_cmd.command(name="threshold")
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def automod_threshold(
    self,
    ctx: commands.Context,
    setting: Literal[
        "message_rate_limit",
        "message_rate_window",
        "duplicate_threshold",
        "mention_limit",
        "mention_rate_limit",
        "mention_rate_window",
    ],
    value: int,
) -> None:
    """Update a single automod threshold.

    The ``setting`` name is constrained by the Literal converter, so it can
    be forwarded directly as a keyword argument.
    """
    if value <= 0:
        await ctx.send("Threshold values must be positive.")
        return
    updates = {setting: value}
    await self.bot.guild_config.update_settings(ctx.guild.id, **updates)
    await ctx.send(f"Updated `{setting}` to {value}.")
@automod_cmd.group(name="allowlist", invoke_without_command=True)
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def automod_allowlist(self, ctx: commands.Context) -> None:
    """Show the scam link allowlist."""
    cfg = await self.bot.guild_config.get_config(ctx.guild.id)
    domains = sorted(cfg.scam_allowlist) if cfg else []
    if domains:
        # Cap the listing at 20 entries to stay well under message limits.
        lines = [f"- `{domain}`" for domain in domains[:20]]
        await ctx.send("Allowed domains:\n" + "\n".join(lines))
    else:
        await ctx.send("No allowlisted domains configured.")
@automod_allowlist.command(name="add")
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def automod_allowlist_add(self, ctx: commands.Context, domain: str) -> None:
    """Add a domain to the scam link allowlist."""
    normalized = normalize_domain(domain)
    if not normalized:
        await ctx.send("Provide a valid domain or URL to allowlist.")
        return
    cfg = await self.bot.guild_config.get_config(ctx.guild.id)
    current = list(cfg.scam_allowlist) if cfg else []
    if normalized in current:
        await ctx.send(f"`{normalized}` is already allowlisted.")
        return
    current.append(normalized)
    await self.bot.guild_config.update_settings(ctx.guild.id, scam_allowlist=current)
    await ctx.send(f"Added `{normalized}` to the allowlist.")
@automod_allowlist.command(name="remove")
@commands.has_permissions(administrator=True)
@commands.guild_only()
async def automod_allowlist_remove(self, ctx: commands.Context, domain: str) -> None:
    """Remove a domain from the scam link allowlist."""
    normalized = normalize_domain(domain)
    # Validate like the `add` subcommand does; previously an unparseable
    # input fell through to a confusing "`None` is not in the allowlist".
    if not normalized:
        await ctx.send("Provide a valid domain or URL to remove.")
        return
    config = await self.bot.guild_config.get_config(ctx.guild.id)
    allowlist = list(config.scam_allowlist) if config else []
    if normalized not in allowlist:
        await ctx.send(f"`{normalized}` is not in the allowlist.")
        return
    allowlist.remove(normalized)
    await self.bot.guild_config.update_settings(ctx.guild.id, scam_allowlist=allowlist)
    await ctx.send(f"Removed `{normalized}` from the allowlist.")
@automod_cmd.command(name="test")
@commands.has_permissions(administrator=True)
@commands.guild_only()
@@ -235,7 +472,7 @@ class Automod(commands.Cog):
results.append(f"**Banned Words**: {result.reason}")
# Check scam links
result = self.automod.check_scam_links(text)
result = self.automod.check_scam_links(text, allowlist=config.scam_allowlist if config else [])
if result:
results.append(f"**Scam Detection**: {result.reason}")

View File

@@ -0,0 +1,71 @@
"""Health check commands."""
import logging
import discord
from discord.ext import commands
from sqlalchemy import select
from guardden.bot import GuardDen
from guardden.utils.ratelimit import RateLimitExceeded
logger = logging.getLogger(__name__)
class Health(commands.Cog):
    """Health checks for the bot."""

    def __init__(self, bot: GuardDen) -> None:
        self.bot = bot

    def cog_check(self, ctx: commands.Context) -> bool:
        """Guild-only; requires owner allowlist plus administrator permission."""
        in_guild = ctx.guild is not None
        if not in_guild or not self.bot.is_owner_allowed(ctx.author.id):
            return False
        return ctx.author.guild_permissions.administrator

    async def cog_before_invoke(self, ctx: commands.Context) -> None:
        """Enforce the shared per-command rate limit before execution."""
        command = ctx.command
        if command is None:
            return
        verdict = self.bot.rate_limiter.acquire_command(
            command.qualified_name,
            user_id=ctx.author.id,
            guild_id=ctx.guild.id if ctx.guild else None,
            channel_id=ctx.channel.id,
        )
        if verdict.is_limited:
            raise RateLimitExceeded(verdict.reset_after)

    async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
        """Turn rate-limit errors into a user-facing message; ignore the rest."""
        if not isinstance(error, RateLimitExceeded):
            return
        await ctx.send(
            f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
        )

    @commands.command(name="health")
    @commands.guild_only()
    async def health(self, ctx: commands.Context) -> None:
        """Check database and AI provider health."""
        try:
            # A trivial SELECT proves the connection/pool is usable.
            async with self.bot.database.session() as session:
                await session.execute(select(1))
            db_status = "ok"
        except Exception as exc:  # pragma: no cover - external dependency
            logger.exception("Health check database failure")
            db_status = f"error: {exc}"

        if self.bot.settings.ai_provider == "none":
            ai_status = "disabled"
        else:
            ai_status = "ok" if self.bot.ai_provider else "unavailable"

        report = discord.Embed(title="GuardDen Health", color=discord.Color.green())
        report.add_field(name="Database", value=db_status, inline=False)
        report.add_field(name="AI Provider", value=ai_status, inline=False)
        await ctx.send(embed=report)
async def setup(bot: GuardDen) -> None:
    """Extension entry point: register the Health cog with the bot."""
    cog = Health(bot)
    await bot.add_cog(cog)

View File

@@ -1,7 +1,6 @@
"""Moderation commands and automod features."""
import logging
import re
from datetime import datetime, timedelta, timezone
import discord
@@ -10,36 +9,43 @@ from sqlalchemy import func, select
from guardden.bot import GuardDen
from guardden.models import ModerationLog, Strike
from guardden.utils import parse_duration
from guardden.utils.ratelimit import RateLimitExceeded
logger = logging.getLogger(__name__)
def parse_duration(duration_str: str) -> timedelta | None:
"""Parse a duration string like '1h', '30m', '7d' into a timedelta."""
match = re.match(r"^(\d+)([smhdw])$", duration_str.lower())
if not match:
return None
amount = int(match.group(1))
unit = match.group(2)
units = {
"s": timedelta(seconds=amount),
"m": timedelta(minutes=amount),
"h": timedelta(hours=amount),
"d": timedelta(days=amount),
"w": timedelta(weeks=amount),
}
return units.get(unit)
class Moderation(commands.Cog):
"""Moderation commands for server management."""
def __init__(self, bot: GuardDen) -> None:
    """Keep a reference to the bot for settings, database, and rate limiting."""
    self.bot = bot
def cog_check(self, ctx: commands.Context) -> bool:
    """Allow moderation commands only in guilds, gated by the owner allowlist.

    ``is_owner_allowed`` returns True for everyone when no owners are set.
    """
    return bool(ctx.guild and self.bot.is_owner_allowed(ctx.author.id))
async def cog_before_invoke(self, ctx: commands.Context) -> None:
    """Enforce the shared per-command rate limit before execution.

    Raises:
        RateLimitExceeded: when the user/guild/channel bucket is exhausted;
            ``cog_command_error`` converts this into a friendly reply.
    """
    command = ctx.command
    if command is None:
        return
    verdict = self.bot.rate_limiter.acquire_command(
        command.qualified_name,
        user_id=ctx.author.id,
        guild_id=ctx.guild.id if ctx.guild else None,
        channel_id=ctx.channel.id,
    )
    if verdict.is_limited:
        raise RateLimitExceeded(verdict.reset_after)
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
    """Turn rate-limit errors into a user-facing message; ignore the rest."""
    if not isinstance(error, RateLimitExceeded):
        return
    await ctx.send(
        f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
    )
async def _log_action(
self,
guild: discord.Guild,
@@ -334,7 +340,15 @@ class Moderation(commands.Cog):
except discord.Forbidden:
pass
await member.kick(reason=f"{ctx.author}: {reason}")
try:
await member.kick(reason=f"{ctx.author}: {reason}")
except discord.Forbidden:
await ctx.send("❌ I don't have permission to kick this member.")
return
except discord.HTTPException as e:
await ctx.send(f"❌ Failed to kick member: {e}")
return
await self._log_action(ctx.guild, member, ctx.author, "kick", reason)
embed = discord.Embed(
@@ -346,7 +360,10 @@ class Moderation(commands.Cog):
embed.add_field(name="Reason", value=reason, inline=False)
embed.set_footer(text=f"Moderator: {ctx.author}")
await ctx.send(embed=embed)
try:
await ctx.send(embed=embed)
except discord.HTTPException:
await ctx.send(f"{member} has been kicked from the server.")
@commands.command(name="ban")
@commands.has_permissions(ban_members=True)
@@ -376,7 +393,15 @@ class Moderation(commands.Cog):
except discord.Forbidden:
pass
await ctx.guild.ban(member, reason=f"{ctx.author}: {reason}", delete_message_days=0)
try:
await ctx.guild.ban(member, reason=f"{ctx.author}: {reason}", delete_message_days=0)
except discord.Forbidden:
await ctx.send("❌ I don't have permission to ban this member.")
return
except discord.HTTPException as e:
await ctx.send(f"❌ Failed to ban member: {e}")
return
await self._log_action(ctx.guild, member, ctx.author, "ban", reason)
embed = discord.Embed(
@@ -388,7 +413,10 @@ class Moderation(commands.Cog):
embed.add_field(name="Reason", value=reason, inline=False)
embed.set_footer(text=f"Moderator: {ctx.author}")
await ctx.send(embed=embed)
try:
await ctx.send(embed=embed)
except discord.HTTPException:
await ctx.send(f"{member} has been banned from the server.")
@commands.command(name="unban")
@commands.has_permissions(ban_members=True)

View File

@@ -13,6 +13,7 @@ from guardden.services.verification import (
PendingVerification,
VerificationService,
)
from guardden.utils.ratelimit import RateLimitExceeded
logger = logging.getLogger(__name__)
@@ -155,6 +156,31 @@ class Verification(commands.Cog):
self.service = VerificationService()
self.cleanup_task.start()
def cog_check(self, ctx: commands.Context) -> bool:
    """Allow verification commands only in guilds, gated by the owner allowlist."""
    return bool(ctx.guild and self.bot.is_owner_allowed(ctx.author.id))
async def cog_before_invoke(self, ctx: commands.Context) -> None:
    """Enforce the shared per-command rate limit before execution.

    Raises:
        RateLimitExceeded: when the user/guild/channel bucket is exhausted;
            ``cog_command_error`` converts this into a friendly reply.
    """
    command = ctx.command
    if command is None:
        return
    verdict = self.bot.rate_limiter.acquire_command(
        command.qualified_name,
        user_id=ctx.author.id,
        guild_id=ctx.guild.id if ctx.guild else None,
        channel_id=ctx.channel.id,
    )
    if verdict.is_limited:
        raise RateLimitExceeded(verdict.reset_after)
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
    """Turn rate-limit errors into a user-facing message; ignore the rest."""
    if not isinstance(error, RateLimitExceeded):
        return
    await ctx.send(
        f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
    )
def cog_unload(self) -> None:
    """Stop the background cleanup loop when the cog is removed."""
    task = self.cleanup_task
    task.cancel()

View File

@@ -1,12 +1,70 @@
"""Configuration management for GuardDen."""
import json
import re
from pathlib import Path
from typing import Literal
from typing import Any, Literal
from pydantic import Field, SecretStr
from pydantic import Field, SecretStr, field_validator, ValidationError
from pydantic_settings import BaseSettings, SettingsConfigDict
# Discord snowflake ID validation regex (64-bit integers, 17-19 digits)
DISCORD_ID_PATTERN = re.compile(r"^\d{17,19}$")
def _validate_discord_id(value: str | int) -> int:
"""Validate a Discord snowflake ID."""
if isinstance(value, int):
id_str = str(value)
else:
id_str = str(value).strip()
# Check format
if not DISCORD_ID_PATTERN.match(id_str):
raise ValueError(f"Invalid Discord ID format: {id_str}")
# Convert to int and validate range
discord_id = int(id_str)
# Discord snowflakes are 64-bit integers, minimum valid ID is around 2010
if discord_id < 100000000000000000 or discord_id > 9999999999999999999:
raise ValueError(f"Discord ID out of valid range: {discord_id}")
return discord_id
def _parse_id_list(value: Any) -> list[int]:
"""Parse an environment value into a list of valid Discord IDs."""
if value is None:
return []
items: list[Any]
if isinstance(value, list):
items = value
elif isinstance(value, str):
text = value.strip()
if not text:
return []
# Only allow comma or semicolon separated values, no JSON parsing for security
items = [part.strip() for part in text.replace(";", ",").split(",") if part.strip()]
else:
items = [value]
parsed: list[int] = []
seen: set[int] = set()
for item in items:
try:
discord_id = _validate_discord_id(item)
if discord_id not in seen:
parsed.append(discord_id)
seen.add(discord_id)
except (ValueError, TypeError):
# Skip invalid IDs rather than failing silently
continue
return parsed
class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # NOTE: fields declared earlier in this class (tokens, AI provider,
    # database configuration, ...) are outside this excerpt.

    # Logging
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
        default="INFO", description="Logging level"
    )
    log_json: bool = Field(default=False, description="Use JSON structured logging format")
    log_file: str | None = Field(default=None, description="Log file path (optional)")

    # Access control — both lists are parsed from env values by the
    # _validate_id_list "before" validator.
    allowed_guilds: list[int] = Field(
        default_factory=list,
        description="Guild IDs the bot is allowed to join (empty = allow all)",
    )
    owner_ids: list[int] = Field(
        default_factory=list,
        description="Owner user IDs with elevated access (empty = allow admins)",
    )

    # Paths
    data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")
    @field_validator("allowed_guilds", "owner_ids", mode="before")
    @classmethod
    def _validate_id_list(cls, value: Any) -> list[int]:
        """Normalize env input (string/list/scalar) into a list of Discord IDs."""
        return _parse_id_list(value)
@field_validator("discord_token")
@classmethod
def _validate_discord_token(cls, value: SecretStr) -> SecretStr:
"""Validate Discord bot token format."""
token = value.get_secret_value()
if not token:
raise ValueError("Discord token cannot be empty")
# Basic Discord token format validation (not perfect but catches common issues)
if len(token) < 50 or not re.match(r"^[A-Za-z0-9._-]+$", token):
raise ValueError("Invalid Discord token format")
return value
@field_validator("anthropic_api_key", "openai_api_key")
@classmethod
def _validate_api_key(cls, value: SecretStr | None) -> SecretStr | None:
"""Validate API key format if provided."""
if value is None:
return None
key = value.get_secret_value()
if not key:
return None
# Basic API key validation
if len(key) < 20:
raise ValueError("API key too short to be valid")
return value
def validate_configuration(self) -> None:
"""Validate the settings for runtime usage."""
# AI provider validation
if self.ai_provider == "anthropic" and not self.anthropic_api_key:
raise ValueError("GUARDDEN_ANTHROPIC_API_KEY is required when AI provider is anthropic")
if self.ai_provider == "openai" and not self.openai_api_key:
raise ValueError("GUARDDEN_OPENAI_API_KEY is required when AI provider is openai")
# Database pool validation
if self.database_pool_min > self.database_pool_max:
raise ValueError("database_pool_min cannot be greater than database_pool_max")
if self.database_pool_min < 1:
raise ValueError("database_pool_min must be at least 1")
# Data directory validation
if not isinstance(self.data_dir, Path):
raise ValueError("data_dir must be a valid path")
def get_settings() -> Settings:
    """Build, validate, and return the application settings.

    Raises:
        ValueError: If runtime validation (validate_configuration) fails.
    """
    # A merge artifact left an early `return Settings()` here, which made the
    # validate_configuration() call below unreachable dead code. Removed so
    # settings are actually validated before use.
    settings = Settings()
    settings.validate_configuration()
    return settings

View File

@@ -0,0 +1 @@
"""Dashboard application package."""

View File

@@ -0,0 +1,267 @@
"""Analytics API routes for the GuardDen dashboard."""
from collections.abc import AsyncIterator
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from guardden.dashboard.auth import require_owner
from guardden.dashboard.config import DashboardSettings
from guardden.dashboard.db import DashboardDatabase
from guardden.dashboard.schemas import (
AIPerformanceStats,
AnalyticsSummary,
ModerationStats,
TimeSeriesDataPoint,
UserActivityStats,
)
from guardden.models import AICheck, MessageActivity, ModerationLog, UserActivity
def _moderation_stats_from_logs(logs) -> ModerationStats:
    """Aggregate ModerationLog rows into a ModerationStats schema."""
    actions_by_type: dict[str, int] = {}
    automatic_count = 0
    manual_count = 0
    for log in logs:
        actions_by_type[log.action] = actions_by_type.get(log.action, 0) + 1
        if log.is_automatic:
            automatic_count += 1
        else:
            manual_count += 1
    # Bucket actions per calendar day for the time-series chart.
    per_day: dict[str, int] = {}
    for log in logs:
        day_key = log.created_at.strftime("%Y-%m-%d")
        per_day[day_key] = per_day.get(day_key, 0) + 1
    actions_over_time = [
        TimeSeriesDataPoint(timestamp=datetime.strptime(day, "%Y-%m-%d"), value=count)
        for day, count in sorted(per_day.items())
    ]
    return ModerationStats(
        total_actions=len(logs),
        actions_by_type=actions_by_type,
        actions_over_time=actions_over_time,
        automatic_vs_manual={"automatic": automatic_count, "manual": manual_count},
    )


def _user_activity_from_rows(activities) -> UserActivityStats:
    """Aggregate MessageActivity rows into a UserActivityStats schema."""
    today = datetime.now().date()
    week_ago = today - timedelta(days=7)
    return UserActivityStats(
        # active_users is a per-row gauge, so take the peak, not the sum.
        active_users=max((a.active_users for a in activities), default=0),
        total_messages=sum(a.total_messages for a in activities),
        new_joins_today=sum(a.new_joins for a in activities if a.date.date() == today),
        new_joins_week=sum(a.new_joins for a in activities if a.date.date() >= week_ago),
    )


def _ai_performance_from_checks(checks) -> AIPerformanceStats:
    """Aggregate AICheck rows into an AIPerformanceStats schema."""
    total_checks = len(checks)
    return AIPerformanceStats(
        total_checks=total_checks,
        flagged_content=sum(1 for c in checks if c.flagged),
        avg_confidence=(
            sum(c.confidence for c in checks) / total_checks if total_checks > 0 else 0.0
        ),
        false_positives=sum(1 for c in checks if c.is_false_positive),
        avg_response_time_ms=(
            sum(c.response_time_ms for c in checks) / total_checks if total_checks > 0 else 0.0
        ),
    )


def create_analytics_router(
    settings: DashboardSettings,
    database: DashboardDatabase,
) -> APIRouter:
    """Create the analytics API router.

    The summary endpoint previously duplicated, line for line, the aggregation
    logic of the three dedicated endpoints; all four now share the module-level
    helpers above. All routes require the owner session.
    """
    router = APIRouter(prefix="/api/analytics")

    async def get_session() -> AsyncIterator[AsyncSession]:
        async for session in database.session():
            yield session

    def require_owner_dep(request: Request) -> None:
        require_owner(settings, request)

    async def _fetch_since(
        session: AsyncSession, model, date_column, guild_id: int | None, days: int
    ):
        """Load all rows of *model* newer than *days* days, optionally per guild."""
        start_date = datetime.now() - timedelta(days=days)
        query = select(model).where(date_column >= start_date)
        if guild_id:
            query = query.where(model.guild_id == guild_id)
        result = await session.execute(query)
        return result.scalars().all()

    @router.get(
        "/summary",
        response_model=AnalyticsSummary,
        dependencies=[Depends(require_owner_dep)],
    )
    async def analytics_summary(
        guild_id: int | None = Query(default=None),
        days: int = Query(default=7, ge=1, le=90),
        session: AsyncSession = Depends(get_session),
    ) -> AnalyticsSummary:
        """Get analytics summary for the specified time period."""
        mod_logs = await _fetch_since(
            session, ModerationLog, ModerationLog.created_at, guild_id, days
        )
        activities = await _fetch_since(
            session, MessageActivity, MessageActivity.date, guild_id, days
        )
        ai_checks = await _fetch_since(session, AICheck, AICheck.created_at, guild_id, days)
        return AnalyticsSummary(
            moderation_stats=_moderation_stats_from_logs(mod_logs),
            user_activity=_user_activity_from_rows(activities),
            ai_performance=_ai_performance_from_checks(ai_checks),
        )

    @router.get(
        "/moderation-stats",
        response_model=ModerationStats,
        dependencies=[Depends(require_owner_dep)],
    )
    async def moderation_stats(
        guild_id: int | None = Query(default=None),
        days: int = Query(default=30, ge=1, le=90),
        session: AsyncSession = Depends(get_session),
    ) -> ModerationStats:
        """Get detailed moderation statistics."""
        logs = await _fetch_since(session, ModerationLog, ModerationLog.created_at, guild_id, days)
        return _moderation_stats_from_logs(logs)

    @router.get(
        "/user-activity",
        response_model=UserActivityStats,
        dependencies=[Depends(require_owner_dep)],
    )
    async def user_activity_stats(
        guild_id: int | None = Query(default=None),
        days: int = Query(default=7, ge=1, le=90),
        session: AsyncSession = Depends(get_session),
    ) -> UserActivityStats:
        """Get user activity statistics."""
        activities = await _fetch_since(
            session, MessageActivity, MessageActivity.date, guild_id, days
        )
        return _user_activity_from_rows(activities)

    @router.get(
        "/ai-performance",
        response_model=AIPerformanceStats,
        dependencies=[Depends(require_owner_dep)],
    )
    async def ai_performance_stats(
        guild_id: int | None = Query(default=None),
        days: int = Query(default=30, ge=1, le=90),
        session: AsyncSession = Depends(get_session),
    ) -> AIPerformanceStats:
        """Get AI moderation performance statistics."""
        checks = await _fetch_since(session, AICheck, AICheck.created_at, guild_id, days)
        return _ai_performance_from_checks(checks)

    return router

View File

@@ -0,0 +1,78 @@
"""Authentication helpers for the dashboard."""
from typing import Any
from urllib.parse import urlencode
import httpx
from authlib.integrations.starlette_client import OAuth
from fastapi import HTTPException, Request, status
from guardden.dashboard.config import DashboardSettings
def build_oauth(settings: DashboardSettings) -> OAuth:
    """Build OAuth client registrations.

    Registers a single "entra" OIDC client; its endpoints are discovered from
    the tenant's well-known metadata document.
    """
    oauth = OAuth()
    oauth.register(
        name="entra",
        client_id=settings.entra_client_id,
        client_secret=settings.entra_client_secret.get_secret_value(),
        # Endpoint discovery via the tenant's OIDC metadata document.
        server_metadata_url=(
            "https://login.microsoftonline.com/"
            f"{settings.entra_tenant_id}/v2.0/.well-known/openid-configuration"
        ),
        client_kwargs={"scope": "openid profile email"},
    )
    return oauth
def discord_authorize_url(settings: DashboardSettings, state: str) -> str:
    """Build the Discord OAuth2 authorization URL (identify scope only)."""
    params = {
        "client_id": settings.discord_client_id,
        "redirect_uri": settings.callback_url("discord"),
        "response_type": "code",
        "scope": "identify",
        # CSRF token; verified against the session on callback.
        "state": state,
    }
    return "https://discord.com/oauth2/authorize?" + urlencode(params)
async def exchange_discord_code(settings: DashboardSettings, code: str) -> dict[str, Any]:
    """Exchange a Discord OAuth code for the user's profile.

    Raises:
        httpx.HTTPStatusError: If the token or profile request fails.
    """
    token_payload = {
        "client_id": settings.discord_client_id,
        "client_secret": settings.discord_client_secret.get_secret_value(),
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": settings.callback_url("discord"),
    }
    async with httpx.AsyncClient(timeout=10.0) as http:
        # Step 1: redeem the authorization code for an access token.
        token_response = await http.post(
            "https://discord.com/api/oauth2/token",
            data=token_payload,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        token_response.raise_for_status()
        access_token = token_response.json()["access_token"]
        # Step 2: fetch the authenticated user's profile.
        profile_response = await http.get(
            "https://discord.com/api/users/@me",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        profile_response.raise_for_status()
        return profile_response.json()
def require_owner(settings: DashboardSettings, request: Request) -> None:
    """Ensure the current session belongs to the configured owner.

    Raises:
        HTTPException: 401 when either identity is missing from the session,
            403 when either identity does not match the configured owner.
    """
    session = request.session
    entra_oid = session.get("entra_oid")
    discord_id = session.get("discord_id")
    if not entra_oid or not discord_id:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
    if str(entra_oid) != settings.owner_entra_object_id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
    # Compare as strings: int(discord_id) on a malformed session value would
    # raise ValueError and surface as a 500 instead of a clean 403.
    if str(discord_id) != str(settings.owner_discord_id):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")

View File

@@ -0,0 +1,68 @@
"""Configuration for the GuardDen dashboard."""
from pathlib import Path
from typing import Any
from pydantic import Field, SecretStr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class DashboardSettings(BaseSettings):
    """Dashboard settings loaded from environment variables."""

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        env_prefix="GUARDDEN_DASHBOARD_",
    )

    # Shares the bot's database; the explicit alias bypasses the dashboard
    # env prefix above.
    database_url: SecretStr = Field(
        validation_alias="GUARDDEN_DATABASE_URL",
        description="Database connection URL",
    )
    base_url: str = Field(
        default="http://localhost:8080",
        description="Base URL for OAuth callbacks",
    )
    # SECURITY: the "change-me" default must be overridden in production —
    # session cookies are signed with this key.
    secret_key: SecretStr = Field(
        default=SecretStr("change-me"),
        description="Session secret key",
    )

    # Microsoft Entra ID (OIDC) credentials
    entra_tenant_id: str = Field(description="Entra ID tenant ID")
    entra_client_id: str = Field(description="Entra ID application client ID")
    entra_client_secret: SecretStr = Field(description="Entra ID application client secret")

    # Discord OAuth credentials
    discord_client_id: str = Field(description="Discord OAuth client ID")
    discord_client_secret: SecretStr = Field(description="Discord OAuth client secret")

    # Owner allow-list: require_owner checks BOTH identities against these.
    owner_discord_id: int = Field(description="Discord user ID allowed to access dashboard")
    owner_entra_object_id: str = Field(description="Entra ID object ID allowed to access")

    cors_origins: list[str] = Field(default_factory=list, description="Allowed CORS origins")
    static_dir: Path = Field(
        default=Path("dashboard/frontend/dist"),
        description="Directory containing built frontend assets",
    )

    @field_validator("cors_origins", mode="before")
    @classmethod
    def _parse_origins(cls, value: Any) -> list[str]:
        """Accept a list or a comma-separated string of origins."""
        if value is None:
            return []
        if isinstance(value, list):
            return [str(item).strip() for item in value if str(item).strip()]
        text = str(value).strip()
        if not text:
            return []
        return [item.strip() for item in text.split(",") if item.strip()]

    def callback_url(self, provider: str) -> str:
        """Return the OAuth callback URL for *provider* under base_url."""
        return f"{self.base_url}/auth/{provider}/callback"
def get_dashboard_settings() -> DashboardSettings:
    """Load dashboard settings from environment (and .env, per model_config)."""
    return DashboardSettings()

View File

@@ -0,0 +1,298 @@
"""Configuration management API routes for the GuardDen dashboard."""
import json
from collections.abc import AsyncIterator
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request, status
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from guardden.dashboard.auth import require_owner
from guardden.dashboard.config import DashboardSettings
from guardden.dashboard.db import DashboardDatabase
from guardden.dashboard.schemas import AutomodRuleConfig, ConfigExport, GuildSettings
from guardden.models import Guild
from guardden.models import GuildSettings as GuildSettingsModel
async def _load_guild_settings(session: AsyncSession, guild_id: int) -> GuildSettingsModel:
    """Fetch the settings row for *guild_id*, raising 404 when absent."""
    result = await session.execute(
        select(GuildSettingsModel).where(GuildSettingsModel.guild_id == guild_id)
    )
    guild_settings = result.scalar_one_or_none()
    if not guild_settings:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Guild settings not found",
        )
    return guild_settings


def _settings_schema(gs: GuildSettingsModel) -> GuildSettings:
    """Map an ORM settings row to the GuildSettings API schema."""
    return GuildSettings(
        guild_id=gs.guild_id,
        prefix=gs.prefix,
        log_channel_id=gs.log_channel_id,
        automod_enabled=gs.automod_enabled,
        ai_moderation_enabled=gs.ai_moderation_enabled,
        ai_sensitivity=gs.ai_sensitivity,
        verification_enabled=gs.verification_enabled,
        verification_role_id=gs.verified_role_id,
        # Not stored per-guild; fixed default (could be derived from strike_actions).
        max_warns_before_action=3,
    )


def _automod_schema(gs: GuildSettingsModel, banned_words_enabled: bool = True) -> AutomodRuleConfig:
    """Map an ORM settings row to the AutomodRuleConfig API schema."""
    return AutomodRuleConfig(
        guild_id=gs.guild_id,
        banned_words_enabled=banned_words_enabled,
        scam_detection_enabled=gs.automod_enabled,
        spam_detection_enabled=gs.anti_spam_enabled,
        invite_filter_enabled=gs.link_filter_enabled,
        max_mentions=gs.mention_limit,
        max_emojis=10,  # not stored per-guild yet; fixed default
        spam_threshold=gs.message_rate_limit,
    )


def _apply_settings(gs: GuildSettingsModel, data: GuildSettings) -> None:
    """Copy API settings onto the ORM row; optional fields are skipped when None."""
    if data.prefix is not None:
        gs.prefix = data.prefix
    if data.log_channel_id is not None:
        gs.log_channel_id = data.log_channel_id
    gs.automod_enabled = data.automod_enabled
    gs.ai_moderation_enabled = data.ai_moderation_enabled
    gs.ai_sensitivity = data.ai_sensitivity
    gs.verification_enabled = data.verification_enabled
    if data.verification_role_id is not None:
        gs.verified_role_id = data.verification_role_id


def create_config_router(
    settings: DashboardSettings,
    database: DashboardDatabase,
) -> APIRouter:
    """Create the configuration management API router.

    Endpoints cover per-guild settings, automod rules, and JSON export/import.
    All routes require the owner session. Each endpoint previously repeated
    the fetch-or-404 lookup and schema mappings; those are factored into the
    module-level helpers above.
    """
    router = APIRouter(prefix="/api/guilds")

    async def get_session() -> AsyncIterator[AsyncSession]:
        async for session in database.session():
            yield session

    def require_owner_dep(request: Request) -> None:
        require_owner(settings, request)

    @router.get(
        "/{guild_id}/settings",
        response_model=GuildSettings,
        dependencies=[Depends(require_owner_dep)],
    )
    async def get_guild_settings(
        guild_id: int = Path(...),
        session: AsyncSession = Depends(get_session),
    ) -> GuildSettings:
        """Get guild settings."""
        guild_settings = await _load_guild_settings(session, guild_id)
        return _settings_schema(guild_settings)

    @router.put(
        "/{guild_id}/settings",
        response_model=GuildSettings,
        dependencies=[Depends(require_owner_dep)],
    )
    async def update_guild_settings(
        guild_id: int = Path(...),
        settings_data: GuildSettings = ...,
        session: AsyncSession = Depends(get_session),
    ) -> GuildSettings:
        """Update guild settings."""
        guild_settings = await _load_guild_settings(session, guild_id)
        _apply_settings(guild_settings, settings_data)
        await session.commit()
        await session.refresh(guild_settings)
        return _settings_schema(guild_settings)

    @router.get(
        "/{guild_id}/automod",
        response_model=AutomodRuleConfig,
        dependencies=[Depends(require_owner_dep)],
    )
    async def get_automod_config(
        guild_id: int = Path(...),
        session: AsyncSession = Depends(get_session),
    ) -> AutomodRuleConfig:
        """Get automod rule configuration."""
        guild_settings = await _load_guild_settings(session, guild_id)
        return _automod_schema(guild_settings)

    @router.put(
        "/{guild_id}/automod",
        response_model=AutomodRuleConfig,
        dependencies=[Depends(require_owner_dep)],
    )
    async def update_automod_config(
        guild_id: int = Path(...),
        automod_data: AutomodRuleConfig = ...,
        session: AsyncSession = Depends(get_session),
    ) -> AutomodRuleConfig:
        """Update automod rule configuration."""
        guild_settings = await _load_guild_settings(session, guild_id)
        # Scam detection maps onto the general automod toggle.
        guild_settings.automod_enabled = automod_data.scam_detection_enabled
        guild_settings.anti_spam_enabled = automod_data.spam_detection_enabled
        guild_settings.link_filter_enabled = automod_data.invite_filter_enabled
        guild_settings.mention_limit = automod_data.max_mentions
        guild_settings.message_rate_limit = automod_data.spam_threshold
        await session.commit()
        await session.refresh(guild_settings)
        return _automod_schema(
            guild_settings, banned_words_enabled=automod_data.banned_words_enabled
        )

    @router.get(
        "/{guild_id}/export",
        dependencies=[Depends(require_owner_dep)],
    )
    async def export_config(
        guild_id: int = Path(...),
        session: AsyncSession = Depends(get_session),
    ) -> StreamingResponse:
        """Export guild configuration as a downloadable JSON file."""
        guild_settings = await _load_guild_settings(session, guild_id)
        export_data = ConfigExport(
            version="1.0",
            guild_settings=_settings_schema(guild_settings),
            automod_rules=_automod_schema(guild_settings),
            exported_at=datetime.now(),
        )
        json_data = export_data.model_dump_json(indent=2)
        return StreamingResponse(
            iter([json_data]),
            media_type="application/json",
            headers={"Content-Disposition": f"attachment; filename=guild_{guild_id}_config.json"},
        )

    @router.post(
        "/{guild_id}/import",
        response_model=GuildSettings,
        dependencies=[Depends(require_owner_dep)],
    )
    async def import_config(
        guild_id: int = Path(...),
        config_data: ConfigExport = ...,
        session: AsyncSession = Depends(get_session),
    ) -> GuildSettings:
        """Import guild configuration from a previously exported JSON payload."""
        guild_settings = await _load_guild_settings(session, guild_id)
        _apply_settings(guild_settings, config_data.guild_settings)
        # NOTE(review): unlike update_automod_config, the import does NOT
        # re-derive automod_enabled from the automod rules — preserved from
        # the original behavior; confirm this asymmetry is intended.
        automod = config_data.automod_rules
        guild_settings.anti_spam_enabled = automod.spam_detection_enabled
        guild_settings.link_filter_enabled = automod.invite_filter_enabled
        guild_settings.mention_limit = automod.max_mentions
        guild_settings.message_rate_limit = automod.spam_threshold
        await session.commit()
        await session.refresh(guild_settings)
        return _settings_schema(guild_settings)

    return router

View File

@@ -0,0 +1,24 @@
"""Database helpers for the dashboard."""
from collections.abc import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from guardden.dashboard.config import DashboardSettings
class DashboardDatabase:
    """Async database session factory for the dashboard."""

    def __init__(self, settings: DashboardSettings) -> None:
        """Create the async engine and session factory from the configured URL."""
        raw_url = settings.database_url.get_secret_value()
        # SQLAlchemy needs the asyncpg dialect marker for async Postgres access.
        if raw_url.startswith("postgresql://"):
            raw_url = raw_url.replace("postgresql://", "postgresql+asyncpg://", 1)
        self._engine = create_async_engine(raw_url, pool_pre_ping=True)
        self._sessionmaker = async_sessionmaker(self._engine, expire_on_commit=False)

    async def session(self) -> AsyncIterator[AsyncSession]:
        """Yield a database session (async generator; one session per use)."""
        async with self._sessionmaker() as db_session:
            yield db_session

View File

@@ -0,0 +1,121 @@
"""FastAPI app for the GuardDen dashboard."""
import logging
import secrets
from pathlib import Path
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from starlette.middleware.sessions import SessionMiddleware
from starlette.staticfiles import StaticFiles
from guardden.dashboard.analytics import create_analytics_router
from guardden.dashboard.auth import (
build_oauth,
discord_authorize_url,
exchange_discord_code,
require_owner,
)
from guardden.dashboard.config import DashboardSettings, get_dashboard_settings
from guardden.dashboard.config_management import create_config_router
from guardden.dashboard.db import DashboardDatabase
from guardden.dashboard.routes import create_api_router
from guardden.dashboard.users import create_users_router
from guardden.dashboard.websocket import create_websocket_router
logger = logging.getLogger(__name__)
def create_app() -> FastAPI:
    """Build and wire the dashboard FastAPI application.

    Sets up session + CORS middleware, auth endpoints (Entra OIDC and
    Discord OAuth), all API routers, and static frontend hosting.
    """
    settings = get_dashboard_settings()
    database = DashboardDatabase(settings)
    oauth = build_oauth(settings)
    app = FastAPI(title="GuardDen Dashboard")
    # Sessions are cookies signed with the configured secret key.
    app.add_middleware(SessionMiddleware, secret_key=settings.secret_key.get_secret_value())
    if settings.cors_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=settings.cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    def require_owner_dep(request: Request) -> None:
        require_owner(settings, request)

    @app.get("/api/health")
    async def health() -> dict[str, str]:
        """Liveness probe."""
        return {"status": "ok"}

    @app.get("/api/me")
    async def me(request: Request) -> dict[str, bool | str | None]:
        """Report the session's login state and whether it is the owner."""
        entra_oid = request.session.get("entra_oid")
        discord_id = request.session.get("discord_id")
        owner = str(entra_oid) == settings.owner_entra_object_id and str(discord_id) == str(
            settings.owner_discord_id
        )
        return {
            "entra": bool(entra_oid),
            "discord": bool(discord_id),
            "owner": owner,
            "entra_oid": entra_oid,
            "discord_id": discord_id,
        }

    @app.get("/auth/entra/login")
    async def entra_login(request: Request) -> RedirectResponse:
        redirect_uri = settings.callback_url("entra")
        return await oauth.entra.authorize_redirect(request, redirect_uri)

    @app.get("/auth/entra/callback")
    async def entra_callback(request: Request) -> RedirectResponse:
        token = await oauth.entra.authorize_access_token(request)
        # NOTE(review): authlib >= 1.0 changed parse_id_token's signature to
        # (token, nonce), and authorize_access_token already exposes the parsed
        # claims as token["userinfo"] — confirm against the pinned authlib version.
        user = await oauth.entra.parse_id_token(request, token)
        request.session["entra_oid"] = user.get("oid")
        return RedirectResponse(url="/")

    @app.get("/auth/discord/login")
    async def discord_login(request: Request) -> RedirectResponse:
        # CSRF protection: random state stored in the session, checked on callback.
        state = secrets.token_urlsafe(16)
        request.session["discord_state"] = state
        return RedirectResponse(url=discord_authorize_url(settings, state))

    @app.get("/auth/discord/callback")
    async def discord_callback(request: Request) -> RedirectResponse:
        params = dict(request.query_params)
        code = params.get("code")
        state = params.get("state")
        if not code or not state:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing code")
        # Pop (not get) the stored state so it is single-use; the previous
        # session.get left it valid for replay for the session's lifetime.
        if state != request.session.pop("discord_state", None):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state")
        profile = await exchange_discord_code(settings, code)
        request.session["discord_id"] = profile.get("id")
        return RedirectResponse(url="/")

    @app.get("/auth/logout")
    async def logout(request: Request) -> RedirectResponse:
        request.session.clear()
        return RedirectResponse(url="/")

    # Include all API routers
    app.include_router(create_api_router(settings, database))
    app.include_router(create_analytics_router(settings, database))
    app.include_router(create_users_router(settings, database))
    app.include_router(create_config_router(settings, database))
    app.include_router(create_websocket_router(settings))

    # Serve the built frontend if present; StaticFiles(html=True) serves
    # index.html for directory requests.
    static_dir = Path(settings.static_dir)
    if static_dir.exists():
        app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
    else:
        logger.warning("Static directory not found: %s", static_dir)
    return app


# Module-level ASGI application instance (e.g. `uvicorn guardden.dashboard.app:app`).
app = create_app()

View File

@@ -0,0 +1,87 @@
"""API routes for the GuardDen dashboard."""
from collections.abc import AsyncIterator
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from guardden.dashboard.auth import require_owner
from guardden.dashboard.config import DashboardSettings
from guardden.dashboard.db import DashboardDatabase
from guardden.dashboard.schemas import GuildSummary, ModerationLogEntry, PaginatedLogs
from guardden.models import Guild, ModerationLog
def create_api_router(
    settings: DashboardSettings,
    database: DashboardDatabase,
) -> APIRouter:
    """Create the dashboard API router (guild list and moderation log feed)."""
    router = APIRouter(prefix="/api")

    async def get_session() -> AsyncIterator[AsyncSession]:
        async for session in database.session():
            yield session

    def require_owner_dep(request: Request) -> None:
        require_owner(settings, request)

    @router.get("/guilds", response_model=list[GuildSummary], dependencies=[Depends(require_owner_dep)])
    async def list_guilds(
        session: AsyncSession = Depends(get_session),
    ) -> list[GuildSummary]:
        """List all known guilds, sorted by name."""
        rows = (await session.execute(select(Guild).order_by(Guild.name.asc()))).scalars().all()
        summaries: list[GuildSummary] = []
        for row in rows:
            summaries.append(
                GuildSummary(id=row.id, name=row.name, owner_id=row.owner_id, premium=row.premium)
            )
        return summaries

    @router.get(
        "/moderation/logs",
        response_model=PaginatedLogs,
        dependencies=[Depends(require_owner_dep)],
    )
    async def list_moderation_logs(
        guild_id: int | None = Query(default=None),
        limit: int = Query(default=50, ge=1, le=200),
        offset: int = Query(default=0, ge=0),
        session: AsyncSession = Depends(get_session),
    ) -> PaginatedLogs:
        """Return a page of moderation logs, newest first, with a total count."""
        base = select(ModerationLog)
        counter = select(func.count(ModerationLog.id))
        if guild_id:
            base = base.where(ModerationLog.guild_id == guild_id)
            counter = counter.where(ModerationLog.guild_id == guild_id)
        total = int((await session.execute(counter)).scalar() or 0)
        page_query = base.order_by(ModerationLog.created_at.desc()).offset(offset).limit(limit)
        page = (await session.execute(page_query)).scalars().all()
        entries = [
            ModerationLogEntry(
                id=entry.id,
                guild_id=entry.guild_id,
                target_id=entry.target_id,
                target_name=entry.target_name,
                moderator_id=entry.moderator_id,
                moderator_name=entry.moderator_name,
                action=entry.action,
                reason=entry.reason,
                duration=entry.duration,
                expires_at=entry.expires_at,
                channel_id=entry.channel_id,
                message_id=entry.message_id,
                message_content=entry.message_content,
                is_automatic=entry.is_automatic,
                created_at=entry.created_at,
            )
            for entry in page
        ]
        return PaginatedLogs(total=total, items=entries)

    return router

View File

@@ -0,0 +1,163 @@
"""Pydantic schemas for dashboard APIs."""
from datetime import datetime
from pydantic import BaseModel, Field
class GuildSummary(BaseModel):
    """Summary of a guild as shown in the dashboard guild list."""

    id: int  # Discord guild snowflake
    name: str
    owner_id: int  # Discord user snowflake of the guild owner
    premium: bool


class ModerationLogEntry(BaseModel):
    """A single moderation action as returned by the logs API."""

    id: int
    guild_id: int
    target_id: int  # user the action was taken against
    target_name: str
    moderator_id: int
    moderator_name: str
    action: str  # action vocabulary defined by the bot (e.g. ban/kick/warn)
    reason: str | None
    duration: int | None  # presumably seconds — TODO confirm unit at the writer
    expires_at: datetime | None
    channel_id: int | None
    message_id: int | None
    message_content: str | None
    is_automatic: bool  # True when issued by automod/AI rather than a human
    created_at: datetime


class PaginatedLogs(BaseModel):
    """Envelope for one page of moderation log entries."""

    total: int  # total matching rows, not just the rows in this page
    items: list[ModerationLogEntry]
# Analytics Schemas
class TimeSeriesDataPoint(BaseModel):
    """One point in a time series (timestamp -> integer value)."""

    timestamp: datetime
    value: int


class ModerationStats(BaseModel):
    """Aggregated moderation statistics."""

    total_actions: int
    actions_by_type: dict[str, int]  # action name -> count
    actions_over_time: list[TimeSeriesDataPoint]
    # presumably keyed "automatic"/"manual" — confirm against the producer
    automatic_vs_manual: dict[str, int]


class UserActivityStats(BaseModel):
    """User activity counters for the analytics dashboard."""

    active_users: int
    total_messages: int
    new_joins_today: int
    new_joins_week: int


class AIPerformanceStats(BaseModel):
    """AI moderation performance metrics."""

    total_checks: int
    flagged_content: int
    avg_confidence: float
    false_positives: int = 0
    avg_response_time_ms: float = 0.0


class AnalyticsSummary(BaseModel):
    """Bundle of all analytics sections returned by the summary endpoint."""

    moderation_stats: ModerationStats
    user_activity: UserActivityStats
    ai_performance: AIPerformanceStats
# User Management Schemas
class UserProfile(BaseModel):
    """Per-guild moderation profile for a single user."""

    user_id: int
    username: str
    strike_count: int
    total_warnings: int
    total_kicks: int
    total_bans: int
    total_timeouts: int
    first_seen: datetime
    last_action: datetime | None  # timestamp of the most recent moderation action, if any


class UserNote(BaseModel):
    """A moderator-authored note attached to a user."""

    id: int
    user_id: int
    guild_id: int
    moderator_id: int
    moderator_name: str
    content: str
    created_at: datetime


class CreateUserNote(BaseModel):
    """Request body for creating a user note."""

    content: str = Field(min_length=1, max_length=2000)


class BulkModerationAction(BaseModel):
    """Request body for applying one action to many users at once."""

    action: str = Field(pattern="^(ban|kick|timeout|warn)$")
    user_ids: list[int] = Field(min_length=1, max_length=100)
    reason: str | None = None
    duration: int | None = None  # only meaningful for timeout — TODO confirm


class BulkActionResult(BaseModel):
    """Outcome of a bulk moderation action."""

    success_count: int
    failed_count: int
    errors: dict[int, str]  # user_id -> error message for each failure
# Configuration Schemas
class GuildSettings(BaseModel):
    """Editable per-guild bot settings exposed to the dashboard."""

    guild_id: int
    prefix: str | None = None  # None means the global default prefix
    log_channel_id: int | None = None
    automod_enabled: bool = True
    ai_moderation_enabled: bool = False
    ai_sensitivity: int = Field(ge=0, le=100, default=50)  # 0-100 scale
    verification_enabled: bool = False
    verification_role_id: int | None = None
    max_warns_before_action: int = Field(ge=1, le=10, default=3)


class AutomodRuleConfig(BaseModel):
    """Per-guild automod rule toggles and thresholds."""

    guild_id: int
    banned_words_enabled: bool = True
    scam_detection_enabled: bool = True
    spam_detection_enabled: bool = True
    invite_filter_enabled: bool = False
    max_mentions: int = Field(ge=1, le=20, default=5)
    max_emojis: int = Field(ge=1, le=50, default=10)
    spam_threshold: int = Field(ge=1, le=20, default=5)


class ConfigExport(BaseModel):
    """Portable export of a guild's full configuration."""

    version: str = "1.0"  # bump when the export format changes
    guild_settings: GuildSettings
    automod_rules: AutomodRuleConfig
    exported_at: datetime
# WebSocket Event Schemas
class WebSocketEvent(BaseModel):
    """Generic real-time event pushed to dashboard WebSocket clients."""

    type: str  # event discriminator, e.g. "moderation_action"
    guild_id: int
    timestamp: datetime
    data: dict[str, object]  # event-specific payload


# NOTE(review): the subclasses below only restate the base fields with a
# fixed default for ``type`` — they add no new fields or validation.
class ModerationEvent(WebSocketEvent):
    type: str = "moderation_action"
    data: dict[str, object]


class UserJoinEvent(WebSocketEvent):
    type: str = "user_join"
    data: dict[str, object]


class AIAlertEvent(WebSocketEvent):
    type: str = "ai_alert"
    data: dict[str, object]

View File

@@ -0,0 +1,246 @@
"""User management API routes for the GuardDen dashboard."""
from collections.abc import AsyncIterator
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from guardden.dashboard.auth import require_owner
from guardden.dashboard.config import DashboardSettings
from guardden.dashboard.db import DashboardDatabase
from guardden.dashboard.schemas import CreateUserNote, UserNote, UserProfile
from guardden.models import ModerationLog, UserActivity
from guardden.models import UserNote as UserNoteModel
def create_users_router(
    settings: DashboardSettings,
    database: DashboardDatabase,
) -> APIRouter:
    """Create the user management API router.

    All endpoints are owner-only and scoped to one guild via the ``guild_id``
    query parameter. Mounted under ``/api/users``.

    Args:
        settings: Dashboard configuration, used for owner authentication.
        database: Database wrapper that provides async SQLAlchemy sessions.

    Returns:
        A configured :class:`fastapi.APIRouter`.
    """
    router = APIRouter(prefix="/api/users")

    async def get_session() -> AsyncIterator[AsyncSession]:
        # Re-yield the database's session generator as a FastAPI dependency.
        async for session in database.session():
            yield session

    def require_owner_dep(request: Request) -> None:
        # Raises (via require_owner) unless the caller is the bot owner.
        require_owner(settings, request)

    async def _last_action_at(
        session: AsyncSession, guild_id: int, user_id: int
    ) -> datetime | None:
        # Timestamp of the most recent moderation action against the user,
        # or None if they have never been actioned.
        query = (
            select(ModerationLog.created_at)
            .where(ModerationLog.guild_id == guild_id)
            .where(ModerationLog.target_id == user_id)
            .order_by(ModerationLog.created_at.desc())
            .limit(1)
        )
        result = await session.execute(query)
        return result.scalar()

    def _to_profile(user: UserActivity, last_action: datetime | None) -> UserProfile:
        # Map a UserActivity row onto the API's UserProfile schema.
        return UserProfile(
            user_id=user.user_id,
            username=user.username,
            strike_count=user.strike_count,
            total_warnings=user.warning_count,
            total_kicks=user.kick_count,
            total_bans=user.ban_count,
            total_timeouts=user.timeout_count,
            first_seen=user.first_seen,
            last_action=last_action,
        )

    @router.get(
        "/search",
        response_model=list[UserProfile],
        dependencies=[Depends(require_owner_dep)],
    )
    async def search_users(
        guild_id: int = Query(...),
        username: str | None = Query(default=None),
        min_strikes: int | None = Query(default=None, ge=0),
        limit: int = Query(default=50, ge=1, le=200),
        session: AsyncSession = Depends(get_session),
    ) -> list[UserProfile]:
        """Search for users in a guild with optional filters."""
        query = select(UserActivity).where(UserActivity.guild_id == guild_id)
        if username:
            query = query.where(UserActivity.username.ilike(f"%{username}%"))
        if min_strikes is not None:
            query = query.where(UserActivity.strike_count >= min_strikes)
        query = query.order_by(UserActivity.last_seen.desc()).limit(limit)
        result = await session.execute(query)
        users = result.scalars().all()
        # NOTE: one extra query per user (N+1). Acceptable for limit<=200 on
        # an admin dashboard; revisit with a window function if it gets hot.
        return [
            _to_profile(user, await _last_action_at(session, guild_id, user.user_id))
            for user in users
        ]

    @router.get(
        "/{user_id}/profile",
        response_model=UserProfile,
        dependencies=[Depends(require_owner_dep)],
    )
    async def get_user_profile(
        user_id: int = Path(...),
        guild_id: int = Query(...),
        session: AsyncSession = Depends(get_session),
    ) -> UserProfile:
        """Get detailed profile for a specific user.

        Raises:
            HTTPException: 404 when the user has no activity in the guild.
        """
        query = (
            select(UserActivity)
            .where(UserActivity.guild_id == guild_id)
            .where(UserActivity.user_id == user_id)
        )
        result = await session.execute(query)
        user = result.scalar_one_or_none()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found in this guild",
            )
        last_action = await _last_action_at(session, guild_id, user_id)
        return _to_profile(user, last_action)

    @router.get(
        "/{user_id}/notes",
        response_model=list[UserNote],
        dependencies=[Depends(require_owner_dep)],
    )
    async def get_user_notes(
        user_id: int = Path(...),
        guild_id: int = Query(...),
        session: AsyncSession = Depends(get_session),
    ) -> list[UserNote]:
        """Get all notes for a specific user, newest first."""
        query = (
            select(UserNoteModel)
            .where(UserNoteModel.guild_id == guild_id)
            .where(UserNoteModel.user_id == user_id)
            .order_by(UserNoteModel.created_at.desc())
        )
        result = await session.execute(query)
        notes = result.scalars().all()
        return [
            UserNote(
                id=note.id,
                user_id=note.user_id,
                guild_id=note.guild_id,
                moderator_id=note.moderator_id,
                moderator_name=note.moderator_name,
                content=note.content,
                created_at=note.created_at,
            )
            for note in notes
        ]

    @router.post(
        "/{user_id}/notes",
        response_model=UserNote,
        dependencies=[Depends(require_owner_dep)],
    )
    async def create_user_note(
        user_id: int = Path(...),
        guild_id: int = Query(...),
        note_data: CreateUserNote = ...,
        request: Request = ...,
        session: AsyncSession = Depends(get_session),
    ) -> UserNote:
        """Create a new note for a user.

        Raises:
            HTTPException: 401 when the session holds no usable Discord id.
        """
        # Moderator identity comes from the signed session cookie.
        moderator_id = request.session.get("discord_id")
        if not moderator_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Discord authentication required",
            )
        try:
            moderator_snowflake = int(moderator_id)
        except (TypeError, ValueError):
            # A malformed session value must not surface as a 500.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Discord authentication required",
            )
        new_note = UserNoteModel(
            user_id=user_id,
            guild_id=guild_id,
            moderator_id=moderator_snowflake,
            moderator_name="Dashboard User",  # TODO: Fetch actual username
            # Timezone-aware timestamp: the column is DateTime(timezone=True),
            # so a naive datetime.now() would be ambiguous across hosts.
            created_at=datetime.now(timezone.utc),
        )
        session.add(new_note)
        await session.commit()
        await session.refresh(new_note)
        return UserNote(
            id=new_note.id,
            user_id=new_note.user_id,
            guild_id=new_note.guild_id,
            moderator_id=new_note.moderator_id,
            moderator_name=new_note.moderator_name,
            content=new_note.content,
            created_at=new_note.created_at,
        )

    @router.delete(
        "/{user_id}/notes/{note_id}",
        status_code=status.HTTP_204_NO_CONTENT,
        dependencies=[Depends(require_owner_dep)],
    )
    async def delete_user_note(
        user_id: int = Path(...),
        note_id: int = Path(...),
        guild_id: int = Query(...),
        session: AsyncSession = Depends(get_session),
    ) -> None:
        """Delete a user note.

        Raises:
            HTTPException: 404 when no note matches all three identifiers.
        """
        query = (
            select(UserNoteModel)
            .where(UserNoteModel.id == note_id)
            .where(UserNoteModel.guild_id == guild_id)
            .where(UserNoteModel.user_id == user_id)
        )
        result = await session.execute(query)
        note = result.scalar_one_or_none()
        if not note:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Note not found",
            )
        await session.delete(note)
        await session.commit()

    return router

View File

@@ -0,0 +1,221 @@
"""WebSocket support for real-time dashboard updates."""
import asyncio
import logging
from datetime import datetime
from typing import Any
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from guardden.dashboard.config import DashboardSettings
from guardden.dashboard.schemas import WebSocketEvent
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Manage WebSocket connections for real-time updates.

    Connections are grouped by guild id. All mutation of the registry is
    guarded by an asyncio.Lock; sends happen on a snapshot taken under the
    lock so one slow client cannot block registry updates.
    """

    def __init__(self) -> None:
        # guild_id -> list of live WebSocket connections for that guild
        self.active_connections: dict[int, list[WebSocket]] = {}
        self._lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket, guild_id: int) -> None:
        """Accept a new WebSocket connection and register it under the guild."""
        await websocket.accept()
        async with self._lock:
            if guild_id not in self.active_connections:
                self.active_connections[guild_id] = []
            self.active_connections[guild_id].append(websocket)
        logger.info("New WebSocket connection for guild %s", guild_id)

    async def disconnect(self, websocket: WebSocket, guild_id: int) -> None:
        """Remove a WebSocket connection; drops the guild entry when empty."""
        async with self._lock:
            if guild_id in self.active_connections:
                if websocket in self.active_connections[guild_id]:
                    self.active_connections[guild_id].remove(websocket)
                # Keep the dict free of empty lists.
                if not self.active_connections[guild_id]:
                    del self.active_connections[guild_id]
        logger.info("WebSocket disconnected for guild %s", guild_id)

    async def broadcast_to_guild(self, guild_id: int, event: WebSocketEvent) -> None:
        """Broadcast an event to all connections for a specific guild."""
        # Snapshot under the lock, then release it before the (potentially
        # slow) network sends.
        async with self._lock:
            connections = self.active_connections.get(guild_id, []).copy()
        if not connections:
            return
        # Convert event to JSON once for all recipients.
        message = event.model_dump_json()
        # Send to all connections, collecting the ones that fail.
        dead_connections = []
        for connection in connections:
            try:
                await connection.send_text(message)
            except Exception as e:
                logger.warning("Failed to send message to WebSocket: %s", e)
                dead_connections.append(connection)
        # Clean up dead connections; the lock is re-acquired here, never held
        # across a send, so there is no re-entrancy.
        if dead_connections:
            async with self._lock:
                if guild_id in self.active_connections:
                    for conn in dead_connections:
                        if conn in self.active_connections[guild_id]:
                            self.active_connections[guild_id].remove(conn)
                    if not self.active_connections[guild_id]:
                        del self.active_connections[guild_id]

    async def broadcast_to_all(self, event: WebSocketEvent) -> None:
        """Broadcast an event to every connected guild."""
        # Snapshot the guild ids first; broadcast_to_guild takes the lock
        # itself, so it must not be called while the lock is held.
        async with self._lock:
            all_guilds = list(self.active_connections.keys())
        for guild_id in all_guilds:
            await self.broadcast_to_guild(guild_id, event)

    def get_connection_count(self, guild_id: int | None = None) -> int:
        """Return the connection count for one guild, or the global total.

        NOTE(review): reads the registry without the lock — fine for a
        monitoring snapshot, but the value may be momentarily stale.
        """
        if guild_id is not None:
            return len(self.active_connections.get(guild_id, []))
        return sum(len(conns) for conns in self.active_connections.values())
# Global connection manager
connection_manager = ConnectionManager()
def create_websocket_router(settings: DashboardSettings) -> APIRouter:
    """Create the WebSocket API router.

    NOTE(review): ``settings`` is currently unused and the ``/ws/events``
    endpoint performs no authentication — any client that can reach the
    server may subscribe to a guild's event stream. Confirm whether owner
    auth should be enforced here as it is on the HTTP routes.
    """
    router = APIRouter()

    @router.websocket("/ws/events")
    async def websocket_events(websocket: WebSocket, guild_id: int) -> None:
        """WebSocket endpoint for real-time events for one guild."""
        await connection_manager.connect(websocket, guild_id)
        try:
            # Send initial connection confirmation.
            # NOTE(review): datetime.now() is naive local time; the rest of
            # the schema uses datetime too — confirm whether UTC is expected.
            await websocket.send_json(
                {
                    "type": "connected",
                    "guild_id": guild_id,
                    "timestamp": datetime.now().isoformat(),
                    "data": {"message": "Connected to real-time events"},
                }
            )
            # Keep connection alive and handle incoming messages.
            while True:
                try:
                    # Wait for messages from client (ping/pong, etc.).
                    data = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
                    # Echo back as heartbeat.
                    if data == "ping":
                        await websocket.send_text("pong")
                except asyncio.TimeoutError:
                    # No client traffic for 30s: send a server-side ping to
                    # keep intermediaries from dropping the idle connection.
                    await websocket.send_json(
                        {
                            "type": "ping",
                            "guild_id": guild_id,
                            "timestamp": datetime.now().isoformat(),
                            "data": {},
                        }
                    )
        except WebSocketDisconnect:
            logger.info("Client disconnected from WebSocket for guild %s", guild_id)
        except Exception as e:
            logger.error("WebSocket error for guild %s: %s", guild_id, e)
        finally:
            # Always deregister, whatever ended the loop.
            await connection_manager.disconnect(websocket, guild_id)

    return router
# Helper functions to broadcast events from other parts of the application
async def broadcast_moderation_action(
    guild_id: int,
    action: str,
    target_id: int,
    target_name: str,
    moderator_name: str,
    reason: str | None = None,
) -> None:
    """Push a moderation-action event to every dashboard client of the guild."""
    payload = {
        "action": action,
        "target_id": target_id,
        "target_name": target_name,
        "moderator_name": moderator_name,
        "reason": reason,
    }
    await connection_manager.broadcast_to_guild(
        guild_id,
        WebSocketEvent(
            type="moderation_action",
            guild_id=guild_id,
            timestamp=datetime.now(),
            data=payload,
        ),
    )
async def broadcast_user_join(
    guild_id: int,
    user_id: int,
    username: str,
) -> None:
    """Push a user-join event to every dashboard client of the guild."""
    join_event = WebSocketEvent(
        type="user_join",
        guild_id=guild_id,
        timestamp=datetime.now(),
        data={"user_id": user_id, "username": username},
    )
    await connection_manager.broadcast_to_guild(guild_id, join_event)
async def broadcast_ai_alert(
    guild_id: int,
    user_id: int,
    severity: str,
    category: str,
    confidence: float,
) -> None:
    """Push an AI moderation alert to every dashboard client of the guild."""
    alert_payload = {
        "user_id": user_id,
        "severity": severity,
        "category": category,
        "confidence": confidence,
    }
    alert = WebSocketEvent(
        type="ai_alert",
        guild_id=guild_id,
        timestamp=datetime.now(),
        data=alert_payload,
    )
    await connection_manager.broadcast_to_guild(guild_id, alert)
async def broadcast_system_event(
    event_type: str,
    data: dict[str, Any],
    guild_id: int | None = None,
) -> None:
    """Push a generic system event.

    Targets one guild when ``guild_id`` is truthy; otherwise fans out to
    every connected guild. The event's guild_id defaults to 0 for global
    broadcasts.
    """
    system_event = WebSocketEvent(
        type=event_type,
        guild_id=guild_id or 0,
        timestamp=datetime.now(),
        data=data,
    )
    if guild_id:
        await connection_manager.broadcast_to_guild(guild_id, system_event)
        return
    await connection_manager.broadcast_to_all(system_event)

234
src/guardden/health.py Normal file
View File

@@ -0,0 +1,234 @@
"""Health check utilities for GuardDen."""
import asyncio
import logging
import sys
from typing import Dict, Any
from guardden.config import get_settings
from guardden.services.database import Database
from guardden.services.ai import create_ai_provider
from guardden.utils.logging import get_logger
logger = get_logger(__name__)
class HealthChecker:
    """Comprehensive health check system for GuardDen.

    Each ``check_*`` method returns a plain dict with at least a ``status``
    key and never raises; failures are reported inline as
    ``{"status": "unhealthy", "error": ...}``.
    """

    def __init__(self) -> None:
        # Building the checker constructs the Database and AI provider from
        # global settings; no connections are opened until a check runs.
        self.settings = get_settings()
        self.database = Database(self.settings)
        self.ai_provider = create_ai_provider(self.settings)

    async def check_database(self) -> Dict[str, Any]:
        """Check database connectivity and round-trip latency."""
        try:
            start_time = asyncio.get_event_loop().time()
            async with self.database.session() as session:
                # Simple query to check connectivity.
                # NOTE(review): SQLAlchemy 2.x requires text("SELECT 1") for
                # raw-string SQL in Session.execute — confirm the Database
                # wrapper still accepts a plain string here.
                result = await session.execute("SELECT 1 as test")
                test_value = result.scalar()
            end_time = asyncio.get_event_loop().time()
            response_time_ms = (end_time - start_time) * 1000
            return {
                "status": "healthy" if test_value == 1 else "unhealthy",
                "response_time_ms": round(response_time_ms, 2),
                # NOTE(review): reaches into the private _engine attribute —
                # consider exposing pool stats on Database instead.
                "connection_pool": {
                    "pool_size": self.database._engine.pool.size() if self.database._engine else 0,
                    "checked_in": self.database._engine.pool.checkedin() if self.database._engine else 0,
                    "checked_out": self.database._engine.pool.checkedout() if self.database._engine else 0,
                }
            }
        except Exception as e:
            logger.error("Database health check failed", exc_info=e)
            return {
                "status": "unhealthy",
                "error": str(e),
                "error_type": type(e).__name__
            }

    async def check_ai_provider(self) -> Dict[str, Any]:
        """Check AI provider configuration (no live request is made)."""
        if self.settings.ai_provider == "none":
            return {
                "status": "disabled",
                "provider": "none"
            }
        try:
            # Simple test to check if AI provider is responsive.
            start_time = asyncio.get_event_loop().time()
            # This is a minimal test - actual implementation would depend on
            # provider. NOTE(review): only type() is inspected, so the
            # reported response time measures essentially nothing.
            provider_type = type(self.ai_provider).__name__
            end_time = asyncio.get_event_loop().time()
            response_time_ms = (end_time - start_time) * 1000
            return {
                "status": "healthy",
                "provider": self.settings.ai_provider,
                "provider_type": provider_type,
                "response_time_ms": round(response_time_ms, 2)
            }
        except Exception as e:
            logger.error("AI provider health check failed", exc_info=e)
            return {
                "status": "unhealthy",
                "provider": self.settings.ai_provider,
                "error": str(e),
                "error_type": type(e).__name__
            }

    async def check_discord_connectivity(self) -> Dict[str, Any]:
        """Check Discord API connectivity via the public gateway endpoint."""
        try:
            # Imported lazily so importing this module never requires aiohttp.
            import aiohttp
            start_time = asyncio.get_event_loop().time()
            async with aiohttp.ClientSession() as session:
                async with session.get("https://discord.com/api/v10/gateway") as response:
                    if response.status == 200:
                        data = await response.json()
                        end_time = asyncio.get_event_loop().time()
                        response_time_ms = (end_time - start_time) * 1000
                        return {
                            "status": "healthy",
                            "response_time_ms": round(response_time_ms, 2),
                            "gateway_url": data.get("url")
                        }
                    else:
                        return {
                            "status": "unhealthy",
                            "http_status": response.status,
                            "error": f"HTTP {response.status}"
                        }
        except Exception as e:
            logger.error("Discord connectivity check failed", exc_info=e)
            return {
                "status": "unhealthy",
                "error": str(e),
                "error_type": type(e).__name__
            }

    async def get_system_info(self) -> Dict[str, Any]:
        """Get host CPU/memory/disk information for health reporting."""
        # Lazy imports: psutil is only needed when a health check runs.
        # NOTE(review): presumably psutil is a declared dependency — verify.
        import psutil
        import platform
        try:
            memory = psutil.virtual_memory()
            disk = psutil.disk_usage('/')
            return {
                "platform": platform.platform(),
                "python_version": platform.python_version(),
                "cpu": {
                    "count": psutil.cpu_count(),
                    # interval=1 blocks this coroutine for ~1s while sampling.
                    "usage_percent": psutil.cpu_percent(interval=1)
                },
                "memory": {
                    "total_mb": round(memory.total / 1024 / 1024),
                    "available_mb": round(memory.available / 1024 / 1024),
                    "usage_percent": memory.percent
                },
                "disk": {
                    "total_gb": round(disk.total / 1024 / 1024 / 1024),
                    "free_gb": round(disk.free / 1024 / 1024 / 1024),
                    "usage_percent": round((disk.used / disk.total) * 100, 1)
                }
            }
        except Exception as e:
            logger.error("Failed to get system info", exc_info=e)
            return {
                "error": str(e),
                "error_type": type(e).__name__
            }

    async def perform_full_health_check(self) -> Dict[str, Any]:
        """Run every check and aggregate an overall status.

        Returns a dict with ``status`` ("healthy"/"degraded"/"unhealthy"),
        the per-check results under ``checks``, and a configuration snapshot.
        """
        logger.info("Starting comprehensive health check")
        checks = {
            "database": await self.check_database(),
            "ai_provider": await self.check_ai_provider(),
            "discord_api": await self.check_discord_connectivity(),
            "system": await self.get_system_info()
        }
        # Determine overall status from the individual checks.
        overall_status = "healthy"
        for check_name, check_result in checks.items():
            if check_name == "system":
                continue  # System info doesn't affect health status
            status = check_result.get("status", "unknown")
            if status in ["unhealthy", "error"]:
                overall_status = "unhealthy"
                break
            # NOTE(review): no check above currently emits "degraded"; this
            # branch is future-proofing.
            elif status == "degraded" and overall_status == "healthy":
                overall_status = "degraded"
        result = {
            "status": overall_status,
            # NOTE(review): event-loop time is monotonic, not wall-clock —
            # confirm consumers don't expect an epoch timestamp here.
            "timestamp": asyncio.get_event_loop().time(),
            "checks": checks,
            "configuration": {
                "ai_provider": self.settings.ai_provider,
                "log_level": self.settings.log_level,
                "database_pool": {
                    "min": self.settings.database_pool_min,
                    "max": self.settings.database_pool_max
                }
            }
        }
        logger.info("Health check completed", extra={"overall_status": overall_status})
        return result
async def main() -> None:
    """CLI health check command.

    With ``--check`` runs the full health check and exits non-zero when the
    overall status is not "healthy"; ``--json`` switches to JSON output.
    """
    import argparse
    parser = argparse.ArgumentParser(description="GuardDen Health Check")
    parser.add_argument("--check", action="store_true", help="Perform health check and exit")
    parser.add_argument("--json", action="store_true", help="Output in JSON format")
    args = parser.parse_args()
    if args.check:
        # Set up minimal logging for health check: suppress the checkers'
        # info-level chatter so only the report reaches stdout.
        logging.basicConfig(level=logging.WARNING)
        health_checker = HealthChecker()
        result = await health_checker.perform_full_health_check()
        if args.json:
            import json
            print(json.dumps(result, indent=2))
        else:
            print(f"Overall Status: {result['status'].upper()}")
            for check_name, check_result in result["checks"].items():
                status = check_result.get("status", "unknown")
                print(f" {check_name}: {status}")
                if "response_time_ms" in check_result:
                    print(f" Response time: {check_result['response_time_ms']}ms")
                if "error" in check_result:
                    print(f" Error: {check_result['error']}")
        # Exit with non-zero code if unhealthy, so CI/probes can react.
        if result["status"] != "healthy":
            sys.exit(1)
    else:
        print("Use --check to perform health check")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,15 +1,19 @@
"""Database models for GuardDen."""

from guardden.models.analytics import AICheck, MessageActivity, UserActivity
from guardden.models.base import Base
from guardden.models.guild import BannedWord, Guild, GuildSettings
from guardden.models.moderation import ModerationLog, Strike, UserNote

# Public re-exports, alphabetically sorted. (A duplicate "BannedWord" entry
# was removed: duplicates in __all__ are harmless at runtime but confuse
# linters and readers.)
__all__ = [
    "AICheck",
    "Base",
    "BannedWord",
    "Guild",
    "GuildSettings",
    "MessageActivity",
    "ModerationLog",
    "Strike",
    "UserActivity",
    "UserNote",
]

View File

@@ -0,0 +1,86 @@
"""Analytics models for tracking bot usage and performance."""
from datetime import datetime
from sqlalchemy import BigInteger, Boolean, DateTime, Float, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from guardden.models.base import Base, SnowflakeID, TimestampMixin
class AICheck(Base, TimestampMixin):
    """Record of AI moderation checks.

    One row per message examined by the AI provider, including the verdict,
    latency, and any later moderator review of false positives.
    """

    __tablename__ = "ai_checks"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    guild_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False, index=True)
    user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False, index=True)
    channel_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
    message_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
    # Check result
    flagged: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
    category: Mapped[str | None] = mapped_column(String(50), nullable=True)
    severity: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    # Performance metrics
    response_time_ms: Mapped[float] = mapped_column(Float, nullable=False)
    provider: Mapped[str] = mapped_column(String(20), nullable=False)  # e.g. the configured AI provider name
    # False positive tracking (set by moderators)
    is_false_positive: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False, index=True
    )
    reviewed_by: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
    reviewed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
class MessageActivity(Base):
    """Daily message activity statistics per guild.

    NOTE(review): no unique constraint on (guild_id, date) is declared here —
    confirm the writer upserts rather than inserting duplicate daily rows.
    """

    __tablename__ = "message_activity"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    guild_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False, index=True)
    date: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
    # Activity counts
    total_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    active_users: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    new_joins: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    # Moderation activity
    automod_triggers: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    ai_checks: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    manual_actions: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
class UserActivity(Base, TimestampMixin):
    """Track user activity and first/last seen timestamps.

    One row per (guild, user) pair.
    NOTE(review): no unique constraint on (guild_id, user_id) is declared
    here — confirm uniqueness is enforced by the writer or a migration.
    """

    __tablename__ = "user_activity"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    guild_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False, index=True)
    user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False, index=True)
    # User information
    username: Mapped[str] = mapped_column(String(100), nullable=False)
    # Activity timestamps
    first_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    last_message: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    # Activity counts
    message_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    command_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    # Moderation stats
    strike_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    warning_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    kick_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    ban_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    timeout_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

View File

@@ -3,7 +3,7 @@
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import Boolean, ForeignKey, Integer, String, Text
from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -66,6 +66,15 @@ class GuildSettings(Base, TimestampMixin):
anti_spam_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
link_filter_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
# Automod thresholds
message_rate_limit: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
message_rate_window: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
duplicate_threshold: Mapped[int] = mapped_column(Integer, default=3, nullable=False)
mention_limit: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
mention_rate_limit: Mapped[int] = mapped_column(Integer, default=10, nullable=False)
mention_rate_window: Mapped[int] = mapped_column(Integer, default=60, nullable=False)
scam_allowlist: Mapped[list[str]] = mapped_column(JSONB, default=list, nullable=False)
# Strike thresholds (actions at each threshold)
strike_actions: Mapped[dict] = mapped_column(
JSONB,
@@ -81,6 +90,8 @@ class GuildSettings(Base, TimestampMixin):
# AI moderation settings
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=50, nullable=False) # 0-100 scale
ai_confidence_threshold: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
ai_log_only: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
# Verification settings

View File

@@ -1,9 +1,12 @@
"""Services for GuardDen."""
from guardden.services.automod import AutomodService
from guardden.services.database import Database
from guardden.services.ratelimit import RateLimiter, get_rate_limiter, ratelimit
from guardden.services.verification import ChallengeType, VerificationService
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from guardden.services.automod import AutomodService
from guardden.services.database import Database
from guardden.services.ratelimit import RateLimiter, get_rate_limiter, ratelimit
from guardden.services.verification import ChallengeType, VerificationService
__all__ = [
"AutomodService",
@@ -14,3 +17,23 @@ __all__ = [
"get_rate_limiter",
"ratelimit",
]
# Mapping of lazily-exported attribute name -> (module path, attribute name).
_LAZY_ATTRS = {
    "AutomodService": ("guardden.services.automod", "AutomodService"),
    "Database": ("guardden.services.database", "Database"),
    "RateLimiter": ("guardden.services.ratelimit", "RateLimiter"),
    "get_rate_limiter": ("guardden.services.ratelimit", "get_rate_limiter"),
    "ratelimit": ("guardden.services.ratelimit", "ratelimit"),
    "ChallengeType": ("guardden.services.verification", "ChallengeType"),
    "VerificationService": ("guardden.services.verification", "VerificationService"),
}


def __getattr__(name: str):
    """Lazily import public service attributes on first access (PEP 562).

    Raises:
        AttributeError: when ``name`` is not a lazily-exported attribute.
    """
    if name in _LAZY_ATTRS:
        module_path, attr = _LAZY_ATTRS[name]
        # importlib.import_module is the documented API for dynamic imports;
        # __import__ with a fromlist is an implementation detail best avoided.
        import importlib

        value = getattr(importlib.import_module(module_path), attr)
        # Cache in module globals so later look-ups bypass __getattr__.
        globals()[name] = value
        return value
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

View File

@@ -5,10 +5,11 @@ from typing import Any
from guardden.services.ai.base import (
AIProvider,
ContentCategory,
ImageAnalysisResult,
ModerationResult,
PhishingAnalysisResult,
parse_categories,
run_with_retries,
)
logger = logging.getLogger(__name__)
@@ -96,7 +97,7 @@ class AnthropicProvider(AIProvider):
async def _call_api(self, system: str, user_content: Any, max_tokens: int = 500) -> str:
"""Make an API call to Claude."""
try:
async def _request() -> str:
message = await self.client.messages.create(
model=self.model,
max_tokens=max_tokens,
@@ -104,6 +105,13 @@ class AnthropicProvider(AIProvider):
messages=[{"role": "user", "content": user_content}],
)
return message.content[0].text
try:
return await run_with_retries(
_request,
logger=logger,
operation_name="Anthropic API call",
)
except Exception as e:
logger.error(f"Anthropic API error: {e}")
raise
@@ -145,11 +153,7 @@ class AnthropicProvider(AIProvider):
response = await self._call_api(system, user_message)
data = self._parse_json_response(response)
categories = [
ContentCategory(cat)
for cat in data.get("categories", [])
if cat in ContentCategory.__members__.values()
]
categories = parse_categories(data.get("categories", []))
return ModerationResult(
is_flagged=data.get("is_flagged", False),

View File

@@ -1,9 +1,12 @@
"""Base classes for AI providers."""
import asyncio
import logging
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import Literal
from typing import Literal, TypeVar
class ContentCategory(str, Enum):
@@ -20,6 +23,64 @@ class ContentCategory(str, Enum):
MISINFORMATION = "misinformation"
_T = TypeVar("_T")
@dataclass(frozen=True)
class RetryConfig:
"""Retry configuration for AI calls."""
retries: int = 3
base_delay: float = 0.25
max_delay: float = 2.0
def parse_categories(values: list[str]) -> list[ContentCategory]:
    """Parse category values into ContentCategory enums.

    Unrecognized values are silently dropped; the order (and duplicates)
    of the recognized values are preserved.
    """
    known_values = {member.value for member in ContentCategory}
    return [ContentCategory(value) for value in values if value in known_values]
async def run_with_retries(
    operation: Callable[[], Awaitable[_T]],
    *,
    config: RetryConfig | None = None,
    logger: logging.Logger | None = None,
    operation_name: str = "AI call",
) -> _T:
    """Run an async operation with retries and exponential backoff.

    Args:
        operation: Zero-argument callable returning the awaitable to run.
        config: Retry settings; defaults to ``RetryConfig()``.
        logger: Optional logger for per-attempt failure warnings.
        operation_name: Human-readable label used in log messages.

    Returns:
        The result of the first successful attempt.

    Raises:
        ValueError: If the configured retry count is less than 1.
        Exception: The error from the final attempt, re-raised unchanged.
    """
    retry_config = config or RetryConfig()
    # Previously a retries < 1 config fell through the loop to an opaque
    # RuntimeError; fail fast with a clear message instead.
    if retry_config.retries < 1:
        raise ValueError("RetryConfig.retries must be >= 1")

    delay = retry_config.base_delay
    for attempt in range(1, retry_config.retries + 1):
        try:
            return await operation()
        except Exception as error:  # noqa: BLE001 - we re-raise after retries
            if attempt >= retry_config.retries:
                # Final attempt: propagate the original error unchanged.
                raise
            if logger:
                logger.warning(
                    "%s failed (attempt %s/%s): %s",
                    operation_name,
                    attempt,
                    retry_config.retries,
                    error,
                )
            await asyncio.sleep(delay)
            # Exponential backoff, capped at max_delay.
            delay = min(retry_config.max_delay, delay * 2)
    raise RuntimeError("Retry loop exited unexpectedly")  # unreachable
@dataclass
class ModerationResult:
"""Result of AI content moderation."""

View File

@@ -9,6 +9,7 @@ from guardden.services.ai.base import (
ImageAnalysisResult,
ModerationResult,
PhishingAnalysisResult,
run_with_retries,
)
logger = logging.getLogger(__name__)
@@ -41,7 +42,7 @@ class OpenAIProvider(AIProvider):
max_tokens: int = 500,
) -> str:
"""Make an API call to OpenAI."""
try:
async def _request() -> str:
response = await self.client.chat.completions.create(
model=self.model,
max_tokens=max_tokens,
@@ -52,6 +53,13 @@ class OpenAIProvider(AIProvider):
response_format={"type": "json_object"},
)
return response.choices[0].message.content or ""
try:
return await run_with_retries(
_request,
logger=logger,
operation_name="OpenAI chat completion",
)
except Exception as e:
logger.error(f"OpenAI API error: {e}")
raise
@@ -71,7 +79,14 @@ class OpenAIProvider(AIProvider):
"""Analyze text content for policy violations."""
# First, use OpenAI's built-in moderation API for quick check
try:
mod_response = await self.client.moderations.create(input=content)
async def _moderate() -> Any:
return await self.client.moderations.create(input=content)
mod_response = await run_with_retries(
_moderate,
logger=logger,
operation_name="OpenAI moderation",
)
results = mod_response.results[0]
# Map OpenAI categories to our categories
@@ -142,20 +157,27 @@ class OpenAIProvider(AIProvider):
sensitivity_note = " Be strict - flag suggestive content."
try:
response = await self.client.chat.completions.create(
model="gpt-4o-mini", # Use vision-capable model
max_tokens=500,
messages=[
{"role": "system", "content": system + sensitivity_note},
{
"role": "user",
"content": [
{"type": "text", "text": "Analyze this image for moderation."},
{"type": "image_url", "image_url": {"url": image_url}},
],
},
],
response_format={"type": "json_object"},
async def _request() -> Any:
return await self.client.chat.completions.create(
model="gpt-4o-mini", # Use vision-capable model
max_tokens=500,
messages=[
{"role": "system", "content": system + sensitivity_note},
{
"role": "user",
"content": [
{"type": "text", "text": "Analyze this image for moderation."},
{"type": "image_url", "image_url": {"url": image_url}},
],
},
],
response_format={"type": "json_object"},
)
response = await run_with_retries(
_request,
logger=logger,
operation_name="OpenAI image analysis",
)
data = self._parse_json_response(response.choices[0].message.content or "{}")

View File

@@ -2,17 +2,150 @@
import logging
import re
import signal
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import NamedTuple
from typing import NamedTuple, Sequence, TYPE_CHECKING
from urllib.parse import urlparse
import discord
if TYPE_CHECKING:
import discord
else:
try:
import discord # type: ignore
except ModuleNotFoundError: # pragma: no cover
class _DiscordStub:
class Message: # minimal stub for type hints
pass
from guardden.models import BannedWord
discord = _DiscordStub() # type: ignore
from guardden.models.guild import BannedWord
logger = logging.getLogger(__name__)
# Circuit breaker for regex safety
class RegexTimeoutError(Exception):
    """Raised when a regex search exceeds its allotted execution time."""
class RegexCircuitBreaker:
    """Circuit breaker to prevent catastrophic backtracking in regex patterns.

    Patterns that time out are disabled for ``failure_threshold`` and
    obviously dangerous patterns are rejected up front. Timeout enforcement
    relies on SIGALRM, so it is only active on Unix in the main thread;
    elsewhere the search runs unguarded but the static checks still apply.
    """

    def __init__(self, timeout_seconds: float = 0.1):
        # Per-search time budget in seconds (fractional values supported).
        self.timeout_seconds = timeout_seconds
        # pattern -> UTC time of the timeout that disabled it.
        self.failed_patterns: dict[str, datetime] = {}
        # How long a timed-out pattern stays disabled.
        self.failure_threshold = timedelta(minutes=5)

    def _timeout_handler(self, signum, frame):
        """SIGALRM handler: abort the in-flight regex search."""
        raise RegexTimeoutError("Regex execution timed out")

    def is_pattern_disabled(self, pattern: str) -> bool:
        """Check if a pattern is temporarily disabled due to timeouts."""
        failure_time = self.failed_patterns.get(pattern)
        if failure_time is None:
            return False
        if datetime.now(timezone.utc) - failure_time > self.failure_threshold:
            # Cool-down elapsed: re-enable the pattern.
            del self.failed_patterns[pattern]
            return False
        return True

    def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
        """Safely execute regex search with timeout protection.

        Returns True only when the pattern compiles, passes the danger
        heuristics and matches *text* within the time budget; every error
        condition returns False.
        """
        if self.is_pattern_disabled(pattern):
            logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
            return False

        # Basic pattern validation to catch obviously problematic patterns.
        if self._is_dangerous_pattern(pattern):
            logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
            return False

        old_handler = None
        timer_armed = False
        try:
            if hasattr(signal, 'SIGALRM'):
                try:
                    old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
                    # BUG FIX: the old code called signal.alarm(int(t * 1000)),
                    # but alarm() takes whole SECONDS, so a 0.1s budget armed a
                    # 100-second alarm. setitimer() supports fractional seconds.
                    signal.setitimer(signal.ITIMER_REAL, self.timeout_seconds)
                    timer_armed = True
                except ValueError:
                    # signal.signal() only works in the main thread; fall back
                    # to running the search unguarded.
                    old_handler = None

            start_time = time.perf_counter()
            compiled_pattern = re.compile(pattern, flags)
            result = bool(compiled_pattern.search(text))
            execution_time = time.perf_counter() - start_time

            # Log patterns that come close to the budget, for monitoring.
            if execution_time > self.timeout_seconds * 0.8:
                logger.warning(
                    f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
                )
            return result

        except RegexTimeoutError:
            # Pattern took too long: disable it temporarily.
            self.failed_patterns[pattern] = datetime.now(timezone.utc)
            logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
            return False
        except re.error as e:
            logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
            return False
        except Exception as e:
            logger.error(f"Unexpected error in regex execution: {e}")
            return False
        finally:
            if timer_armed:
                # Cancel the timer and restore the previous handler (SIG_DFL
                # when the previous handler is not representable in Python).
                signal.setitimer(signal.ITIMER_REAL, 0)
                signal.signal(
                    signal.SIGALRM,
                    old_handler if old_handler is not None else signal.SIG_DFL,
                )

    def _is_dangerous_pattern(self, pattern: str) -> bool:
        """Basic heuristic to detect regexes prone to catastrophic backtracking."""
        # Literal fragments commonly seen in pathological patterns.
        dangerous_indicators = [
            r'(\w+)+',  # Nested quantifiers
            r'(\d+)+',  # Nested quantifiers on digits
            r'(.+)+',   # Nested quantifiers on anything
            r'(.*)+',   # Nested quantifiers on anything (greedy)
            r'(\w*)+',  # Nested quantifiers with *
            r'(\S+)+',  # Nested quantifiers on non-whitespace
        ]

        # Excessively long patterns.
        if len(pattern) > 500:
            return True
        # Nested quantifiers (simplified textual detection).
        if '+)+' in pattern or '*)+' in pattern or '?)+' in pattern:
            return True
        # Excessive repetition operators.
        if pattern.count('+') > 10 or pattern.count('*') > 10:
            return True
        # Known-dangerous literal fragments.
        return any(dangerous in pattern for dangerous in dangerous_indicators)


# Global circuit breaker instance shared by the module-level checks.
_regex_circuit_breaker = RegexCircuitBreaker()
# Known scam/phishing patterns
SCAM_PATTERNS = [
@@ -47,10 +180,10 @@ SUSPICIOUS_TLDS = {
".gq",
}
# URL pattern for extraction
# URL pattern for extraction - more restrictive for security
URL_PATTERN = re.compile(
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*|"
r"(?:www\.)?[-\w]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|ru|cn)[^\s]*",
r"https?://(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?:/[^\s]*)?|"
r"(?:www\.)?[a-zA-Z0-9-]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|gov|edu)(?:/[^\s]*)?",
re.IGNORECASE,
)
@@ -91,6 +224,66 @@ class AutomodResult:
matched_filter: str = ""
@dataclass(frozen=True)
class SpamConfig:
    """Configuration for spam thresholds.

    Consumed by AutomodService.check_spam; callers may pass a per-guild
    instance to override the service defaults.
    """

    # Max messages allowed within message_rate_window seconds.
    message_rate_limit: int = 5
    # Sliding window (seconds) for the message-rate check.
    message_rate_window: int = 5
    # Number of identical (normalized) messages that counts as spam.
    duplicate_threshold: int = 3
    # Max mentions allowed in a single message.
    mention_limit: int = 5
    # Max cumulative mentions allowed within mention_rate_window seconds.
    mention_rate_limit: int = 10
    # Sliding window (seconds) for the mention-rate check.
    mention_rate_window: int = 60
def normalize_domain(value: str) -> str:
    """Normalize a domain or URL for allowlist checks, with validation.

    Lowercases the input, extracts the hostname (adding an ``http://``
    scheme when missing so urlparse sees a netloc) and strips a leading
    ``www.``.

    Returns:
        The normalized hostname, or ``""`` when the input is empty,
        oversized, contains control characters, or cannot be parsed.
    """
    if not value or not isinstance(value, str):
        return ""

    text = value.strip().lower()
    if not text or len(text) > 2000:  # reject excessively long URLs
        return ""
    # Control characters are never valid in a URL; treat them as hostile input.
    if any(char in text for char in ('\x00', '\n', '\r', '\t')):
        return ""

    try:
        if "://" not in text:
            # Without a scheme, urlparse treats the input as a path, not a host.
            text = f"http://{text}"
        hostname = urlparse(text).hostname or ""
    except Exception:
        # urlparse / .hostname can raise assorted errors (ValueError,
        # UnicodeError, ...) on malicious input; any failure means "no host".
        return ""

    # RFC 1035 caps a full domain name at 253 characters.
    if not hostname or len(hostname) > 253:
        return ""
    if any(char in hostname for char in (' ', '\x00', '\n', '\r', '\t')):
        return ""

    # Treat "www.example.com" and "example.com" as the same allowlist entry.
    return hostname.removeprefix("www.")
def is_allowed_domain(hostname: str, allowlist: set[str]) -> bool:
    """Return True when *hostname* equals, or is a subdomain of, an allowlisted domain."""
    if not hostname:
        return False
    return any(
        hostname == allowed or hostname.endswith(f".{allowed}")
        for allowed in allowlist
    )
class AutomodService:
"""Service for automatic content moderation."""
@@ -104,23 +297,25 @@ class AutomodService:
lambda: defaultdict(UserSpamTracker)
)
# Spam thresholds
self.message_rate_limit = 5 # messages per window
self.message_rate_window = 5 # seconds
self.duplicate_threshold = 3 # same message count
self.mention_limit = 5 # mentions per message
self.mention_rate_limit = 10 # mentions per window
self.mention_rate_window = 60 # seconds
# Default spam thresholds
self.default_spam_config = SpamConfig()
def _get_content_hash(self, content: str) -> str:
"""Get a normalized hash of message content for duplicate detection."""
# Normalize: lowercase, remove extra spaces, remove special chars
normalized = re.sub(r"[^\w\s]", "", content.lower())
normalized = re.sub(r"\s+", " ", normalized).strip()
# Use simple string operations for basic patterns to avoid regex overhead
normalized = content.lower()
# Remove special characters (simplified approach)
normalized = ''.join(c for c in normalized if c.isalnum() or c.isspace())
# Normalize whitespace
normalized = ' '.join(normalized.split())
return normalized
def check_banned_words(
self, content: str, banned_words: list[BannedWord]
self, content: str, banned_words: Sequence[BannedWord]
) -> AutomodResult | None:
"""Check message against banned words list."""
content_lower = content.lower()
@@ -129,12 +324,9 @@ class AutomodService:
matched = False
if banned.is_regex:
try:
if re.search(banned.pattern, content, re.IGNORECASE):
matched = True
except re.error:
logger.warning(f"Invalid regex pattern: {banned.pattern}")
continue
# Use circuit breaker for safe regex execution
if _regex_circuit_breaker.safe_regex_search(banned.pattern, content, re.IGNORECASE):
matched = True
else:
if banned.pattern.lower() in content_lower:
matched = True
@@ -155,7 +347,9 @@ class AutomodService:
return None
def check_scam_links(self, content: str) -> AutomodResult | None:
def check_scam_links(
self, content: str, allowlist: list[str] | None = None
) -> AutomodResult | None:
"""Check message for scam/phishing patterns."""
# Check for known scam patterns
for pattern in self._scam_patterns:
@@ -167,10 +361,25 @@ class AutomodService:
matched_filter="scam_pattern",
)
allowlist_set = {normalize_domain(domain) for domain in allowlist or [] if domain}
# Check URLs for suspicious TLDs
urls = URL_PATTERN.findall(content)
for url in urls:
# Limit URL length to prevent processing extremely long URLs
if len(url) > 2000:
continue
url_lower = url.lower()
hostname = normalize_domain(url)
# Skip if hostname normalization failed (security check)
if not hostname:
continue
if allowlist_set and is_allowed_domain(hostname, allowlist_set):
continue
for tld in SUSPICIOUS_TLDS:
if tld in url_lower:
# Additional check: is it trying to impersonate a known domain?
@@ -194,12 +403,21 @@ class AutomodService:
return None
def check_spam(
self, message: discord.Message, anti_spam_enabled: bool = True
self,
message: discord.Message,
anti_spam_enabled: bool = True,
spam_config: SpamConfig | None = None,
) -> AutomodResult | None:
"""Check message for spam behavior."""
if not anti_spam_enabled:
return None
# Skip DM messages
if message.guild is None:
return None
config = spam_config or self.default_spam_config
guild_id = message.guild.id
user_id = message.author.id
tracker = self._spam_trackers[guild_id][user_id]
@@ -213,21 +431,24 @@ class AutomodService:
tracker.messages.append(SpamRecord(content_hash, now))
# Rate limit check
recent_window = now - timedelta(seconds=self.message_rate_window)
recent_window = now - timedelta(seconds=config.message_rate_window)
recent_messages = [m for m in tracker.messages if m.timestamp > recent_window]
if len(recent_messages) > self.message_rate_limit:
if len(recent_messages) > config.message_rate_limit:
return AutomodResult(
should_delete=True,
should_timeout=True,
timeout_duration=60, # 1 minute timeout
reason=f"Sending messages too fast ({len(recent_messages)} in {self.message_rate_window}s)",
reason=(
f"Sending messages too fast ({len(recent_messages)} in "
f"{config.message_rate_window}s)"
),
matched_filter="rate_limit",
)
# Duplicate message check
duplicate_count = sum(1 for m in tracker.messages if m.content_hash == content_hash)
if duplicate_count >= self.duplicate_threshold:
if duplicate_count >= config.duplicate_threshold:
return AutomodResult(
should_delete=True,
should_warn=True,
@@ -240,7 +461,7 @@ class AutomodService:
if message.mention_everyone:
mention_count += 100 # Treat @everyone as many mentions
if mention_count > self.mention_limit:
if mention_count > config.mention_limit:
return AutomodResult(
should_delete=True,
should_timeout=True,
@@ -249,6 +470,26 @@ class AutomodService:
matched_filter="mass_mention",
)
if mention_count > 0:
if tracker.last_mention_time:
window = timedelta(seconds=config.mention_rate_window)
if now - tracker.last_mention_time > window:
tracker.mention_count = 0
tracker.mention_count += mention_count
tracker.last_mention_time = now
if tracker.mention_count > config.mention_rate_limit:
return AutomodResult(
should_delete=True,
should_timeout=True,
timeout_duration=300,
reason=(
"Too many mentions in a short period "
f"({tracker.mention_count} in {config.mention_rate_window}s)"
),
matched_filter="mention_rate",
)
return None
def check_invite_links(self, content: str, allow_invites: bool = True) -> AutomodResult | None:

View File

@@ -0,0 +1,155 @@
"""Redis caching service for improved performance."""
import asyncio
import json
import logging
from typing import Any, TypeVar
logger = logging.getLogger(__name__)
T = TypeVar("T")
class CacheService:
"""Service for caching data with Redis (optional) or in-memory fallback."""
def __init__(self, redis_url: str | None = None) -> None:
self.redis_url = redis_url
self._redis_client: Any = None
self._memory_cache: dict[str, tuple[Any, float]] = {}
self._lock = asyncio.Lock()
async def initialize(self) -> None:
"""Initialize Redis connection if URL is provided."""
if not self.redis_url:
logger.info("Redis URL not configured, using in-memory cache")
return
try:
import redis.asyncio as aioredis
self._redis_client = await aioredis.from_url(
self.redis_url,
encoding="utf-8",
decode_responses=True,
)
# Test connection
await self._redis_client.ping()
logger.info("Redis cache initialized successfully")
except ImportError:
logger.warning("redis package not installed, using in-memory cache")
except Exception as e:
logger.error("Failed to connect to Redis: %s, using in-memory cache", e)
self._redis_client = None
async def close(self) -> None:
"""Close Redis connection."""
if self._redis_client:
await self._redis_client.close()
async def get(self, key: str) -> Any | None:
"""Get a value from cache."""
if self._redis_client:
try:
value = await self._redis_client.get(key)
if value:
return json.loads(value)
return None
except Exception as e:
logger.error("Redis get error for key %s: %s", key, e)
return None
else:
# In-memory fallback
async with self._lock:
if key in self._memory_cache:
value, expiry = self._memory_cache[key]
if expiry == 0 or asyncio.get_event_loop().time() < expiry:
return value
else:
del self._memory_cache[key]
return None
async def set(self, key: str, value: Any, ttl: int = 300) -> bool:
"""Set a value in cache with TTL in seconds."""
if self._redis_client:
try:
serialized = json.dumps(value)
await self._redis_client.set(key, serialized, ex=ttl)
return True
except Exception as e:
logger.error("Redis set error for key %s: %s", key, e)
return False
else:
# In-memory fallback
async with self._lock:
expiry = asyncio.get_event_loop().time() + ttl if ttl > 0 else 0
self._memory_cache[key] = (value, expiry)
return True
async def delete(self, key: str) -> bool:
"""Delete a value from cache."""
if self._redis_client:
try:
await self._redis_client.delete(key)
return True
except Exception as e:
logger.error("Redis delete error for key %s: %s", key, e)
return False
else:
async with self._lock:
if key in self._memory_cache:
del self._memory_cache[key]
return True
async def clear_pattern(self, pattern: str) -> int:
"""Clear all keys matching a pattern."""
if self._redis_client:
try:
keys = []
async for key in self._redis_client.scan_iter(match=pattern):
keys.append(key)
if keys:
await self._redis_client.delete(*keys)
return len(keys)
except Exception as e:
logger.error("Redis clear pattern error for %s: %s", pattern, e)
return 0
else:
# In-memory fallback
async with self._lock:
import fnmatch
keys_to_delete = [
key for key in self._memory_cache.keys() if fnmatch.fnmatch(key, pattern)
]
for key in keys_to_delete:
del self._memory_cache[key]
return len(keys_to_delete)
def get_stats(self) -> dict[str, Any]:
"""Get cache statistics."""
if self._redis_client:
return {"type": "redis", "url": self.redis_url}
else:
return {
"type": "memory",
"size": len(self._memory_cache),
}
# Global cache instance (module-level singleton).
_cache_service: CacheService | None = None


def get_cache_service() -> CacheService:
    """Get the global cache service instance, creating it lazily.

    The default instance has no Redis URL, so it uses the in-memory backend
    until replaced via set_cache_service().
    """
    global _cache_service
    if _cache_service is None:
        _cache_service = CacheService()
    return _cache_service


def set_cache_service(service: CacheService) -> None:
    """Set the global cache service instance (e.g. one configured with Redis)."""
    global _cache_service
    _cache_service = service

View File

@@ -1,30 +1,43 @@
"""Guild configuration service."""
import logging
from functools import lru_cache
import discord
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from guardden.models import BannedWord, Guild, GuildSettings
from guardden.services.cache import CacheService, get_cache_service
from guardden.services.database import Database
logger = logging.getLogger(__name__)
class GuildConfigService:
"""Manages guild configurations with caching."""
"""Manages guild configurations with multi-tier caching."""
def __init__(self, database: Database) -> None:
def __init__(self, database: Database, cache: CacheService | None = None) -> None:
self.database = database
self._cache: dict[int, GuildSettings] = {}
self.cache = cache or get_cache_service()
self._memory_cache: dict[int, GuildSettings] = {}
self._cache_ttl = 300 # 5 minutes
async def get_config(self, guild_id: int) -> GuildSettings | None:
"""Get guild configuration, using cache if available."""
if guild_id in self._cache:
return self._cache[guild_id]
"""Get guild configuration, using multi-tier cache."""
# Check memory cache first
if guild_id in self._memory_cache:
return self._memory_cache[guild_id]
# Check Redis cache
cache_key = f"guild_config:{guild_id}"
cached_data = await self.cache.get(cache_key)
if cached_data:
# Store in memory cache for faster access
settings = GuildSettings(**cached_data)
self._memory_cache[guild_id] = settings
return settings
# Fetch from database
async with self.database.session() as session:
result = await session.execute(
select(GuildSettings).where(GuildSettings.guild_id == guild_id)
@@ -32,7 +45,19 @@ class GuildConfigService:
settings = result.scalar_one_or_none()
if settings:
self._cache[guild_id] = settings
# Store in both caches
self._memory_cache[guild_id] = settings
# Serialize settings for Redis
settings_dict = {
"guild_id": settings.guild_id,
"prefix": settings.prefix,
"log_channel_id": settings.log_channel_id,
"automod_enabled": settings.automod_enabled,
"ai_moderation_enabled": settings.ai_moderation_enabled,
"ai_sensitivity": settings.ai_sensitivity,
# Add other fields as needed
}
await self.cache.set(cache_key, settings_dict, ttl=self._cache_ttl)
return settings
@@ -94,9 +119,11 @@ class GuildConfigService:
return settings
def invalidate_cache(self, guild_id: int) -> None:
"""Remove a guild from the cache."""
self._cache.pop(guild_id, None)
async def invalidate_cache(self, guild_id: int) -> None:
"""Remove a guild from all caches."""
self._memory_cache.pop(guild_id, None)
cache_key = f"guild_config:{guild_id}"
await self.cache.delete(cache_key)
async def get_banned_words(self, guild_id: int) -> list[BannedWord]:
"""Get all banned words for a guild."""

View File

@@ -5,6 +5,7 @@ from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from functools import wraps
from typing import Callable
logger = logging.getLogger(__name__)
@@ -211,6 +212,23 @@ class RateLimiter:
bucket_key=bucket_key,
)
    def acquire_command(
        self,
        command_name: str,
        user_id: int | None = None,
        guild_id: int | None = None,
        channel_id: int | None = None,
    ) -> RateLimitResult:
        """Acquire a per-command rate limit slot.

        Lazily registers a config for ``command:<name>`` on first use by
        cloning the generic "command" config (or a 5-requests / 10-seconds
        member-scoped default), then delegates to acquire().
        """
        action = f"command:{command_name}"
        if action not in self._configs:
            # First use of this command: give it a dedicated config so it can
            # be tuned independently of the generic "command" bucket later.
            base = self._configs.get("command", RateLimitConfig(5, 10, RateLimitScope.MEMBER))
            self.configure(
                action,
                RateLimitConfig(base.max_requests, base.window_seconds, base.scope),
            )
        return self.acquire(action, user_id=user_id, guild_id=guild_id, channel_id=channel_id)
def reset(
self,
action: str,
@@ -266,6 +284,7 @@ def ratelimit(
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(self, ctx, *args, **kwargs):
limiter = get_rate_limiter()
@@ -292,9 +311,6 @@ def ratelimit(
return await func(self, ctx, *args, **kwargs)
# Preserve function metadata
wrapper.__name__ = func.__name__
wrapper.__doc__ = func.__doc__
return wrapper
return decorator

View File

@@ -10,8 +10,6 @@ from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any
import discord
logger = logging.getLogger(__name__)
@@ -217,6 +215,28 @@ class EmojiChallengeGenerator(ChallengeGenerator):
return names.get(emoji, "correct")
class QuestionsChallengeGenerator(ChallengeGenerator):
    """Generates custom question challenges.

    Serves either operator-supplied (question, answer) pairs or a small
    built-in default set.
    """

    # Fallback (question, answer) pairs used when no custom questions are given.
    DEFAULT_QUESTIONS = [
        ("What color is the sky on a clear day?", "blue"),
        ("Type the word 'verified' to continue.", "verified"),
        ("What is 2 + 2?", "4"),
        ("What planet do we live on?", "earth"),
    ]

    def __init__(self, questions: list[tuple[str, str]] | None = None) -> None:
        # NOTE: an empty list also falls back to the defaults (`or` semantics).
        self.questions = questions or self.DEFAULT_QUESTIONS

    def generate(self) -> Challenge:
        """Return a QUESTIONS-type Challenge built from a randomly chosen pair."""
        question, answer = random.choice(self.questions)
        return Challenge(
            challenge_type=ChallengeType.QUESTIONS,
            question=question,
            answer=answer,
        )
class VerificationService:
"""Service for managing member verification."""
@@ -230,6 +250,7 @@ class VerificationService:
ChallengeType.CAPTCHA: CaptchaChallengeGenerator(),
ChallengeType.MATH: MathChallengeGenerator(),
ChallengeType.EMOJI: EmojiChallengeGenerator(),
ChallengeType.QUESTIONS: QuestionsChallengeGenerator(),
}
def create_challenge(

View File

@@ -1,5 +1,30 @@
"""Utility functions for GuardDen."""
from datetime import timedelta
from guardden.utils.logging import setup_logging
__all__ = ["setup_logging"]
def parse_duration(duration_str: str) -> timedelta | None:
"""Parse a duration string like '1h', '30m', '7d' into a timedelta."""
import re
match = re.match(r"^(\d+)([smhdw])$", duration_str.lower())
if not match:
return None
amount = int(match.group(1))
unit = match.group(2)
units = {
"s": timedelta(seconds=amount),
"m": timedelta(minutes=amount),
"h": timedelta(hours=amount),
"d": timedelta(days=amount),
"w": timedelta(weeks=amount),
}
return units.get(unit)
__all__ = ["parse_duration", "setup_logging"]

View File

@@ -1,27 +1,294 @@
"""Logging configuration for GuardDen."""
"""Structured logging utilities for GuardDen."""
import json
import logging
import sys
from typing import Literal
from datetime import datetime, timezone
from typing import Any, Dict, Literal
try:
import structlog
from structlog.contextvars import bind_contextvars, clear_contextvars, unbind_contextvars
from structlog.stdlib import BoundLogger
STRUCTLOG_AVAILABLE = True
except ImportError:
STRUCTLOG_AVAILABLE = False
# Fallback types when structlog is not available
BoundLogger = logging.Logger
def setup_logging(level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO") -> None:
"""Configure logging for the application."""
log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
date_format = "%Y-%m-%d %H:%M:%S"
class JSONFormatter(logging.Formatter):
    """Custom JSON formatter for structured logging.

    Serializes each record to a single JSON object containing timestamp,
    level, logger, message, source location, optional exception/stack
    details and any user-supplied ``extra`` fields.
    """

    # LogRecord attributes that belong to the standard record layout and must
    # not be copied into the "extra" payload. 'taskName' was added to
    # LogRecord in Python 3.12.
    _RESERVED_ATTRS = frozenset({
        'name', 'msg', 'args', 'levelname', 'levelno', 'pathname', 'filename',
        'module', 'lineno', 'funcName', 'created', 'msecs', 'relativeCreated',
        'thread', 'threadName', 'processName', 'process', 'getMessage',
        'exc_info', 'exc_text', 'stack_info', 'message', 'taskName',
    })

    def format(self, record: logging.LogRecord) -> str:
        """Format log record as a JSON string."""
        log_data = {
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Exception details (logger.exception() / exc_info=True).
        if record.exc_info:
            log_data["exception"] = {
                "type": record.exc_info[0].__name__ if record.exc_info[0] else None,
                "message": str(record.exc_info[1]) if record.exc_info[1] else None,
                "traceback": self.formatException(record.exc_info),
            }

        # Stack info (stack_info=True) — previously collected but never emitted.
        if record.stack_info:
            log_data["stack"] = self.formatStack(record.stack_info)

        # Anything passed via `extra=` becomes a plain attribute on the
        # record; forward those fields under a dedicated key.
        extra_fields = {
            key: value
            for key, value in record.__dict__.items()
            if key not in self._RESERVED_ATTRS
        }
        if extra_fields:
            log_data["extra"] = extra_fields

        # default=str keeps non-JSON-serializable extras from breaking logging.
        return json.dumps(log_data, default=str, ensure_ascii=False)
# Configure root logger
logging.basicConfig(
level=getattr(logging, level),
format=log_format,
datefmt=date_format,
handlers=[logging.StreamHandler(sys.stdout)],
)
# Reduce noise from third-party libraries
logging.getLogger("discord").setLevel(logging.WARNING)
logging.getLogger("discord.http").setLevel(logging.WARNING)
logging.getLogger("asyncio").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.engine").setLevel(
logging.DEBUG if level == "DEBUG" else logging.WARNING
)
class GuardDenLogger:
    """Custom logger configuration for GuardDen.

    Instantiating this class configures the process-wide (root) logger as a
    side effect: structlog-based JSON output when structlog is available and
    requested, otherwise plain stdlib logging, plus quieter levels for noisy
    third-party loggers.
    """

    def __init__(self, level: str = "INFO", json_format: bool = False):
        # Root log level name ("DEBUG", "INFO", ...), normalized to upper case.
        self.level = level.upper()
        # Whether to emit one JSON object per line instead of the text format.
        self.json_format = json_format
        # NOTE: configures global logging immediately on construction.
        self.configure_logging()

    def configure_logging(self) -> None:
        """Configure structured logging for the application."""
        # Clear any existing configuration so repeated setup does not stack handlers.
        logging.root.handlers.clear()

        # structlog is optional; fall back to stdlib-only configuration.
        if STRUCTLOG_AVAILABLE and self.json_format:
            self._configure_structlog()
        else:
            self._configure_stdlib_logging()

        # Configure specific loggers
        self._configure_library_loggers()

    def _configure_structlog(self) -> None:
        """Configure structlog for structured logging (JSON output)."""
        structlog.configure(
            processors=[
                # Add context variables to log entries
                structlog.contextvars.merge_contextvars,
                # Drop events below the stdlib logger's level
                structlog.stdlib.filter_by_level,
                # Add logger name to event dict
                structlog.stdlib.add_logger_name,
                # Add log level to event dict
                structlog.stdlib.add_log_level,
                # Perform %-style formatting
                structlog.stdlib.PositionalArgumentsFormatter(),
                # Add timestamp
                structlog.processors.TimeStamper(fmt="iso"),
                # Add stack info when requested
                structlog.processors.StackInfoRenderer(),
                # Format exceptions
                structlog.processors.format_exc_info,
                # Unicode-encode strings
                structlog.processors.UnicodeDecoder(),
                # Hand the event off to stdlib logging for final rendering
                structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
            ],
            wrapper_class=structlog.stdlib.BoundLogger,
            logger_factory=structlog.stdlib.LoggerFactory(),
            cache_logger_on_first_use=True,
        )

        # Configure stdlib logging with JSON formatter
        handler = logging.StreamHandler(sys.stdout)
        formatter = JSONFormatter()
        handler.setFormatter(formatter)

        # Set up root logger
        root_logger = logging.getLogger()
        root_logger.addHandler(handler)
        root_logger.setLevel(getattr(logging, self.level))

    def _configure_stdlib_logging(self) -> None:
        """Configure standard library logging (used when structlog is absent)."""
        if self.json_format:
            handler = logging.StreamHandler(sys.stdout)
            formatter = JSONFormatter()
        else:
            # Use traditional human-readable format for development
            log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
            date_format = "%Y-%m-%d %H:%M:%S"
            handler = logging.StreamHandler(sys.stdout)
            formatter = logging.Formatter(log_format, datefmt=date_format)

        handler.setFormatter(formatter)

        # Configure root logger
        logging.basicConfig(
            level=getattr(logging, self.level),
            handlers=[handler],
        )

    def _configure_library_loggers(self) -> None:
        """Configure logging levels for third-party libraries."""
        # Discord.py can be quite verbose
        logging.getLogger("discord").setLevel(logging.WARNING)
        logging.getLogger("discord.http").setLevel(logging.WARNING)
        logging.getLogger("discord.gateway").setLevel(logging.WARNING)
        logging.getLogger("discord.client").setLevel(logging.WARNING)

        # SQLAlchemy logging: only surface SQL when running at DEBUG
        logging.getLogger("sqlalchemy.engine").setLevel(
            logging.DEBUG if self.level == "DEBUG" else logging.WARNING
        )
        logging.getLogger("sqlalchemy.dialects").setLevel(logging.WARNING)
        logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
        logging.getLogger("sqlalchemy.orm").setLevel(logging.WARNING)

        # HTTP libraries
        logging.getLogger("urllib3").setLevel(logging.WARNING)
        logging.getLogger("httpx").setLevel(logging.WARNING)
        logging.getLogger("httpcore").setLevel(logging.WARNING)

        # Other libraries
        logging.getLogger("asyncio").setLevel(logging.WARNING)
def get_logger(name: str) -> BoundLogger:
    """Return a logger for *name*: structlog when available, stdlib otherwise."""
    if not STRUCTLOG_AVAILABLE:
        # Plain stdlib fallback keeps the same call surface for callers.
        return logging.getLogger(name)
    return structlog.get_logger(name)
def bind_context(**kwargs: Any) -> None:
    """Bind context variables for structured logging (no-op without structlog)."""
    if not STRUCTLOG_AVAILABLE:
        return
    bind_contextvars(**kwargs)
def unbind_context(*keys: str) -> None:
    """Unbind specific context variables (no-op without structlog)."""
    if not STRUCTLOG_AVAILABLE:
        return
    unbind_contextvars(*keys)
def clear_context() -> None:
    """Clear all context variables (no-op without structlog)."""
    if not STRUCTLOG_AVAILABLE:
        return
    clear_contextvars()
class LoggingMiddleware:
"""Middleware for logging Discord bot events and commands."""
def __init__(self, logger: BoundLogger):
self.logger = logger
def log_command_start(self, ctx, command_name: str) -> None:
"""Log when a command starts."""
bind_context(
command=command_name,
user_id=ctx.author.id,
user_name=str(ctx.author),
guild_id=ctx.guild.id if ctx.guild else None,
guild_name=ctx.guild.name if ctx.guild else None,
channel_id=ctx.channel.id,
channel_name=getattr(ctx.channel, 'name', 'DM'),
)
if hasattr(self.logger, 'info'):
self.logger.info(
"Command started",
extra={
"command": command_name,
"args": ctx.args if hasattr(ctx, 'args') else None,
}
)
def log_command_success(self, ctx, command_name: str, duration: float) -> None:
"""Log successful command completion."""
if hasattr(self.logger, 'info'):
self.logger.info(
"Command completed successfully",
extra={
"command": command_name,
"duration_ms": round(duration * 1000, 2),
}
)
def log_command_error(self, ctx, command_name: str, error: Exception, duration: float) -> None:
"""Log command errors."""
if hasattr(self.logger, 'error'):
self.logger.error(
"Command failed",
exc_info=error,
extra={
"command": command_name,
"error_type": type(error).__name__,
"error_message": str(error),
"duration_ms": round(duration * 1000, 2),
}
)
def log_moderation_action(
self,
action: str,
target_id: int,
target_name: str,
moderator_id: int,
moderator_name: str,
guild_id: int,
reason: str = None,
duration: int = None,
**extra: Any,
) -> None:
"""Log moderation actions."""
if hasattr(self.logger, 'info'):
self.logger.info(
"Moderation action performed",
extra={
"action": action,
"target_id": target_id,
"target_name": target_name,
"moderator_id": moderator_id,
"moderator_name": moderator_name,
"guild_id": guild_id,
"reason": reason,
"duration_seconds": duration,
**extra,
}
)
# Process-wide middleware singleton, created lazily by get_logging_middleware().
# Annotated `| None` (the original `LoggingMiddleware = None` annotation was wrong).
_logging_middleware: LoggingMiddleware | None = None


def get_logging_middleware() -> LoggingMiddleware:
    """Return the global logging middleware instance, creating it on first use."""
    global _logging_middleware
    if _logging_middleware is None:
        _logging_middleware = LoggingMiddleware(get_logger("guardden.middleware"))
    return _logging_middleware
def setup_logging(
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO",
    json_format: bool = False
) -> None:
    """Set up logging for the GuardDen application.

    Args:
        level: Log level name applied to the logging configuration.
        json_format: When True, emit machine-readable JSON log lines;
            otherwise use a human-friendly text format.
    """
    # GuardDenLogger configures handlers/levels as a side effect of
    # construction; the instance itself is intentionally discarded.
    GuardDenLogger(level=level, json_format=json_format)

View File

@@ -0,0 +1,328 @@
"""Prometheus metrics utilities for GuardDen."""
import logging
import sys
import time
from functools import wraps
from typing import Any, Dict, Optional
try:
    from prometheus_client import Counter, Histogram, Gauge, Info, start_http_server, CollectorRegistry, REGISTRY
    PROMETHEUS_AVAILABLE = True
except ImportError:
    PROMETHEUS_AVAILABLE = False

    # Mock objects when the prometheus_client package is not installed.
    class MockMetric:
        """No-op stand-in mimicking the prometheus_client metric API."""

        def __init__(self, *args, **kwargs):
            # Accept the real constructor signature
            # (name, documentation, labelnames, registry=..., buckets=...).
            pass

        def labels(self, *args, **kwargs):
            # Real metrics return a labeled child; the mock chains on itself
            # so `metric.labels(...).inc()` works without prometheus_client.
            return self

        def inc(self, *args, **kwargs): pass
        def observe(self, *args, **kwargs): pass
        def set(self, *args, **kwargs): pass
        def info(self, *args, **kwargs): pass

    Counter = Histogram = Gauge = Info = MockMetric
    CollectorRegistry = REGISTRY = None
class GuardDenMetrics:
    """Centralized metrics collection for GuardDen.

    Wraps a fixed set of Prometheus counters/histograms/gauges. When the
    prometheus_client package is unavailable, ``self.enabled`` is False and
    every recording method returns early, so callers never need to guard.
    """

    def __init__(self, registry: Optional[CollectorRegistry] = None) -> None:
        # Fall back to the process-global default registry when none is given.
        self.registry = registry or REGISTRY
        self.enabled = PROMETHEUS_AVAILABLE
        if not self.enabled:
            # Skip metric construction entirely; the names below would
            # otherwise register against a real registry.
            return
        # Bot metrics
        self.bot_commands_total = Counter(
            'guardden_commands_total',
            'Total number of commands executed',
            ['command', 'guild', 'status'],
            registry=self.registry
        )
        self.bot_command_duration = Histogram(
            'guardden_command_duration_seconds',
            'Command execution duration in seconds',
            ['command', 'guild'],
            registry=self.registry
        )
        self.bot_guilds_total = Gauge(
            'guardden_guilds_total',
            'Total number of guilds the bot is in',
            registry=self.registry
        )
        self.bot_users_total = Gauge(
            'guardden_users_total',
            'Total number of users across all guilds',
            registry=self.registry
        )
        # Moderation metrics
        self.moderation_actions_total = Counter(
            'guardden_moderation_actions_total',
            'Total number of moderation actions',
            ['action', 'guild', 'automated'],
            registry=self.registry
        )
        self.automod_triggers_total = Counter(
            'guardden_automod_triggers_total',
            'Total number of automod triggers',
            ['filter_type', 'guild', 'action'],
            registry=self.registry
        )
        # AI metrics
        self.ai_requests_total = Counter(
            'guardden_ai_requests_total',
            'Total number of AI provider requests',
            ['provider', 'operation', 'status'],
            registry=self.registry
        )
        self.ai_request_duration = Histogram(
            'guardden_ai_request_duration_seconds',
            'AI request duration in seconds',
            ['provider', 'operation'],
            registry=self.registry
        )
        self.ai_confidence_score = Histogram(
            'guardden_ai_confidence_score',
            'AI confidence scores',
            ['provider', 'operation'],
            # Uniform 0.1-wide buckets over the [0, 1] confidence range.
            buckets=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
            registry=self.registry
        )
        # Database metrics
        self.database_connections_active = Gauge(
            'guardden_database_connections_active',
            'Number of active database connections',
            registry=self.registry
        )
        self.database_query_duration = Histogram(
            'guardden_database_query_duration_seconds',
            'Database query duration in seconds',
            ['operation'],
            registry=self.registry
        )
        # System metrics
        self.bot_info = Info(
            'guardden_bot_info',
            'Bot information',
            registry=self.registry
        )
        self.last_heartbeat = Gauge(
            'guardden_last_heartbeat_timestamp',
            'Timestamp of last successful heartbeat',
            registry=self.registry
        )

    def record_command(self, command: str, guild_id: Optional[int], status: str, duration: float) -> None:
        """Record command execution metrics (count + duration histogram)."""
        if not self.enabled:
            return
        # DMs have no guild; use a fixed 'dm' label to bound cardinality.
        guild_str = str(guild_id) if guild_id else 'dm'
        self.bot_commands_total.labels(command=command, guild=guild_str, status=status).inc()
        self.bot_command_duration.labels(command=command, guild=guild_str).observe(duration)

    def record_moderation_action(self, action: str, guild_id: int, automated: bool) -> None:
        """Record moderation action metrics."""
        if not self.enabled:
            return
        self.moderation_actions_total.labels(
            action=action,
            guild=str(guild_id),
            # Prometheus labels are strings; booleans become 'true'/'false'.
            automated=str(automated).lower()
        ).inc()

    def record_automod_trigger(self, filter_type: str, guild_id: int, action: str) -> None:
        """Record automod trigger metrics."""
        if not self.enabled:
            return
        self.automod_triggers_total.labels(
            filter_type=filter_type,
            guild=str(guild_id),
            action=action
        ).inc()

    def record_ai_request(self, provider: str, operation: str, status: str, duration: float, confidence: Optional[float] = None) -> None:
        """Record AI request metrics; confidence is observed only when given."""
        if not self.enabled:
            return
        self.ai_requests_total.labels(
            provider=provider,
            operation=operation,
            status=status
        ).inc()
        self.ai_request_duration.labels(
            provider=provider,
            operation=operation
        ).observe(duration)
        if confidence is not None:
            self.ai_confidence_score.labels(
                provider=provider,
                operation=operation
            ).observe(confidence)

    def update_guild_count(self, count: int) -> None:
        """Update total guild count."""
        if not self.enabled:
            return
        self.bot_guilds_total.set(count)

    def update_user_count(self, count: int) -> None:
        """Update total user count."""
        if not self.enabled:
            return
        self.bot_users_total.set(count)

    def update_database_connections(self, active: int) -> None:
        """Update active database connections."""
        if not self.enabled:
            return
        self.database_connections_active.set(active)

    def record_database_query(self, operation: str, duration: float) -> None:
        """Record database query metrics."""
        if not self.enabled:
            return
        self.database_query_duration.labels(operation=operation).observe(duration)

    def update_bot_info(self, info: Dict[str, str]) -> None:
        """Update bot information (exposed as an Info metric's label set)."""
        if not self.enabled:
            return
        self.bot_info.info(info)

    def heartbeat(self) -> None:
        """Record heartbeat timestamp (seconds since epoch)."""
        if not self.enabled:
            return
        self.last_heartbeat.set(time.time())
# Lazily-created, process-wide metrics singleton.
_metrics: Optional[GuardDenMetrics] = None


def get_metrics() -> GuardDenMetrics:
    """Return the global metrics instance, creating it on first use."""
    global _metrics
    if _metrics is None:
        _metrics = GuardDenMetrics()
    return _metrics
def start_metrics_server(port: int = 8001) -> None:
    """Start the Prometheus metrics HTTP server (no-op without prometheus_client)."""
    if not PROMETHEUS_AVAILABLE:
        return
    start_http_server(port)
def metrics_middleware(func):
    """Decorator that records count/status/duration metrics for an async command.

    When prometheus_client is unavailable the wrapper is a transparent
    pass-through. Otherwise every call — success or failure — is recorded
    via ``get_metrics().record_command`` in a ``finally`` block, and the
    original exception (if any) is re-raised unchanged.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        if not PROMETHEUS_AVAILABLE:
            return await func(*args, **kwargs)
        start_time = time.time()
        status = "success"
        # Bound Command objects expose qualified_name; otherwise fall back
        # to the wrapped function's own name.
        if args and hasattr(args[0], 'qualified_name'):
            command_name = args[0].qualified_name
        else:
            command_name = func.__name__
        # The invocation context may be the first or second positional arg:
        # cog methods receive (self, ctx, ...), plain commands receive (ctx, ...).
        # (The original only looked at args[1], missing the plain-command case.)
        ctx = next((a for a in args[:2] if hasattr(a, 'guild')), None)
        try:
            return await func(*args, **kwargs)
        except Exception:
            status = "error"
            raise
        finally:
            duration = time.time() - start_time
            guild_id = ctx.guild.id if ctx and ctx.guild else None
            get_metrics().record_command(
                command=command_name,
                guild_id=guild_id,
                status=status,
                duration=duration,
            )
    return wrapper
class MetricsCollector:
    """Periodic metrics collector for system stats.

    Intended to be called on a schedule (e.g. a background task loop);
    each ``collect_bot_metrics`` call takes a fresh snapshot.
    """

    def __init__(self, bot) -> None:
        self.bot = bot
        self.metrics = get_metrics()

    async def collect_bot_metrics(self) -> None:
        """Collect basic bot metrics: guild/user counts, DB pool, bot info."""
        if not PROMETHEUS_AVAILABLE:
            return
        # Guild count
        self.metrics.update_guild_count(len(self.bot.guilds))
        # Total user count across all guilds (member_count can be None,
        # e.g. when the guild is not fully chunked).
        total_users = sum(guild.member_count or 0 for guild in self.bot.guilds)
        self.metrics.update_user_count(total_users)
        # Database connections if available
        if hasattr(self.bot, 'database') and self.bot.database._engine:
            try:
                pool = self.bot.database._engine.pool
                if hasattr(pool, 'checkedout'):
                    self.metrics.update_database_connections(pool.checkedout())
            except Exception:
                # Best-effort: pool introspection must never break collection.
                pass
        # Bot info. Requires the module-level `import sys` (the original code
        # referenced sys without importing it, raising NameError at runtime).
        self.metrics.update_bot_info({
            'version': getattr(self.bot, 'version', 'unknown'),
            'python_version': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
            'discord_py_version': str(discord.__version__) if 'discord' in globals() else 'unknown',
        })
        # Heartbeat
        self.metrics.heartbeat()
def setup_metrics(bot, port: int = 8001) -> Optional[MetricsCollector]:
    """Set up metrics collection for the bot.

    Starts the Prometheus HTTP exporter on *port* and returns a collector
    bound to *bot*. Returns None when prometheus_client is unavailable or
    the server fails to start — metrics are optional and must never block
    bot startup.
    """
    if not PROMETHEUS_AVAILABLE:
        return None
    try:
        start_metrics_server(port)
        return MetricsCollector(bot)
    except Exception as exc:
        # Log but don't fail startup. Uses the module-level logging import
        # instead of the original inline __import__('logging') hack; lazy
        # %-style args avoid formatting when the level is disabled.
        logging.getLogger(__name__).error("Failed to start metrics server: %s", exc)
        return None

View File

@@ -0,0 +1,10 @@
"""Rate limit helpers for Discord commands."""
from dataclasses import dataclass
@dataclass
class RateLimitExceeded(Exception):
"""Raised when a command is rate limited."""
retry_after: float