"""Database connection and session management."""

import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import asyncpg
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from guardden.config import Settings
from guardden.models import Base

logger = logging.getLogger(__name__)


def _to_sqlalchemy_url(db_url: str) -> str:
    """Return *db_url* with the asyncpg driver named, as SQLAlchemy expects.

    ``postgresql://`` and ``postgres://`` URLs are rewritten to
    ``postgresql+asyncpg://``; any other URL is returned unchanged.
    """
    # SQLAlchemy no longer accepts "postgres://" as an alias, so map both.
    for prefix in ("postgresql://", "postgres://"):
        if db_url.startswith(prefix):
            return "postgresql+asyncpg://" + db_url[len(prefix):]
    return db_url


def _to_asyncpg_dsn(db_url: str) -> str:
    """Return *db_url* as a DSN that asyncpg accepts.

    asyncpg only understands the ``postgresql``/``postgres`` schemes, so a
    SQLAlchemy-style ``postgresql+asyncpg://`` URL has its driver suffix
    stripped; any other URL is returned unchanged.
    """
    prefix = "postgresql+asyncpg://"
    if db_url.startswith(prefix):
        return "postgresql://" + db_url[len(prefix):]
    return db_url


class Database:
    """Manages database connections and sessions.

    Two independent pools to the same database are held: a SQLAlchemy async
    engine (used by :meth:`session` and :meth:`create_tables`) and a raw
    asyncpg pool (exposed via :attr:`pool`). Call :meth:`connect` before use
    and :meth:`disconnect` on shutdown.
    """

    def __init__(self, settings: Settings) -> None:
        """Store *settings*; no connections are opened until :meth:`connect`."""
        self.settings = settings
        self._engine: AsyncEngine | None = None
        self._session_factory: async_sessionmaker[AsyncSession] | None = None
        self._pool: asyncpg.Pool | None = None

    async def connect(self) -> None:
        """Initialize the SQLAlchemy engine/session factory and the asyncpg pool.

        If the asyncpg pool cannot be created, the already-created engine is
        disposed and this object is left disconnected.

        Raises:
            ValueError: If the configured pool bounds are invalid
                (``min < 0``, ``max < 1`` or ``min > max``).
        """
        db_url = self.settings.database_url.get_secret_value()
        pool_min = self.settings.database_pool_min
        pool_max = self.settings.database_pool_max

        # asyncpg rejects these bounds too, but only after the engine exists;
        # worse, a negative max_overflow means "unlimited" to SQLAlchemy.
        if pool_min < 0 or pool_max < 1 or pool_min > pool_max:
            raise ValueError(
                f"Invalid database pool bounds: min={pool_min}, max={pool_max}"
            )

        # SQLAlchemy treats pool_size=0 as "no size limit", so keep at least
        # one persistent slot and give the rest to overflow; the engine's
        # total connection cap stays at pool_max either way.
        engine_pool_size = max(pool_min, 1)
        engine = create_async_engine(
            _to_sqlalchemy_url(db_url),
            pool_size=engine_pool_size,
            max_overflow=pool_max - engine_pool_size,
            echo=self.settings.log_level == "DEBUG",
        )

        try:
            # Also create a raw asyncpg pool for performance-critical operations.
            pool = await asyncpg.create_pool(
                _to_asyncpg_dsn(db_url),
                min_size=pool_min,
                max_size=pool_max,
            )
        except BaseException:
            # Don't leak the engine if the pool can't be created (this also
            # covers task cancellation), then let the original error propagate.
            await engine.dispose()
            raise

        # Only publish state once both pools exist, so a failed connect()
        # never leaves the object half-connected.
        self._engine = engine
        self._session_factory = async_sessionmaker(
            engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )
        self._pool = pool
        logger.info("Database connection established")

    async def disconnect(self) -> None:
        """Close all database connections.

        The engine is disposed even if closing the asyncpg pool raises. After
        this call, :meth:`session`, :meth:`create_tables` and :attr:`pool`
        raise ``RuntimeError`` until :meth:`connect` is called again.
        """
        # Detach state first so a failure below can't leave stale handles.
        pool, self._pool = self._pool, None
        engine, self._engine = self._engine, None
        self._session_factory = None

        try:
            if pool is not None:
                await pool.close()
        finally:
            if engine is not None:
                await engine.dispose()
        logger.info("Database connections closed")

    async def create_tables(self) -> None:
        """Create all tables declared on ``Base.metadata``.

        Raises:
            RuntimeError: If :meth:`connect` has not been called.
        """
        if self._engine is None:
            raise RuntimeError("Database not connected")

        async with self._engine.begin() as conn:
            # metadata.create_all is synchronous; run it on the async connection.
            await conn.run_sync(Base.metadata.create_all)
        logger.info("Database tables created")

    @asynccontextmanager
    async def session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield a session that commits on success and rolls back on error.

        Raises:
            RuntimeError: If :meth:`connect` has not been called.
        """
        if self._session_factory is None:
            raise RuntimeError("Database not connected")

        async with self._session_factory() as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    @property
    def pool(self) -> asyncpg.Pool:
        """Get the raw asyncpg connection pool.

        Raises:
            RuntimeError: If :meth:`connect` has not been called.
        """
        if self._pool is None:
            raise RuntimeError("Database not connected")
        return self._pool