Refactor: remove unused cogs and services, simplify architecture
- Remove admin.py and search.py cogs - Remove searxng.py service and rate_limiter.py utility - Update bot.py, ai_chat.py, config.py, and ai_service.py - Update documentation and docker-compose.yml
This commit is contained in:
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DaemonBoyfriend(commands.Bot):
|
||||
"""The main bot class for Daemon Boyfriend."""
|
||||
"""The main bot class."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
intents = discord.Intents.default()
|
||||
@@ -21,13 +21,13 @@ class DaemonBoyfriend(commands.Bot):
|
||||
intents.members = True
|
||||
|
||||
super().__init__(
|
||||
command_prefix=settings.command_prefix,
|
||||
command_prefix="!", # Required but not used
|
||||
intents=intents,
|
||||
help_command=None, # We use slash commands instead
|
||||
help_command=None,
|
||||
)
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Load cogs and sync commands on startup."""
|
||||
"""Load cogs on startup."""
|
||||
# Load all cogs
|
||||
cogs_path = Path(__file__).parent / "cogs"
|
||||
for cog_file in cogs_path.glob("*.py"):
|
||||
@@ -40,18 +40,6 @@ class DaemonBoyfriend(commands.Bot):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load cog {cog_name}: {e}")
|
||||
|
||||
# Sync slash commands
|
||||
if settings.discord_guild_id:
|
||||
# Sync to specific guild for faster testing (instant)
|
||||
guild = discord.Object(id=settings.discord_guild_id)
|
||||
self.tree.copy_global_to(guild=guild)
|
||||
await self.tree.sync(guild=guild)
|
||||
logger.info(f"Synced commands to guild {settings.discord_guild_id}")
|
||||
else:
|
||||
# Global sync (can take up to 1 hour to propagate)
|
||||
await self.tree.sync()
|
||||
logger.info("Synced commands globally")
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
"""Called when the bot is ready."""
|
||||
if self.user is None:
|
||||
@@ -59,31 +47,11 @@ class DaemonBoyfriend(commands.Bot):
|
||||
|
||||
logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
logger.info(f"Connected to {len(self.guilds)} guild(s)")
|
||||
logger.info(f"Bot name: {settings.bot_name}")
|
||||
|
||||
# Set activity status
|
||||
# Set activity status from config
|
||||
activity = discord.Activity(
|
||||
type=discord.ActivityType.watching,
|
||||
name="over the MSC group",
|
||||
name=settings.bot_status,
|
||||
)
|
||||
await self.change_presence(activity=activity)
|
||||
|
||||
async def on_command_error(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
error: commands.CommandError, # type: ignore[type-arg]
|
||||
) -> None:
|
||||
"""Global error handler for prefix commands."""
|
||||
if isinstance(error, commands.CommandNotFound):
|
||||
return # Ignore unknown commands
|
||||
|
||||
if isinstance(error, commands.MissingPermissions):
|
||||
await ctx.send("You don't have permission to use this command.")
|
||||
return
|
||||
|
||||
if isinstance(error, commands.CommandOnCooldown):
|
||||
await ctx.send(f"Command on cooldown. Try again in {error.retry_after:.1f}s")
|
||||
return
|
||||
|
||||
# Log unexpected errors
|
||||
logger.error(f"Command error: {error}", exc_info=error)
|
||||
await ctx.send("An unexpected error occurred.")
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
"""Admin cog - administrative commands for the bot."""
|
||||
|
||||
import logging
|
||||
import platform
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from daemon_boyfriend.config import settings
|
||||
from daemon_boyfriend.services import SearXNGService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdminCog(commands.Cog):
|
||||
"""Administrative commands for bot management."""
|
||||
|
||||
def __init__(self, bot: commands.Bot) -> None:
|
||||
self.bot = bot
|
||||
self.start_time = datetime.now()
|
||||
|
||||
@app_commands.command(name="ping", description="Check bot latency")
|
||||
async def ping(self, interaction: discord.Interaction) -> None:
|
||||
"""Check the bot's latency."""
|
||||
latency = round(self.bot.latency * 1000)
|
||||
await interaction.response.send_message(f"Pong! Latency: {latency}ms")
|
||||
|
||||
@app_commands.command(name="status", description="Show bot status and info")
|
||||
async def status(self, interaction: discord.Interaction) -> None:
|
||||
"""Show bot status and information."""
|
||||
# Calculate uptime
|
||||
uptime = datetime.now() - self.start_time
|
||||
hours, remainder = divmod(int(uptime.total_seconds()), 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
uptime_str = f"{hours}h {minutes}m {seconds}s"
|
||||
|
||||
# Check SearXNG status
|
||||
searxng = SearXNGService()
|
||||
searxng_status = "Online" if await searxng.health_check() else "Offline"
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"{settings.bot_name} Status",
|
||||
color=discord.Color.green(),
|
||||
)
|
||||
|
||||
embed.add_field(name="Uptime", value=uptime_str, inline=True)
|
||||
embed.add_field(name="Latency", value=f"{round(self.bot.latency * 1000)}ms", inline=True)
|
||||
embed.add_field(name="Guilds", value=str(len(self.bot.guilds)), inline=True)
|
||||
|
||||
embed.add_field(name="AI Provider", value=settings.ai_provider, inline=True)
|
||||
embed.add_field(name="AI Model", value=settings.ai_model, inline=True)
|
||||
embed.add_field(name="SearXNG", value=searxng_status, inline=True)
|
||||
|
||||
embed.add_field(
|
||||
name="Python", value=f"{sys.version_info.major}.{sys.version_info.minor}", inline=True
|
||||
)
|
||||
embed.add_field(name="Discord.py", value=discord.__version__, inline=True)
|
||||
embed.add_field(name="Platform", value=platform.system(), inline=True)
|
||||
|
||||
await interaction.response.send_message(embed=embed)
|
||||
|
||||
@app_commands.command(name="provider", description="Show or change AI provider")
|
||||
@app_commands.describe(provider="The AI provider to switch to (admin only)")
|
||||
@app_commands.choices(
|
||||
provider=[
|
||||
app_commands.Choice(name="OpenAI", value="openai"),
|
||||
app_commands.Choice(name="OpenRouter", value="openrouter"),
|
||||
app_commands.Choice(name="Anthropic (Claude)", value="anthropic"),
|
||||
]
|
||||
)
|
||||
async def provider(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
provider: str | None = None,
|
||||
) -> None:
|
||||
"""Show current AI provider or change it (info only, actual change requires restart)."""
|
||||
if provider is None:
|
||||
# Just show current provider
|
||||
await interaction.response.send_message(
|
||||
f"Current AI provider: **{settings.ai_provider}**\nModel: **{settings.ai_model}**",
|
||||
ephemeral=True,
|
||||
)
|
||||
else:
|
||||
# Inform that changing requires restart
|
||||
await interaction.response.send_message(
|
||||
f"To change the AI provider to **{provider}**, update the `AI_PROVIDER` "
|
||||
f"and corresponding API key in your `.env` file, then restart the bot.",
|
||||
ephemeral=True,
|
||||
)
|
||||
|
||||
@app_commands.command(name="help", description="Show available commands")
|
||||
async def help_command(self, interaction: discord.Interaction) -> None:
|
||||
"""Show help information."""
|
||||
embed = discord.Embed(
|
||||
title=f"{settings.bot_name} - Help",
|
||||
description="Here are the available commands:",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Chat",
|
||||
value=(
|
||||
f"**@{settings.bot_name} <message>** - Chat with me by mentioning me\n"
|
||||
"**/chat <message>** - Chat using slash command\n"
|
||||
"**/clear** - Clear your conversation history"
|
||||
),
|
||||
inline=False,
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Search",
|
||||
value=("**/search <query>** - Search the web\n**/image <query>** - Search for images"),
|
||||
inline=False,
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Info",
|
||||
value=(
|
||||
"**/ping** - Check bot latency\n"
|
||||
"**/status** - Show bot status\n"
|
||||
"**/provider** - Show current AI provider\n"
|
||||
"**/help** - Show this help message"
|
||||
),
|
||||
inline=False,
|
||||
)
|
||||
|
||||
embed.set_footer(text=f"Made with love for the MSC group")
|
||||
|
||||
await interaction.response.send_message(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot) -> None:
|
||||
"""Load the Admin cog."""
|
||||
await bot.add_cog(AdminCog(bot))
|
||||
@@ -1,10 +1,9 @@
|
||||
"""AI Chat cog - handles chat commands and mention responses."""
|
||||
"""AI Chat cog - handles mention responses."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from daemon_boyfriend.config import settings
|
||||
@@ -66,42 +65,13 @@ def split_message(content: str, max_length: int = MAX_MESSAGE_LENGTH) -> list[st
|
||||
|
||||
|
||||
class AIChatCog(commands.Cog):
|
||||
"""AI conversation commands and mention handling."""
|
||||
"""AI conversation via mentions."""
|
||||
|
||||
def __init__(self, bot: commands.Bot) -> None:
|
||||
self.bot = bot
|
||||
self.ai_service = AIService()
|
||||
self.conversations = ConversationManager()
|
||||
|
||||
@app_commands.command(name="chat", description="Chat with Daemon Boyfriend")
|
||||
@app_commands.describe(message="Your message to the bot")
|
||||
async def chat(self, interaction: discord.Interaction, message: str) -> None:
|
||||
"""Slash command to chat with the AI."""
|
||||
await interaction.response.defer(thinking=True)
|
||||
|
||||
try:
|
||||
response_text = await self._generate_response(interaction.user.id, message)
|
||||
|
||||
# Split long responses
|
||||
chunks = split_message(response_text)
|
||||
await interaction.followup.send(chunks[0])
|
||||
|
||||
# Send additional chunks as follow-up messages
|
||||
for chunk in chunks[1:]:
|
||||
await interaction.followup.send(chunk)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Chat error: {e}", exc_info=True)
|
||||
await interaction.followup.send("Sorry, I encountered an error. Please try again.")
|
||||
|
||||
@app_commands.command(name="clear", description="Clear your conversation history")
|
||||
async def clear_history(self, interaction: discord.Interaction) -> None:
|
||||
"""Clear the user's conversation history."""
|
||||
self.conversations.clear_history(interaction.user.id)
|
||||
await interaction.response.send_message(
|
||||
"Your conversation history has been cleared!", ephemeral=True
|
||||
)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""Respond when the bot is mentioned."""
|
||||
@@ -117,8 +87,8 @@ class AIChatCog(commands.Cog):
|
||||
content = self._extract_message_content(message)
|
||||
|
||||
if not content:
|
||||
# Just a mention with no message
|
||||
await message.reply(f"Hey {message.author.display_name}! How can I help you?")
|
||||
# Just a mention with no message - use configured description
|
||||
await message.reply(f"Hey {message.author.display_name}! {settings.bot_description}")
|
||||
return
|
||||
|
||||
# Show typing indicator while generating response
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
"""Search cog - web search using SearXNG."""
|
||||
|
||||
import logging
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from daemon_boyfriend.services import SearchResponse, SearXNGService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_search_embed(response: SearchResponse) -> discord.Embed:
|
||||
"""Create a Discord embed for search results.
|
||||
|
||||
Args:
|
||||
response: The search response
|
||||
|
||||
Returns:
|
||||
A formatted Discord embed
|
||||
"""
|
||||
embed = discord.Embed(
|
||||
title=f"Search: {response.query}",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
if not response.results:
|
||||
embed.description = "No results found."
|
||||
return embed
|
||||
|
||||
# Add results as fields
|
||||
for i, result in enumerate(response.results, 1):
|
||||
# Truncate content if too long
|
||||
content = result.content
|
||||
if len(content) > 200:
|
||||
content = content[:197] + "..."
|
||||
|
||||
embed.add_field(
|
||||
name=f"{i}. {result.title[:100]}",
|
||||
value=f"{content}\n[Link]({result.url})",
|
||||
inline=False,
|
||||
)
|
||||
|
||||
# Add footer with result count
|
||||
embed.set_footer(text=f"Found {response.number_of_results} results")
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
class SearchCog(commands.Cog):
|
||||
"""Web search commands using SearXNG."""
|
||||
|
||||
def __init__(self, bot: commands.Bot) -> None:
|
||||
self.bot = bot
|
||||
self.search_service = SearXNGService()
|
||||
|
||||
@app_commands.command(name="search", description="Search the web")
|
||||
@app_commands.describe(
|
||||
query="What to search for",
|
||||
category="Search category",
|
||||
)
|
||||
@app_commands.choices(
|
||||
category=[
|
||||
app_commands.Choice(name="General", value="general"),
|
||||
app_commands.Choice(name="Images", value="images"),
|
||||
app_commands.Choice(name="News", value="news"),
|
||||
app_commands.Choice(name="Science", value="science"),
|
||||
app_commands.Choice(name="IT", value="it"),
|
||||
app_commands.Choice(name="Videos", value="videos"),
|
||||
]
|
||||
)
|
||||
async def search(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
query: str,
|
||||
category: str = "general",
|
||||
) -> None:
|
||||
"""Search the web using SearXNG."""
|
||||
await interaction.response.defer(thinking=True)
|
||||
|
||||
try:
|
||||
response = await self.search_service.search(
|
||||
query=query,
|
||||
categories=[category],
|
||||
num_results=5,
|
||||
)
|
||||
|
||||
embed = create_search_embed(response)
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Search error: {e}", exc_info=True)
|
||||
await interaction.followup.send(
|
||||
f"Search failed. Make sure SearXNG is running and accessible."
|
||||
)
|
||||
|
||||
@app_commands.command(name="image", description="Search for images")
|
||||
@app_commands.describe(query="What to search for")
|
||||
async def image_search(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
query: str,
|
||||
) -> None:
|
||||
"""Search for images using SearXNG."""
|
||||
await interaction.response.defer(thinking=True)
|
||||
|
||||
try:
|
||||
response = await self.search_service.search(
|
||||
query=query,
|
||||
categories=["images"],
|
||||
num_results=5,
|
||||
)
|
||||
|
||||
if not response.results:
|
||||
await interaction.followup.send(f"No images found for: {query}")
|
||||
return
|
||||
|
||||
# Create embed with first image
|
||||
embed = discord.Embed(
|
||||
title=f"Image search: {query}",
|
||||
color=discord.Color.purple(),
|
||||
)
|
||||
|
||||
# Add image results
|
||||
for i, result in enumerate(response.results[:5], 1):
|
||||
embed.add_field(
|
||||
name=f"{i}. {result.title[:50]}",
|
||||
value=f"[View]({result.url})",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Image search error: {e}", exc_info=True)
|
||||
await interaction.followup.send(
|
||||
"Image search failed. Make sure SearXNG is running and accessible."
|
||||
)
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot) -> None:
|
||||
"""Load the Search cog."""
|
||||
await bot.add_cog(SearchCog(bot))
|
||||
@@ -17,10 +17,6 @@ class Settings(BaseSettings):
|
||||
|
||||
# Discord Configuration
|
||||
discord_token: str = Field(..., description="Discord bot token")
|
||||
discord_guild_id: int | None = Field(
|
||||
None, description="Test guild ID for faster command sync during development"
|
||||
)
|
||||
command_prefix: str = Field("!", description="Legacy command prefix")
|
||||
|
||||
# AI Provider Configuration
|
||||
ai_provider: Literal["openai", "openrouter", "anthropic"] = Field(
|
||||
@@ -35,25 +31,31 @@ class Settings(BaseSettings):
|
||||
openrouter_api_key: str | None = Field(None, description="OpenRouter API key")
|
||||
anthropic_api_key: str | None = Field(None, description="Anthropic API key")
|
||||
|
||||
# SearXNG Configuration
|
||||
searxng_base_url: str = Field(
|
||||
"http://localhost:8080", description="SearXNG instance URL"
|
||||
)
|
||||
searxng_timeout: int = Field(10, description="Search timeout in seconds")
|
||||
|
||||
# Rate Limiting
|
||||
rate_limit_messages: int = Field(10, description="Messages per user per minute")
|
||||
rate_limit_searches: int = Field(5, description="Searches per user per minute")
|
||||
|
||||
# Logging
|
||||
log_level: str = Field("INFO", description="Logging level")
|
||||
|
||||
# Bot Behavior
|
||||
bot_name: str = Field("Daemon Boyfriend", description="Bot display name")
|
||||
# Bot Identity
|
||||
bot_name: str = Field("AI Bot", description="Bot display name")
|
||||
bot_personality: str = Field(
|
||||
"helpful, witty, and slightly mischievous",
|
||||
"helpful and friendly",
|
||||
description="Bot personality description for system prompt",
|
||||
)
|
||||
bot_description: str = Field(
|
||||
"I'm an AI assistant here to help you.",
|
||||
description="Bot description shown when mentioned without a message",
|
||||
)
|
||||
bot_status: str = Field(
|
||||
"for mentions",
|
||||
description="Bot status message (shown as 'Watching ...')",
|
||||
)
|
||||
|
||||
# System Prompt (optional override)
|
||||
system_prompt: str | None = Field(
|
||||
None,
|
||||
description="Custom system prompt. If not set, a default is generated from bot_name and bot_personality",
|
||||
)
|
||||
|
||||
# Conversation Settings
|
||||
max_conversation_history: int = Field(
|
||||
20, description="Max messages to keep in conversation memory per user"
|
||||
)
|
||||
|
||||
@@ -3,14 +3,10 @@
|
||||
from .ai_service import AIService
|
||||
from .conversation import ConversationManager
|
||||
from .providers import AIResponse, Message
|
||||
from .searxng import SearchResponse, SearchResult, SearXNGService
|
||||
|
||||
__all__ = [
|
||||
"AIService",
|
||||
"AIResponse",
|
||||
"Message",
|
||||
"ConversationManager",
|
||||
"SearXNGService",
|
||||
"SearchResponse",
|
||||
"SearchResult",
|
||||
]
|
||||
|
||||
@@ -93,9 +93,14 @@ class AIService:
|
||||
)
|
||||
|
||||
def get_system_prompt(self) -> str:
|
||||
"""Get the default system prompt for the bot."""
|
||||
"""Get the system prompt for the bot."""
|
||||
# Use custom system prompt if provided
|
||||
if self._config.system_prompt:
|
||||
return self._config.system_prompt
|
||||
|
||||
# Generate default system prompt from bot identity settings
|
||||
return (
|
||||
f"You are {self._config.bot_name}, a {self._config.bot_personality} "
|
||||
f"Discord bot for the MSC group. Keep your responses concise and engaging. "
|
||||
f"Discord bot. Keep your responses concise and engaging. "
|
||||
f"You can use Discord markdown formatting in your responses."
|
||||
)
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
"""SearXNG search service."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import aiohttp
|
||||
|
||||
from daemon_boyfriend.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""A single search result."""
|
||||
|
||||
title: str
|
||||
url: str
|
||||
content: str # snippet/description
|
||||
engine: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResponse:
|
||||
"""Response from a SearXNG search."""
|
||||
|
||||
query: str
|
||||
results: list[SearchResult]
|
||||
suggestions: list[str]
|
||||
number_of_results: int
|
||||
|
||||
|
||||
class SearXNGService:
|
||||
"""SearXNG search service client."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
timeout: int | None = None,
|
||||
) -> None:
|
||||
self.base_url = (base_url or settings.searxng_base_url).rstrip("/")
|
||||
self.timeout = aiohttp.ClientTimeout(total=timeout or settings.searxng_timeout)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
categories: list[str] | None = None,
|
||||
language: str = "en",
|
||||
safesearch: Literal[0, 1, 2] = 1, # 0=off, 1=moderate, 2=strict
|
||||
num_results: int = 5,
|
||||
) -> SearchResponse:
|
||||
"""Search using SearXNG.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
categories: Search categories (general, images, videos, news, it, science)
|
||||
language: Language code (e.g., "en", "nl")
|
||||
safesearch: Safe search level (0=off, 1=moderate, 2=strict)
|
||||
num_results: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
SearchResponse with results
|
||||
|
||||
Raises:
|
||||
aiohttp.ClientError: On network errors
|
||||
"""
|
||||
params: dict[str, str | int] = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"language": language,
|
||||
"safesearch": safesearch,
|
||||
}
|
||||
|
||||
if categories:
|
||||
params["categories"] = ",".join(categories)
|
||||
|
||||
logger.debug(f"Searching SearXNG: {query}")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=self.timeout) as session:
|
||||
async with session.get(f"{self.base_url}/search", params=params) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
|
||||
# Parse results
|
||||
results = [
|
||||
SearchResult(
|
||||
title=r.get("title", "No title"),
|
||||
url=r.get("url", ""),
|
||||
content=r.get("content", "No description"),
|
||||
engine=r.get("engine", "unknown"),
|
||||
)
|
||||
for r in data.get("results", [])[:num_results]
|
||||
]
|
||||
|
||||
return SearchResponse(
|
||||
query=query,
|
||||
results=results,
|
||||
suggestions=data.get("suggestions", []),
|
||||
number_of_results=data.get("number_of_results", len(results)),
|
||||
)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Check if the SearXNG instance is reachable.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=self.timeout) as session:
|
||||
# Try the search endpoint with a simple query
|
||||
async with session.get(
|
||||
f"{self.base_url}/search",
|
||||
params={"q": "test", "format": "json"},
|
||||
) as response:
|
||||
return response.status == 200
|
||||
except Exception as e:
|
||||
logger.error(f"SearXNG health check failed: {e}")
|
||||
return False
|
||||
@@ -1,13 +1,8 @@
|
||||
"""Utility modules."""
|
||||
|
||||
from .logging import get_logger, setup_logging
|
||||
from .rate_limiter import RateLimiter, chat_limiter, rate_limited, search_limiter
|
||||
|
||||
__all__ = [
|
||||
"setup_logging",
|
||||
"get_logger",
|
||||
"RateLimiter",
|
||||
"rate_limited",
|
||||
"chat_limiter",
|
||||
"search_limiter",
|
||||
]
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
"""Rate limiting utilities."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
|
||||
from daemon_boyfriend.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Per-user rate limiting.
|
||||
|
||||
Tracks the number of calls per user within a time window
|
||||
and rejects calls that exceed the limit.
|
||||
"""
|
||||
|
||||
def __init__(self, max_calls: int, period_seconds: int) -> None:
|
||||
"""Initialize the rate limiter.
|
||||
|
||||
Args:
|
||||
max_calls: Maximum calls allowed per period
|
||||
period_seconds: Length of the period in seconds
|
||||
"""
|
||||
self.max_calls = max_calls
|
||||
self.period = timedelta(seconds=period_seconds)
|
||||
self.calls: dict[int, list[datetime]] = defaultdict(list)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(self, user_id: int) -> bool:
|
||||
"""Check if a user is within rate limit and record the call.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
True if allowed, False if rate limited
|
||||
"""
|
||||
async with self._lock:
|
||||
now = datetime.now()
|
||||
cutoff = now - self.period
|
||||
|
||||
# Clean old entries
|
||||
self.calls[user_id] = [t for t in self.calls[user_id] if t > cutoff]
|
||||
|
||||
if len(self.calls[user_id]) >= self.max_calls:
|
||||
return False
|
||||
|
||||
self.calls[user_id].append(now)
|
||||
return True
|
||||
|
||||
def remaining(self, user_id: int) -> int:
|
||||
"""Get remaining calls for a user in current period.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
Number of remaining calls
|
||||
"""
|
||||
now = datetime.now()
|
||||
cutoff = now - self.period
|
||||
recent = [t for t in self.calls[user_id] if t > cutoff]
|
||||
return max(0, self.max_calls - len(recent))
|
||||
|
||||
def reset_time(self, user_id: int) -> float:
|
||||
"""Get seconds until rate limit resets for a user.
|
||||
|
||||
Args:
|
||||
user_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
Seconds until reset, or 0 if not rate limited
|
||||
"""
|
||||
if not self.calls[user_id]:
|
||||
return 0
|
||||
|
||||
oldest = min(self.calls[user_id])
|
||||
reset_at = oldest + self.period
|
||||
remaining = (reset_at - datetime.now()).total_seconds()
|
||||
return max(0, remaining)
|
||||
|
||||
|
||||
def rate_limited(
|
||||
limiter: RateLimiter,
|
||||
) -> Callable[
|
||||
[Callable[..., Coroutine[Any, Any, T]]],
|
||||
Callable[..., Coroutine[Any, Any, T | None]],
|
||||
]:
|
||||
"""Decorator for rate-limited slash commands.
|
||||
|
||||
Args:
|
||||
limiter: The RateLimiter instance to use
|
||||
|
||||
Returns:
|
||||
Decorated function that checks rate limit before execution
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[..., Coroutine[Any, Any, T]],
|
||||
) -> Callable[..., Coroutine[Any, Any, T | None]]:
|
||||
@wraps(func)
|
||||
async def wrapper(
|
||||
self: Any, interaction: discord.Interaction, *args: Any, **kwargs: Any
|
||||
) -> T | None:
|
||||
if not await limiter.is_allowed(interaction.user.id):
|
||||
reset_time = limiter.reset_time(interaction.user.id)
|
||||
await interaction.response.send_message(
|
||||
f"You're sending messages too quickly. Please wait {reset_time:.0f} seconds.",
|
||||
ephemeral=True,
|
||||
)
|
||||
return None
|
||||
return await func(self, interaction, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Pre-configured rate limiters
|
||||
chat_limiter = RateLimiter(
|
||||
max_calls=settings.rate_limit_messages,
|
||||
period_seconds=60,
|
||||
)
|
||||
|
||||
search_limiter = RateLimiter(
|
||||
max_calls=settings.rate_limit_searches,
|
||||
period_seconds=60,
|
||||
)
|
||||
Reference in New Issue
Block a user