quick commit
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 6m9s
CI/CD Pipeline / Security Scanning (push) Successful in 26s
CI/CD Pipeline / Tests (3.11) (push) Failing after 5m24s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m23s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Deploy to Staging (push) Has been skipped
CI/CD Pipeline / Deploy to Production (push) Has been skipped
CI/CD Pipeline / Notification (push) Successful in 1s
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 6m9s
CI/CD Pipeline / Security Scanning (push) Successful in 26s
CI/CD Pipeline / Tests (3.11) (push) Failing after 5m24s
CI/CD Pipeline / Tests (3.12) (push) Failing after 5m23s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
CI/CD Pipeline / Deploy to Staging (push) Has been skipped
CI/CD Pipeline / Deploy to Production (push) Has been skipped
CI/CD Pipeline / Notification (push) Successful in 1s
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
"""Main bot class for GuardDen."""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import platform
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import discord
|
||||
@@ -9,11 +11,14 @@ from discord.ext import commands
|
||||
from guardden.config import Settings
|
||||
from guardden.services.ai import AIProvider, create_ai_provider
|
||||
from guardden.services.database import Database
|
||||
from guardden.services.ratelimit import RateLimiter
|
||||
from guardden.utils.logging import get_logger, get_logging_middleware, setup_logging
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from guardden.services.guild_config import GuildConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__)
|
||||
logging_middleware = get_logging_middleware()
|
||||
|
||||
|
||||
class GuardDen(commands.Bot):
|
||||
@@ -37,6 +42,7 @@ class GuardDen(commands.Bot):
|
||||
self.database = Database(settings)
|
||||
self.guild_config: "GuildConfigService | None" = None
|
||||
self.ai_provider: AIProvider | None = None
|
||||
self.rate_limiter = RateLimiter()
|
||||
|
||||
async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]:
|
||||
"""Get the command prefix for a guild."""
|
||||
@@ -50,10 +56,32 @@ class GuardDen(commands.Bot):
|
||||
|
||||
return [self.settings.discord_prefix]
|
||||
|
||||
def is_guild_allowed(self, guild_id: int) -> bool:
|
||||
"""Check if a guild is allowed to run the bot."""
|
||||
return not self.settings.allowed_guilds or guild_id in self.settings.allowed_guilds
|
||||
|
||||
def is_owner_allowed(self, user_id: int) -> bool:
|
||||
"""Check if a user is allowed elevated access."""
|
||||
return not self.settings.owner_ids or user_id in self.settings.owner_ids
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Called when the bot is starting up."""
|
||||
logger.info("Starting GuardDen setup...")
|
||||
|
||||
self.settings.validate_configuration()
|
||||
logger.info(
|
||||
"Configuration loaded: ai_provider=%s, log_level=%s, allowed_guilds=%s, owner_ids=%s",
|
||||
self.settings.ai_provider,
|
||||
self.settings.log_level,
|
||||
self.settings.allowed_guilds or "all",
|
||||
self.settings.owner_ids or "admins",
|
||||
)
|
||||
logger.info(
|
||||
"Runtime versions: python=%s, discord.py=%s",
|
||||
platform.python_version(),
|
||||
discord.__version__,
|
||||
)
|
||||
|
||||
# Connect to database
|
||||
await self.database.connect()
|
||||
await self.database.create_tables()
|
||||
@@ -86,14 +114,27 @@ class GuardDen(commands.Bot):
|
||||
"guardden.cogs.automod",
|
||||
"guardden.cogs.ai_moderation",
|
||||
"guardden.cogs.verification",
|
||||
"guardden.cogs.health",
|
||||
]
|
||||
|
||||
failed_cogs = []
|
||||
for cog in cogs:
|
||||
try:
|
||||
await self.load_extension(cog)
|
||||
logger.info(f"Loaded cog: {cog}")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import cog {cog}: {e}")
|
||||
failed_cogs.append(cog)
|
||||
except commands.ExtensionError as e:
|
||||
logger.error(f"Discord extension error loading {cog}: {e}")
|
||||
failed_cogs.append(cog)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load cog {cog}: {e}")
|
||||
logger.error(f"Unexpected error loading cog {cog}: {e}", exc_info=True)
|
||||
failed_cogs.append(cog)
|
||||
|
||||
if failed_cogs:
|
||||
logger.warning(f"Failed to load {len(failed_cogs)} cog(s): {', '.join(failed_cogs)}")
|
||||
# Don't fail startup if some cogs fail to load, but log it prominently
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
"""Called when the bot is fully connected and ready."""
|
||||
@@ -103,9 +144,30 @@ class GuardDen(commands.Bot):
|
||||
|
||||
# Ensure all guilds have database entries
|
||||
if self.guild_config:
|
||||
initialized = 0
|
||||
failed_guilds = []
|
||||
|
||||
for guild in self.guilds:
|
||||
await self.guild_config.create_guild(guild)
|
||||
logger.info(f"Initialized config for {len(self.guilds)} guild(s)")
|
||||
try:
|
||||
if not self.is_guild_allowed(guild.id):
|
||||
logger.warning(
|
||||
"Leaving unauthorized guild %s (ID: %s)", guild.name, guild.id
|
||||
)
|
||||
try:
|
||||
await guild.leave()
|
||||
except discord.HTTPException as e:
|
||||
logger.error(f"Failed to leave guild {guild.id}: {e}")
|
||||
continue
|
||||
|
||||
await self.guild_config.create_guild(guild)
|
||||
initialized += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize config for guild {guild.id} ({guild.name}): {e}", exc_info=True)
|
||||
failed_guilds.append(guild.id)
|
||||
|
||||
logger.info("Initialized config for %s guild(s)", initialized)
|
||||
if failed_guilds:
|
||||
logger.warning(f"Failed to initialize {len(failed_guilds)} guild(s): {failed_guilds}")
|
||||
|
||||
# Set presence
|
||||
activity = discord.Activity(
|
||||
@@ -117,6 +179,7 @@ class GuardDen(commands.Bot):
|
||||
async def close(self) -> None:
|
||||
"""Clean up when shutting down."""
|
||||
logger.info("Shutting down GuardDen...")
|
||||
await self._shutdown_cogs()
|
||||
if self.ai_provider:
|
||||
try:
|
||||
await self.ai_provider.close()
|
||||
@@ -125,10 +188,30 @@ class GuardDen(commands.Bot):
|
||||
await self.database.disconnect()
|
||||
await super().close()
|
||||
|
||||
async def _shutdown_cogs(self) -> None:
|
||||
"""Ensure cogs can clean up background tasks."""
|
||||
for cog in list(self.cogs.values()):
|
||||
unload = getattr(cog, "cog_unload", None)
|
||||
if unload is None:
|
||||
continue
|
||||
try:
|
||||
result = unload()
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except Exception as e:
|
||||
logger.error("Error during cog unload (%s): %s", cog.qualified_name, e)
|
||||
|
||||
async def on_guild_join(self, guild: discord.Guild) -> None:
|
||||
"""Called when the bot joins a new guild."""
|
||||
logger.info(f"Joined guild: {guild.name} (ID: {guild.id})")
|
||||
|
||||
if not self.is_guild_allowed(guild.id):
|
||||
logger.warning(
|
||||
"Guild %s (ID: %s) not in allowlist, leaving.", guild.name, guild.id
|
||||
)
|
||||
await guild.leave()
|
||||
return
|
||||
|
||||
if self.guild_config:
|
||||
await self.guild_config.create_guild(guild)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.utils.ratelimit import RateLimitExceeded
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,12 +18,32 @@ class Admin(commands.Cog):
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
|
||||
async def cog_check(self, ctx: commands.Context) -> bool:
|
||||
def cog_check(self, ctx: commands.Context) -> bool:
|
||||
"""Ensure only administrators can use these commands."""
|
||||
if not ctx.guild:
|
||||
return False
|
||||
if not self.bot.is_owner_allowed(ctx.author.id):
|
||||
return False
|
||||
return ctx.author.guild_permissions.administrator
|
||||
|
||||
async def cog_before_invoke(self, ctx: commands.Context) -> None:
|
||||
if not ctx.command:
|
||||
return
|
||||
result = self.bot.rate_limiter.acquire_command(
|
||||
ctx.command.qualified_name,
|
||||
user_id=ctx.author.id,
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
)
|
||||
if result.is_limited:
|
||||
raise RateLimitExceeded(result.reset_after)
|
||||
|
||||
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
|
||||
if isinstance(error, RateLimitExceeded):
|
||||
await ctx.send(
|
||||
f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
|
||||
)
|
||||
|
||||
@commands.group(name="config", invoke_without_command=True)
|
||||
@commands.guild_only()
|
||||
async def config(self, ctx: commands.Context) -> None:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""AI-powered moderation cog."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
@@ -9,16 +8,13 @@ import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.models import ModerationLog
|
||||
from guardden.services.ai.base import ContentCategory, ModerationResult
|
||||
from guardden.services.automod import URL_PATTERN, is_allowed_domain, normalize_domain
|
||||
from guardden.utils.ratelimit import RateLimitExceeded
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# URL pattern for extraction
|
||||
URL_PATTERN = re.compile(
|
||||
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
class AIModeration(commands.Cog):
|
||||
"""AI-powered content moderation."""
|
||||
@@ -28,6 +24,30 @@ class AIModeration(commands.Cog):
|
||||
# Track recently analyzed messages to avoid duplicates (deque auto-removes oldest)
|
||||
self._analyzed_messages: deque[int] = deque(maxlen=1000)
|
||||
|
||||
def cog_check(self, ctx: commands.Context) -> bool:
|
||||
"""Optional owner allowlist for AI commands."""
|
||||
if not ctx.guild:
|
||||
return False
|
||||
return self.bot.is_owner_allowed(ctx.author.id)
|
||||
|
||||
async def cog_before_invoke(self, ctx: commands.Context) -> None:
|
||||
if not ctx.command:
|
||||
return
|
||||
result = self.bot.rate_limiter.acquire_command(
|
||||
ctx.command.qualified_name,
|
||||
user_id=ctx.author.id,
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
)
|
||||
if result.is_limited:
|
||||
raise RateLimitExceeded(result.reset_after)
|
||||
|
||||
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
|
||||
if isinstance(error, RateLimitExceeded):
|
||||
await ctx.send(
|
||||
f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
|
||||
)
|
||||
|
||||
def _should_analyze(self, message: discord.Message) -> bool:
|
||||
"""Determine if a message should be analyzed by AI."""
|
||||
# Skip if already analyzed
|
||||
@@ -67,21 +87,37 @@ class AIModeration(commands.Cog):
|
||||
threshold = 100 - config.ai_sensitivity # e.g., sensitivity 70 = threshold 30
|
||||
if result.severity < threshold:
|
||||
logger.debug(
|
||||
f"AI flagged content but below threshold: "
|
||||
f"severity={result.severity}, threshold={threshold}"
|
||||
"AI flagged content but below threshold: severity=%s, threshold=%s",
|
||||
result.severity,
|
||||
threshold,
|
||||
)
|
||||
return
|
||||
|
||||
if result.confidence < config.ai_confidence_threshold:
|
||||
logger.debug(
|
||||
"AI flagged content but below confidence threshold: confidence=%s, threshold=%s",
|
||||
result.confidence,
|
||||
config.ai_confidence_threshold,
|
||||
)
|
||||
return
|
||||
|
||||
log_only = config.ai_log_only
|
||||
|
||||
# Determine action based on suggested action and severity
|
||||
should_delete = result.suggested_action in ("delete", "timeout", "ban")
|
||||
should_timeout = result.suggested_action in ("timeout", "ban") and result.severity > 70
|
||||
should_delete = not log_only and result.suggested_action in ("delete", "timeout", "ban")
|
||||
should_timeout = (
|
||||
not log_only
|
||||
and result.suggested_action in ("timeout", "ban")
|
||||
and result.severity > 70
|
||||
)
|
||||
timeout_duration: int | None = None
|
||||
|
||||
# Delete message if needed
|
||||
if should_delete:
|
||||
try:
|
||||
await message.delete()
|
||||
except discord.Forbidden:
|
||||
logger.warning(f"Cannot delete message: missing permissions")
|
||||
logger.warning("Cannot delete message: missing permissions")
|
||||
except discord.NotFound:
|
||||
pass
|
||||
|
||||
@@ -96,8 +132,19 @@ class AIModeration(commands.Cog):
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
await self._log_ai_db_action(
|
||||
message,
|
||||
result,
|
||||
analysis_type,
|
||||
log_only=log_only,
|
||||
timeout_duration=timeout_duration,
|
||||
)
|
||||
|
||||
# Log to mod channel
|
||||
await self._log_ai_action(message, result, analysis_type)
|
||||
await self._log_ai_action(message, result, analysis_type, log_only=log_only)
|
||||
|
||||
if log_only:
|
||||
return
|
||||
|
||||
# Notify user
|
||||
try:
|
||||
@@ -122,6 +169,7 @@ class AIModeration(commands.Cog):
|
||||
message: discord.Message,
|
||||
result: ModerationResult,
|
||||
analysis_type: str,
|
||||
log_only: bool = False,
|
||||
) -> None:
|
||||
"""Log an AI moderation action."""
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
@@ -142,9 +190,10 @@ class AIModeration(commands.Cog):
|
||||
icon_url=message.author.display_avatar.url,
|
||||
)
|
||||
|
||||
action_label = "log-only" if log_only else result.suggested_action
|
||||
embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True)
|
||||
embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True)
|
||||
embed.add_field(name="Action", value=result.suggested_action, inline=True)
|
||||
embed.add_field(name="Action", value=action_label, inline=True)
|
||||
|
||||
categories = ", ".join(cat.value for cat in result.categories)
|
||||
embed.add_field(name="Categories", value=categories or "None", inline=False)
|
||||
@@ -160,10 +209,43 @@ class AIModeration(commands.Cog):
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
async def _log_ai_db_action(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: ModerationResult,
|
||||
analysis_type: str,
|
||||
log_only: bool,
|
||||
timeout_duration: int | None,
|
||||
) -> None:
|
||||
"""Log an AI moderation action to the database."""
|
||||
action = "ai_log" if log_only else f"ai_{result.suggested_action}"
|
||||
reason = result.explanation or f"AI moderation flagged content ({analysis_type})"
|
||||
expires_at = None
|
||||
if timeout_duration:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=timeout_duration)
|
||||
|
||||
async with self.bot.database.session() as session:
|
||||
entry = ModerationLog(
|
||||
guild_id=message.guild.id,
|
||||
target_id=message.author.id,
|
||||
target_name=str(message.author),
|
||||
moderator_id=self.bot.user.id if self.bot.user else 0,
|
||||
moderator_name=str(self.bot.user) if self.bot.user else "GuardDen",
|
||||
action=action,
|
||||
reason=reason,
|
||||
duration=timeout_duration,
|
||||
expires_at=expires_at,
|
||||
channel_id=message.channel.id,
|
||||
message_id=message.id,
|
||||
message_content=message.content,
|
||||
is_automatic=True,
|
||||
)
|
||||
session.add(entry)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""Analyze messages with AI moderation."""
|
||||
print(f"[AI_MOD] Received message from {message.author}", flush=True)
|
||||
logger.debug("AI moderation received message from %s", message.author)
|
||||
|
||||
# Skip bot messages early
|
||||
if message.author.bot:
|
||||
@@ -247,7 +329,11 @@ class AIModeration(commands.Cog):
|
||||
|
||||
# Analyze URLs for phishing
|
||||
urls = URL_PATTERN.findall(message.content)
|
||||
allowlist = {normalize_domain(domain) for domain in config.scam_allowlist if domain}
|
||||
for url in urls[:3]: # Limit to first 3 URLs
|
||||
hostname = normalize_domain(url)
|
||||
if allowlist and is_allowed_domain(hostname, allowlist):
|
||||
continue
|
||||
phishing_result = await self.bot.ai_provider.analyze_phishing(
|
||||
url=url,
|
||||
message_content=message.content,
|
||||
@@ -291,6 +377,16 @@ class AIModeration(commands.Cog):
|
||||
value=f"{config.ai_sensitivity}/100" if config else "50/100",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Confidence Threshold",
|
||||
value=f"{config.ai_confidence_threshold:.2f}" if config else "0.70",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Log Only",
|
||||
value="✅ Enabled" if config and config.ai_log_only else "❌ Disabled",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="AI Provider",
|
||||
value=self.bot.settings.ai_provider.capitalize(),
|
||||
@@ -333,6 +429,27 @@ class AIModeration(commands.Cog):
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_sensitivity=level)
|
||||
await ctx.send(f"AI sensitivity set to {level}/100.")
|
||||
|
||||
@ai_cmd.command(name="threshold")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_threshold(self, ctx: commands.Context, value: float) -> None:
|
||||
"""Set AI confidence threshold (0.0-1.0)."""
|
||||
if not 0.0 <= value <= 1.0:
|
||||
await ctx.send("Threshold must be between 0.0 and 1.0.")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_confidence_threshold=value)
|
||||
await ctx.send(f"AI confidence threshold set to {value:.2f}.")
|
||||
|
||||
@ai_cmd.command(name="logonly")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_logonly(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable log-only mode for AI moderation."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_log_only=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"AI log-only mode {status}.")
|
||||
|
||||
@ai_cmd.command(name="nsfw")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
|
||||
@@ -2,12 +2,21 @@
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.services.automod import AutomodResult, AutomodService
|
||||
from guardden.models import ModerationLog, Strike
|
||||
from guardden.services.automod import (
|
||||
AutomodResult,
|
||||
AutomodService,
|
||||
SpamConfig,
|
||||
normalize_domain,
|
||||
)
|
||||
from guardden.utils.ratelimit import RateLimitExceeded
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,6 +28,135 @@ class Automod(commands.Cog):
|
||||
self.bot = bot
|
||||
self.automod = AutomodService()
|
||||
|
||||
def cog_check(self, ctx: commands.Context) -> bool:
|
||||
"""Optional owner allowlist for automod commands."""
|
||||
if not ctx.guild:
|
||||
return False
|
||||
return self.bot.is_owner_allowed(ctx.author.id)
|
||||
|
||||
async def cog_before_invoke(self, ctx: commands.Context) -> None:
|
||||
if not ctx.command:
|
||||
return
|
||||
result = self.bot.rate_limiter.acquire_command(
|
||||
ctx.command.qualified_name,
|
||||
user_id=ctx.author.id,
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
)
|
||||
if result.is_limited:
|
||||
raise RateLimitExceeded(result.reset_after)
|
||||
|
||||
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
|
||||
if isinstance(error, RateLimitExceeded):
|
||||
await ctx.send(
|
||||
f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
|
||||
)
|
||||
|
||||
def _spam_config(self, config) -> SpamConfig:
|
||||
if not config:
|
||||
return self.automod.default_spam_config
|
||||
return SpamConfig(
|
||||
message_rate_limit=config.message_rate_limit,
|
||||
message_rate_window=config.message_rate_window,
|
||||
duplicate_threshold=config.duplicate_threshold,
|
||||
mention_limit=config.mention_limit,
|
||||
mention_rate_limit=config.mention_rate_limit,
|
||||
mention_rate_window=config.mention_rate_window,
|
||||
)
|
||||
|
||||
async def _get_strike_count(self, guild_id: int, user_id: int) -> int:
|
||||
async with self.bot.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(func.sum(Strike.points)).where(
|
||||
Strike.guild_id == guild_id,
|
||||
Strike.user_id == user_id,
|
||||
Strike.is_active == True,
|
||||
)
|
||||
)
|
||||
total = result.scalar()
|
||||
return total or 0
|
||||
|
||||
async def _add_strike(
|
||||
self,
|
||||
guild: discord.Guild,
|
||||
member: discord.Member,
|
||||
reason: str,
|
||||
) -> int:
|
||||
async with self.bot.database.session() as session:
|
||||
strike = Strike(
|
||||
guild_id=guild.id,
|
||||
user_id=member.id,
|
||||
user_name=str(member),
|
||||
moderator_id=self.bot.user.id if self.bot.user else 0,
|
||||
reason=reason,
|
||||
points=1,
|
||||
)
|
||||
session.add(strike)
|
||||
|
||||
return await self._get_strike_count(guild.id, member.id)
|
||||
|
||||
async def _apply_strike_actions(
|
||||
self,
|
||||
member: discord.Member,
|
||||
total_strikes: int,
|
||||
config,
|
||||
) -> None:
|
||||
if not config or not config.strike_actions:
|
||||
return
|
||||
|
||||
for threshold, action_config in sorted(
|
||||
config.strike_actions.items(), key=lambda item: int(item[0]), reverse=True
|
||||
):
|
||||
if total_strikes < int(threshold):
|
||||
continue
|
||||
action = action_config.get("action")
|
||||
if action == "ban":
|
||||
await member.ban(reason=f"Automod: {total_strikes} strikes")
|
||||
elif action == "kick":
|
||||
await member.kick(reason=f"Automod: {total_strikes} strikes")
|
||||
elif action == "timeout":
|
||||
duration = action_config.get("duration", 3600)
|
||||
await member.timeout(
|
||||
timedelta(seconds=duration),
|
||||
reason=f"Automod: {total_strikes} strikes",
|
||||
)
|
||||
break
|
||||
|
||||
async def _log_database_action(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: AutomodResult,
|
||||
) -> None:
|
||||
async with self.bot.database.session() as session:
|
||||
action = "delete"
|
||||
if result.should_timeout:
|
||||
action = "timeout"
|
||||
elif result.should_strike:
|
||||
action = "strike"
|
||||
elif result.should_warn:
|
||||
action = "warn"
|
||||
|
||||
expires_at = None
|
||||
if result.timeout_duration:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=result.timeout_duration)
|
||||
|
||||
log_entry = ModerationLog(
|
||||
guild_id=message.guild.id,
|
||||
target_id=message.author.id,
|
||||
target_name=str(message.author),
|
||||
moderator_id=self.bot.user.id if self.bot.user else 0,
|
||||
moderator_name=str(self.bot.user) if self.bot.user else "GuardDen",
|
||||
action=action,
|
||||
reason=result.reason,
|
||||
duration=result.timeout_duration or None,
|
||||
expires_at=expires_at,
|
||||
channel_id=message.channel.id,
|
||||
message_id=message.id,
|
||||
message_content=message.content,
|
||||
is_automatic=True,
|
||||
)
|
||||
session.add(log_entry)
|
||||
|
||||
async def _handle_violation(
|
||||
self,
|
||||
message: discord.Message,
|
||||
@@ -45,8 +183,15 @@ class Automod(commands.Cog):
|
||||
logger.warning(f"Cannot timeout {message.author}: missing permissions")
|
||||
|
||||
# Log the action
|
||||
await self._log_database_action(message, result)
|
||||
await self._log_automod_action(message, result)
|
||||
|
||||
# Apply strike escalation if configured
|
||||
if (result.should_warn or result.should_strike) and isinstance(message.author, discord.Member):
|
||||
total = await self._add_strike(message.guild, message.author, result.reason)
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
await self._apply_strike_actions(message.author, total, config)
|
||||
|
||||
# Notify the user via DM
|
||||
try:
|
||||
embed = discord.Embed(
|
||||
@@ -136,13 +281,22 @@ class Automod(commands.Cog):
|
||||
if banned_words:
|
||||
result = self.automod.check_banned_words(message.content, banned_words)
|
||||
|
||||
spam_config = self._spam_config(config)
|
||||
|
||||
# Check scam links (if link filter enabled)
|
||||
if not result and config.link_filter_enabled:
|
||||
result = self.automod.check_scam_links(message.content)
|
||||
result = self.automod.check_scam_links(
|
||||
message.content,
|
||||
allowlist=config.scam_allowlist,
|
||||
)
|
||||
|
||||
# Check spam
|
||||
if not result and config.anti_spam_enabled:
|
||||
result = self.automod.check_spam(message, anti_spam_enabled=True)
|
||||
result = self.automod.check_spam(
|
||||
message,
|
||||
anti_spam_enabled=True,
|
||||
spam_config=spam_config,
|
||||
)
|
||||
|
||||
# Check invite links (if link filter enabled)
|
||||
if not result and config.link_filter_enabled:
|
||||
@@ -194,20 +348,27 @@ class Automod(commands.Cog):
|
||||
inline=True,
|
||||
)
|
||||
|
||||
spam_config = self._spam_config(config)
|
||||
|
||||
# Show thresholds
|
||||
embed.add_field(
|
||||
name="Rate Limit",
|
||||
value=f"{self.automod.message_rate_limit} msgs / {self.automod.message_rate_window}s",
|
||||
value=f"{spam_config.message_rate_limit} msgs / {spam_config.message_rate_window}s",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Duplicate Threshold",
|
||||
value=f"{self.automod.duplicate_threshold} same messages",
|
||||
value=f"{spam_config.duplicate_threshold} same messages",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Mention Limit",
|
||||
value=f"{self.automod.mention_limit} per message",
|
||||
value=f"{spam_config.mention_limit} per message",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Mention Rate",
|
||||
value=f"{spam_config.mention_rate_limit} mentions / {spam_config.mention_rate_window}s",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
@@ -220,6 +381,82 @@ class Automod(commands.Cog):
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@automod_cmd.command(name="threshold")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def automod_threshold(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
setting: Literal[
|
||||
"message_rate_limit",
|
||||
"message_rate_window",
|
||||
"duplicate_threshold",
|
||||
"mention_limit",
|
||||
"mention_rate_limit",
|
||||
"mention_rate_window",
|
||||
],
|
||||
value: int,
|
||||
) -> None:
|
||||
"""Update a single automod threshold."""
|
||||
if value <= 0:
|
||||
await ctx.send("Threshold values must be positive.")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, **{setting: value})
|
||||
await ctx.send(f"Updated `{setting}` to {value}.")
|
||||
|
||||
@automod_cmd.group(name="allowlist", invoke_without_command=True)
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def automod_allowlist(self, ctx: commands.Context) -> None:
|
||||
"""Show the scam link allowlist."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
allowlist = sorted(config.scam_allowlist) if config else []
|
||||
if not allowlist:
|
||||
await ctx.send("No allowlisted domains configured.")
|
||||
return
|
||||
|
||||
formatted = "\n".join(f"- `{domain}`" for domain in allowlist[:20])
|
||||
await ctx.send(f"Allowed domains:\n{formatted}")
|
||||
|
||||
@automod_allowlist.command(name="add")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def automod_allowlist_add(self, ctx: commands.Context, domain: str) -> None:
|
||||
"""Add a domain to the scam link allowlist."""
|
||||
normalized = normalize_domain(domain)
|
||||
if not normalized:
|
||||
await ctx.send("Provide a valid domain or URL to allowlist.")
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
allowlist = list(config.scam_allowlist) if config else []
|
||||
|
||||
if normalized in allowlist:
|
||||
await ctx.send(f"`{normalized}` is already allowlisted.")
|
||||
return
|
||||
|
||||
allowlist.append(normalized)
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, scam_allowlist=allowlist)
|
||||
await ctx.send(f"Added `{normalized}` to the allowlist.")
|
||||
|
||||
@automod_allowlist.command(name="remove")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def automod_allowlist_remove(self, ctx: commands.Context, domain: str) -> None:
|
||||
"""Remove a domain from the scam link allowlist."""
|
||||
normalized = normalize_domain(domain)
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
allowlist = list(config.scam_allowlist) if config else []
|
||||
|
||||
if normalized not in allowlist:
|
||||
await ctx.send(f"`{normalized}` is not in the allowlist.")
|
||||
return
|
||||
|
||||
allowlist.remove(normalized)
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, scam_allowlist=allowlist)
|
||||
await ctx.send(f"Removed `{normalized}` from the allowlist.")
|
||||
|
||||
@automod_cmd.command(name="test")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
@@ -235,7 +472,7 @@ class Automod(commands.Cog):
|
||||
results.append(f"**Banned Words**: {result.reason}")
|
||||
|
||||
# Check scam links
|
||||
result = self.automod.check_scam_links(text)
|
||||
result = self.automod.check_scam_links(text, allowlist=config.scam_allowlist if config else [])
|
||||
if result:
|
||||
results.append(f"**Scam Detection**: {result.reason}")
|
||||
|
||||
|
||||
71
src/guardden/cogs/health.py
Normal file
71
src/guardden/cogs/health.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Health check commands."""
|
||||
|
||||
import logging
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from sqlalchemy import select
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.utils.ratelimit import RateLimitExceeded
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Health(commands.Cog):
|
||||
"""Health checks for the bot."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
|
||||
def cog_check(self, ctx: commands.Context) -> bool:
|
||||
if not ctx.guild:
|
||||
return False
|
||||
if not self.bot.is_owner_allowed(ctx.author.id):
|
||||
return False
|
||||
return ctx.author.guild_permissions.administrator
|
||||
|
||||
async def cog_before_invoke(self, ctx: commands.Context) -> None:
|
||||
if not ctx.command:
|
||||
return
|
||||
result = self.bot.rate_limiter.acquire_command(
|
||||
ctx.command.qualified_name,
|
||||
user_id=ctx.author.id,
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
)
|
||||
if result.is_limited:
|
||||
raise RateLimitExceeded(result.reset_after)
|
||||
|
||||
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
|
||||
if isinstance(error, RateLimitExceeded):
|
||||
await ctx.send(
|
||||
f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
|
||||
)
|
||||
|
||||
@commands.command(name="health")
|
||||
@commands.guild_only()
|
||||
async def health(self, ctx: commands.Context) -> None:
|
||||
"""Check database and AI provider health."""
|
||||
db_status = "ok"
|
||||
try:
|
||||
async with self.bot.database.session() as session:
|
||||
await session.execute(select(1))
|
||||
except Exception as exc: # pragma: no cover - external dependency
|
||||
logger.exception("Health check database failure")
|
||||
db_status = f"error: {exc}"
|
||||
|
||||
ai_status = "disabled"
|
||||
if self.bot.settings.ai_provider != "none":
|
||||
ai_status = "ok" if self.bot.ai_provider else "unavailable"
|
||||
|
||||
embed = discord.Embed(title="GuardDen Health", color=discord.Color.green())
|
||||
embed.add_field(name="Database", value=db_status, inline=False)
|
||||
embed.add_field(name="AI Provider", value=ai_status, inline=False)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the health cog."""
|
||||
await bot.add_cog(Health(bot))
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Moderation commands and automod features."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import discord
|
||||
@@ -10,36 +9,43 @@ from sqlalchemy import func, select
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.models import ModerationLog, Strike
|
||||
from guardden.utils import parse_duration
|
||||
from guardden.utils.ratelimit import RateLimitExceeded
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_duration(duration_str: str) -> timedelta | None:
|
||||
"""Parse a duration string like '1h', '30m', '7d' into a timedelta."""
|
||||
match = re.match(r"^(\d+)([smhdw])$", duration_str.lower())
|
||||
if not match:
|
||||
return None
|
||||
|
||||
amount = int(match.group(1))
|
||||
unit = match.group(2)
|
||||
|
||||
units = {
|
||||
"s": timedelta(seconds=amount),
|
||||
"m": timedelta(minutes=amount),
|
||||
"h": timedelta(hours=amount),
|
||||
"d": timedelta(days=amount),
|
||||
"w": timedelta(weeks=amount),
|
||||
}
|
||||
|
||||
return units.get(unit)
|
||||
|
||||
|
||||
class Moderation(commands.Cog):
|
||||
"""Moderation commands for server management."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
|
||||
def cog_check(self, ctx: commands.Context) -> bool:
|
||||
if not ctx.guild:
|
||||
return False
|
||||
if not self.bot.is_owner_allowed(ctx.author.id):
|
||||
return False
|
||||
return True
|
||||
|
||||
async def cog_before_invoke(self, ctx: commands.Context) -> None:
|
||||
if not ctx.command:
|
||||
return
|
||||
result = self.bot.rate_limiter.acquire_command(
|
||||
ctx.command.qualified_name,
|
||||
user_id=ctx.author.id,
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
)
|
||||
if result.is_limited:
|
||||
raise RateLimitExceeded(result.reset_after)
|
||||
|
||||
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
|
||||
if isinstance(error, RateLimitExceeded):
|
||||
await ctx.send(
|
||||
f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
|
||||
)
|
||||
|
||||
async def _log_action(
|
||||
self,
|
||||
guild: discord.Guild,
|
||||
@@ -334,7 +340,15 @@ class Moderation(commands.Cog):
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
await member.kick(reason=f"{ctx.author}: {reason}")
|
||||
try:
|
||||
await member.kick(reason=f"{ctx.author}: {reason}")
|
||||
except discord.Forbidden:
|
||||
await ctx.send("❌ I don't have permission to kick this member.")
|
||||
return
|
||||
except discord.HTTPException as e:
|
||||
await ctx.send(f"❌ Failed to kick member: {e}")
|
||||
return
|
||||
|
||||
await self._log_action(ctx.guild, member, ctx.author, "kick", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
@@ -346,7 +360,10 @@ class Moderation(commands.Cog):
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
try:
|
||||
await ctx.send(embed=embed)
|
||||
except discord.HTTPException:
|
||||
await ctx.send(f"✅ {member} has been kicked from the server.")
|
||||
|
||||
@commands.command(name="ban")
|
||||
@commands.has_permissions(ban_members=True)
|
||||
@@ -376,7 +393,15 @@ class Moderation(commands.Cog):
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
await ctx.guild.ban(member, reason=f"{ctx.author}: {reason}", delete_message_days=0)
|
||||
try:
|
||||
await ctx.guild.ban(member, reason=f"{ctx.author}: {reason}", delete_message_days=0)
|
||||
except discord.Forbidden:
|
||||
await ctx.send("❌ I don't have permission to ban this member.")
|
||||
return
|
||||
except discord.HTTPException as e:
|
||||
await ctx.send(f"❌ Failed to ban member: {e}")
|
||||
return
|
||||
|
||||
await self._log_action(ctx.guild, member, ctx.author, "ban", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
@@ -388,7 +413,10 @@ class Moderation(commands.Cog):
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
try:
|
||||
await ctx.send(embed=embed)
|
||||
except discord.HTTPException:
|
||||
await ctx.send(f"✅ {member} has been banned from the server.")
|
||||
|
||||
@commands.command(name="unban")
|
||||
@commands.has_permissions(ban_members=True)
|
||||
|
||||
@@ -13,6 +13,7 @@ from guardden.services.verification import (
|
||||
PendingVerification,
|
||||
VerificationService,
|
||||
)
|
||||
from guardden.utils.ratelimit import RateLimitExceeded
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -155,6 +156,31 @@ class Verification(commands.Cog):
|
||||
self.service = VerificationService()
|
||||
self.cleanup_task.start()
|
||||
|
||||
def cog_check(self, ctx: commands.Context) -> bool:
|
||||
if not ctx.guild:
|
||||
return False
|
||||
if not self.bot.is_owner_allowed(ctx.author.id):
|
||||
return False
|
||||
return True
|
||||
|
||||
async def cog_before_invoke(self, ctx: commands.Context) -> None:
|
||||
if not ctx.command:
|
||||
return
|
||||
result = self.bot.rate_limiter.acquire_command(
|
||||
ctx.command.qualified_name,
|
||||
user_id=ctx.author.id,
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
)
|
||||
if result.is_limited:
|
||||
raise RateLimitExceeded(result.reset_after)
|
||||
|
||||
async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None:
|
||||
if isinstance(error, RateLimitExceeded):
|
||||
await ctx.send(
|
||||
f"You're being rate limited. Try again in {error.retry_after:.1f} seconds."
|
||||
)
|
||||
|
||||
def cog_unload(self) -> None:
|
||||
self.cleanup_task.cancel()
|
||||
|
||||
|
||||
@@ -1,12 +1,70 @@
|
||||
"""Configuration management for GuardDen."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic import Field, SecretStr, field_validator, ValidationError
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
# Discord snowflake ID validation regex (64-bit integers, 17-19 digits)
|
||||
DISCORD_ID_PATTERN = re.compile(r"^\d{17,19}$")
|
||||
|
||||
|
||||
def _validate_discord_id(value: str | int) -> int:
|
||||
"""Validate a Discord snowflake ID."""
|
||||
if isinstance(value, int):
|
||||
id_str = str(value)
|
||||
else:
|
||||
id_str = str(value).strip()
|
||||
|
||||
# Check format
|
||||
if not DISCORD_ID_PATTERN.match(id_str):
|
||||
raise ValueError(f"Invalid Discord ID format: {id_str}")
|
||||
|
||||
# Convert to int and validate range
|
||||
discord_id = int(id_str)
|
||||
# Discord snowflakes are 64-bit integers, minimum valid ID is around 2010
|
||||
if discord_id < 100000000000000000 or discord_id > 9999999999999999999:
|
||||
raise ValueError(f"Discord ID out of valid range: {discord_id}")
|
||||
|
||||
return discord_id
|
||||
|
||||
|
||||
def _parse_id_list(value: Any) -> list[int]:
|
||||
"""Parse an environment value into a list of valid Discord IDs."""
|
||||
if value is None:
|
||||
return []
|
||||
|
||||
items: list[Any]
|
||||
if isinstance(value, list):
|
||||
items = value
|
||||
elif isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return []
|
||||
# Only allow comma or semicolon separated values, no JSON parsing for security
|
||||
items = [part.strip() for part in text.replace(";", ",").split(",") if part.strip()]
|
||||
else:
|
||||
items = [value]
|
||||
|
||||
parsed: list[int] = []
|
||||
seen: set[int] = set()
|
||||
for item in items:
|
||||
try:
|
||||
discord_id = _validate_discord_id(item)
|
||||
if discord_id not in seen:
|
||||
parsed.append(discord_id)
|
||||
seen.add(discord_id)
|
||||
except (ValueError, TypeError):
|
||||
# Skip invalid IDs rather than failing silently
|
||||
continue
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
@@ -40,11 +98,79 @@ class Settings(BaseSettings):
|
||||
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
|
||||
default="INFO", description="Logging level"
|
||||
)
|
||||
log_json: bool = Field(default=False, description="Use JSON structured logging format")
|
||||
log_file: str | None = Field(default=None, description="Log file path (optional)")
|
||||
|
||||
# Access control
|
||||
allowed_guilds: list[int] = Field(
|
||||
default_factory=list,
|
||||
description="Guild IDs the bot is allowed to join (empty = allow all)",
|
||||
)
|
||||
owner_ids: list[int] = Field(
|
||||
default_factory=list,
|
||||
description="Owner user IDs with elevated access (empty = allow admins)",
|
||||
)
|
||||
|
||||
# Paths
|
||||
data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")
|
||||
|
||||
@field_validator("allowed_guilds", "owner_ids", mode="before")
|
||||
@classmethod
|
||||
def _validate_id_list(cls, value: Any) -> list[int]:
|
||||
return _parse_id_list(value)
|
||||
|
||||
@field_validator("discord_token")
|
||||
@classmethod
|
||||
def _validate_discord_token(cls, value: SecretStr) -> SecretStr:
|
||||
"""Validate Discord bot token format."""
|
||||
token = value.get_secret_value()
|
||||
if not token:
|
||||
raise ValueError("Discord token cannot be empty")
|
||||
|
||||
# Basic Discord token format validation (not perfect but catches common issues)
|
||||
if len(token) < 50 or not re.match(r"^[A-Za-z0-9._-]+$", token):
|
||||
raise ValueError("Invalid Discord token format")
|
||||
|
||||
return value
|
||||
|
||||
@field_validator("anthropic_api_key", "openai_api_key")
|
||||
@classmethod
|
||||
def _validate_api_key(cls, value: SecretStr | None) -> SecretStr | None:
|
||||
"""Validate API key format if provided."""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
key = value.get_secret_value()
|
||||
if not key:
|
||||
return None
|
||||
|
||||
# Basic API key validation
|
||||
if len(key) < 20:
|
||||
raise ValueError("API key too short to be valid")
|
||||
|
||||
return value
|
||||
|
||||
def validate_configuration(self) -> None:
|
||||
"""Validate the settings for runtime usage."""
|
||||
# AI provider validation
|
||||
if self.ai_provider == "anthropic" and not self.anthropic_api_key:
|
||||
raise ValueError("GUARDDEN_ANTHROPIC_API_KEY is required when AI provider is anthropic")
|
||||
if self.ai_provider == "openai" and not self.openai_api_key:
|
||||
raise ValueError("GUARDDEN_OPENAI_API_KEY is required when AI provider is openai")
|
||||
|
||||
# Database pool validation
|
||||
if self.database_pool_min > self.database_pool_max:
|
||||
raise ValueError("database_pool_min cannot be greater than database_pool_max")
|
||||
if self.database_pool_min < 1:
|
||||
raise ValueError("database_pool_min must be at least 1")
|
||||
|
||||
# Data directory validation
|
||||
if not isinstance(self.data_dir, Path):
|
||||
raise ValueError("data_dir must be a valid path")
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""Get application settings instance."""
|
||||
return Settings()
|
||||
settings = Settings()
|
||||
settings.validate_configuration()
|
||||
return settings
|
||||
|
||||
1
src/guardden/dashboard/__init__.py
Normal file
1
src/guardden/dashboard/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Dashboard application package."""
|
||||
267
src/guardden/dashboard/analytics.py
Normal file
267
src/guardden/dashboard/analytics.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Analytics API routes for the GuardDen dashboard."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from guardden.dashboard.auth import require_owner
|
||||
from guardden.dashboard.config import DashboardSettings
|
||||
from guardden.dashboard.db import DashboardDatabase
|
||||
from guardden.dashboard.schemas import (
|
||||
AIPerformanceStats,
|
||||
AnalyticsSummary,
|
||||
ModerationStats,
|
||||
TimeSeriesDataPoint,
|
||||
UserActivityStats,
|
||||
)
|
||||
from guardden.models import AICheck, MessageActivity, ModerationLog, UserActivity
|
||||
|
||||
|
||||
def create_analytics_router(
|
||||
settings: DashboardSettings,
|
||||
database: DashboardDatabase,
|
||||
) -> APIRouter:
|
||||
"""Create the analytics API router."""
|
||||
router = APIRouter(prefix="/api/analytics")
|
||||
|
||||
async def get_session() -> AsyncIterator[AsyncSession]:
|
||||
async for session in database.session():
|
||||
yield session
|
||||
|
||||
def require_owner_dep(request: Request) -> None:
|
||||
require_owner(settings, request)
|
||||
|
||||
@router.get(
|
||||
"/summary",
|
||||
response_model=AnalyticsSummary,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def analytics_summary(
|
||||
guild_id: int | None = Query(default=None),
|
||||
days: int = Query(default=7, ge=1, le=90),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AnalyticsSummary:
|
||||
"""Get analytics summary for the specified time period."""
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
# Moderation stats
|
||||
mod_query = select(ModerationLog).where(ModerationLog.created_at >= start_date)
|
||||
if guild_id:
|
||||
mod_query = mod_query.where(ModerationLog.guild_id == guild_id)
|
||||
|
||||
mod_result = await session.execute(mod_query)
|
||||
mod_logs = mod_result.scalars().all()
|
||||
|
||||
total_actions = len(mod_logs)
|
||||
actions_by_type: dict[str, int] = {}
|
||||
automatic_count = 0
|
||||
manual_count = 0
|
||||
|
||||
for log in mod_logs:
|
||||
actions_by_type[log.action] = actions_by_type.get(log.action, 0) + 1
|
||||
if log.is_automatic:
|
||||
automatic_count += 1
|
||||
else:
|
||||
manual_count += 1
|
||||
|
||||
# Time series data (group by day)
|
||||
time_series: dict[str, int] = {}
|
||||
for log in mod_logs:
|
||||
day_key = log.created_at.strftime("%Y-%m-%d")
|
||||
time_series[day_key] = time_series.get(day_key, 0) + 1
|
||||
|
||||
actions_over_time = [
|
||||
TimeSeriesDataPoint(timestamp=datetime.strptime(day, "%Y-%m-%d"), value=count)
|
||||
for day, count in sorted(time_series.items())
|
||||
]
|
||||
|
||||
moderation_stats = ModerationStats(
|
||||
total_actions=total_actions,
|
||||
actions_by_type=actions_by_type,
|
||||
actions_over_time=actions_over_time,
|
||||
automatic_vs_manual={"automatic": automatic_count, "manual": manual_count},
|
||||
)
|
||||
|
||||
# User activity stats
|
||||
activity_query = select(MessageActivity).where(MessageActivity.date >= start_date)
|
||||
if guild_id:
|
||||
activity_query = activity_query.where(MessageActivity.guild_id == guild_id)
|
||||
|
||||
activity_result = await session.execute(activity_query)
|
||||
activities = activity_result.scalars().all()
|
||||
|
||||
total_messages = sum(a.total_messages for a in activities)
|
||||
active_users = max((a.active_users for a in activities), default=0)
|
||||
|
||||
# New joins
|
||||
today = datetime.now().date()
|
||||
week_ago = today - timedelta(days=7)
|
||||
new_joins_today = sum(a.new_joins for a in activities if a.date.date() == today)
|
||||
new_joins_week = sum(a.new_joins for a in activities if a.date.date() >= week_ago)
|
||||
|
||||
user_activity = UserActivityStats(
|
||||
active_users=active_users,
|
||||
total_messages=total_messages,
|
||||
new_joins_today=new_joins_today,
|
||||
new_joins_week=new_joins_week,
|
||||
)
|
||||
|
||||
# AI performance stats
|
||||
ai_query = select(AICheck).where(AICheck.created_at >= start_date)
|
||||
if guild_id:
|
||||
ai_query = ai_query.where(AICheck.guild_id == guild_id)
|
||||
|
||||
ai_result = await session.execute(ai_query)
|
||||
ai_checks = ai_result.scalars().all()
|
||||
|
||||
total_checks = len(ai_checks)
|
||||
flagged_content = sum(1 for c in ai_checks if c.flagged)
|
||||
avg_confidence = (
|
||||
sum(c.confidence for c in ai_checks) / total_checks if total_checks > 0 else 0.0
|
||||
)
|
||||
false_positives = sum(1 for c in ai_checks if c.is_false_positive)
|
||||
avg_response_time = (
|
||||
sum(c.response_time_ms for c in ai_checks) / total_checks if total_checks > 0 else 0.0
|
||||
)
|
||||
|
||||
ai_performance = AIPerformanceStats(
|
||||
total_checks=total_checks,
|
||||
flagged_content=flagged_content,
|
||||
avg_confidence=avg_confidence,
|
||||
false_positives=false_positives,
|
||||
avg_response_time_ms=avg_response_time,
|
||||
)
|
||||
|
||||
return AnalyticsSummary(
|
||||
moderation_stats=moderation_stats,
|
||||
user_activity=user_activity,
|
||||
ai_performance=ai_performance,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/moderation-stats",
|
||||
response_model=ModerationStats,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def moderation_stats(
|
||||
guild_id: int | None = Query(default=None),
|
||||
days: int = Query(default=30, ge=1, le=90),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> ModerationStats:
|
||||
"""Get detailed moderation statistics."""
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
query = select(ModerationLog).where(ModerationLog.created_at >= start_date)
|
||||
if guild_id:
|
||||
query = query.where(ModerationLog.guild_id == guild_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
logs = result.scalars().all()
|
||||
|
||||
total_actions = len(logs)
|
||||
actions_by_type: dict[str, int] = {}
|
||||
automatic_count = 0
|
||||
manual_count = 0
|
||||
|
||||
for log in logs:
|
||||
actions_by_type[log.action] = actions_by_type.get(log.action, 0) + 1
|
||||
if log.is_automatic:
|
||||
automatic_count += 1
|
||||
else:
|
||||
manual_count += 1
|
||||
|
||||
# Time series data
|
||||
time_series: dict[str, int] = {}
|
||||
for log in logs:
|
||||
day_key = log.created_at.strftime("%Y-%m-%d")
|
||||
time_series[day_key] = time_series.get(day_key, 0) + 1
|
||||
|
||||
actions_over_time = [
|
||||
TimeSeriesDataPoint(timestamp=datetime.strptime(day, "%Y-%m-%d"), value=count)
|
||||
for day, count in sorted(time_series.items())
|
||||
]
|
||||
|
||||
return ModerationStats(
|
||||
total_actions=total_actions,
|
||||
actions_by_type=actions_by_type,
|
||||
actions_over_time=actions_over_time,
|
||||
automatic_vs_manual={"automatic": automatic_count, "manual": manual_count},
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/user-activity",
|
||||
response_model=UserActivityStats,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def user_activity_stats(
|
||||
guild_id: int | None = Query(default=None),
|
||||
days: int = Query(default=7, ge=1, le=90),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> UserActivityStats:
|
||||
"""Get user activity statistics."""
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
query = select(MessageActivity).where(MessageActivity.date >= start_date)
|
||||
if guild_id:
|
||||
query = query.where(MessageActivity.guild_id == guild_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
activities = result.scalars().all()
|
||||
|
||||
total_messages = sum(a.total_messages for a in activities)
|
||||
active_users = max((a.active_users for a in activities), default=0)
|
||||
|
||||
today = datetime.now().date()
|
||||
week_ago = today - timedelta(days=7)
|
||||
new_joins_today = sum(a.new_joins for a in activities if a.date.date() == today)
|
||||
new_joins_week = sum(a.new_joins for a in activities if a.date.date() >= week_ago)
|
||||
|
||||
return UserActivityStats(
|
||||
active_users=active_users,
|
||||
total_messages=total_messages,
|
||||
new_joins_today=new_joins_today,
|
||||
new_joins_week=new_joins_week,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/ai-performance",
|
||||
response_model=AIPerformanceStats,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def ai_performance_stats(
|
||||
guild_id: int | None = Query(default=None),
|
||||
days: int = Query(default=30, ge=1, le=90),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AIPerformanceStats:
|
||||
"""Get AI moderation performance statistics."""
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
query = select(AICheck).where(AICheck.created_at >= start_date)
|
||||
if guild_id:
|
||||
query = query.where(AICheck.guild_id == guild_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
checks = result.scalars().all()
|
||||
|
||||
total_checks = len(checks)
|
||||
flagged_content = sum(1 for c in checks if c.flagged)
|
||||
avg_confidence = (
|
||||
sum(c.confidence for c in checks) / total_checks if total_checks > 0 else 0.0
|
||||
)
|
||||
false_positives = sum(1 for c in checks if c.is_false_positive)
|
||||
avg_response_time = (
|
||||
sum(c.response_time_ms for c in checks) / total_checks if total_checks > 0 else 0.0
|
||||
)
|
||||
|
||||
return AIPerformanceStats(
|
||||
total_checks=total_checks,
|
||||
flagged_content=flagged_content,
|
||||
avg_confidence=avg_confidence,
|
||||
false_positives=false_positives,
|
||||
avg_response_time_ms=avg_response_time,
|
||||
)
|
||||
|
||||
return router
|
||||
78
src/guardden/dashboard/auth.py
Normal file
78
src/guardden/dashboard/auth.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Authentication helpers for the dashboard."""
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from guardden.dashboard.config import DashboardSettings
|
||||
|
||||
|
||||
def build_oauth(settings: DashboardSettings) -> OAuth:
|
||||
"""Build OAuth client registrations."""
|
||||
oauth = OAuth()
|
||||
oauth.register(
|
||||
name="entra",
|
||||
client_id=settings.entra_client_id,
|
||||
client_secret=settings.entra_client_secret.get_secret_value(),
|
||||
server_metadata_url=(
|
||||
"https://login.microsoftonline.com/"
|
||||
f"{settings.entra_tenant_id}/v2.0/.well-known/openid-configuration"
|
||||
),
|
||||
client_kwargs={"scope": "openid profile email"},
|
||||
)
|
||||
return oauth
|
||||
|
||||
|
||||
def discord_authorize_url(settings: DashboardSettings, state: str) -> str:
|
||||
"""Generate the Discord OAuth authorization URL."""
|
||||
query = urlencode(
|
||||
{
|
||||
"client_id": settings.discord_client_id,
|
||||
"redirect_uri": settings.callback_url("discord"),
|
||||
"response_type": "code",
|
||||
"scope": "identify",
|
||||
"state": state,
|
||||
}
|
||||
)
|
||||
return f"https://discord.com/oauth2/authorize?{query}"
|
||||
|
||||
|
||||
async def exchange_discord_code(settings: DashboardSettings, code: str) -> dict[str, Any]:
|
||||
"""Exchange a Discord OAuth code for a user profile."""
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
token_response = await client.post(
|
||||
"https://discord.com/api/oauth2/token",
|
||||
data={
|
||||
"client_id": settings.discord_client_id,
|
||||
"client_secret": settings.discord_client_secret.get_secret_value(),
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": settings.callback_url("discord"),
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
token_response.raise_for_status()
|
||||
token_data = token_response.json()
|
||||
|
||||
user_response = await client.get(
|
||||
"https://discord.com/api/users/@me",
|
||||
headers={"Authorization": f"Bearer {token_data['access_token']}"},
|
||||
)
|
||||
user_response.raise_for_status()
|
||||
return user_response.json()
|
||||
|
||||
|
||||
def require_owner(settings: DashboardSettings, request: Request) -> None:
|
||||
"""Ensure the current session is the configured owner."""
|
||||
session = request.session
|
||||
entra_oid = session.get("entra_oid")
|
||||
discord_id = session.get("discord_id")
|
||||
if not entra_oid or not discord_id:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
|
||||
if str(entra_oid) != settings.owner_entra_object_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
if int(discord_id) != settings.owner_discord_id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
68
src/guardden/dashboard/config.py
Normal file
68
src/guardden/dashboard/config.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Configuration for the GuardDen dashboard."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, SecretStr, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class DashboardSettings(BaseSettings):
|
||||
"""Dashboard settings loaded from environment variables."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
env_prefix="GUARDDEN_DASHBOARD_",
|
||||
)
|
||||
|
||||
database_url: SecretStr = Field(
|
||||
validation_alias="GUARDDEN_DATABASE_URL",
|
||||
description="Database connection URL",
|
||||
)
|
||||
|
||||
base_url: str = Field(
|
||||
default="http://localhost:8080",
|
||||
description="Base URL for OAuth callbacks",
|
||||
)
|
||||
secret_key: SecretStr = Field(
|
||||
default=SecretStr("change-me"),
|
||||
description="Session secret key",
|
||||
)
|
||||
|
||||
entra_tenant_id: str = Field(description="Entra ID tenant ID")
|
||||
entra_client_id: str = Field(description="Entra ID application client ID")
|
||||
entra_client_secret: SecretStr = Field(description="Entra ID application client secret")
|
||||
|
||||
discord_client_id: str = Field(description="Discord OAuth client ID")
|
||||
discord_client_secret: SecretStr = Field(description="Discord OAuth client secret")
|
||||
|
||||
owner_discord_id: int = Field(description="Discord user ID allowed to access dashboard")
|
||||
owner_entra_object_id: str = Field(description="Entra ID object ID allowed to access")
|
||||
|
||||
cors_origins: list[str] = Field(default_factory=list, description="Allowed CORS origins")
|
||||
static_dir: Path = Field(
|
||||
default=Path("dashboard/frontend/dist"),
|
||||
description="Directory containing built frontend assets",
|
||||
)
|
||||
|
||||
@field_validator("cors_origins", mode="before")
|
||||
@classmethod
|
||||
def _parse_origins(cls, value: Any) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, list):
|
||||
return [str(item).strip() for item in value if str(item).strip()]
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return []
|
||||
return [item.strip() for item in text.split(",") if item.strip()]
|
||||
|
||||
def callback_url(self, provider: str) -> str:
|
||||
return f"{self.base_url}/auth/{provider}/callback"
|
||||
|
||||
|
||||
def get_dashboard_settings() -> DashboardSettings:
|
||||
"""Load dashboard settings from environment."""
|
||||
return DashboardSettings()
|
||||
298
src/guardden/dashboard/config_management.py
Normal file
298
src/guardden/dashboard/config_management.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""Configuration management API routes for the GuardDen dashboard."""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from guardden.dashboard.auth import require_owner
|
||||
from guardden.dashboard.config import DashboardSettings
|
||||
from guardden.dashboard.db import DashboardDatabase
|
||||
from guardden.dashboard.schemas import AutomodRuleConfig, ConfigExport, GuildSettings
|
||||
from guardden.models import Guild
|
||||
from guardden.models import GuildSettings as GuildSettingsModel
|
||||
|
||||
|
||||
def create_config_router(
|
||||
settings: DashboardSettings,
|
||||
database: DashboardDatabase,
|
||||
) -> APIRouter:
|
||||
"""Create the configuration management API router."""
|
||||
router = APIRouter(prefix="/api/guilds")
|
||||
|
||||
async def get_session() -> AsyncIterator[AsyncSession]:
|
||||
async for session in database.session():
|
||||
yield session
|
||||
|
||||
def require_owner_dep(request: Request) -> None:
|
||||
require_owner(settings, request)
|
||||
|
||||
@router.get(
|
||||
"/{guild_id}/settings",
|
||||
response_model=GuildSettings,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def get_guild_settings(
|
||||
guild_id: int = Path(...),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> GuildSettings:
|
||||
"""Get guild settings."""
|
||||
query = select(GuildSettingsModel).where(GuildSettingsModel.guild_id == guild_id)
|
||||
result = await session.execute(query)
|
||||
guild_settings = result.scalar_one_or_none()
|
||||
|
||||
if not guild_settings:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Guild settings not found",
|
||||
)
|
||||
|
||||
return GuildSettings(
|
||||
guild_id=guild_settings.guild_id,
|
||||
prefix=guild_settings.prefix,
|
||||
log_channel_id=guild_settings.log_channel_id,
|
||||
automod_enabled=guild_settings.automod_enabled,
|
||||
ai_moderation_enabled=guild_settings.ai_moderation_enabled,
|
||||
ai_sensitivity=guild_settings.ai_sensitivity,
|
||||
verification_enabled=guild_settings.verification_enabled,
|
||||
verification_role_id=guild_settings.verified_role_id,
|
||||
max_warns_before_action=3, # Default value, could be derived from strike_actions
|
||||
)
|
||||
|
||||
@router.put(
|
||||
"/{guild_id}/settings",
|
||||
response_model=GuildSettings,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def update_guild_settings(
|
||||
guild_id: int = Path(...),
|
||||
settings_data: GuildSettings = ...,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> GuildSettings:
|
||||
"""Update guild settings."""
|
||||
query = select(GuildSettingsModel).where(GuildSettingsModel.guild_id == guild_id)
|
||||
result = await session.execute(query)
|
||||
guild_settings = result.scalar_one_or_none()
|
||||
|
||||
if not guild_settings:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Guild settings not found",
|
||||
)
|
||||
|
||||
# Update settings
|
||||
if settings_data.prefix is not None:
|
||||
guild_settings.prefix = settings_data.prefix
|
||||
if settings_data.log_channel_id is not None:
|
||||
guild_settings.log_channel_id = settings_data.log_channel_id
|
||||
guild_settings.automod_enabled = settings_data.automod_enabled
|
||||
guild_settings.ai_moderation_enabled = settings_data.ai_moderation_enabled
|
||||
guild_settings.ai_sensitivity = settings_data.ai_sensitivity
|
||||
guild_settings.verification_enabled = settings_data.verification_enabled
|
||||
if settings_data.verification_role_id is not None:
|
||||
guild_settings.verified_role_id = settings_data.verification_role_id
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(guild_settings)
|
||||
|
||||
return GuildSettings(
|
||||
guild_id=guild_settings.guild_id,
|
||||
prefix=guild_settings.prefix,
|
||||
log_channel_id=guild_settings.log_channel_id,
|
||||
automod_enabled=guild_settings.automod_enabled,
|
||||
ai_moderation_enabled=guild_settings.ai_moderation_enabled,
|
||||
ai_sensitivity=guild_settings.ai_sensitivity,
|
||||
verification_enabled=guild_settings.verification_enabled,
|
||||
verification_role_id=guild_settings.verified_role_id,
|
||||
max_warns_before_action=3,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/{guild_id}/automod",
|
||||
response_model=AutomodRuleConfig,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def get_automod_config(
|
||||
guild_id: int = Path(...),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AutomodRuleConfig:
|
||||
"""Get automod rule configuration."""
|
||||
query = select(GuildSettingsModel).where(GuildSettingsModel.guild_id == guild_id)
|
||||
result = await session.execute(query)
|
||||
guild_settings = result.scalar_one_or_none()
|
||||
|
||||
if not guild_settings:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Guild settings not found",
|
||||
)
|
||||
|
||||
return AutomodRuleConfig(
|
||||
guild_id=guild_settings.guild_id,
|
||||
banned_words_enabled=True, # Derived from automod_enabled
|
||||
scam_detection_enabled=guild_settings.automod_enabled,
|
||||
spam_detection_enabled=guild_settings.anti_spam_enabled,
|
||||
invite_filter_enabled=guild_settings.link_filter_enabled,
|
||||
max_mentions=guild_settings.mention_limit,
|
||||
max_emojis=10, # Default value
|
||||
spam_threshold=guild_settings.message_rate_limit,
|
||||
)
|
||||
|
||||
@router.put(
|
||||
"/{guild_id}/automod",
|
||||
response_model=AutomodRuleConfig,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def update_automod_config(
|
||||
guild_id: int = Path(...),
|
||||
automod_data: AutomodRuleConfig = ...,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> AutomodRuleConfig:
|
||||
"""Update automod rule configuration."""
|
||||
query = select(GuildSettingsModel).where(GuildSettingsModel.guild_id == guild_id)
|
||||
result = await session.execute(query)
|
||||
guild_settings = result.scalar_one_or_none()
|
||||
|
||||
if not guild_settings:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Guild settings not found",
|
||||
)
|
||||
|
||||
# Update automod settings
|
||||
guild_settings.automod_enabled = automod_data.scam_detection_enabled
|
||||
guild_settings.anti_spam_enabled = automod_data.spam_detection_enabled
|
||||
guild_settings.link_filter_enabled = automod_data.invite_filter_enabled
|
||||
guild_settings.mention_limit = automod_data.max_mentions
|
||||
guild_settings.message_rate_limit = automod_data.spam_threshold
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(guild_settings)
|
||||
|
||||
return AutomodRuleConfig(
|
||||
guild_id=guild_settings.guild_id,
|
||||
banned_words_enabled=automod_data.banned_words_enabled,
|
||||
scam_detection_enabled=guild_settings.automod_enabled,
|
||||
spam_detection_enabled=guild_settings.anti_spam_enabled,
|
||||
invite_filter_enabled=guild_settings.link_filter_enabled,
|
||||
max_mentions=guild_settings.mention_limit,
|
||||
max_emojis=10,
|
||||
spam_threshold=guild_settings.message_rate_limit,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/{guild_id}/export",
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def export_config(
|
||||
guild_id: int = Path(...),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
"""Export guild configuration as JSON."""
|
||||
query = select(GuildSettingsModel).where(GuildSettingsModel.guild_id == guild_id)
|
||||
result = await session.execute(query)
|
||||
guild_settings = result.scalar_one_or_none()
|
||||
|
||||
if not guild_settings:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Guild settings not found",
|
||||
)
|
||||
|
||||
# Build export data
|
||||
export_data = ConfigExport(
|
||||
version="1.0",
|
||||
guild_settings=GuildSettings(
|
||||
guild_id=guild_settings.guild_id,
|
||||
prefix=guild_settings.prefix,
|
||||
log_channel_id=guild_settings.log_channel_id,
|
||||
automod_enabled=guild_settings.automod_enabled,
|
||||
ai_moderation_enabled=guild_settings.ai_moderation_enabled,
|
||||
ai_sensitivity=guild_settings.ai_sensitivity,
|
||||
verification_enabled=guild_settings.verification_enabled,
|
||||
verification_role_id=guild_settings.verified_role_id,
|
||||
max_warns_before_action=3,
|
||||
),
|
||||
automod_rules=AutomodRuleConfig(
|
||||
guild_id=guild_settings.guild_id,
|
||||
banned_words_enabled=True,
|
||||
scam_detection_enabled=guild_settings.automod_enabled,
|
||||
spam_detection_enabled=guild_settings.anti_spam_enabled,
|
||||
invite_filter_enabled=guild_settings.link_filter_enabled,
|
||||
max_mentions=guild_settings.mention_limit,
|
||||
max_emojis=10,
|
||||
spam_threshold=guild_settings.message_rate_limit,
|
||||
),
|
||||
exported_at=datetime.now(),
|
||||
)
|
||||
|
||||
# Convert to JSON
|
||||
json_data = export_data.model_dump_json(indent=2)
|
||||
|
||||
return StreamingResponse(
|
||||
iter([json_data]),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f"attachment; filename=guild_{guild_id}_config.json"},
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/{guild_id}/import",
|
||||
response_model=GuildSettings,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def import_config(
|
||||
guild_id: int = Path(...),
|
||||
config_data: ConfigExport = ...,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> GuildSettings:
|
||||
"""Import guild configuration from JSON."""
|
||||
query = select(GuildSettingsModel).where(GuildSettingsModel.guild_id == guild_id)
|
||||
result = await session.execute(query)
|
||||
guild_settings = result.scalar_one_or_none()
|
||||
|
||||
if not guild_settings:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Guild settings not found",
|
||||
)
|
||||
|
||||
# Import settings
|
||||
settings = config_data.guild_settings
|
||||
if settings.prefix is not None:
|
||||
guild_settings.prefix = settings.prefix
|
||||
if settings.log_channel_id is not None:
|
||||
guild_settings.log_channel_id = settings.log_channel_id
|
||||
guild_settings.automod_enabled = settings.automod_enabled
|
||||
guild_settings.ai_moderation_enabled = settings.ai_moderation_enabled
|
||||
guild_settings.ai_sensitivity = settings.ai_sensitivity
|
||||
guild_settings.verification_enabled = settings.verification_enabled
|
||||
if settings.verification_role_id is not None:
|
||||
guild_settings.verified_role_id = settings.verification_role_id
|
||||
|
||||
# Import automod rules
|
||||
automod = config_data.automod_rules
|
||||
guild_settings.anti_spam_enabled = automod.spam_detection_enabled
|
||||
guild_settings.link_filter_enabled = automod.invite_filter_enabled
|
||||
guild_settings.mention_limit = automod.max_mentions
|
||||
guild_settings.message_rate_limit = automod.spam_threshold
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(guild_settings)
|
||||
|
||||
return GuildSettings(
|
||||
guild_id=guild_settings.guild_id,
|
||||
prefix=guild_settings.prefix,
|
||||
log_channel_id=guild_settings.log_channel_id,
|
||||
automod_enabled=guild_settings.automod_enabled,
|
||||
ai_moderation_enabled=guild_settings.ai_moderation_enabled,
|
||||
ai_sensitivity=guild_settings.ai_sensitivity,
|
||||
verification_enabled=guild_settings.verification_enabled,
|
||||
verification_role_id=guild_settings.verified_role_id,
|
||||
max_warns_before_action=3,
|
||||
)
|
||||
|
||||
return router
|
||||
24
src/guardden/dashboard/db.py
Normal file
24
src/guardden/dashboard/db.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Database helpers for the dashboard."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from guardden.dashboard.config import DashboardSettings
|
||||
|
||||
|
||||
class DashboardDatabase:
|
||||
"""Async database session factory for the dashboard."""
|
||||
|
||||
def __init__(self, settings: DashboardSettings) -> None:
|
||||
db_url = settings.database_url.get_secret_value()
|
||||
if db_url.startswith("postgresql://"):
|
||||
db_url = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
|
||||
self._engine = create_async_engine(db_url, pool_pre_ping=True)
|
||||
self._sessionmaker = async_sessionmaker(self._engine, expire_on_commit=False)
|
||||
|
||||
async def session(self) -> AsyncIterator[AsyncSession]:
|
||||
"""Yield a database session."""
|
||||
async with self._sessionmaker() as session:
|
||||
yield session
|
||||
121
src/guardden/dashboard/main.py
Normal file
121
src/guardden/dashboard/main.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""FastAPI app for the GuardDen dashboard."""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.staticfiles import StaticFiles
|
||||
|
||||
from guardden.dashboard.analytics import create_analytics_router
|
||||
from guardden.dashboard.auth import (
|
||||
build_oauth,
|
||||
discord_authorize_url,
|
||||
exchange_discord_code,
|
||||
require_owner,
|
||||
)
|
||||
from guardden.dashboard.config import DashboardSettings, get_dashboard_settings
|
||||
from guardden.dashboard.config_management import create_config_router
|
||||
from guardden.dashboard.db import DashboardDatabase
|
||||
from guardden.dashboard.routes import create_api_router
|
||||
from guardden.dashboard.users import create_users_router
|
||||
from guardden.dashboard.websocket import create_websocket_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
settings = get_dashboard_settings()
|
||||
database = DashboardDatabase(settings)
|
||||
oauth = build_oauth(settings)
|
||||
|
||||
app = FastAPI(title="GuardDen Dashboard")
|
||||
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key.get_secret_value())
|
||||
|
||||
if settings.cors_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
def require_owner_dep(request: Request) -> None:
|
||||
require_owner(settings, request)
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/me")
|
||||
async def me(request: Request) -> dict[str, bool | str | None]:
|
||||
entra_oid = request.session.get("entra_oid")
|
||||
discord_id = request.session.get("discord_id")
|
||||
owner = str(entra_oid) == settings.owner_entra_object_id and str(discord_id) == str(
|
||||
settings.owner_discord_id
|
||||
)
|
||||
return {
|
||||
"entra": bool(entra_oid),
|
||||
"discord": bool(discord_id),
|
||||
"owner": owner,
|
||||
"entra_oid": entra_oid,
|
||||
"discord_id": discord_id,
|
||||
}
|
||||
|
||||
@app.get("/auth/entra/login")
|
||||
async def entra_login(request: Request) -> RedirectResponse:
|
||||
redirect_uri = settings.callback_url("entra")
|
||||
return await oauth.entra.authorize_redirect(request, redirect_uri)
|
||||
|
||||
@app.get("/auth/entra/callback")
|
||||
async def entra_callback(request: Request) -> RedirectResponse:
|
||||
token = await oauth.entra.authorize_access_token(request)
|
||||
user = await oauth.entra.parse_id_token(request, token)
|
||||
request.session["entra_oid"] = user.get("oid")
|
||||
return RedirectResponse(url="/")
|
||||
|
||||
@app.get("/auth/discord/login")
|
||||
async def discord_login(request: Request) -> RedirectResponse:
|
||||
state = secrets.token_urlsafe(16)
|
||||
request.session["discord_state"] = state
|
||||
return RedirectResponse(url=discord_authorize_url(settings, state))
|
||||
|
||||
@app.get("/auth/discord/callback")
|
||||
async def discord_callback(request: Request) -> RedirectResponse:
|
||||
params = dict(request.query_params)
|
||||
code = params.get("code")
|
||||
state = params.get("state")
|
||||
if not code or not state:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing code")
|
||||
if state != request.session.get("discord_state"):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state")
|
||||
profile = await exchange_discord_code(settings, code)
|
||||
request.session["discord_id"] = profile.get("id")
|
||||
return RedirectResponse(url="/")
|
||||
|
||||
@app.get("/auth/logout")
|
||||
async def logout(request: Request) -> RedirectResponse:
|
||||
request.session.clear()
|
||||
return RedirectResponse(url="/")
|
||||
|
||||
# Include all API routers
|
||||
app.include_router(create_api_router(settings, database))
|
||||
app.include_router(create_analytics_router(settings, database))
|
||||
app.include_router(create_users_router(settings, database))
|
||||
app.include_router(create_config_router(settings, database))
|
||||
app.include_router(create_websocket_router(settings))
|
||||
|
||||
static_dir = Path(settings.static_dir)
|
||||
if static_dir.exists():
|
||||
app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
|
||||
else:
|
||||
logger.warning("Static directory not found: %s", static_dir)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
87
src/guardden/dashboard/routes.py
Normal file
87
src/guardden/dashboard/routes.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""API routes for the GuardDen dashboard."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from guardden.dashboard.auth import require_owner
|
||||
from guardden.dashboard.config import DashboardSettings
|
||||
from guardden.dashboard.db import DashboardDatabase
|
||||
from guardden.dashboard.schemas import GuildSummary, ModerationLogEntry, PaginatedLogs
|
||||
from guardden.models import Guild, ModerationLog
|
||||
|
||||
|
||||
def create_api_router(
|
||||
settings: DashboardSettings,
|
||||
database: DashboardDatabase,
|
||||
) -> APIRouter:
|
||||
"""Create the dashboard API router."""
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
async def get_session() -> AsyncIterator[AsyncSession]:
|
||||
async for session in database.session():
|
||||
yield session
|
||||
|
||||
def require_owner_dep(request: Request) -> None:
|
||||
require_owner(settings, request)
|
||||
|
||||
@router.get("/guilds", response_model=list[GuildSummary], dependencies=[Depends(require_owner_dep)])
|
||||
async def list_guilds(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[GuildSummary]:
|
||||
result = await session.execute(select(Guild).order_by(Guild.name.asc()))
|
||||
guilds = result.scalars().all()
|
||||
return [
|
||||
GuildSummary(id=g.id, name=g.name, owner_id=g.owner_id, premium=g.premium)
|
||||
for g in guilds
|
||||
]
|
||||
|
||||
@router.get(
|
||||
"/moderation/logs",
|
||||
response_model=PaginatedLogs,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def list_moderation_logs(
|
||||
guild_id: int | None = Query(default=None),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> PaginatedLogs:
|
||||
query = select(ModerationLog)
|
||||
count_query = select(func.count(ModerationLog.id))
|
||||
if guild_id:
|
||||
query = query.where(ModerationLog.guild_id == guild_id)
|
||||
count_query = count_query.where(ModerationLog.guild_id == guild_id)
|
||||
|
||||
query = query.order_by(ModerationLog.created_at.desc()).offset(offset).limit(limit)
|
||||
total_result = await session.execute(count_query)
|
||||
total = int(total_result.scalar() or 0)
|
||||
|
||||
result = await session.execute(query)
|
||||
logs = result.scalars().all()
|
||||
items = [
|
||||
ModerationLogEntry(
|
||||
id=log.id,
|
||||
guild_id=log.guild_id,
|
||||
target_id=log.target_id,
|
||||
target_name=log.target_name,
|
||||
moderator_id=log.moderator_id,
|
||||
moderator_name=log.moderator_name,
|
||||
action=log.action,
|
||||
reason=log.reason,
|
||||
duration=log.duration,
|
||||
expires_at=log.expires_at,
|
||||
channel_id=log.channel_id,
|
||||
message_id=log.message_id,
|
||||
message_content=log.message_content,
|
||||
is_automatic=log.is_automatic,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
|
||||
return PaginatedLogs(total=total, items=items)
|
||||
|
||||
return router
|
||||
163
src/guardden/dashboard/schemas.py
Normal file
163
src/guardden/dashboard/schemas.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Pydantic schemas for dashboard APIs."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GuildSummary(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
owner_id: int
|
||||
premium: bool
|
||||
|
||||
|
||||
class ModerationLogEntry(BaseModel):
|
||||
id: int
|
||||
guild_id: int
|
||||
target_id: int
|
||||
target_name: str
|
||||
moderator_id: int
|
||||
moderator_name: str
|
||||
action: str
|
||||
reason: str | None
|
||||
duration: int | None
|
||||
expires_at: datetime | None
|
||||
channel_id: int | None
|
||||
message_id: int | None
|
||||
message_content: str | None
|
||||
is_automatic: bool
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class PaginatedLogs(BaseModel):
|
||||
total: int
|
||||
items: list[ModerationLogEntry]
|
||||
|
||||
|
||||
# Analytics Schemas
|
||||
class TimeSeriesDataPoint(BaseModel):
|
||||
timestamp: datetime
|
||||
value: int
|
||||
|
||||
|
||||
class ModerationStats(BaseModel):
|
||||
total_actions: int
|
||||
actions_by_type: dict[str, int]
|
||||
actions_over_time: list[TimeSeriesDataPoint]
|
||||
automatic_vs_manual: dict[str, int]
|
||||
|
||||
|
||||
class UserActivityStats(BaseModel):
|
||||
active_users: int
|
||||
total_messages: int
|
||||
new_joins_today: int
|
||||
new_joins_week: int
|
||||
|
||||
|
||||
class AIPerformanceStats(BaseModel):
|
||||
total_checks: int
|
||||
flagged_content: int
|
||||
avg_confidence: float
|
||||
false_positives: int = 0
|
||||
avg_response_time_ms: float = 0.0
|
||||
|
||||
|
||||
class AnalyticsSummary(BaseModel):
|
||||
moderation_stats: ModerationStats
|
||||
user_activity: UserActivityStats
|
||||
ai_performance: AIPerformanceStats
|
||||
|
||||
|
||||
# User Management Schemas
|
||||
class UserProfile(BaseModel):
|
||||
user_id: int
|
||||
username: str
|
||||
strike_count: int
|
||||
total_warnings: int
|
||||
total_kicks: int
|
||||
total_bans: int
|
||||
total_timeouts: int
|
||||
first_seen: datetime
|
||||
last_action: datetime | None
|
||||
|
||||
|
||||
class UserNote(BaseModel):
|
||||
id: int
|
||||
user_id: int
|
||||
guild_id: int
|
||||
moderator_id: int
|
||||
moderator_name: str
|
||||
content: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class CreateUserNote(BaseModel):
|
||||
content: str = Field(min_length=1, max_length=2000)
|
||||
|
||||
|
||||
class BulkModerationAction(BaseModel):
|
||||
action: str = Field(pattern="^(ban|kick|timeout|warn)$")
|
||||
user_ids: list[int] = Field(min_length=1, max_length=100)
|
||||
reason: str | None = None
|
||||
duration: int | None = None
|
||||
|
||||
|
||||
class BulkActionResult(BaseModel):
|
||||
success_count: int
|
||||
failed_count: int
|
||||
errors: dict[int, str]
|
||||
|
||||
|
||||
# Configuration Schemas
|
||||
class GuildSettings(BaseModel):
|
||||
guild_id: int
|
||||
prefix: str | None = None
|
||||
log_channel_id: int | None = None
|
||||
automod_enabled: bool = True
|
||||
ai_moderation_enabled: bool = False
|
||||
ai_sensitivity: int = Field(ge=0, le=100, default=50)
|
||||
verification_enabled: bool = False
|
||||
verification_role_id: int | None = None
|
||||
max_warns_before_action: int = Field(ge=1, le=10, default=3)
|
||||
|
||||
|
||||
class AutomodRuleConfig(BaseModel):
|
||||
guild_id: int
|
||||
banned_words_enabled: bool = True
|
||||
scam_detection_enabled: bool = True
|
||||
spam_detection_enabled: bool = True
|
||||
invite_filter_enabled: bool = False
|
||||
max_mentions: int = Field(ge=1, le=20, default=5)
|
||||
max_emojis: int = Field(ge=1, le=50, default=10)
|
||||
spam_threshold: int = Field(ge=1, le=20, default=5)
|
||||
|
||||
|
||||
class ConfigExport(BaseModel):
|
||||
version: str = "1.0"
|
||||
guild_settings: GuildSettings
|
||||
automod_rules: AutomodRuleConfig
|
||||
exported_at: datetime
|
||||
|
||||
|
||||
# WebSocket Event Schemas
|
||||
class WebSocketEvent(BaseModel):
|
||||
type: str
|
||||
guild_id: int
|
||||
timestamp: datetime
|
||||
data: dict[str, object]
|
||||
|
||||
|
||||
class ModerationEvent(WebSocketEvent):
|
||||
type: str = "moderation_action"
|
||||
data: dict[str, object]
|
||||
|
||||
|
||||
class UserJoinEvent(WebSocketEvent):
|
||||
type: str = "user_join"
|
||||
data: dict[str, object]
|
||||
|
||||
|
||||
class AIAlertEvent(WebSocketEvent):
|
||||
type: str = "ai_alert"
|
||||
data: dict[str, object]
|
||||
246
src/guardden/dashboard/users.py
Normal file
246
src/guardden/dashboard/users.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""User management API routes for the GuardDen dashboard."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from guardden.dashboard.auth import require_owner
|
||||
from guardden.dashboard.config import DashboardSettings
|
||||
from guardden.dashboard.db import DashboardDatabase
|
||||
from guardden.dashboard.schemas import CreateUserNote, UserNote, UserProfile
|
||||
from guardden.models import ModerationLog, UserActivity
|
||||
from guardden.models import UserNote as UserNoteModel
|
||||
|
||||
|
||||
def create_users_router(
|
||||
settings: DashboardSettings,
|
||||
database: DashboardDatabase,
|
||||
) -> APIRouter:
|
||||
"""Create the user management API router."""
|
||||
router = APIRouter(prefix="/api/users")
|
||||
|
||||
async def get_session() -> AsyncIterator[AsyncSession]:
|
||||
async for session in database.session():
|
||||
yield session
|
||||
|
||||
def require_owner_dep(request: Request) -> None:
|
||||
require_owner(settings, request)
|
||||
|
||||
@router.get(
|
||||
"/search",
|
||||
response_model=list[UserProfile],
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def search_users(
|
||||
guild_id: int = Query(...),
|
||||
username: str | None = Query(default=None),
|
||||
min_strikes: int | None = Query(default=None, ge=0),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[UserProfile]:
|
||||
"""Search for users in a guild with optional filters."""
|
||||
query = select(UserActivity).where(UserActivity.guild_id == guild_id)
|
||||
|
||||
if username:
|
||||
query = query.where(UserActivity.username.ilike(f"%{username}%"))
|
||||
|
||||
if min_strikes is not None:
|
||||
query = query.where(UserActivity.strike_count >= min_strikes)
|
||||
|
||||
query = query.order_by(UserActivity.last_seen.desc()).limit(limit)
|
||||
|
||||
result = await session.execute(query)
|
||||
users = result.scalars().all()
|
||||
|
||||
# Get last moderation action for each user
|
||||
profiles = []
|
||||
for user in users:
|
||||
last_action_query = (
|
||||
select(ModerationLog.created_at)
|
||||
.where(ModerationLog.guild_id == guild_id)
|
||||
.where(ModerationLog.target_id == user.user_id)
|
||||
.order_by(ModerationLog.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
last_action_result = await session.execute(last_action_query)
|
||||
last_action = last_action_result.scalar()
|
||||
|
||||
profiles.append(
|
||||
UserProfile(
|
||||
user_id=user.user_id,
|
||||
username=user.username,
|
||||
strike_count=user.strike_count,
|
||||
total_warnings=user.warning_count,
|
||||
total_kicks=user.kick_count,
|
||||
total_bans=user.ban_count,
|
||||
total_timeouts=user.timeout_count,
|
||||
first_seen=user.first_seen,
|
||||
last_action=last_action,
|
||||
)
|
||||
)
|
||||
|
||||
return profiles
|
||||
|
||||
@router.get(
|
||||
"/{user_id}/profile",
|
||||
response_model=UserProfile,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def get_user_profile(
|
||||
user_id: int = Path(...),
|
||||
guild_id: int = Query(...),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Get detailed profile for a specific user."""
|
||||
query = (
|
||||
select(UserActivity)
|
||||
.where(UserActivity.guild_id == guild_id)
|
||||
.where(UserActivity.user_id == user_id)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found in this guild",
|
||||
)
|
||||
|
||||
# Get last moderation action
|
||||
last_action_query = (
|
||||
select(ModerationLog.created_at)
|
||||
.where(ModerationLog.guild_id == guild_id)
|
||||
.where(ModerationLog.target_id == user_id)
|
||||
.order_by(ModerationLog.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
last_action_result = await session.execute(last_action_query)
|
||||
last_action = last_action_result.scalar()
|
||||
|
||||
return UserProfile(
|
||||
user_id=user.user_id,
|
||||
username=user.username,
|
||||
strike_count=user.strike_count,
|
||||
total_warnings=user.warning_count,
|
||||
total_kicks=user.kick_count,
|
||||
total_bans=user.ban_count,
|
||||
total_timeouts=user.timeout_count,
|
||||
first_seen=user.first_seen,
|
||||
last_action=last_action,
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/{user_id}/notes",
|
||||
response_model=list[UserNote],
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def get_user_notes(
|
||||
user_id: int = Path(...),
|
||||
guild_id: int = Query(...),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[UserNote]:
|
||||
"""Get all notes for a specific user."""
|
||||
query = (
|
||||
select(UserNoteModel)
|
||||
.where(UserNoteModel.guild_id == guild_id)
|
||||
.where(UserNoteModel.user_id == user_id)
|
||||
.order_by(UserNoteModel.created_at.desc())
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
notes = result.scalars().all()
|
||||
|
||||
return [
|
||||
UserNote(
|
||||
id=note.id,
|
||||
user_id=note.user_id,
|
||||
guild_id=note.guild_id,
|
||||
moderator_id=note.moderator_id,
|
||||
moderator_name=note.moderator_name,
|
||||
content=note.content,
|
||||
created_at=note.created_at,
|
||||
)
|
||||
for note in notes
|
||||
]
|
||||
|
||||
@router.post(
|
||||
"/{user_id}/notes",
|
||||
response_model=UserNote,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def create_user_note(
|
||||
user_id: int = Path(...),
|
||||
guild_id: int = Query(...),
|
||||
note_data: CreateUserNote = ...,
|
||||
request: Request = ...,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> UserNote:
|
||||
"""Create a new note for a user."""
|
||||
# Get moderator info from session
|
||||
moderator_id = request.session.get("discord_id")
|
||||
if not moderator_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Discord authentication required",
|
||||
)
|
||||
|
||||
# Create the note
|
||||
new_note = UserNoteModel(
|
||||
user_id=user_id,
|
||||
guild_id=guild_id,
|
||||
moderator_id=int(moderator_id),
|
||||
moderator_name="Dashboard User", # TODO: Fetch actual username
|
||||
content=note_data.content,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
session.add(new_note)
|
||||
await session.commit()
|
||||
await session.refresh(new_note)
|
||||
|
||||
return UserNote(
|
||||
id=new_note.id,
|
||||
user_id=new_note.user_id,
|
||||
guild_id=new_note.guild_id,
|
||||
moderator_id=new_note.moderator_id,
|
||||
moderator_name=new_note.moderator_name,
|
||||
content=new_note.content,
|
||||
created_at=new_note.created_at,
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/{user_id}/notes/{note_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
dependencies=[Depends(require_owner_dep)],
|
||||
)
|
||||
async def delete_user_note(
|
||||
user_id: int = Path(...),
|
||||
note_id: int = Path(...),
|
||||
guild_id: int = Query(...),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
"""Delete a user note."""
|
||||
query = (
|
||||
select(UserNoteModel)
|
||||
.where(UserNoteModel.id == note_id)
|
||||
.where(UserNoteModel.guild_id == guild_id)
|
||||
.where(UserNoteModel.user_id == user_id)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
note = result.scalar_one_or_none()
|
||||
|
||||
if not note:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Note not found",
|
||||
)
|
||||
|
||||
await session.delete(note)
|
||||
await session.commit()
|
||||
|
||||
return router
|
||||
221
src/guardden/dashboard/websocket.py
Normal file
221
src/guardden/dashboard/websocket.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""WebSocket support for real-time dashboard updates."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from guardden.dashboard.config import DashboardSettings
|
||||
from guardden.dashboard.schemas import WebSocketEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manage WebSocket connections for real-time updates."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: dict[int, list[WebSocket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def connect(self, websocket: WebSocket, guild_id: int) -> None:
|
||||
"""Accept a new WebSocket connection."""
|
||||
await websocket.accept()
|
||||
async with self._lock:
|
||||
if guild_id not in self.active_connections:
|
||||
self.active_connections[guild_id] = []
|
||||
self.active_connections[guild_id].append(websocket)
|
||||
logger.info("New WebSocket connection for guild %s", guild_id)
|
||||
|
||||
async def disconnect(self, websocket: WebSocket, guild_id: int) -> None:
|
||||
"""Remove a WebSocket connection."""
|
||||
async with self._lock:
|
||||
if guild_id in self.active_connections:
|
||||
if websocket in self.active_connections[guild_id]:
|
||||
self.active_connections[guild_id].remove(websocket)
|
||||
if not self.active_connections[guild_id]:
|
||||
del self.active_connections[guild_id]
|
||||
logger.info("WebSocket disconnected for guild %s", guild_id)
|
||||
|
||||
async def broadcast_to_guild(self, guild_id: int, event: WebSocketEvent) -> None:
|
||||
"""Broadcast an event to all connections for a specific guild."""
|
||||
async with self._lock:
|
||||
connections = self.active_connections.get(guild_id, []).copy()
|
||||
|
||||
if not connections:
|
||||
return
|
||||
|
||||
# Convert event to JSON
|
||||
message = event.model_dump_json()
|
||||
|
||||
# Send to all connections
|
||||
dead_connections = []
|
||||
for connection in connections:
|
||||
try:
|
||||
await connection.send_text(message)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send message to WebSocket: %s", e)
|
||||
dead_connections.append(connection)
|
||||
|
||||
# Clean up dead connections
|
||||
if dead_connections:
|
||||
async with self._lock:
|
||||
if guild_id in self.active_connections:
|
||||
for conn in dead_connections:
|
||||
if conn in self.active_connections[guild_id]:
|
||||
self.active_connections[guild_id].remove(conn)
|
||||
if not self.active_connections[guild_id]:
|
||||
del self.active_connections[guild_id]
|
||||
|
||||
async def broadcast_to_all(self, event: WebSocketEvent) -> None:
|
||||
"""Broadcast an event to all connections."""
|
||||
async with self._lock:
|
||||
all_guilds = list(self.active_connections.keys())
|
||||
|
||||
for guild_id in all_guilds:
|
||||
await self.broadcast_to_guild(guild_id, event)
|
||||
|
||||
def get_connection_count(self, guild_id: int | None = None) -> int:
|
||||
"""Get the number of active connections."""
|
||||
if guild_id is not None:
|
||||
return len(self.active_connections.get(guild_id, []))
|
||||
return sum(len(conns) for conns in self.active_connections.values())
|
||||
|
||||
|
||||
# Global connection manager
|
||||
connection_manager = ConnectionManager()
|
||||
|
||||
|
||||
def create_websocket_router(settings: DashboardSettings) -> APIRouter:
|
||||
"""Create the WebSocket API router."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.websocket("/ws/events")
|
||||
async def websocket_events(websocket: WebSocket, guild_id: int) -> None:
|
||||
"""WebSocket endpoint for real-time events."""
|
||||
await connection_manager.connect(websocket, guild_id)
|
||||
try:
|
||||
# Send initial connection confirmation
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"guild_id": guild_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {"message": "Connected to real-time events"},
|
||||
}
|
||||
)
|
||||
|
||||
# Keep connection alive and handle incoming messages
|
||||
while True:
|
||||
try:
|
||||
# Wait for messages from client (ping/pong, etc.)
|
||||
data = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
|
||||
|
||||
# Echo back as heartbeat
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Send periodic ping to keep connection alive
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "ping",
|
||||
"guild_id": guild_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": {},
|
||||
}
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Client disconnected from WebSocket for guild %s", guild_id)
|
||||
except Exception as e:
|
||||
logger.error("WebSocket error for guild %s: %s", guild_id, e)
|
||||
finally:
|
||||
await connection_manager.disconnect(websocket, guild_id)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
# Helper functions to broadcast events from other parts of the application
|
||||
async def broadcast_moderation_action(
|
||||
guild_id: int,
|
||||
action: str,
|
||||
target_id: int,
|
||||
target_name: str,
|
||||
moderator_name: str,
|
||||
reason: str | None = None,
|
||||
) -> None:
|
||||
"""Broadcast a moderation action event."""
|
||||
event = WebSocketEvent(
|
||||
type="moderation_action",
|
||||
guild_id=guild_id,
|
||||
timestamp=datetime.now(),
|
||||
data={
|
||||
"action": action,
|
||||
"target_id": target_id,
|
||||
"target_name": target_name,
|
||||
"moderator_name": moderator_name,
|
||||
"reason": reason,
|
||||
},
|
||||
)
|
||||
await connection_manager.broadcast_to_guild(guild_id, event)
|
||||
|
||||
|
||||
async def broadcast_user_join(
|
||||
guild_id: int,
|
||||
user_id: int,
|
||||
username: str,
|
||||
) -> None:
|
||||
"""Broadcast a user join event."""
|
||||
event = WebSocketEvent(
|
||||
type="user_join",
|
||||
guild_id=guild_id,
|
||||
timestamp=datetime.now(),
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"username": username,
|
||||
},
|
||||
)
|
||||
await connection_manager.broadcast_to_guild(guild_id, event)
|
||||
|
||||
|
||||
async def broadcast_ai_alert(
|
||||
guild_id: int,
|
||||
user_id: int,
|
||||
severity: str,
|
||||
category: str,
|
||||
confidence: float,
|
||||
) -> None:
|
||||
"""Broadcast an AI moderation alert."""
|
||||
event = WebSocketEvent(
|
||||
type="ai_alert",
|
||||
guild_id=guild_id,
|
||||
timestamp=datetime.now(),
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"severity": severity,
|
||||
"category": category,
|
||||
"confidence": confidence,
|
||||
},
|
||||
)
|
||||
await connection_manager.broadcast_to_guild(guild_id, event)
|
||||
|
||||
|
||||
async def broadcast_system_event(
|
||||
event_type: str,
|
||||
data: dict[str, Any],
|
||||
guild_id: int | None = None,
|
||||
) -> None:
|
||||
"""Broadcast a generic system event."""
|
||||
event = WebSocketEvent(
|
||||
type=event_type,
|
||||
guild_id=guild_id or 0,
|
||||
timestamp=datetime.now(),
|
||||
data=data,
|
||||
)
|
||||
if guild_id:
|
||||
await connection_manager.broadcast_to_guild(guild_id, event)
|
||||
else:
|
||||
await connection_manager.broadcast_to_all(event)
|
||||
234
src/guardden/health.py
Normal file
234
src/guardden/health.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Health check utilities for GuardDen."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
from guardden.config import get_settings
|
||||
from guardden.services.database import Database
|
||||
from guardden.services.ai import create_ai_provider
|
||||
from guardden.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HealthChecker:
|
||||
"""Comprehensive health check system for GuardDen."""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self.database = Database(self.settings)
|
||||
self.ai_provider = create_ai_provider(self.settings)
|
||||
|
||||
async def check_database(self) -> Dict[str, Any]:
|
||||
"""Check database connectivity and performance."""
|
||||
try:
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
async with self.database.session() as session:
|
||||
# Simple query to check connectivity
|
||||
result = await session.execute("SELECT 1 as test")
|
||||
test_value = result.scalar()
|
||||
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
response_time_ms = (end_time - start_time) * 1000
|
||||
|
||||
return {
|
||||
"status": "healthy" if test_value == 1 else "unhealthy",
|
||||
"response_time_ms": round(response_time_ms, 2),
|
||||
"connection_pool": {
|
||||
"pool_size": self.database._engine.pool.size() if self.database._engine else 0,
|
||||
"checked_in": self.database._engine.pool.checkedin() if self.database._engine else 0,
|
||||
"checked_out": self.database._engine.pool.checkedout() if self.database._engine else 0,
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Database health check failed", exc_info=e)
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__
|
||||
}
|
||||
|
||||
async def check_ai_provider(self) -> Dict[str, Any]:
|
||||
"""Check AI provider connectivity."""
|
||||
if self.settings.ai_provider == "none":
|
||||
return {
|
||||
"status": "disabled",
|
||||
"provider": "none"
|
||||
}
|
||||
|
||||
try:
|
||||
# Simple test to check if AI provider is responsive
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# This is a minimal test - actual implementation would depend on provider
|
||||
provider_type = type(self.ai_provider).__name__
|
||||
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
response_time_ms = (end_time - start_time) * 1000
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"provider": self.settings.ai_provider,
|
||||
"provider_type": provider_type,
|
||||
"response_time_ms": round(response_time_ms, 2)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("AI provider health check failed", exc_info=e)
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"provider": self.settings.ai_provider,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__
|
||||
}
|
||||
|
||||
async def check_discord_connectivity(self) -> Dict[str, Any]:
|
||||
"""Check Discord API connectivity (basic test)."""
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get("https://discord.com/api/v10/gateway") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
end_time = asyncio.get_event_loop().time()
|
||||
response_time_ms = (end_time - start_time) * 1000
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"response_time_ms": round(response_time_ms, 2),
|
||||
"gateway_url": data.get("url")
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"http_status": response.status,
|
||||
"error": f"HTTP {response.status}"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Discord connectivity check failed", exc_info=e)
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__
|
||||
}
|
||||
|
||||
async def get_system_info(self) -> Dict[str, Any]:
|
||||
"""Get system information for health reporting."""
|
||||
import psutil
|
||||
import platform
|
||||
|
||||
try:
|
||||
memory = psutil.virtual_memory()
|
||||
disk = psutil.disk_usage('/')
|
||||
|
||||
return {
|
||||
"platform": platform.platform(),
|
||||
"python_version": platform.python_version(),
|
||||
"cpu": {
|
||||
"count": psutil.cpu_count(),
|
||||
"usage_percent": psutil.cpu_percent(interval=1)
|
||||
},
|
||||
"memory": {
|
||||
"total_mb": round(memory.total / 1024 / 1024),
|
||||
"available_mb": round(memory.available / 1024 / 1024),
|
||||
"usage_percent": memory.percent
|
||||
},
|
||||
"disk": {
|
||||
"total_gb": round(disk.total / 1024 / 1024 / 1024),
|
||||
"free_gb": round(disk.free / 1024 / 1024 / 1024),
|
||||
"usage_percent": round((disk.used / disk.total) * 100, 1)
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("Failed to get system info", exc_info=e)
|
||||
return {
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__
|
||||
}
|
||||
|
||||
async def perform_full_health_check(self) -> Dict[str, Any]:
|
||||
"""Perform comprehensive health check."""
|
||||
logger.info("Starting comprehensive health check")
|
||||
|
||||
checks = {
|
||||
"database": await self.check_database(),
|
||||
"ai_provider": await self.check_ai_provider(),
|
||||
"discord_api": await self.check_discord_connectivity(),
|
||||
"system": await self.get_system_info()
|
||||
}
|
||||
|
||||
# Determine overall status
|
||||
overall_status = "healthy"
|
||||
for check_name, check_result in checks.items():
|
||||
if check_name == "system":
|
||||
continue # System info doesn't affect health status
|
||||
|
||||
status = check_result.get("status", "unknown")
|
||||
if status in ["unhealthy", "error"]:
|
||||
overall_status = "unhealthy"
|
||||
break
|
||||
elif status == "degraded" and overall_status == "healthy":
|
||||
overall_status = "degraded"
|
||||
|
||||
result = {
|
||||
"status": overall_status,
|
||||
"timestamp": asyncio.get_event_loop().time(),
|
||||
"checks": checks,
|
||||
"configuration": {
|
||||
"ai_provider": self.settings.ai_provider,
|
||||
"log_level": self.settings.log_level,
|
||||
"database_pool": {
|
||||
"min": self.settings.database_pool_min,
|
||||
"max": self.settings.database_pool_max
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("Health check completed", extra={"overall_status": overall_status})
|
||||
return result
|
||||
|
||||
|
||||
async def main():
|
||||
"""CLI health check command."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="GuardDen Health Check")
|
||||
parser.add_argument("--check", action="store_true", help="Perform health check and exit")
|
||||
parser.add_argument("--json", action="store_true", help="Output in JSON format")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.check:
|
||||
# Set up minimal logging for health check
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
health_checker = HealthChecker()
|
||||
result = await health_checker.perform_full_health_check()
|
||||
|
||||
if args.json:
|
||||
import json
|
||||
print(json.dumps(result, indent=2))
|
||||
else:
|
||||
print(f"Overall Status: {result['status'].upper()}")
|
||||
for check_name, check_result in result["checks"].items():
|
||||
status = check_result.get("status", "unknown")
|
||||
print(f" {check_name}: {status}")
|
||||
if "response_time_ms" in check_result:
|
||||
print(f" Response time: {check_result['response_time_ms']}ms")
|
||||
if "error" in check_result:
|
||||
print(f" Error: {check_result['error']}")
|
||||
|
||||
# Exit with non-zero code if unhealthy
|
||||
if result["status"] != "healthy":
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("Use --check to perform health check")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,15 +1,19 @@
|
||||
"""Database models for GuardDen."""
|
||||
|
||||
from guardden.models.analytics import AICheck, MessageActivity, UserActivity
|
||||
from guardden.models.base import Base
|
||||
from guardden.models.guild import BannedWord, Guild, GuildSettings
|
||||
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
||||
|
||||
__all__ = [
|
||||
"AICheck",
|
||||
"Base",
|
||||
"BannedWord",
|
||||
"Guild",
|
||||
"GuildSettings",
|
||||
"BannedWord",
|
||||
"MessageActivity",
|
||||
"ModerationLog",
|
||||
"Strike",
|
||||
"UserActivity",
|
||||
"UserNote",
|
||||
]
|
||||
|
||||
86
src/guardden/models/analytics.py
Normal file
86
src/guardden/models/analytics.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Analytics models for tracking bot usage and performance."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, DateTime, Float, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from guardden.models.base import Base, SnowflakeID, TimestampMixin
|
||||
|
||||
|
||||
class AICheck(Base, TimestampMixin):
|
||||
"""Record of AI moderation checks."""
|
||||
|
||||
__tablename__ = "ai_checks"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False, index=True)
|
||||
channel_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
message_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
|
||||
# Check result
|
||||
flagged: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||
category: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
severity: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
# Performance metrics
|
||||
response_time_ms: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
|
||||
# False positive tracking (set by moderators)
|
||||
is_false_positive: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False, index=True
|
||||
)
|
||||
reviewed_by: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
reviewed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
|
||||
class MessageActivity(Base):
|
||||
"""Daily message activity statistics per guild."""
|
||||
|
||||
__tablename__ = "message_activity"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False, index=True)
|
||||
date: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||
|
||||
# Activity counts
|
||||
total_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
active_users: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
new_joins: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
# Moderation activity
|
||||
automod_triggers: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
ai_checks: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
manual_actions: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
|
||||
class UserActivity(Base, TimestampMixin):
|
||||
"""Track user activity and first/last seen timestamps."""
|
||||
|
||||
__tablename__ = "user_activity"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False, index=True)
|
||||
|
||||
# User information
|
||||
username: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
|
||||
# Activity timestamps
|
||||
first_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
last_seen: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
last_message: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Activity counts
|
||||
message_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
command_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
# Moderation stats
|
||||
strike_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
warning_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
kick_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
ban_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
timeout_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
@@ -3,7 +3,7 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Boolean, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
@@ -66,6 +66,15 @@ class GuildSettings(Base, TimestampMixin):
|
||||
anti_spam_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
link_filter_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Automod thresholds
|
||||
message_rate_limit: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
|
||||
message_rate_window: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
|
||||
duplicate_threshold: Mapped[int] = mapped_column(Integer, default=3, nullable=False)
|
||||
mention_limit: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
|
||||
mention_rate_limit: Mapped[int] = mapped_column(Integer, default=10, nullable=False)
|
||||
mention_rate_window: Mapped[int] = mapped_column(Integer, default=60, nullable=False)
|
||||
scam_allowlist: Mapped[list[str]] = mapped_column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Strike thresholds (actions at each threshold)
|
||||
strike_actions: Mapped[dict] = mapped_column(
|
||||
JSONB,
|
||||
@@ -81,6 +90,8 @@ class GuildSettings(Base, TimestampMixin):
|
||||
# AI moderation settings
|
||||
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=50, nullable=False) # 0-100 scale
|
||||
ai_confidence_threshold: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
|
||||
ai_log_only: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Verification settings
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Services for GuardDen."""
|
||||
|
||||
from guardden.services.automod import AutomodService
|
||||
from guardden.services.database import Database
|
||||
from guardden.services.ratelimit import RateLimiter, get_rate_limiter, ratelimit
|
||||
from guardden.services.verification import ChallengeType, VerificationService
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from guardden.services.automod import AutomodService
|
||||
from guardden.services.database import Database
|
||||
from guardden.services.ratelimit import RateLimiter, get_rate_limiter, ratelimit
|
||||
from guardden.services.verification import ChallengeType, VerificationService
|
||||
|
||||
__all__ = [
|
||||
"AutomodService",
|
||||
@@ -14,3 +17,23 @@ __all__ = [
|
||||
"get_rate_limiter",
|
||||
"ratelimit",
|
||||
]
|
||||
|
||||
_LAZY_ATTRS = {
|
||||
"AutomodService": ("guardden.services.automod", "AutomodService"),
|
||||
"Database": ("guardden.services.database", "Database"),
|
||||
"RateLimiter": ("guardden.services.ratelimit", "RateLimiter"),
|
||||
"get_rate_limiter": ("guardden.services.ratelimit", "get_rate_limiter"),
|
||||
"ratelimit": ("guardden.services.ratelimit", "ratelimit"),
|
||||
"ChallengeType": ("guardden.services.verification", "ChallengeType"),
|
||||
"VerificationService": ("guardden.services.verification", "VerificationService"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in _LAZY_ATTRS:
|
||||
module_path, attr = _LAZY_ATTRS[name]
|
||||
module = __import__(module_path, fromlist=[attr])
|
||||
value = getattr(module, attr)
|
||||
globals()[name] = value
|
||||
return value
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
@@ -5,10 +5,11 @@ from typing import Any
|
||||
|
||||
from guardden.services.ai.base import (
|
||||
AIProvider,
|
||||
ContentCategory,
|
||||
ImageAnalysisResult,
|
||||
ModerationResult,
|
||||
PhishingAnalysisResult,
|
||||
parse_categories,
|
||||
run_with_retries,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -96,7 +97,7 @@ class AnthropicProvider(AIProvider):
|
||||
|
||||
async def _call_api(self, system: str, user_content: Any, max_tokens: int = 500) -> str:
|
||||
"""Make an API call to Claude."""
|
||||
try:
|
||||
async def _request() -> str:
|
||||
message = await self.client.messages.create(
|
||||
model=self.model,
|
||||
max_tokens=max_tokens,
|
||||
@@ -104,6 +105,13 @@ class AnthropicProvider(AIProvider):
|
||||
messages=[{"role": "user", "content": user_content}],
|
||||
)
|
||||
return message.content[0].text
|
||||
|
||||
try:
|
||||
return await run_with_retries(
|
||||
_request,
|
||||
logger=logger,
|
||||
operation_name="Anthropic API call",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
@@ -145,11 +153,7 @@ class AnthropicProvider(AIProvider):
|
||||
response = await self._call_api(system, user_message)
|
||||
data = self._parse_json_response(response)
|
||||
|
||||
categories = [
|
||||
ContentCategory(cat)
|
||||
for cat in data.get("categories", [])
|
||||
if cat in ContentCategory.__members__.values()
|
||||
]
|
||||
categories = parse_categories(data.get("categories", []))
|
||||
|
||||
return ModerationResult(
|
||||
is_flagged=data.get("is_flagged", False),
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Base classes for AI providers."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from typing import Literal, TypeVar
|
||||
|
||||
|
||||
class ContentCategory(str, Enum):
|
||||
@@ -20,6 +23,64 @@ class ContentCategory(str, Enum):
|
||||
MISINFORMATION = "misinformation"
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RetryConfig:
|
||||
"""Retry configuration for AI calls."""
|
||||
|
||||
retries: int = 3
|
||||
base_delay: float = 0.25
|
||||
max_delay: float = 2.0
|
||||
|
||||
|
||||
def parse_categories(values: list[str]) -> list[ContentCategory]:
|
||||
"""Parse category values into ContentCategory enums."""
|
||||
categories: list[ContentCategory] = []
|
||||
for value in values:
|
||||
try:
|
||||
categories.append(ContentCategory(value))
|
||||
except ValueError:
|
||||
continue
|
||||
return categories
|
||||
|
||||
|
||||
async def run_with_retries(
|
||||
operation: Callable[[], Awaitable[_T]],
|
||||
*,
|
||||
config: RetryConfig | None = None,
|
||||
logger: logging.Logger | None = None,
|
||||
operation_name: str = "AI call",
|
||||
) -> _T:
|
||||
"""Run an async operation with retries and backoff."""
|
||||
retry_config = config or RetryConfig()
|
||||
delay = retry_config.base_delay
|
||||
last_error: Exception | None = None
|
||||
|
||||
for attempt in range(1, retry_config.retries + 1):
|
||||
try:
|
||||
return await operation()
|
||||
except Exception as error: # noqa: BLE001 - we re-raise after retries
|
||||
last_error = error
|
||||
if attempt >= retry_config.retries:
|
||||
raise
|
||||
if logger:
|
||||
logger.warning(
|
||||
"%s failed (attempt %s/%s): %s",
|
||||
operation_name,
|
||||
attempt,
|
||||
retry_config.retries,
|
||||
error,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
delay = min(retry_config.max_delay, delay * 2)
|
||||
|
||||
if last_error:
|
||||
raise last_error
|
||||
raise RuntimeError("Retry loop exited unexpectedly")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModerationResult:
|
||||
"""Result of AI content moderation."""
|
||||
|
||||
@@ -9,6 +9,7 @@ from guardden.services.ai.base import (
|
||||
ImageAnalysisResult,
|
||||
ModerationResult,
|
||||
PhishingAnalysisResult,
|
||||
run_with_retries,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -41,7 +42,7 @@ class OpenAIProvider(AIProvider):
|
||||
max_tokens: int = 500,
|
||||
) -> str:
|
||||
"""Make an API call to OpenAI."""
|
||||
try:
|
||||
async def _request() -> str:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
max_tokens=max_tokens,
|
||||
@@ -52,6 +53,13 @@ class OpenAIProvider(AIProvider):
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
try:
|
||||
return await run_with_retries(
|
||||
_request,
|
||||
logger=logger,
|
||||
operation_name="OpenAI chat completion",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
@@ -71,7 +79,14 @@ class OpenAIProvider(AIProvider):
|
||||
"""Analyze text content for policy violations."""
|
||||
# First, use OpenAI's built-in moderation API for quick check
|
||||
try:
|
||||
mod_response = await self.client.moderations.create(input=content)
|
||||
async def _moderate() -> Any:
|
||||
return await self.client.moderations.create(input=content)
|
||||
|
||||
mod_response = await run_with_retries(
|
||||
_moderate,
|
||||
logger=logger,
|
||||
operation_name="OpenAI moderation",
|
||||
)
|
||||
results = mod_response.results[0]
|
||||
|
||||
# Map OpenAI categories to our categories
|
||||
@@ -142,20 +157,27 @@ class OpenAIProvider(AIProvider):
|
||||
sensitivity_note = " Be strict - flag suggestive content."
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model="gpt-4o-mini", # Use vision-capable model
|
||||
max_tokens=500,
|
||||
messages=[
|
||||
{"role": "system", "content": system + sensitivity_note},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Analyze this image for moderation."},
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
],
|
||||
},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
async def _request() -> Any:
|
||||
return await self.client.chat.completions.create(
|
||||
model="gpt-4o-mini", # Use vision-capable model
|
||||
max_tokens=500,
|
||||
messages=[
|
||||
{"role": "system", "content": system + sensitivity_note},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Analyze this image for moderation."},
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
],
|
||||
},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
response = await run_with_retries(
|
||||
_request,
|
||||
logger=logger,
|
||||
operation_name="OpenAI image analysis",
|
||||
)
|
||||
|
||||
data = self._parse_json_response(response.choices[0].message.content or "{}")
|
||||
|
||||
@@ -2,17 +2,150 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
import signal
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import NamedTuple
|
||||
from typing import NamedTuple, Sequence, TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import discord
|
||||
if TYPE_CHECKING:
|
||||
import discord
|
||||
else:
|
||||
try:
|
||||
import discord # type: ignore
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
class _DiscordStub:
|
||||
class Message: # minimal stub for type hints
|
||||
pass
|
||||
|
||||
from guardden.models import BannedWord
|
||||
discord = _DiscordStub() # type: ignore
|
||||
|
||||
from guardden.models.guild import BannedWord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Circuit breaker for regex safety
|
||||
class RegexTimeoutError(Exception):
|
||||
"""Raised when regex execution takes too long."""
|
||||
pass
|
||||
|
||||
|
||||
class RegexCircuitBreaker:
|
||||
"""Circuit breaker to prevent catastrophic backtracking in regex patterns."""
|
||||
|
||||
def __init__(self, timeout_seconds: float = 0.1):
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.failed_patterns: dict[str, datetime] = {}
|
||||
self.failure_threshold = timedelta(minutes=5) # Disable pattern for 5 minutes after failure
|
||||
|
||||
def _timeout_handler(self, signum, frame):
|
||||
"""Signal handler for regex timeout."""
|
||||
raise RegexTimeoutError("Regex execution timed out")
|
||||
|
||||
def is_pattern_disabled(self, pattern: str) -> bool:
|
||||
"""Check if a pattern is temporarily disabled due to timeouts."""
|
||||
if pattern not in self.failed_patterns:
|
||||
return False
|
||||
|
||||
failure_time = self.failed_patterns[pattern]
|
||||
if datetime.now(timezone.utc) - failure_time > self.failure_threshold:
|
||||
# Re-enable the pattern after threshold time
|
||||
del self.failed_patterns[pattern]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
|
||||
"""Safely execute regex search with timeout protection."""
|
||||
if self.is_pattern_disabled(pattern):
|
||||
logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
|
||||
return False
|
||||
|
||||
# Basic pattern validation to catch obviously problematic patterns
|
||||
if self._is_dangerous_pattern(pattern):
|
||||
logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
|
||||
return False
|
||||
|
||||
old_handler = None
|
||||
try:
|
||||
# Set up timeout signal (Unix systems only)
|
||||
if hasattr(signal, 'SIGALRM'):
|
||||
old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
|
||||
signal.alarm(int(self.timeout_seconds * 1000)) # Convert to milliseconds
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Compile and execute regex
|
||||
compiled_pattern = re.compile(pattern, flags)
|
||||
result = bool(compiled_pattern.search(text))
|
||||
|
||||
execution_time = time.perf_counter() - start_time
|
||||
|
||||
# Log slow patterns for monitoring
|
||||
if execution_time > self.timeout_seconds * 0.8:
|
||||
logger.warning(
|
||||
f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except RegexTimeoutError:
|
||||
# Pattern took too long, disable it temporarily
|
||||
self.failed_patterns[pattern] = datetime.now(timezone.utc)
|
||||
logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
|
||||
return False
|
||||
|
||||
except re.error as e:
|
||||
logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in regex execution: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
# Clean up timeout signal
|
||||
if hasattr(signal, 'SIGALRM') and old_handler is not None:
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
|
||||
def _is_dangerous_pattern(self, pattern: str) -> bool:
|
||||
"""Basic heuristic to detect potentially dangerous regex patterns."""
|
||||
# Check for patterns that are commonly problematic
|
||||
dangerous_indicators = [
|
||||
r'(\w+)+', # Nested quantifiers
|
||||
r'(\d+)+', # Nested quantifiers on digits
|
||||
r'(.+)+', # Nested quantifiers on anything
|
||||
r'(.*)+', # Nested quantifiers on anything (greedy)
|
||||
r'(\w*)+', # Nested quantifiers with *
|
||||
r'(\S+)+', # Nested quantifiers on non-whitespace
|
||||
]
|
||||
|
||||
# Check for excessively long patterns
|
||||
if len(pattern) > 500:
|
||||
return True
|
||||
|
||||
# Check for nested quantifiers (simplified detection)
|
||||
if '+)+' in pattern or '*)+' in pattern or '?)+' in pattern:
|
||||
return True
|
||||
|
||||
# Check for excessive repetition operators
|
||||
if pattern.count('+') > 10 or pattern.count('*') > 10:
|
||||
return True
|
||||
|
||||
# Check for specific dangerous patterns
|
||||
for dangerous in dangerous_indicators:
|
||||
if dangerous in pattern:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Global circuit breaker instance
|
||||
_regex_circuit_breaker = RegexCircuitBreaker()
|
||||
|
||||
|
||||
# Known scam/phishing patterns
|
||||
SCAM_PATTERNS = [
|
||||
@@ -47,10 +180,10 @@ SUSPICIOUS_TLDS = {
|
||||
".gq",
|
||||
}
|
||||
|
||||
# URL pattern for extraction
|
||||
# URL pattern for extraction - more restrictive for security
|
||||
URL_PATTERN = re.compile(
|
||||
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*|"
|
||||
r"(?:www\.)?[-\w]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|ru|cn)[^\s]*",
|
||||
r"https?://(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?:/[^\s]*)?|"
|
||||
r"(?:www\.)?[a-zA-Z0-9-]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|gov|edu)(?:/[^\s]*)?",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
@@ -91,6 +224,66 @@ class AutomodResult:
|
||||
matched_filter: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SpamConfig:
|
||||
"""Configuration for spam thresholds."""
|
||||
|
||||
message_rate_limit: int = 5
|
||||
message_rate_window: int = 5
|
||||
duplicate_threshold: int = 3
|
||||
mention_limit: int = 5
|
||||
mention_rate_limit: int = 10
|
||||
mention_rate_window: int = 60
|
||||
|
||||
|
||||
def normalize_domain(value: str) -> str:
|
||||
"""Normalize a domain or URL for allowlist checks with security validation."""
|
||||
if not value or not isinstance(value, str):
|
||||
return ""
|
||||
|
||||
text = value.strip().lower()
|
||||
if not text or len(text) > 2000: # Prevent excessively long URLs
|
||||
return ""
|
||||
|
||||
# Sanitize input to prevent injection attacks
|
||||
if any(char in text for char in ['\x00', '\n', '\r', '\t']):
|
||||
return ""
|
||||
|
||||
try:
|
||||
if "://" not in text:
|
||||
text = f"http://{text}"
|
||||
|
||||
parsed = urlparse(text)
|
||||
hostname = parsed.hostname or ""
|
||||
|
||||
# Additional validation for hostname
|
||||
if not hostname or len(hostname) > 253: # RFC limit
|
||||
return ""
|
||||
|
||||
# Check for malicious patterns
|
||||
if any(char in hostname for char in [' ', '\x00', '\n', '\r', '\t']):
|
||||
return ""
|
||||
|
||||
# Remove www prefix
|
||||
if hostname.startswith("www."):
|
||||
hostname = hostname[4:]
|
||||
|
||||
return hostname
|
||||
except (ValueError, UnicodeError, Exception):
|
||||
# urlparse can raise various exceptions with malicious input
|
||||
return ""
|
||||
|
||||
|
||||
def is_allowed_domain(hostname: str, allowlist: set[str]) -> bool:
|
||||
"""Check if a hostname is allowlisted."""
|
||||
if not hostname:
|
||||
return False
|
||||
for domain in allowlist:
|
||||
if hostname == domain or hostname.endswith(f".{domain}"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AutomodService:
|
||||
"""Service for automatic content moderation."""
|
||||
|
||||
@@ -104,23 +297,25 @@ class AutomodService:
|
||||
lambda: defaultdict(UserSpamTracker)
|
||||
)
|
||||
|
||||
# Spam thresholds
|
||||
self.message_rate_limit = 5 # messages per window
|
||||
self.message_rate_window = 5 # seconds
|
||||
self.duplicate_threshold = 3 # same message count
|
||||
self.mention_limit = 5 # mentions per message
|
||||
self.mention_rate_limit = 10 # mentions per window
|
||||
self.mention_rate_window = 60 # seconds
|
||||
# Default spam thresholds
|
||||
self.default_spam_config = SpamConfig()
|
||||
|
||||
def _get_content_hash(self, content: str) -> str:
|
||||
"""Get a normalized hash of message content for duplicate detection."""
|
||||
# Normalize: lowercase, remove extra spaces, remove special chars
|
||||
normalized = re.sub(r"[^\w\s]", "", content.lower())
|
||||
normalized = re.sub(r"\s+", " ", normalized).strip()
|
||||
# Use simple string operations for basic patterns to avoid regex overhead
|
||||
normalized = content.lower()
|
||||
|
||||
# Remove special characters (simplified approach)
|
||||
normalized = ''.join(c for c in normalized if c.isalnum() or c.isspace())
|
||||
|
||||
# Normalize whitespace
|
||||
normalized = ' '.join(normalized.split())
|
||||
|
||||
return normalized
|
||||
|
||||
def check_banned_words(
|
||||
self, content: str, banned_words: list[BannedWord]
|
||||
self, content: str, banned_words: Sequence[BannedWord]
|
||||
) -> AutomodResult | None:
|
||||
"""Check message against banned words list."""
|
||||
content_lower = content.lower()
|
||||
@@ -129,12 +324,9 @@ class AutomodService:
|
||||
matched = False
|
||||
|
||||
if banned.is_regex:
|
||||
try:
|
||||
if re.search(banned.pattern, content, re.IGNORECASE):
|
||||
matched = True
|
||||
except re.error:
|
||||
logger.warning(f"Invalid regex pattern: {banned.pattern}")
|
||||
continue
|
||||
# Use circuit breaker for safe regex execution
|
||||
if _regex_circuit_breaker.safe_regex_search(banned.pattern, content, re.IGNORECASE):
|
||||
matched = True
|
||||
else:
|
||||
if banned.pattern.lower() in content_lower:
|
||||
matched = True
|
||||
@@ -155,7 +347,9 @@ class AutomodService:
|
||||
|
||||
return None
|
||||
|
||||
def check_scam_links(self, content: str) -> AutomodResult | None:
|
||||
def check_scam_links(
|
||||
self, content: str, allowlist: list[str] | None = None
|
||||
) -> AutomodResult | None:
|
||||
"""Check message for scam/phishing patterns."""
|
||||
# Check for known scam patterns
|
||||
for pattern in self._scam_patterns:
|
||||
@@ -167,10 +361,25 @@ class AutomodService:
|
||||
matched_filter="scam_pattern",
|
||||
)
|
||||
|
||||
allowlist_set = {normalize_domain(domain) for domain in allowlist or [] if domain}
|
||||
|
||||
# Check URLs for suspicious TLDs
|
||||
urls = URL_PATTERN.findall(content)
|
||||
for url in urls:
|
||||
# Limit URL length to prevent processing extremely long URLs
|
||||
if len(url) > 2000:
|
||||
continue
|
||||
|
||||
url_lower = url.lower()
|
||||
hostname = normalize_domain(url)
|
||||
|
||||
# Skip if hostname normalization failed (security check)
|
||||
if not hostname:
|
||||
continue
|
||||
|
||||
if allowlist_set and is_allowed_domain(hostname, allowlist_set):
|
||||
continue
|
||||
|
||||
for tld in SUSPICIOUS_TLDS:
|
||||
if tld in url_lower:
|
||||
# Additional check: is it trying to impersonate a known domain?
|
||||
@@ -194,12 +403,21 @@ class AutomodService:
|
||||
return None
|
||||
|
||||
def check_spam(
|
||||
self, message: discord.Message, anti_spam_enabled: bool = True
|
||||
self,
|
||||
message: discord.Message,
|
||||
anti_spam_enabled: bool = True,
|
||||
spam_config: SpamConfig | None = None,
|
||||
) -> AutomodResult | None:
|
||||
"""Check message for spam behavior."""
|
||||
if not anti_spam_enabled:
|
||||
return None
|
||||
|
||||
# Skip DM messages
|
||||
if message.guild is None:
|
||||
return None
|
||||
|
||||
config = spam_config or self.default_spam_config
|
||||
|
||||
guild_id = message.guild.id
|
||||
user_id = message.author.id
|
||||
tracker = self._spam_trackers[guild_id][user_id]
|
||||
@@ -213,21 +431,24 @@ class AutomodService:
|
||||
tracker.messages.append(SpamRecord(content_hash, now))
|
||||
|
||||
# Rate limit check
|
||||
recent_window = now - timedelta(seconds=self.message_rate_window)
|
||||
recent_window = now - timedelta(seconds=config.message_rate_window)
|
||||
recent_messages = [m for m in tracker.messages if m.timestamp > recent_window]
|
||||
|
||||
if len(recent_messages) > self.message_rate_limit:
|
||||
if len(recent_messages) > config.message_rate_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
timeout_duration=60, # 1 minute timeout
|
||||
reason=f"Sending messages too fast ({len(recent_messages)} in {self.message_rate_window}s)",
|
||||
reason=(
|
||||
f"Sending messages too fast ({len(recent_messages)} in "
|
||||
f"{config.message_rate_window}s)"
|
||||
),
|
||||
matched_filter="rate_limit",
|
||||
)
|
||||
|
||||
# Duplicate message check
|
||||
duplicate_count = sum(1 for m in tracker.messages if m.content_hash == content_hash)
|
||||
if duplicate_count >= self.duplicate_threshold:
|
||||
if duplicate_count >= config.duplicate_threshold:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
@@ -240,7 +461,7 @@ class AutomodService:
|
||||
if message.mention_everyone:
|
||||
mention_count += 100 # Treat @everyone as many mentions
|
||||
|
||||
if mention_count > self.mention_limit:
|
||||
if mention_count > config.mention_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
@@ -249,6 +470,26 @@ class AutomodService:
|
||||
matched_filter="mass_mention",
|
||||
)
|
||||
|
||||
if mention_count > 0:
|
||||
if tracker.last_mention_time:
|
||||
window = timedelta(seconds=config.mention_rate_window)
|
||||
if now - tracker.last_mention_time > window:
|
||||
tracker.mention_count = 0
|
||||
tracker.mention_count += mention_count
|
||||
tracker.last_mention_time = now
|
||||
|
||||
if tracker.mention_count > config.mention_rate_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
timeout_duration=300,
|
||||
reason=(
|
||||
"Too many mentions in a short period "
|
||||
f"({tracker.mention_count} in {config.mention_rate_window}s)"
|
||||
),
|
||||
matched_filter="mention_rate",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_invite_links(self, content: str, allow_invites: bool = True) -> AutomodResult | None:
|
||||
|
||||
155
src/guardden/services/cache.py
Normal file
155
src/guardden/services/cache.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Redis caching service for improved performance."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class CacheService:
|
||||
"""Service for caching data with Redis (optional) or in-memory fallback."""
|
||||
|
||||
def __init__(self, redis_url: str | None = None) -> None:
|
||||
self.redis_url = redis_url
|
||||
self._redis_client: Any = None
|
||||
self._memory_cache: dict[str, tuple[Any, float]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize Redis connection if URL is provided."""
|
||||
if not self.redis_url:
|
||||
logger.info("Redis URL not configured, using in-memory cache")
|
||||
return
|
||||
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
self._redis_client = await aioredis.from_url(
|
||||
self.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
# Test connection
|
||||
await self._redis_client.ping()
|
||||
logger.info("Redis cache initialized successfully")
|
||||
except ImportError:
|
||||
logger.warning("redis package not installed, using in-memory cache")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to Redis: %s, using in-memory cache", e)
|
||||
self._redis_client = None
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Redis connection."""
|
||||
if self._redis_client:
|
||||
await self._redis_client.close()
|
||||
|
||||
async def get(self, key: str) -> Any | None:
|
||||
"""Get a value from cache."""
|
||||
if self._redis_client:
|
||||
try:
|
||||
value = await self._redis_client.get(key)
|
||||
if value:
|
||||
return json.loads(value)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Redis get error for key %s: %s", key, e)
|
||||
return None
|
||||
else:
|
||||
# In-memory fallback
|
||||
async with self._lock:
|
||||
if key in self._memory_cache:
|
||||
value, expiry = self._memory_cache[key]
|
||||
if expiry == 0 or asyncio.get_event_loop().time() < expiry:
|
||||
return value
|
||||
else:
|
||||
del self._memory_cache[key]
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: Any, ttl: int = 300) -> bool:
|
||||
"""Set a value in cache with TTL in seconds."""
|
||||
if self._redis_client:
|
||||
try:
|
||||
serialized = json.dumps(value)
|
||||
await self._redis_client.set(key, serialized, ex=ttl)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Redis set error for key %s: %s", key, e)
|
||||
return False
|
||||
else:
|
||||
# In-memory fallback
|
||||
async with self._lock:
|
||||
expiry = asyncio.get_event_loop().time() + ttl if ttl > 0 else 0
|
||||
self._memory_cache[key] = (value, expiry)
|
||||
return True
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete a value from cache."""
|
||||
if self._redis_client:
|
||||
try:
|
||||
await self._redis_client.delete(key)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Redis delete error for key %s: %s", key, e)
|
||||
return False
|
||||
else:
|
||||
async with self._lock:
|
||||
if key in self._memory_cache:
|
||||
del self._memory_cache[key]
|
||||
return True
|
||||
|
||||
async def clear_pattern(self, pattern: str) -> int:
|
||||
"""Clear all keys matching a pattern."""
|
||||
if self._redis_client:
|
||||
try:
|
||||
keys = []
|
||||
async for key in self._redis_client.scan_iter(match=pattern):
|
||||
keys.append(key)
|
||||
if keys:
|
||||
await self._redis_client.delete(*keys)
|
||||
return len(keys)
|
||||
except Exception as e:
|
||||
logger.error("Redis clear pattern error for %s: %s", pattern, e)
|
||||
return 0
|
||||
else:
|
||||
# In-memory fallback
|
||||
async with self._lock:
|
||||
import fnmatch
|
||||
|
||||
keys_to_delete = [
|
||||
key for key in self._memory_cache.keys() if fnmatch.fnmatch(key, pattern)
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self._memory_cache[key]
|
||||
return len(keys_to_delete)
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
if self._redis_client:
|
||||
return {"type": "redis", "url": self.redis_url}
|
||||
else:
|
||||
return {
|
||||
"type": "memory",
|
||||
"size": len(self._memory_cache),
|
||||
}
|
||||
|
||||
|
||||
# Global cache instance
|
||||
_cache_service: CacheService | None = None
|
||||
|
||||
|
||||
def get_cache_service() -> CacheService:
|
||||
"""Get the global cache service instance."""
|
||||
global _cache_service
|
||||
if _cache_service is None:
|
||||
_cache_service = CacheService()
|
||||
return _cache_service
|
||||
|
||||
|
||||
def set_cache_service(service: CacheService) -> None:
|
||||
"""Set the global cache service instance."""
|
||||
global _cache_service
|
||||
_cache_service = service
|
||||
@@ -1,30 +1,43 @@
|
||||
"""Guild configuration service."""
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
|
||||
import discord
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from guardden.models import BannedWord, Guild, GuildSettings
|
||||
from guardden.services.cache import CacheService, get_cache_service
|
||||
from guardden.services.database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuildConfigService:
|
||||
"""Manages guild configurations with caching."""
|
||||
"""Manages guild configurations with multi-tier caching."""
|
||||
|
||||
def __init__(self, database: Database) -> None:
|
||||
def __init__(self, database: Database, cache: CacheService | None = None) -> None:
|
||||
self.database = database
|
||||
self._cache: dict[int, GuildSettings] = {}
|
||||
self.cache = cache or get_cache_service()
|
||||
self._memory_cache: dict[int, GuildSettings] = {}
|
||||
self._cache_ttl = 300 # 5 minutes
|
||||
|
||||
async def get_config(self, guild_id: int) -> GuildSettings | None:
|
||||
"""Get guild configuration, using cache if available."""
|
||||
if guild_id in self._cache:
|
||||
return self._cache[guild_id]
|
||||
"""Get guild configuration, using multi-tier cache."""
|
||||
# Check memory cache first
|
||||
if guild_id in self._memory_cache:
|
||||
return self._memory_cache[guild_id]
|
||||
|
||||
# Check Redis cache
|
||||
cache_key = f"guild_config:{guild_id}"
|
||||
cached_data = await self.cache.get(cache_key)
|
||||
if cached_data:
|
||||
# Store in memory cache for faster access
|
||||
settings = GuildSettings(**cached_data)
|
||||
self._memory_cache[guild_id] = settings
|
||||
return settings
|
||||
|
||||
# Fetch from database
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(GuildSettings).where(GuildSettings.guild_id == guild_id)
|
||||
@@ -32,7 +45,19 @@ class GuildConfigService:
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if settings:
|
||||
self._cache[guild_id] = settings
|
||||
# Store in both caches
|
||||
self._memory_cache[guild_id] = settings
|
||||
# Serialize settings for Redis
|
||||
settings_dict = {
|
||||
"guild_id": settings.guild_id,
|
||||
"prefix": settings.prefix,
|
||||
"log_channel_id": settings.log_channel_id,
|
||||
"automod_enabled": settings.automod_enabled,
|
||||
"ai_moderation_enabled": settings.ai_moderation_enabled,
|
||||
"ai_sensitivity": settings.ai_sensitivity,
|
||||
# Add other fields as needed
|
||||
}
|
||||
await self.cache.set(cache_key, settings_dict, ttl=self._cache_ttl)
|
||||
|
||||
return settings
|
||||
|
||||
@@ -94,9 +119,11 @@ class GuildConfigService:
|
||||
|
||||
return settings
|
||||
|
||||
def invalidate_cache(self, guild_id: int) -> None:
|
||||
"""Remove a guild from the cache."""
|
||||
self._cache.pop(guild_id, None)
|
||||
async def invalidate_cache(self, guild_id: int) -> None:
|
||||
"""Remove a guild from all caches."""
|
||||
self._memory_cache.pop(guild_id, None)
|
||||
cache_key = f"guild_config:{guild_id}"
|
||||
await self.cache.delete(cache_key)
|
||||
|
||||
async def get_banned_words(self, guild_id: int) -> list[BannedWord]:
|
||||
"""Get all banned words for a guild."""
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -211,6 +212,23 @@ class RateLimiter:
|
||||
bucket_key=bucket_key,
|
||||
)
|
||||
|
||||
def acquire_command(
|
||||
self,
|
||||
command_name: str,
|
||||
user_id: int | None = None,
|
||||
guild_id: int | None = None,
|
||||
channel_id: int | None = None,
|
||||
) -> RateLimitResult:
|
||||
"""Acquire a per-command rate limit slot."""
|
||||
action = f"command:{command_name}"
|
||||
if action not in self._configs:
|
||||
base = self._configs.get("command", RateLimitConfig(5, 10, RateLimitScope.MEMBER))
|
||||
self.configure(
|
||||
action,
|
||||
RateLimitConfig(base.max_requests, base.window_seconds, base.scope),
|
||||
)
|
||||
return self.acquire(action, user_id=user_id, guild_id=guild_id, channel_id=channel_id)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
action: str,
|
||||
@@ -266,6 +284,7 @@ def ratelimit(
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(self, ctx, *args, **kwargs):
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
@@ -292,9 +311,6 @@ def ratelimit(
|
||||
|
||||
return await func(self, ctx, *args, **kwargs)
|
||||
|
||||
# Preserve function metadata
|
||||
wrapper.__name__ = func.__name__
|
||||
wrapper.__doc__ = func.__doc__
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -10,8 +10,6 @@ from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -217,6 +215,28 @@ class EmojiChallengeGenerator(ChallengeGenerator):
|
||||
return names.get(emoji, "correct")
|
||||
|
||||
|
||||
class QuestionsChallengeGenerator(ChallengeGenerator):
|
||||
"""Generates custom question challenges."""
|
||||
|
||||
DEFAULT_QUESTIONS = [
|
||||
("What color is the sky on a clear day?", "blue"),
|
||||
("Type the word 'verified' to continue.", "verified"),
|
||||
("What is 2 + 2?", "4"),
|
||||
("What planet do we live on?", "earth"),
|
||||
]
|
||||
|
||||
def __init__(self, questions: list[tuple[str, str]] | None = None) -> None:
|
||||
self.questions = questions or self.DEFAULT_QUESTIONS
|
||||
|
||||
def generate(self) -> Challenge:
|
||||
question, answer = random.choice(self.questions)
|
||||
return Challenge(
|
||||
challenge_type=ChallengeType.QUESTIONS,
|
||||
question=question,
|
||||
answer=answer,
|
||||
)
|
||||
|
||||
|
||||
class VerificationService:
|
||||
"""Service for managing member verification."""
|
||||
|
||||
@@ -230,6 +250,7 @@ class VerificationService:
|
||||
ChallengeType.CAPTCHA: CaptchaChallengeGenerator(),
|
||||
ChallengeType.MATH: MathChallengeGenerator(),
|
||||
ChallengeType.EMOJI: EmojiChallengeGenerator(),
|
||||
ChallengeType.QUESTIONS: QuestionsChallengeGenerator(),
|
||||
}
|
||||
|
||||
def create_challenge(
|
||||
|
||||
@@ -1,5 +1,30 @@
|
||||
"""Utility functions for GuardDen."""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from guardden.utils.logging import setup_logging
|
||||
|
||||
__all__ = ["setup_logging"]
|
||||
|
||||
def parse_duration(duration_str: str) -> timedelta | None:
|
||||
"""Parse a duration string like '1h', '30m', '7d' into a timedelta."""
|
||||
import re
|
||||
|
||||
match = re.match(r"^(\d+)([smhdw])$", duration_str.lower())
|
||||
if not match:
|
||||
return None
|
||||
|
||||
amount = int(match.group(1))
|
||||
unit = match.group(2)
|
||||
|
||||
units = {
|
||||
"s": timedelta(seconds=amount),
|
||||
"m": timedelta(minutes=amount),
|
||||
"h": timedelta(hours=amount),
|
||||
"d": timedelta(days=amount),
|
||||
"w": timedelta(weeks=amount),
|
||||
}
|
||||
|
||||
return units.get(unit)
|
||||
|
||||
|
||||
__all__ = ["parse_duration", "setup_logging"]
|
||||
|
||||
@@ -1,27 +1,294 @@
|
||||
"""Logging configuration for GuardDen."""
|
||||
"""Structured logging utilities for GuardDen."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import Literal
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Literal
|
||||
|
||||
try:
|
||||
import structlog
|
||||
from structlog.contextvars import bind_contextvars, clear_contextvars, unbind_contextvars
|
||||
from structlog.stdlib import BoundLogger
|
||||
STRUCTLOG_AVAILABLE = True
|
||||
except ImportError:
|
||||
STRUCTLOG_AVAILABLE = False
|
||||
# Fallback types when structlog is not available
|
||||
BoundLogger = logging.Logger
|
||||
|
||||
|
||||
def setup_logging(level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO") -> None:
|
||||
"""Configure logging for the application."""
|
||||
log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
|
||||
date_format = "%Y-%m-%d %H:%M:%S"
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""Custom JSON formatter for structured logging."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Format log record as JSON."""
|
||||
log_data = {
|
||||
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
"module": record.module,
|
||||
"function": record.funcName,
|
||||
"line": record.lineno,
|
||||
}
|
||||
|
||||
# Add exception information if present
|
||||
if record.exc_info:
|
||||
log_data["exception"] = {
|
||||
"type": record.exc_info[0].__name__ if record.exc_info[0] else None,
|
||||
"message": str(record.exc_info[1]) if record.exc_info[1] else None,
|
||||
"traceback": self.formatException(record.exc_info) if record.exc_info else None,
|
||||
}
|
||||
|
||||
# Add extra fields from the record
|
||||
extra_fields = {}
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in {
|
||||
'name', 'msg', 'args', 'levelname', 'levelno', 'pathname', 'filename',
|
||||
'module', 'lineno', 'funcName', 'created', 'msecs', 'relativeCreated',
|
||||
'thread', 'threadName', 'processName', 'process', 'getMessage',
|
||||
'exc_info', 'exc_text', 'stack_info', 'message'
|
||||
}:
|
||||
extra_fields[key] = value
|
||||
|
||||
if extra_fields:
|
||||
log_data["extra"] = extra_fields
|
||||
|
||||
return json.dumps(log_data, default=str, ensure_ascii=False)
|
||||
|
||||
# Configure root logger
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, level),
|
||||
format=log_format,
|
||||
datefmt=date_format,
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
# Reduce noise from third-party libraries
|
||||
logging.getLogger("discord").setLevel(logging.WARNING)
|
||||
logging.getLogger("discord.http").setLevel(logging.WARNING)
|
||||
logging.getLogger("asyncio").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(
|
||||
logging.DEBUG if level == "DEBUG" else logging.WARNING
|
||||
)
|
||||
class GuardDenLogger:
|
||||
"""Custom logger configuration for GuardDen."""
|
||||
|
||||
def __init__(self, level: str = "INFO", json_format: bool = False):
|
||||
self.level = level.upper()
|
||||
self.json_format = json_format
|
||||
self.configure_logging()
|
||||
|
||||
def configure_logging(self) -> None:
|
||||
"""Configure structured logging for the application."""
|
||||
# Clear any existing configuration
|
||||
logging.root.handlers.clear()
|
||||
|
||||
if STRUCTLOG_AVAILABLE and self.json_format:
|
||||
self._configure_structlog()
|
||||
else:
|
||||
self._configure_stdlib_logging()
|
||||
|
||||
# Configure specific loggers
|
||||
self._configure_library_loggers()
|
||||
|
||||
def _configure_structlog(self) -> None:
|
||||
"""Configure structlog for structured logging."""
|
||||
structlog.configure(
|
||||
processors=[
|
||||
# Add context variables to log entries
|
||||
structlog.contextvars.merge_contextvars,
|
||||
# Add log level to event dict
|
||||
structlog.stdlib.filter_by_level,
|
||||
# Add logger name to event dict
|
||||
structlog.stdlib.add_logger_name,
|
||||
# Add log level to event dict
|
||||
structlog.stdlib.add_log_level,
|
||||
# Perform %-style formatting
|
||||
structlog.stdlib.PositionalArgumentsFormatter(),
|
||||
# Add timestamp
|
||||
structlog.processors.TimeStamper(fmt="iso"),
|
||||
# Add stack info when requested
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
# Format exceptions
|
||||
structlog.processors.format_exc_info,
|
||||
# Unicode-encode strings
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
# Pass to stdlib logging
|
||||
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
|
||||
],
|
||||
wrapper_class=structlog.stdlib.BoundLogger,
|
||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
|
||||
# Configure stdlib logging with JSON formatter
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
formatter = JSONFormatter()
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
# Set up root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.addHandler(handler)
|
||||
root_logger.setLevel(getattr(logging, self.level))
|
||||
|
||||
def _configure_stdlib_logging(self) -> None:
|
||||
"""Configure standard library logging."""
|
||||
if self.json_format:
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
formatter = JSONFormatter()
|
||||
else:
|
||||
# Use traditional format for development
|
||||
log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
|
||||
date_format = "%Y-%m-%d %H:%M:%S"
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
formatter = logging.Formatter(log_format, datefmt=date_format)
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
# Configure root logger
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, self.level),
|
||||
handlers=[handler],
|
||||
)
|
||||
|
||||
def _configure_library_loggers(self) -> None:
|
||||
"""Configure logging levels for third-party libraries."""
|
||||
# Discord.py can be quite verbose
|
||||
logging.getLogger("discord").setLevel(logging.WARNING)
|
||||
logging.getLogger("discord.http").setLevel(logging.WARNING)
|
||||
logging.getLogger("discord.gateway").setLevel(logging.WARNING)
|
||||
logging.getLogger("discord.client").setLevel(logging.WARNING)
|
||||
|
||||
# SQLAlchemy logging
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(
|
||||
logging.DEBUG if self.level == "DEBUG" else logging.WARNING
|
||||
)
|
||||
logging.getLogger("sqlalchemy.dialects").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.orm").setLevel(logging.WARNING)
|
||||
|
||||
# HTTP libraries
|
||||
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
|
||||
# Other libraries
|
||||
logging.getLogger("asyncio").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def get_logger(name: str) -> BoundLogger:
|
||||
"""Get a structured logger instance."""
|
||||
if STRUCTLOG_AVAILABLE:
|
||||
return structlog.get_logger(name)
|
||||
else:
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def bind_context(**kwargs: Any) -> None:
|
||||
"""Bind context variables for structured logging."""
|
||||
if STRUCTLOG_AVAILABLE:
|
||||
bind_contextvars(**kwargs)
|
||||
|
||||
|
||||
def unbind_context(*keys: str) -> None:
|
||||
"""Unbind specific context variables."""
|
||||
if STRUCTLOG_AVAILABLE:
|
||||
unbind_contextvars(*keys)
|
||||
|
||||
|
||||
def clear_context() -> None:
|
||||
"""Clear all context variables."""
|
||||
if STRUCTLOG_AVAILABLE:
|
||||
clear_contextvars()
|
||||
|
||||
|
||||
class LoggingMiddleware:
|
||||
"""Middleware for logging Discord bot events and commands."""
|
||||
|
||||
def __init__(self, logger: BoundLogger):
|
||||
self.logger = logger
|
||||
|
||||
def log_command_start(self, ctx, command_name: str) -> None:
|
||||
"""Log when a command starts."""
|
||||
bind_context(
|
||||
command=command_name,
|
||||
user_id=ctx.author.id,
|
||||
user_name=str(ctx.author),
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
guild_name=ctx.guild.name if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
channel_name=getattr(ctx.channel, 'name', 'DM'),
|
||||
)
|
||||
if hasattr(self.logger, 'info'):
|
||||
self.logger.info(
|
||||
"Command started",
|
||||
extra={
|
||||
"command": command_name,
|
||||
"args": ctx.args if hasattr(ctx, 'args') else None,
|
||||
}
|
||||
)
|
||||
|
||||
def log_command_success(self, ctx, command_name: str, duration: float) -> None:
|
||||
"""Log successful command completion."""
|
||||
if hasattr(self.logger, 'info'):
|
||||
self.logger.info(
|
||||
"Command completed successfully",
|
||||
extra={
|
||||
"command": command_name,
|
||||
"duration_ms": round(duration * 1000, 2),
|
||||
}
|
||||
)
|
||||
|
||||
def log_command_error(self, ctx, command_name: str, error: Exception, duration: float) -> None:
|
||||
"""Log command errors."""
|
||||
if hasattr(self.logger, 'error'):
|
||||
self.logger.error(
|
||||
"Command failed",
|
||||
exc_info=error,
|
||||
extra={
|
||||
"command": command_name,
|
||||
"error_type": type(error).__name__,
|
||||
"error_message": str(error),
|
||||
"duration_ms": round(duration * 1000, 2),
|
||||
}
|
||||
)
|
||||
|
||||
def log_moderation_action(
|
||||
self,
|
||||
action: str,
|
||||
target_id: int,
|
||||
target_name: str,
|
||||
moderator_id: int,
|
||||
moderator_name: str,
|
||||
guild_id: int,
|
||||
reason: str = None,
|
||||
duration: int = None,
|
||||
**extra: Any,
|
||||
) -> None:
|
||||
"""Log moderation actions."""
|
||||
if hasattr(self.logger, 'info'):
|
||||
self.logger.info(
|
||||
"Moderation action performed",
|
||||
extra={
|
||||
"action": action,
|
||||
"target_id": target_id,
|
||||
"target_name": target_name,
|
||||
"moderator_id": moderator_id,
|
||||
"moderator_name": moderator_name,
|
||||
"guild_id": guild_id,
|
||||
"reason": reason,
|
||||
"duration_seconds": duration,
|
||||
**extra,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Global logging middleware instance
|
||||
_logging_middleware: LoggingMiddleware = None
|
||||
|
||||
|
||||
def get_logging_middleware() -> LoggingMiddleware:
|
||||
"""Get the global logging middleware instance."""
|
||||
global _logging_middleware
|
||||
if _logging_middleware is None:
|
||||
logger = get_logger("guardden.middleware")
|
||||
_logging_middleware = LoggingMiddleware(logger)
|
||||
return _logging_middleware
|
||||
|
||||
|
||||
def setup_logging(
|
||||
level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO",
|
||||
json_format: bool = False
|
||||
) -> None:
|
||||
"""Set up logging for the GuardDen application."""
|
||||
GuardDenLogger(level=level, json_format=json_format)
|
||||
|
||||
328
src/guardden/utils/metrics.py
Normal file
328
src/guardden/utils/metrics.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""Prometheus metrics utilities for GuardDen."""
|
||||
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
try:
|
||||
from prometheus_client import Counter, Histogram, Gauge, Info, start_http_server, CollectorRegistry, REGISTRY
|
||||
PROMETHEUS_AVAILABLE = True
|
||||
except ImportError:
|
||||
PROMETHEUS_AVAILABLE = False
|
||||
# Mock objects when Prometheus client is not available
|
||||
class MockMetric:
|
||||
def inc(self, *args, **kwargs): pass
|
||||
def observe(self, *args, **kwargs): pass
|
||||
def set(self, *args, **kwargs): pass
|
||||
def info(self, *args, **kwargs): pass
|
||||
|
||||
Counter = Histogram = Gauge = Info = MockMetric
|
||||
CollectorRegistry = REGISTRY = None
|
||||
|
||||
|
||||
class GuardDenMetrics:
|
||||
"""Centralized metrics collection for GuardDen."""
|
||||
|
||||
def __init__(self, registry: Optional[CollectorRegistry] = None):
|
||||
self.registry = registry or REGISTRY
|
||||
self.enabled = PROMETHEUS_AVAILABLE
|
||||
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
# Bot metrics
|
||||
self.bot_commands_total = Counter(
|
||||
'guardden_commands_total',
|
||||
'Total number of commands executed',
|
||||
['command', 'guild', 'status'],
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
self.bot_command_duration = Histogram(
|
||||
'guardden_command_duration_seconds',
|
||||
'Command execution duration in seconds',
|
||||
['command', 'guild'],
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
self.bot_guilds_total = Gauge(
|
||||
'guardden_guilds_total',
|
||||
'Total number of guilds the bot is in',
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
self.bot_users_total = Gauge(
|
||||
'guardden_users_total',
|
||||
'Total number of users across all guilds',
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
# Moderation metrics
|
||||
self.moderation_actions_total = Counter(
|
||||
'guardden_moderation_actions_total',
|
||||
'Total number of moderation actions',
|
||||
['action', 'guild', 'automated'],
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
self.automod_triggers_total = Counter(
|
||||
'guardden_automod_triggers_total',
|
||||
'Total number of automod triggers',
|
||||
['filter_type', 'guild', 'action'],
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
# AI metrics
|
||||
self.ai_requests_total = Counter(
|
||||
'guardden_ai_requests_total',
|
||||
'Total number of AI provider requests',
|
||||
['provider', 'operation', 'status'],
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
self.ai_request_duration = Histogram(
|
||||
'guardden_ai_request_duration_seconds',
|
||||
'AI request duration in seconds',
|
||||
['provider', 'operation'],
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
self.ai_confidence_score = Histogram(
|
||||
'guardden_ai_confidence_score',
|
||||
'AI confidence scores',
|
||||
['provider', 'operation'],
|
||||
buckets=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
# Database metrics
|
||||
self.database_connections_active = Gauge(
|
||||
'guardden_database_connections_active',
|
||||
'Number of active database connections',
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
self.database_query_duration = Histogram(
|
||||
'guardden_database_query_duration_seconds',
|
||||
'Database query duration in seconds',
|
||||
['operation'],
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
# System metrics
|
||||
self.bot_info = Info(
|
||||
'guardden_bot_info',
|
||||
'Bot information',
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
self.last_heartbeat = Gauge(
|
||||
'guardden_last_heartbeat_timestamp',
|
||||
'Timestamp of last successful heartbeat',
|
||||
registry=self.registry
|
||||
)
|
||||
|
||||
def record_command(self, command: str, guild_id: Optional[int], status: str, duration: float):
|
||||
"""Record command execution metrics."""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
guild_str = str(guild_id) if guild_id else 'dm'
|
||||
self.bot_commands_total.labels(command=command, guild=guild_str, status=status).inc()
|
||||
self.bot_command_duration.labels(command=command, guild=guild_str).observe(duration)
|
||||
|
||||
def record_moderation_action(self, action: str, guild_id: int, automated: bool):
|
||||
"""Record moderation action metrics."""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
self.moderation_actions_total.labels(
|
||||
action=action,
|
||||
guild=str(guild_id),
|
||||
automated=str(automated).lower()
|
||||
).inc()
|
||||
|
||||
def record_automod_trigger(self, filter_type: str, guild_id: int, action: str):
|
||||
"""Record automod trigger metrics."""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
self.automod_triggers_total.labels(
|
||||
filter_type=filter_type,
|
||||
guild=str(guild_id),
|
||||
action=action
|
||||
).inc()
|
||||
|
||||
def record_ai_request(self, provider: str, operation: str, status: str, duration: float, confidence: Optional[float] = None):
|
||||
"""Record AI request metrics."""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
self.ai_requests_total.labels(
|
||||
provider=provider,
|
||||
operation=operation,
|
||||
status=status
|
||||
).inc()
|
||||
|
||||
self.ai_request_duration.labels(
|
||||
provider=provider,
|
||||
operation=operation
|
||||
).observe(duration)
|
||||
|
||||
if confidence is not None:
|
||||
self.ai_confidence_score.labels(
|
||||
provider=provider,
|
||||
operation=operation
|
||||
).observe(confidence)
|
||||
|
||||
def update_guild_count(self, count: int):
|
||||
"""Update total guild count."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self.bot_guilds_total.set(count)
|
||||
|
||||
def update_user_count(self, count: int):
|
||||
"""Update total user count."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self.bot_users_total.set(count)
|
||||
|
||||
def update_database_connections(self, active: int):
|
||||
"""Update active database connections."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self.database_connections_active.set(active)
|
||||
|
||||
def record_database_query(self, operation: str, duration: float):
|
||||
"""Record database query metrics."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self.database_query_duration.labels(operation=operation).observe(duration)
|
||||
|
||||
def update_bot_info(self, info: Dict[str, str]):
|
||||
"""Update bot information."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self.bot_info.info(info)
|
||||
|
||||
def heartbeat(self):
|
||||
"""Record heartbeat timestamp."""
|
||||
if not self.enabled:
|
||||
return
|
||||
self.last_heartbeat.set(time.time())
|
||||
|
||||
|
||||
# Global metrics instance
|
||||
_metrics: Optional[GuardDenMetrics] = None
|
||||
|
||||
|
||||
def get_metrics() -> GuardDenMetrics:
|
||||
"""Get the global metrics instance."""
|
||||
global _metrics
|
||||
if _metrics is None:
|
||||
_metrics = GuardDenMetrics()
|
||||
return _metrics
|
||||
|
||||
|
||||
def start_metrics_server(port: int = 8001) -> None:
|
||||
"""Start Prometheus metrics HTTP server."""
|
||||
if PROMETHEUS_AVAILABLE:
|
||||
start_http_server(port)
|
||||
|
||||
|
||||
def metrics_middleware(func):
|
||||
"""Decorator to automatically record command metrics."""
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
if not PROMETHEUS_AVAILABLE:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
start_time = time.time()
|
||||
status = "success"
|
||||
|
||||
try:
|
||||
# Try to extract context information
|
||||
ctx = None
|
||||
if args and hasattr(args[0], 'qualified_name'):
|
||||
# This is likely a command
|
||||
command_name = args[0].qualified_name
|
||||
if len(args) > 1 and hasattr(args[1], 'guild'):
|
||||
ctx = args[1]
|
||||
else:
|
||||
command_name = func.__name__
|
||||
|
||||
result = await func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as e:
|
||||
status = "error"
|
||||
raise
|
||||
finally:
|
||||
duration = time.time() - start_time
|
||||
guild_id = ctx.guild.id if ctx and ctx.guild else None
|
||||
|
||||
metrics = get_metrics()
|
||||
metrics.record_command(
|
||||
command=command_name,
|
||||
guild_id=guild_id,
|
||||
status=status,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class MetricsCollector:
|
||||
"""Periodic metrics collector for system stats."""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.metrics = get_metrics()
|
||||
|
||||
async def collect_bot_metrics(self):
|
||||
"""Collect basic bot metrics."""
|
||||
if not PROMETHEUS_AVAILABLE:
|
||||
return
|
||||
|
||||
# Guild count
|
||||
guild_count = len(self.bot.guilds)
|
||||
self.metrics.update_guild_count(guild_count)
|
||||
|
||||
# Total user count across all guilds
|
||||
total_users = sum(guild.member_count or 0 for guild in self.bot.guilds)
|
||||
self.metrics.update_user_count(total_users)
|
||||
|
||||
# Database connections if available
|
||||
if hasattr(self.bot, 'database') and self.bot.database._engine:
|
||||
try:
|
||||
pool = self.bot.database._engine.pool
|
||||
if hasattr(pool, 'checkedout'):
|
||||
active_connections = pool.checkedout()
|
||||
self.metrics.update_database_connections(active_connections)
|
||||
except Exception:
|
||||
pass # Ignore database connection metrics errors
|
||||
|
||||
# Bot info
|
||||
self.metrics.update_bot_info({
|
||||
'version': getattr(self.bot, 'version', 'unknown'),
|
||||
'python_version': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
|
||||
'discord_py_version': str(discord.__version__) if 'discord' in globals() else 'unknown',
|
||||
})
|
||||
|
||||
# Heartbeat
|
||||
self.metrics.heartbeat()
|
||||
|
||||
|
||||
def setup_metrics(bot, port: int = 8001) -> Optional[MetricsCollector]:
|
||||
"""Set up metrics collection for the bot."""
|
||||
if not PROMETHEUS_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
start_metrics_server(port)
|
||||
collector = MetricsCollector(bot)
|
||||
return collector
|
||||
except Exception as e:
|
||||
# Log error but don't fail startup
|
||||
logger = __import__('logging').getLogger(__name__)
|
||||
logger.error(f"Failed to start metrics server: {e}")
|
||||
return None
|
||||
10
src/guardden/utils/ratelimit.py
Normal file
10
src/guardden/utils/ratelimit.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Rate limit helpers for Discord commands."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitExceeded(Exception):
|
||||
"""Raised when a command is rate limited."""
|
||||
|
||||
retry_after: float
|
||||
Reference in New Issue
Block a user