first commit

This commit is contained in:
2026-01-10 21:46:27 +01:00
parent d00593415d
commit 561f1a8fb1
30 changed files with 1932 additions and 1 deletions

View File

@@ -0,0 +1,13 @@
"""Utility modules."""
from .logging import get_logger, setup_logging
from .rate_limiter import RateLimiter, chat_limiter, rate_limited, search_limiter
__all__ = [
"setup_logging",
"get_logger",
"RateLimiter",
"rate_limited",
"chat_limiter",
"search_limiter",
]

View File

@@ -0,0 +1,66 @@
"""Global error handling for the bot."""
import logging
import traceback
import discord
from discord import app_commands
from discord.ext import commands
logger = logging.getLogger(__name__)
class ErrorHandler(commands.Cog):
"""Global error handling cog."""
def __init__(self, bot: commands.Bot) -> None:
self.bot = bot
# Set up the app command error handler
self.bot.tree.on_error = self.on_app_command_error
async def on_app_command_error(
self,
interaction: discord.Interaction,
error: app_commands.AppCommandError,
) -> None:
"""Handle slash command errors.
Args:
interaction: The interaction that caused the error
error: The error that was raised
"""
# Determine the error message
if isinstance(error, app_commands.CommandOnCooldown):
message = f"This command is on cooldown. Try again in {error.retry_after:.1f} seconds."
elif isinstance(error, app_commands.MissingPermissions):
message = "You don't have permission to use this command."
elif isinstance(error, app_commands.BotMissingPermissions):
message = "I don't have the required permissions to do that."
elif isinstance(error, app_commands.CommandNotFound):
message = "That command doesn't exist."
elif isinstance(error, app_commands.TransformerError):
message = f"Invalid input: {error}"
elif isinstance(error, app_commands.CheckFailure):
message = "You can't use this command."
else:
# Log unexpected errors
logger.error(
f"Unhandled app command error: {error}",
exc_info=error,
)
message = "An unexpected error occurred. Please try again later."
# Send the error message
try:
if interaction.response.is_done():
await interaction.followup.send(message, ephemeral=True)
else:
await interaction.response.send_message(message, ephemeral=True)
except discord.HTTPException:
# If we can't send a message, just log it
logger.error(f"Failed to send error message: {message}")
async def setup(bot: commands.Bot) -> None:
"""Load the error handler cog."""
await bot.add_cog(ErrorHandler(bot))

View File

@@ -0,0 +1,44 @@
"""Logging configuration for the bot."""
import logging
import sys
from daemon_boyfriend.config import settings
def setup_logging() -> None:
"""Configure application-wide logging."""
# Create formatter
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(settings.log_level)
# Remove existing handlers
root_logger.handlers.clear()
# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
root_logger.addHandler(console_handler)
# Reduce noise from discord.py
logging.getLogger("discord").setLevel(logging.WARNING)
logging.getLogger("discord.http").setLevel(logging.WARNING)
logging.getLogger("discord.gateway").setLevel(logging.WARNING)
# Reduce noise from aiohttp
logging.getLogger("aiohttp").setLevel(logging.WARNING)
# Reduce noise from httpx (used by openai/anthropic)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
def get_logger(name: str) -> logging.Logger:
"""Get a logger with the given name."""
return logging.getLogger(name)

View File

@@ -0,0 +1,138 @@
"""Rate limiting utilities."""
import asyncio
import logging
from collections import defaultdict
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Callable, Coroutine, TypeVar
import discord
from discord import app_commands
from daemon_boyfriend.config import settings
logger = logging.getLogger(__name__)
T = TypeVar("T")
class RateLimiter:
"""Per-user rate limiting.
Tracks the number of calls per user within a time window
and rejects calls that exceed the limit.
"""
def __init__(self, max_calls: int, period_seconds: int) -> None:
"""Initialize the rate limiter.
Args:
max_calls: Maximum calls allowed per period
period_seconds: Length of the period in seconds
"""
self.max_calls = max_calls
self.period = timedelta(seconds=period_seconds)
self.calls: dict[int, list[datetime]] = defaultdict(list)
self._lock = asyncio.Lock()
async def is_allowed(self, user_id: int) -> bool:
"""Check if a user is within rate limit and record the call.
Args:
user_id: Discord user ID
Returns:
True if allowed, False if rate limited
"""
async with self._lock:
now = datetime.now()
cutoff = now - self.period
# Clean old entries
self.calls[user_id] = [t for t in self.calls[user_id] if t > cutoff]
if len(self.calls[user_id]) >= self.max_calls:
return False
self.calls[user_id].append(now)
return True
def remaining(self, user_id: int) -> int:
"""Get remaining calls for a user in current period.
Args:
user_id: Discord user ID
Returns:
Number of remaining calls
"""
now = datetime.now()
cutoff = now - self.period
recent = [t for t in self.calls[user_id] if t > cutoff]
return max(0, self.max_calls - len(recent))
def reset_time(self, user_id: int) -> float:
"""Get seconds until rate limit resets for a user.
Args:
user_id: Discord user ID
Returns:
Seconds until reset, or 0 if not rate limited
"""
if not self.calls[user_id]:
return 0
oldest = min(self.calls[user_id])
reset_at = oldest + self.period
remaining = (reset_at - datetime.now()).total_seconds()
return max(0, remaining)
def rate_limited(
limiter: RateLimiter,
) -> Callable[
[Callable[..., Coroutine[Any, Any, T]]],
Callable[..., Coroutine[Any, Any, T | None]],
]:
"""Decorator for rate-limited slash commands.
Args:
limiter: The RateLimiter instance to use
Returns:
Decorated function that checks rate limit before execution
"""
def decorator(
func: Callable[..., Coroutine[Any, Any, T]],
) -> Callable[..., Coroutine[Any, Any, T | None]]:
@wraps(func)
async def wrapper(
self: Any, interaction: discord.Interaction, *args: Any, **kwargs: Any
) -> T | None:
if not await limiter.is_allowed(interaction.user.id):
reset_time = limiter.reset_time(interaction.user.id)
await interaction.response.send_message(
f"You're sending messages too quickly. Please wait {reset_time:.0f} seconds.",
ephemeral=True,
)
return None
return await func(self, interaction, *args, **kwargs)
return wrapper
return decorator
# Pre-configured rate limiters
chat_limiter = RateLimiter(
max_calls=settings.rate_limit_messages,
period_seconds=60,
)
search_limiter = RateLimiter(
max_calls=settings.rate_limit_searches,
period_seconds=60,
)