Features:
- Core moderation: warn, kick, ban, timeout, strike system
- Automod: banned words filter, scam detection, anti-spam, link filtering
- AI moderation: Claude/OpenAI integration, NSFW detection, phishing analysis
- Verification system: button, captcha, math, emoji challenges
- Rate limiting system with configurable scopes
- Event logging: joins, leaves, message edits/deletes, voice activity
- Per-guild configuration with caching
- Docker deployment support

Bug fixes applied:
- Fixed await on session.delete() in guild_config.py
- Fixed memory leak in AI moderation message tracking (use deque)
- Added error handling to bot shutdown
- Added error handling to timeout command
- Removed unused Literal import
- Added prefix validation
- Added image analysis limit (3 per message)
- Fixed test mock for SQLAlchemy model
100 lines
3.1 KiB
Python
100 lines
3.1 KiB
Python
"""Database connection and session management."""
|
|
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from typing import AsyncGenerator
|
|
|
|
import asyncpg
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from guardden.config import Settings
|
|
from guardden.models import Base
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Database:
    """Manages database connections and sessions.

    Two independent handles are opened against the same database:

    * a SQLAlchemy async engine, used for ORM sessions via :meth:`session`;
    * a raw asyncpg pool, exposed via :attr:`pool`.

    Nothing is opened until :meth:`connect` is awaited; :meth:`disconnect`
    returns the object to the fully disconnected state.
    """

    def __init__(self, settings: Settings) -> None:
        """Store *settings*; no connection is made here."""
        self.settings = settings
        self._engine = None
        self._session_factory = None
        self._pool: asyncpg.Pool | None = None

    @staticmethod
    def _driver_urls(db_url: str) -> tuple[str, str]:
        """Return ``(sqlalchemy_url, asyncpg_dsn)`` derived from *db_url*.

        SQLAlchemy needs the ``postgresql+asyncpg://`` scheme while asyncpg
        only understands plain ``postgresql://`` / ``postgres://``, so each
        driver gets the spelling it accepts regardless of which one was
        configured. URLs with any other scheme are passed through untouched.
        """
        for scheme in ("postgresql+asyncpg://", "postgresql://", "postgres://"):
            if db_url.startswith(scheme):
                remainder = db_url[len(scheme):]
                return f"postgresql+asyncpg://{remainder}", f"postgresql://{remainder}"
        return db_url, db_url

    async def connect(self) -> None:
        """Initialize database connection pool.

        Creates the SQLAlchemy engine and session factory, then the raw
        asyncpg pool. Instance state is only updated once both succeed, so a
        failure leaves the object disconnected rather than half-initialised.

        Raises:
            Exception: whatever the drivers raise when the engine or pool
                cannot be created; the exception is propagated unchanged.
        """
        db_url = self.settings.database_url.get_secret_value()
        sqlalchemy_url, asyncpg_dsn = self._driver_urls(db_url)

        pool_min = self.settings.database_pool_min
        pool_max = self.settings.database_pool_max

        engine = create_async_engine(
            sqlalchemy_url,
            pool_size=pool_min,
            # pool_size + max_overflow is the connection ceiling. Clamp at 0:
            # a max below min would otherwise go negative, and SQLAlchemy
            # reads -1 as "no overflow limit".
            max_overflow=max(0, pool_max - pool_min),
            echo=self.settings.log_level == "DEBUG",
        )

        # Also create a raw asyncpg pool for performance-critical operations
        try:
            pool = await asyncpg.create_pool(
                asyncpg_dsn,
                min_size=pool_min,
                max_size=pool_max,
            )
        except Exception:
            # Do not leave an orphaned engine behind when the pool fails.
            await engine.dispose()
            raise

        self._engine = engine
        self._session_factory = async_sessionmaker(
            engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )
        self._pool = pool

        logger.info("Database connection established")

    async def disconnect(self) -> None:
        """Close all database connections.

        Safe to call when not connected. All handles (pool, engine and the
        session factory) are cleared before closing, so :meth:`session` and
        :attr:`pool` raise ``RuntimeError`` afterwards instead of operating
        on a disposed engine.
        """
        pool, self._pool = self._pool, None
        engine, self._engine = self._engine, None
        self._session_factory = None

        try:
            if pool is not None:
                await pool.close()
        finally:
            # Dispose the engine even if closing the raw pool failed.
            if engine is not None:
                await engine.dispose()

        logger.info("Database connections closed")

    async def create_tables(self) -> None:
        """Create all database tables.

        Runs ``Base.metadata.create_all`` inside a single engine transaction.

        Raises:
            RuntimeError: if :meth:`connect` has not been called.
        """
        if self._engine is None:
            raise RuntimeError("Database not connected")

        async with self._engine.begin() as conn:
            # create_all is a synchronous API; run_sync bridges it onto the
            # async connection.
            await conn.run_sync(Base.metadata.create_all)

        logger.info("Database tables created")

    @asynccontextmanager
    async def session(self) -> AsyncGenerator[AsyncSession, None]:
        """Get a database session context manager.

        The session is committed when the ``async with`` body finishes
        normally and rolled back (with the exception re-raised) when the
        body, or the commit itself, raises.

        Raises:
            RuntimeError: if :meth:`connect` has not been called.
        """
        if self._session_factory is None:
            raise RuntimeError("Database not connected")

        async with self._session_factory() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    @property
    def pool(self) -> asyncpg.Pool:
        """Get the raw asyncpg connection pool.

        Raises:
            RuntimeError: if :meth:`connect` has not been called.
        """
        if self._pool is None:
            raise RuntimeError("Database not connected")
        return self._pool
|