- Add PostgreSQL database with SQLAlchemy async support - Create models: User, UserFact, UserPreference, Conversation, Message, Guild, GuildMember - Add custom name support so bot knows 'who is who' - Add user facts system for remembering things about users - Add persistent conversation history that survives restarts - Add memory commands cog (!setname, !remember, !whatdoyouknow, !forgetme) - Add admin commands (!setusername, !teachbot) - Set up Alembic for database migrations - Update docker-compose with PostgreSQL service - Gracefully fall back to in-memory storage when DB is not configured
189 lines
6.0 KiB
Python
189 lines
6.0 KiB
Python
"""Persistent conversation management using PostgreSQL."""
|
|
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from daemon_boyfriend.config import settings
|
|
from daemon_boyfriend.models import Conversation, Message, User
|
|
from daemon_boyfriend.services.providers import Message as ProviderMessage
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PersistentConversationManager:
    """Manages conversation history in PostgreSQL.

    All database work goes through the ``AsyncSession`` passed to the
    constructor. Methods may ``flush()`` so database-generated keys become
    available, but none of them commits -- committing or rolling back is left
    to whoever owns the session.
    """

    def __init__(self, session: AsyncSession, max_history: int | None = None) -> None:
        """Create a manager bound to a database session.

        Args:
            session: Async SQLAlchemy session used for every query.
            max_history: Maximum number of messages returned by
                ``get_history``. ``None`` falls back to
                ``settings.max_conversation_history``.
        """
        self._session = session
        # Only ``None`` means "use the configured default"; an explicitly
        # passed value (including 0) is honoured as given.
        self.max_history = (
            max_history if max_history is not None else settings.max_conversation_history
        )
        # A conversation whose last message is older than this is treated as
        # finished, so a new one is started instead of reusing it.
        self._timeout = timedelta(minutes=settings.conversation_timeout_minutes)

    async def get_or_create_conversation(
        self,
        user: User,
        guild_id: int | None = None,
        channel_id: int | None = None,
    ) -> Conversation:
        """Get the user's active conversation or create a new one.

        An existing conversation is reused when it belongs to ``user``, is
        still marked active, and received a message within the configured
        timeout. When ``channel_id`` is given only conversations in that
        channel match; otherwise the most recent match in any channel is used.

        Args:
            user: User model instance
            guild_id: Discord guild ID (None for DMs). Only stored on newly
                created conversations; it is not used for the lookup.
            channel_id: Discord channel ID

        Returns:
            Conversation model instance
        """
        # NOTE(review): naive UTC timestamp, presumably matching a
        # timezone-naive ``last_message_at`` column -- confirm before
        # switching to timezone-aware datetimes.
        cutoff = datetime.utcnow() - self._timeout

        stmt = select(Conversation).where(
            Conversation.user_id == user.id,
            Conversation.is_active == True,  # noqa: E712
            Conversation.last_message_at > cutoff,
        )
        if channel_id is not None:
            stmt = stmt.where(Conversation.channel_id == channel_id)
        # Only the most recently used conversation is needed.
        stmt = stmt.order_by(Conversation.last_message_at.desc()).limit(1)

        result = await self._session.execute(stmt)
        # ``Result`` has no ``scalar_first()``; take the first ORM entity.
        conversation = result.scalars().first()

        if conversation is not None:
            logger.debug(
                "Found existing conversation %s for user %s",
                conversation.id,
                user.discord_id,
            )
            return conversation

        conversation = Conversation(
            user_id=user.id,
            guild_id=guild_id,
            channel_id=channel_id,
        )
        self._session.add(conversation)
        # Flush so the database assigns ``conversation.id`` before it is
        # logged and used as a foreign key by later messages.
        await self._session.flush()
        logger.info(
            "Created new conversation %s for user %s", conversation.id, user.discord_id
        )
        return conversation

    async def get_history(self, conversation: Conversation) -> list[ProviderMessage]:
        """Get recent conversation history as provider messages.

        Fetches up to ``self.max_history`` of the newest messages and returns
        them oldest-first, which is the order the AI should read them in.

        Args:
            conversation: Conversation model instance

        Returns:
            List of ProviderMessage instances for the AI
        """
        stmt = (
            select(Message)
            .where(Message.conversation_id == conversation.id)
            # ``id`` breaks ties between rows with identical ``created_at``
            # values (possible when several messages are written in one
            # transaction), keeping user/assistant order deterministic.
            # NOTE(review): assumes ``Message.id`` increases with insertion
            # order -- confirm against the model.
            .order_by(Message.created_at.desc(), Message.id.desc())
            .limit(self.max_history)
        )

        result = await self._session.execute(stmt)
        # The query returns newest-first; reverse into chronological order.
        messages = list(reversed(result.scalars().all()))

        # Only role and text are forwarded; stored image URLs are not
        # attached here and would need special handling if required.
        return [ProviderMessage(role=msg.role, content=msg.content) for msg in messages]

    async def add_message(
        self,
        conversation: Conversation,
        user: User,
        role: str,
        content: str,
        discord_message_id: int | None = None,
        image_urls: list[str] | None = None,
    ) -> Message:
        """Add a message to the conversation and update its stats.

        Args:
            conversation: Conversation model instance
            user: User model instance
            role: Message role (user, assistant, system)
            content: Message content
            discord_message_id: Discord's message ID (optional)
            image_urls: List of image URLs (optional)

        Returns:
            Created Message instance
        """
        message = Message(
            conversation_id=conversation.id,
            user_id=user.id,
            role=role,
            content=content,
            discord_message_id=discord_message_id,
            has_images=bool(image_urls),
            image_urls=image_urls,
        )
        self._session.add(message)

        # Keep the conversation's activity timestamp and counter in sync; the
        # timestamp is what ``get_or_create_conversation`` compares against
        # its timeout cutoff.
        conversation.last_message_at = datetime.utcnow()
        conversation.message_count = (conversation.message_count or 0) + 1

        await self._session.flush()
        return message

    async def add_exchange(
        self,
        conversation: Conversation,
        user: User,
        user_message: str,
        assistant_message: str,
        discord_message_id: int | None = None,
        image_urls: list[str] | None = None,
    ) -> tuple[Message, Message]:
        """Add a user/assistant exchange to the conversation.

        The user message is stored first, then the assistant reply. Both are
        recorded with ``user.id`` as their ``user_id``; the Discord message ID
        and image URLs apply to the user message only.

        Args:
            conversation: Conversation model instance
            user: User model instance
            user_message: The user's message content
            assistant_message: The assistant's response
            discord_message_id: Discord's message ID (optional)
            image_urls: List of image URLs in user message (optional)

        Returns:
            Tuple of (user_message, assistant_message) Message instances
        """
        user_msg = await self.add_message(
            conversation, user, "user", user_message, discord_message_id, image_urls
        )
        assistant_msg = await self.add_message(conversation, user, "assistant", assistant_message)
        return user_msg, assistant_msg

    async def clear_conversation(self, conversation: Conversation) -> None:
        """Mark a conversation as inactive.

        Inactive conversations are skipped by ``get_or_create_conversation``.
        The change is not flushed here; it is persisted when the owning
        session next flushes or commits.

        Args:
            conversation: Conversation model instance
        """
        conversation.is_active = False
        logger.info("Marked conversation %s as inactive", conversation.id)

    async def get_message_count(self, conversation: Conversation) -> int:
        """Get the number of messages in a conversation.

        Reads the cached ``message_count`` counter maintained by
        ``add_message`` rather than counting rows.

        Args:
            conversation: Conversation model instance

        Returns:
            Number of messages
        """
        return conversation.message_count