first commit
This commit is contained in:
13
src/daemon_boyfriend/utils/__init__.py
Normal file
13
src/daemon_boyfriend/utils/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Utility modules."""
|
||||
|
||||
from .logging import get_logger, setup_logging
|
||||
from .rate_limiter import RateLimiter, chat_limiter, rate_limited, search_limiter
|
||||
|
||||
__all__ = [
|
||||
"setup_logging",
|
||||
"get_logger",
|
||||
"RateLimiter",
|
||||
"rate_limited",
|
||||
"chat_limiter",
|
||||
"search_limiter",
|
||||
]
|
||||
66
src/daemon_boyfriend/utils/error_handler.py
Normal file
66
src/daemon_boyfriend/utils/error_handler.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Global error handling for the bot."""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorHandler(commands.Cog):
|
||||
"""Global error handling cog."""
|
||||
|
||||
def __init__(self, bot: commands.Bot) -> None:
|
||||
self.bot = bot
|
||||
# Set up the app command error handler
|
||||
self.bot.tree.on_error = self.on_app_command_error
|
||||
|
||||
async def on_app_command_error(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
error: app_commands.AppCommandError,
|
||||
) -> None:
|
||||
"""Handle slash command errors.
|
||||
|
||||
Args:
|
||||
interaction: The interaction that caused the error
|
||||
error: The error that was raised
|
||||
"""
|
||||
# Determine the error message
|
||||
if isinstance(error, app_commands.CommandOnCooldown):
|
||||
message = f"This command is on cooldown. Try again in {error.retry_after:.1f} seconds."
|
||||
elif isinstance(error, app_commands.MissingPermissions):
|
||||
message = "You don't have permission to use this command."
|
||||
elif isinstance(error, app_commands.BotMissingPermissions):
|
||||
message = "I don't have the required permissions to do that."
|
||||
elif isinstance(error, app_commands.CommandNotFound):
|
||||
message = "That command doesn't exist."
|
||||
elif isinstance(error, app_commands.TransformerError):
|
||||
message = f"Invalid input: {error}"
|
||||
elif isinstance(error, app_commands.CheckFailure):
|
||||
message = "You can't use this command."
|
||||
else:
|
||||
# Log unexpected errors
|
||||
logger.error(
|
||||
f"Unhandled app command error: {error}",
|
||||
exc_info=error,
|
||||
)
|
||||
message = "An unexpected error occurred. Please try again later."
|
||||
|
||||
# Send the error message
|
||||
try:
|
||||
if interaction.response.is_done():
|
||||
await interaction.followup.send(message, ephemeral=True)
|
||||
else:
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
except discord.HTTPException:
|
||||
# If we can't send a message, just log it
|
||||
logger.error(f"Failed to send error message: {message}")
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot) -> None:
|
||||
"""Load the error handler cog."""
|
||||
await bot.add_cog(ErrorHandler(bot))
|
||||
44
src/daemon_boyfriend/utils/logging.py
Normal file
44
src/daemon_boyfriend/utils/logging.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Logging configuration for the bot."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from daemon_boyfriend.config import settings
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
"""Configure application-wide logging."""
|
||||
# Create formatter
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
# Configure root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(settings.log_level)
|
||||
|
||||
# Remove existing handlers
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# Reduce noise from discord.py
|
||||
logging.getLogger("discord").setLevel(logging.WARNING)
|
||||
logging.getLogger("discord.http").setLevel(logging.WARNING)
|
||||
logging.getLogger("discord.gateway").setLevel(logging.WARNING)
|
||||
|
||||
# Reduce noise from aiohttp
|
||||
logging.getLogger("aiohttp").setLevel(logging.WARNING)
|
||||
|
||||
# Reduce noise from httpx (used by openai/anthropic)
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a logger with the given name."""
|
||||
return logging.getLogger(name)
|
||||
138
src/daemon_boyfriend/utils/rate_limiter.py
Normal file
138
src/daemon_boyfriend/utils/rate_limiter.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Rate limiting utilities."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
|
||||
from daemon_boyfriend.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Per-user rate limiting.
|
||||
|
||||
Tracks the number of calls per user within a time window
|
||||
and rejects calls that exceed the limit.
|
||||
"""
|
||||
|
||||
def __init__(self, max_calls: int, period_seconds: int) -> None:
|
||||
"""Initialize the rate limiter.
|
||||
|
||||
Args:
|
||||
max_calls: Maximum calls allowed per period
|
||||
period_seconds: Length of the period in seconds
|
||||
"""
|
||||
self.max_calls = max_calls
|
||||
self.period = timedelta(seconds=period_seconds)
|
||||
self.calls: dict[int, list[datetime]] = defaultdict(list)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(self, user_id: int) -> bool:
|
||||
"""Check if a user is within rate limit and record the call.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
True if allowed, False if rate limited
|
||||
"""
|
||||
async with self._lock:
|
||||
now = datetime.now()
|
||||
cutoff = now - self.period
|
||||
|
||||
# Clean old entries
|
||||
self.calls[user_id] = [t for t in self.calls[user_id] if t > cutoff]
|
||||
|
||||
if len(self.calls[user_id]) >= self.max_calls:
|
||||
return False
|
||||
|
||||
self.calls[user_id].append(now)
|
||||
return True
|
||||
|
||||
def remaining(self, user_id: int) -> int:
|
||||
"""Get remaining calls for a user in current period.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
Number of remaining calls
|
||||
"""
|
||||
now = datetime.now()
|
||||
cutoff = now - self.period
|
||||
recent = [t for t in self.calls[user_id] if t > cutoff]
|
||||
return max(0, self.max_calls - len(recent))
|
||||
|
||||
def reset_time(self, user_id: int) -> float:
|
||||
"""Get seconds until rate limit resets for a user.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
Seconds until reset, or 0 if not rate limited
|
||||
"""
|
||||
if not self.calls[user_id]:
|
||||
return 0
|
||||
|
||||
oldest = min(self.calls[user_id])
|
||||
reset_at = oldest + self.period
|
||||
remaining = (reset_at - datetime.now()).total_seconds()
|
||||
return max(0, remaining)
|
||||
|
||||
|
||||
def rate_limited(
|
||||
limiter: RateLimiter,
|
||||
) -> Callable[
|
||||
[Callable[..., Coroutine[Any, Any, T]]],
|
||||
Callable[..., Coroutine[Any, Any, T | None]],
|
||||
]:
|
||||
"""Decorator for rate-limited slash commands.
|
||||
|
||||
Args:
|
||||
limiter: The RateLimiter instance to use
|
||||
|
||||
Returns:
|
||||
Decorated function that checks rate limit before execution
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Coroutine[Any, Any, T]],
|
||||
) -> Callable[..., Coroutine[Any, Any, T | None]]:
|
||||
@wraps(func)
|
||||
async def wrapper(
|
||||
self: Any, interaction: discord.Interaction, *args: Any, **kwargs: Any
|
||||
) -> T | None:
|
||||
if not await limiter.is_allowed(interaction.user.id):
|
||||
reset_time = limiter.reset_time(interaction.user.id)
|
||||
await interaction.response.send_message(
|
||||
f"You're sending messages too quickly. Please wait {reset_time:.0f} seconds.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return None
|
||||
return await func(self, interaction, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Pre-configured rate limiters
|
||||
chat_limiter = RateLimiter(
|
||||
max_calls=settings.rate_limit_messages,
|
||||
period_seconds=60,
|
||||
)
|
||||
|
||||
search_limiter = RateLimiter(
|
||||
max_calls=settings.rate_limit_searches,
|
||||
period_seconds=60,
|
||||
)
|
||||
Reference in New Issue
Block a user