"""Persistent conversation management using PostgreSQL.""" import logging from datetime import datetime, timedelta from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from daemon_boyfriend.config import settings from daemon_boyfriend.models import Conversation, Message, User from daemon_boyfriend.services.providers import Message as ProviderMessage logger = logging.getLogger(__name__) class PersistentConversationManager: """Manages conversation history in PostgreSQL.""" def __init__(self, session: AsyncSession, max_history: int | None = None) -> None: self._session = session self.max_history = max_history or settings.max_conversation_history self._timeout = timedelta(minutes=settings.conversation_timeout_minutes) async def get_or_create_conversation( self, user: User, guild_id: int | None = None, channel_id: int | None = None, ) -> Conversation: """Get active conversation or create new one. Args: user: User model instance guild_id: Discord guild ID (None for DMs) channel_id: Discord channel ID Returns: Conversation model instance """ # Look for recent active conversation in this channel cutoff = datetime.utcnow() - self._timeout stmt = select(Conversation).where( Conversation.user_id == user.id, Conversation.is_active == True, # noqa: E712 Conversation.last_message_at > cutoff, ) if channel_id: stmt = stmt.where(Conversation.channel_id == channel_id) stmt = stmt.order_by(Conversation.last_message_at.desc()) result = await self._session.execute(stmt) conversation = result.scalar_first() if conversation: logger.debug( f"Found existing conversation {conversation.id} for user {user.discord_id}" ) return conversation # Create new conversation conversation = Conversation( user_id=user.id, guild_id=guild_id, channel_id=channel_id, ) self._session.add(conversation) await self._session.flush() logger.info(f"Created new conversation {conversation.id} for user {user.discord_id}") return conversation async def get_history(self, conversation: Conversation) -> list[ProviderMessage]: """Get conversation history as provider messages. Args: conversation: Conversation model instance Returns: List of ProviderMessage instances for the AI """ stmt = ( select(Message) .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(self.max_history) ) result = await self._session.execute(stmt) messages = list(reversed(result.scalars().all())) return [ ProviderMessage( role=msg.role, content=msg.content, # Note: Images would need special handling if needed ) for msg in messages ] async def add_message( self, conversation: Conversation, user: User, role: str, content: str, discord_message_id: int | None = None, image_urls: list[str] | None = None, ) -> Message: """Add a message to the conversation. Args: conversation: Conversation model instance user: User model instance role: Message role (user, assistant, system) content: Message content discord_message_id: Discord's message ID (optional) image_urls: List of image URLs (optional) Returns: Created Message instance """ message = Message( conversation_id=conversation.id, user_id=user.id, role=role, content=content, discord_message_id=discord_message_id, has_images=bool(image_urls), image_urls=image_urls, ) self._session.add(message) # Update conversation stats conversation.last_message_at = datetime.utcnow() conversation.message_count += 1 await self._session.flush() return message async def add_exchange( self, conversation: Conversation, user: User, user_message: str, assistant_message: str, discord_message_id: int | None = None, image_urls: list[str] | None = None, ) -> tuple[Message, Message]: """Add a user/assistant exchange to the conversation. Args: conversation: Conversation model instance user: User model instance user_message: The user's message content assistant_message: The assistant's response discord_message_id: Discord's message ID (optional) image_urls: List of image URLs in user message (optional) Returns: Tuple of (user_message, assistant_message) Message instances """ user_msg = await self.add_message( conversation, user, "user", user_message, discord_message_id, image_urls ) assistant_msg = await self.add_message(conversation, user, "assistant", assistant_message) return user_msg, assistant_msg async def clear_conversation(self, conversation: Conversation) -> None: """Mark a conversation as inactive. Args: conversation: Conversation model instance """ conversation.is_active = False logger.info(f"Marked conversation {conversation.id} as inactive") async def get_message_count(self, conversation: Conversation) -> int: """Get the number of messages in a conversation. Args: conversation: Conversation model instance Returns: Number of messages """ return conversation.message_count