Add PostgreSQL memory system for persistent user and conversation storage

- Add PostgreSQL database with SQLAlchemy async support
- Create models: User, UserFact, UserPreference, Conversation, Message, Guild, GuildMember
- Add custom name support so bot knows 'who is who'
- Add user facts system for remembering things about users
- Add persistent conversation history that survives restarts
- Add memory commands cog (!setname, !remember, !whatdoyouknow, !forgetme)
- Add admin commands (!setusername, !teachbot)
- Set up Alembic for database migrations
- Update docker-compose with PostgreSQL service
- Gracefully falls back to in-memory storage when DB not configured
This commit is contained in:
2026-01-12 14:00:06 +01:00
parent 853e7c9fcd
commit e00d4fd501
20 changed files with 1623 additions and 13 deletions

View File

@@ -7,6 +7,7 @@ import discord
from discord.ext import commands
from daemon_boyfriend.config import settings
from daemon_boyfriend.services import db
logger = logging.getLogger(__name__)
@@ -27,7 +28,14 @@ class DaemonBoyfriend(commands.Bot):
)
async def setup_hook(self) -> None:
"""Load cogs on startup."""
"""Initialize database and load cogs on startup."""
# Initialize database if configured
if db.is_configured:
await db.init()
logger.info("Database initialized")
else:
logger.info("Database not configured, using in-memory storage")
# Load all cogs
cogs_path = Path(__file__).parent / "cogs"
for cog_file in cogs_path.glob("*.py"):
@@ -55,3 +63,10 @@ class DaemonBoyfriend(commands.Bot):
name=settings.bot_status,
)
await self.change_presence(activity=activity)
async def close(self) -> None:
"""Clean up resources on shutdown."""
logger.info("Shutting down bot...")
if db.is_initialized:
await db.close()
await super().close()

View File

@@ -12,7 +12,10 @@ from daemon_boyfriend.services import (
ConversationManager,
ImageAttachment,
Message,
PersistentConversationManager,
SearXNGService,
UserService,
db,
)
from daemon_boyfriend.utils import get_monitor
@@ -77,11 +80,17 @@ class AIChatCog(commands.Cog):
def __init__(self, bot: commands.Bot) -> None:
self.bot = bot
self.ai_service = AIService()
# Fallback in-memory conversation manager (used when DB not configured)
self.conversations = ConversationManager()
self.search_service: SearXNGService | None = None
if settings.searxng_enabled and settings.searxng_url:
self.search_service = SearXNGService(settings.searxng_url)
@property
def use_database(self) -> bool:
"""Check if database is available for use."""
return db.is_initialized
@commands.Cog.listener()
async def on_message(self, message: discord.Message) -> None:
"""Respond when the bot is mentioned."""
@@ -395,6 +404,95 @@ class AIChatCog(commands.Cog):
Returns:
The AI's response text
"""
if self.use_database:
return await self._generate_response_with_db(message, user_message)
else:
return await self._generate_response_in_memory(message, user_message)
async def _generate_response_with_db(self, message: discord.Message, user_message: str) -> str:
"""Generate response using database-backed storage."""
async with db.session() as session:
user_service = UserService(session)
conv_manager = PersistentConversationManager(session)
# Get or create user
user = await user_service.get_or_create_user(
discord_id=message.author.id,
username=message.author.name,
display_name=message.author.display_name,
)
# Get or create conversation
conversation = await conv_manager.get_or_create_conversation(
user=user,
guild_id=message.guild.id if message.guild else None,
channel_id=message.channel.id,
)
# Get history
history = await conv_manager.get_history(conversation)
# Extract any image attachments from the message
images = self._extract_image_attachments(message)
image_urls = [img.url for img in images] if images else None
# Add current message to history for the API call
current_message = Message(role="user", content=user_message, images=images)
messages = history + [current_message]
# Check if we should search the web
search_context = await self._maybe_search(user_message)
# Get context about mentioned users
mentioned_users_context = self._get_mentioned_users_context(message)
# Build system prompt with additional context
system_prompt = self.ai_service.get_system_prompt()
# Add user context from database (custom name, known facts)
user_context = await user_service.get_user_context(user)
system_prompt += f"\n\n--- User Context ---\n{user_context}"
# Add mentioned users context
if mentioned_users_context:
system_prompt += f"\n\n--- {mentioned_users_context} ---"
# Add search results if available
if search_context:
system_prompt += (
"\n\n--- Web Search Results ---\n"
"Use the following current information from the web to help answer the user's question. "
"Cite sources when relevant.\n\n"
f"{search_context}"
)
# Generate response
response = await self.ai_service.chat(
messages=messages,
system_prompt=system_prompt,
)
# Save the exchange to database
await conv_manager.add_exchange(
conversation=conversation,
user=user,
user_message=user_message,
assistant_message=response.content,
discord_message_id=message.id,
image_urls=image_urls,
)
logger.debug(
f"Generated response for user {user.discord_id}: "
f"{len(response.content)} chars, {response.usage}"
)
return response.content
async def _generate_response_in_memory(
self, message: discord.Message, user_message: str
) -> str:
"""Generate response using in-memory storage (fallback)."""
user_id = message.author.id
# Get conversation history

View File

@@ -0,0 +1,260 @@
"""Memory management cog - commands for managing bot memory about users."""
import logging
import discord
from discord.ext import commands
from daemon_boyfriend.services import UserService, db
logger = logging.getLogger(__name__)
class MemoryCog(commands.Cog):
"""Commands for managing bot memory about users."""
def __init__(self, bot: commands.Bot) -> None:
self.bot = bot
def _check_database(self) -> bool:
"""Check if database is available."""
return db.is_initialized
@commands.command(name="setname")
async def set_name(self, ctx: commands.Context, *, name: str) -> None:
"""Set your preferred name for the bot to use.
Usage: !setname John
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
if len(name) > 100:
await ctx.reply("Name is too long! Please use 100 characters or less.")
return
async with db.session() as session:
user_service = UserService(session)
user = await user_service.get_or_create_user(
discord_id=ctx.author.id,
username=ctx.author.name,
display_name=ctx.author.display_name,
)
await user_service.set_custom_name(ctx.author.id, name)
await ctx.reply(f"Got it! I'll call you **{name}** from now on.")
@commands.command(name="clearname")
async def clear_name(self, ctx: commands.Context) -> None:
"""Clear your custom name and use your Discord name instead.
Usage: !clearname
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
async with db.session() as session:
user_service = UserService(session)
await user_service.set_custom_name(ctx.author.id, None)
await ctx.reply("Done! I'll use your Discord display name now.")
@commands.command(name="remember")
async def remember_fact(self, ctx: commands.Context, *, fact: str) -> None:
"""Tell the bot something to remember about you.
Usage: !remember I love pizza
Usage: !remember My favorite color is blue
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
if len(fact) > 500:
await ctx.reply("That's too long to remember! Please keep it under 500 characters.")
return
async with db.session() as session:
user_service = UserService(session)
user = await user_service.get_or_create_user(
discord_id=ctx.author.id,
username=ctx.author.name,
display_name=ctx.author.display_name,
)
await user_service.add_fact(
user=user,
fact_type="general",
fact_content=fact,
source="explicit",
confidence=1.0,
)
await ctx.reply(f"I'll remember that!")
@commands.command(name="whatdoyouknow", aliases=["aboutme", "myinfo"])
async def what_do_you_know(self, ctx: commands.Context) -> None:
"""Show what the bot remembers about you.
Usage: !whatdoyouknow
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
async with db.session() as session:
user_service = UserService(session)
user = await user_service.get_user_by_discord_id(ctx.author.id)
if not user:
await ctx.reply("I don't have any information about you yet!")
return
facts = await user_service.get_user_facts(user, active_only=True)
embed = discord.Embed(
title=f"What I know about {user.display_name}",
color=discord.Color.blue(),
)
embed.add_field(
name="Discord Username",
value=user.discord_username,
inline=True,
)
if user.custom_name:
embed.add_field(
name="Preferred Name",
value=user.custom_name,
inline=True,
)
embed.add_field(
name="First Seen",
value=user.first_seen_at.strftime("%Y-%m-%d"),
inline=True,
)
if facts:
facts_text = "\n".join(f"- {fact.fact_content}" for fact in facts[:15])
if len(facts) > 15:
facts_text += f"\n... and {len(facts) - 15} more"
embed.add_field(
name=f"Things I Remember ({len(facts)})",
value=facts_text or "Nothing yet!",
inline=False,
)
else:
embed.add_field(
name="Things I Remember",
value="Nothing yet! Use `!remember` to tell me something.",
inline=False,
)
await ctx.reply(embed=embed)
@commands.command(name="forgetme")
async def forget_me(self, ctx: commands.Context) -> None:
"""Clear all facts the bot knows about you.
Usage: !forgetme
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
async with db.session() as session:
user_service = UserService(session)
user = await user_service.get_user_by_discord_id(ctx.author.id)
if not user:
await ctx.reply("I don't have any information about you to forget!")
return
count = await user_service.delete_user_facts(user)
if count > 0:
await ctx.reply(f"Done! I've forgotten {count} thing(s) about you.")
else:
await ctx.reply("I didn't have anything to forget about you!")
@commands.command(name="setusername")
@commands.has_permissions(administrator=True)
async def set_user_name(
self, ctx: commands.Context, user: discord.Member, *, name: str
) -> None:
"""[Admin] Set a custom name for another user.
Usage: !setusername @user John
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
if len(name) > 100:
await ctx.reply("Name is too long! Please use 100 characters or less.")
return
async with db.session() as session:
user_service = UserService(session)
db_user = await user_service.get_or_create_user(
discord_id=user.id,
username=user.name,
display_name=user.display_name,
)
await user_service.set_custom_name(user.id, name)
await ctx.reply(f"Got it! I'll call {user.mention} **{name}** from now on.")
@commands.command(name="teachbot")
@commands.has_permissions(administrator=True)
async def teach_bot(self, ctx: commands.Context, user: discord.Member, *, fact: str) -> None:
"""[Admin] Teach the bot a fact about a user.
Usage: !teachbot @user They are a software developer
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
if len(fact) > 500:
await ctx.reply("That's too long! Please keep it under 500 characters.")
return
async with db.session() as session:
user_service = UserService(session)
db_user = await user_service.get_or_create_user(
discord_id=user.id,
username=user.name,
display_name=user.display_name,
)
await user_service.add_fact(
user=db_user,
fact_type="general",
fact_content=fact,
source="admin",
confidence=1.0,
)
await ctx.reply(f"I'll remember that about {user.mention}!")
@set_user_name.error
@teach_bot.error
async def admin_command_error(
self, ctx: commands.Context, error: commands.CommandError
) -> None:
"""Handle errors for admin commands."""
if isinstance(error, commands.MissingPermissions):
await ctx.reply("You need administrator permissions to use this command.")
elif isinstance(error, commands.MemberNotFound):
await ctx.reply("I couldn't find that user.")
else:
logger.error(f"Error in admin command: {error}", exc_info=True)
await ctx.reply(f"An error occurred: {error}")
async def setup(bot: commands.Bot) -> None:
"""Load the Memory cog."""
await bot.add_cog(MemoryCog(bot))

View File

@@ -67,6 +67,20 @@ class Settings(BaseSettings):
max_conversation_history: int = Field(
20, description="Max messages to keep in conversation memory per user"
)
conversation_timeout_minutes: int = Field(
60, ge=5, le=1440, description="Minutes of inactivity before starting new conversation"
)
# Database Configuration
database_url: str | None = Field(
None,
description="PostgreSQL connection URL (asyncpg format). If not set, uses in-memory storage.",
)
database_echo: bool = Field(False, description="Echo SQL statements (for debugging)")
database_pool_size: int = Field(5, ge=1, le=20, description="Database connection pool size")
database_max_overflow: int = Field(
10, ge=0, le=30, description="Max connections beyond pool size"
)
# SearXNG Configuration
searxng_url: str | None = Field(None, description="SearXNG instance URL for web search")

View File

@@ -0,0 +1,17 @@
"""Database models."""
from .base import Base
from .conversation import Conversation, Message
from .guild import Guild, GuildMember
from .user import User, UserFact, UserPreference
__all__ = [
"Base",
"Conversation",
"Guild",
"GuildMember",
"Message",
"User",
"UserFact",
"UserPreference",
]

View File

@@ -0,0 +1,28 @@
"""SQLAlchemy base model and metadata configuration."""
from datetime import datetime
from sqlalchemy import MetaData
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
# Naming convention for constraints (helps with migrations)
convention = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
metadata = MetaData(naming_convention=convention)
class Base(AsyncAttrs, DeclarativeBase):
"""Base class for all database models."""
metadata = metadata
# Common timestamp columns
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(default=datetime.utcnow, onupdate=datetime.utcnow)

View File

@@ -0,0 +1,57 @@
"""Conversation and message database models."""
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import ARRAY, BigInteger, Boolean, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base
if TYPE_CHECKING:
from .user import User
class Conversation(Base):
"""A conversation session with a user."""
__tablename__ = "conversations"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
guild_id: Mapped[int | None] = mapped_column(BigInteger)
channel_id: Mapped[int | None] = mapped_column(BigInteger, index=True)
started_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
last_message_at: Mapped[datetime] = mapped_column(default=datetime.utcnow, index=True)
message_count: Mapped[int] = mapped_column(Integer, default=0)
is_active: Mapped[bool] = mapped_column(Boolean, default=True, index=True)
# Relationships
user: Mapped["User"] = relationship(back_populates="conversations")
messages: Mapped[list["Message"]] = relationship(
back_populates="conversation",
cascade="all, delete-orphan",
order_by="Message.created_at",
)
class Message(Base):
"""Individual chat message."""
__tablename__ = "messages"
id: Mapped[int] = mapped_column(primary_key=True)
conversation_id: Mapped[int] = mapped_column(
ForeignKey("conversations.id", ondelete="CASCADE"), index=True
)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
discord_message_id: Mapped[int | None] = mapped_column(BigInteger)
role: Mapped[str] = mapped_column(String(20)) # user, assistant, system
content: Mapped[str] = mapped_column(Text)
has_images: Mapped[bool] = mapped_column(Boolean, default=False)
image_urls: Mapped[list[str] | None] = mapped_column(ARRAY(Text), default=None)
token_count: Mapped[int | None] = mapped_column(Integer)
# Relationships
conversation: Mapped["Conversation"] = relationship(back_populates="messages")
user: Mapped["User"] = relationship(back_populates="messages")

View File

@@ -0,0 +1,50 @@
"""Guild (Discord server) database models."""
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import ARRAY, BigInteger, Boolean, ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base
if TYPE_CHECKING:
from .user import User
class Guild(Base):
"""Discord server tracking."""
__tablename__ = "guilds"
id: Mapped[int] = mapped_column(primary_key=True)
discord_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
name: Mapped[str] = mapped_column(String(255))
joined_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
settings: Mapped[dict] = mapped_column(JSONB, default=dict)
# Relationships
members: Mapped[list["GuildMember"]] = relationship(
back_populates="guild", cascade="all, delete-orphan"
)
class GuildMember(Base):
"""Track users per guild with guild-specific settings."""
__tablename__ = "guild_members"
id: Mapped[int] = mapped_column(primary_key=True)
guild_id: Mapped[int] = mapped_column(ForeignKey("guilds.id", ondelete="CASCADE"), index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
guild_nickname: Mapped[str | None] = mapped_column(String(255))
roles: Mapped[list[str] | None] = mapped_column(ARRAY(Text), default=None)
joined_guild_at: Mapped[datetime | None] = mapped_column(default=None)
# Relationships
guild: Mapped["Guild"] = relationship(back_populates="members")
user: Mapped["User"] = relationship(back_populates="guild_memberships")
__table_args__ = (UniqueConstraint("guild_id", "user_id"),)

View File

@@ -0,0 +1,83 @@
"""User-related database models."""
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import BigInteger, Boolean, Float, ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base
if TYPE_CHECKING:
from .conversation import Conversation, Message
from .guild import GuildMember
class User(Base):
"""Discord user tracking."""
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True)
discord_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
discord_username: Mapped[str] = mapped_column(String(255))
discord_display_name: Mapped[str | None] = mapped_column(String(255))
custom_name: Mapped[str | None] = mapped_column(String(255))
first_seen_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
last_seen_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
# Relationships
preferences: Mapped[list["UserPreference"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
facts: Mapped[list["UserFact"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
conversations: Mapped[list["Conversation"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
messages: Mapped[list["Message"]] = relationship(back_populates="user")
guild_memberships: Mapped[list["GuildMember"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
@property
def display_name(self) -> str:
"""Get the name to use when addressing this user."""
return self.custom_name or self.discord_display_name or self.discord_username
class UserPreference(Base):
"""Per-user preferences/settings."""
__tablename__ = "user_preferences"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
preference_key: Mapped[str] = mapped_column(String(100))
preference_value: Mapped[str | None] = mapped_column(Text)
# Relationships
user: Mapped["User"] = relationship(back_populates="preferences")
__table_args__ = (UniqueConstraint("user_id", "preference_key"),)
class UserFact(Base):
"""Facts the bot has learned about users."""
__tablename__ = "user_facts"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
fact_type: Mapped[str] = mapped_column(String(50), index=True) # general, preference, hobby
fact_content: Mapped[str] = mapped_column(Text)
confidence: Mapped[float] = mapped_column(Float, default=1.0)
source: Mapped[str] = mapped_column(String(50), default="conversation")
is_active: Mapped[bool] = mapped_column(Boolean, default=True, index=True)
learned_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
last_referenced_at: Mapped[datetime | None] = mapped_column(default=None)
# Relationships
user: Mapped["User"] = relationship(back_populates="facts")

View File

@@ -2,14 +2,22 @@
from .ai_service import AIService
from .conversation import ConversationManager
from .database import DatabaseService, db, get_db
from .persistent_conversation import PersistentConversationManager
from .providers import AIResponse, ImageAttachment, Message
from .searxng import SearXNGService
from .user_service import UserService
__all__ = [
"AIService",
"AIResponse",
"ConversationManager",
"DatabaseService",
"ImageAttachment",
"Message",
"ConversationManager",
"PersistentConversationManager",
"SearXNGService",
"UserService",
"db",
"get_db",
]

View File

@@ -0,0 +1,94 @@
"""Database connection and session management."""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from daemon_boyfriend.config import settings
from daemon_boyfriend.models.base import Base
logger = logging.getLogger(__name__)
class DatabaseService:
"""Manages database connections and sessions."""
def __init__(self) -> None:
self._engine = None
self._session_factory: async_sessionmaker[AsyncSession] | None = None
self._initialized = False
@property
def is_configured(self) -> bool:
"""Check if database URL is configured."""
return settings.database_url is not None
@property
def is_initialized(self) -> bool:
"""Check if database has been initialized."""
return self._initialized
async def init(self) -> None:
"""Initialize database connection."""
if not self.is_configured:
logger.info("Database URL not configured, skipping database initialization")
return
logger.info("Initializing database connection...")
self._engine = create_async_engine(
settings.database_url,
echo=settings.database_echo,
pool_size=settings.database_pool_size,
max_overflow=settings.database_max_overflow,
)
self._session_factory = async_sessionmaker(
self._engine,
class_=AsyncSession,
expire_on_commit=False,
)
self._initialized = True
logger.info("Database connection initialized")
async def close(self) -> None:
"""Close database connection."""
if self._engine:
logger.info("Closing database connection...")
await self._engine.dispose()
self._engine = None
self._session_factory = None
self._initialized = False
logger.info("Database connection closed")
async def create_tables(self) -> None:
"""Create all tables (for development/testing)."""
if not self._engine:
raise RuntimeError("Database not initialized")
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("Database tables created")
@asynccontextmanager
async def session(self) -> AsyncGenerator[AsyncSession, None]:
"""Get a database session with automatic commit/rollback."""
if not self._session_factory:
raise RuntimeError("Database not initialized. Call init() first.")
async with self._session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
# Global instance
db = DatabaseService()
def get_db() -> DatabaseService:
"""Get the global database service instance."""
return db

View File

@@ -0,0 +1,188 @@
"""Persistent conversation management using PostgreSQL."""
import logging
from datetime import datetime, timedelta
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from daemon_boyfriend.config import settings
from daemon_boyfriend.models import Conversation, Message, User
from daemon_boyfriend.services.providers import Message as ProviderMessage
logger = logging.getLogger(__name__)
class PersistentConversationManager:
"""Manages conversation history in PostgreSQL."""
def __init__(self, session: AsyncSession, max_history: int | None = None) -> None:
self._session = session
self.max_history = max_history or settings.max_conversation_history
self._timeout = timedelta(minutes=settings.conversation_timeout_minutes)
async def get_or_create_conversation(
self,
user: User,
guild_id: int | None = None,
channel_id: int | None = None,
) -> Conversation:
"""Get active conversation or create new one.
Args:
user: User model instance
guild_id: Discord guild ID (None for DMs)
channel_id: Discord channel ID
Returns:
Conversation model instance
"""
# Look for recent active conversation in this channel
cutoff = datetime.utcnow() - self._timeout
stmt = select(Conversation).where(
Conversation.user_id == user.id,
Conversation.is_active == True, # noqa: E712
Conversation.last_message_at > cutoff,
)
if channel_id:
stmt = stmt.where(Conversation.channel_id == channel_id)
stmt = stmt.order_by(Conversation.last_message_at.desc())
result = await self._session.execute(stmt)
conversation = result.scalar_first()
if conversation:
logger.debug(
f"Found existing conversation {conversation.id} for user {user.discord_id}"
)
return conversation
# Create new conversation
conversation = Conversation(
user_id=user.id,
guild_id=guild_id,
channel_id=channel_id,
)
self._session.add(conversation)
await self._session.flush()
logger.info(f"Created new conversation {conversation.id} for user {user.discord_id}")
return conversation
async def get_history(self, conversation: Conversation) -> list[ProviderMessage]:
"""Get conversation history as provider messages.
Args:
conversation: Conversation model instance
Returns:
List of ProviderMessage instances for the AI
"""
stmt = (
select(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(self.max_history)
)
result = await self._session.execute(stmt)
messages = list(reversed(result.scalars().all()))
return [
ProviderMessage(
role=msg.role,
content=msg.content,
# Note: Images would need special handling if needed
)
for msg in messages
]
async def add_message(
self,
conversation: Conversation,
user: User,
role: str,
content: str,
discord_message_id: int | None = None,
image_urls: list[str] | None = None,
) -> Message:
"""Add a message to the conversation.
Args:
conversation: Conversation model instance
user: User model instance
role: Message role (user, assistant, system)
content: Message content
discord_message_id: Discord's message ID (optional)
image_urls: List of image URLs (optional)
Returns:
Created Message instance
"""
message = Message(
conversation_id=conversation.id,
user_id=user.id,
role=role,
content=content,
discord_message_id=discord_message_id,
has_images=bool(image_urls),
image_urls=image_urls,
)
self._session.add(message)
# Update conversation stats
conversation.last_message_at = datetime.utcnow()
conversation.message_count += 1
await self._session.flush()
return message
async def add_exchange(
self,
conversation: Conversation,
user: User,
user_message: str,
assistant_message: str,
discord_message_id: int | None = None,
image_urls: list[str] | None = None,
) -> tuple[Message, Message]:
"""Add a user/assistant exchange to the conversation.
Args:
conversation: Conversation model instance
user: User model instance
user_message: The user's message content
assistant_message: The assistant's response
discord_message_id: Discord's message ID (optional)
image_urls: List of image URLs in user message (optional)
Returns:
Tuple of (user_message, assistant_message) Message instances
"""
user_msg = await self.add_message(
conversation, user, "user", user_message, discord_message_id, image_urls
)
assistant_msg = await self.add_message(conversation, user, "assistant", assistant_message)
return user_msg, assistant_msg
async def clear_conversation(self, conversation: Conversation) -> None:
"""Mark a conversation as inactive.
Args:
conversation: Conversation model instance
"""
conversation.is_active = False
logger.info(f"Marked conversation {conversation.id} as inactive")
async def get_message_count(self, conversation: Conversation) -> int:
"""Get the number of messages in a conversation.
Args:
conversation: Conversation model instance
Returns:
Number of messages
"""
return conversation.message_count

View File

@@ -0,0 +1,250 @@
"""User management service."""
import logging
from datetime import datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from daemon_boyfriend.models import User, UserFact, UserPreference
logger = logging.getLogger(__name__)
class UserService:
"""Service for user-related operations."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def get_or_create_user(
self,
discord_id: int,
username: str,
display_name: str | None = None,
) -> User:
"""Get existing user or create new one.
Args:
discord_id: Discord user ID
username: Discord username
display_name: Discord display name
Returns:
User model instance
"""
stmt = select(User).where(User.discord_id == discord_id)
result = await self._session.execute(stmt)
user = result.scalar_one_or_none()
if user:
# Update last seen and current name
user.last_seen_at = datetime.utcnow()
user.discord_username = username
if display_name:
user.discord_display_name = display_name
return user
# Create new user
user = User(
discord_id=discord_id,
discord_username=username,
discord_display_name=display_name,
)
self._session.add(user)
await self._session.flush()
logger.info(f"Created new user: {username} (discord_id={discord_id})")
return user
async def get_user_by_discord_id(self, discord_id: int) -> User | None:
"""Get a user by their Discord ID.
Args:
discord_id: Discord user ID
Returns:
User if found, None otherwise
"""
stmt = select(User).where(User.discord_id == discord_id)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def set_custom_name(self, discord_id: int, custom_name: str | None) -> User | None:
"""Set a custom name for a user.
Args:
discord_id: Discord user ID
custom_name: Custom name to use, or None to clear
Returns:
Updated user if found, None otherwise
"""
user = await self.get_user_by_discord_id(discord_id)
if user:
user.custom_name = custom_name
logger.info(f"Set custom name for user {discord_id}: {custom_name}")
return user
async def add_fact(
self,
user: User,
fact_type: str,
fact_content: str,
source: str = "conversation",
confidence: float = 1.0,
) -> UserFact:
"""Add a fact about a user.
Args:
user: User model instance
fact_type: Type of fact (general, preference, hobby, relationship, etc.)
fact_content: The fact content
source: How the fact was learned (conversation, explicit, inferred)
confidence: Confidence level (0.0 to 1.0)
Returns:
Created UserFact instance
"""
fact = UserFact(
user_id=user.id,
fact_type=fact_type,
fact_content=fact_content,
source=source,
confidence=confidence,
)
self._session.add(fact)
await self._session.flush()
logger.debug(f"Added fact for user {user.discord_id}: [{fact_type}] {fact_content}")
return fact
async def get_user_facts(
self,
user: User,
fact_type: str | None = None,
active_only: bool = True,
) -> list[UserFact]:
"""Get facts about a user.
Args:
user: User model instance
fact_type: Optional filter by fact type
active_only: Only return active facts
Returns:
List of UserFact instances
"""
stmt = select(UserFact).where(UserFact.user_id == user.id)
if active_only:
stmt = stmt.where(UserFact.is_active == True) # noqa: E712
if fact_type:
stmt = stmt.where(UserFact.fact_type == fact_type)
stmt = stmt.order_by(UserFact.learned_at.desc())
result = await self._session.execute(stmt)
return list(result.scalars().all())
async def delete_user_facts(self, user: User) -> int:
"""Soft-delete all facts for a user (set is_active=False).
Args:
user: User model instance
Returns:
Number of facts deactivated
"""
facts = await self.get_user_facts(user, active_only=True)
for fact in facts:
fact.is_active = False
logger.info(f"Deactivated {len(facts)} facts for user {user.discord_id}")
return len(facts)
async def set_preference(self, user: User, key: str, value: str | None) -> UserPreference:
"""Set a user preference.
Args:
user: User model instance
key: Preference key
value: Preference value, or None to delete
Returns:
UserPreference instance
"""
stmt = select(UserPreference).where(
UserPreference.user_id == user.id,
UserPreference.preference_key == key,
)
result = await self._session.execute(stmt)
pref = result.scalar_one_or_none()
if pref:
pref.preference_value = value
else:
pref = UserPreference(
user_id=user.id,
preference_key=key,
preference_value=value,
)
self._session.add(pref)
await self._session.flush()
return pref
async def get_preference(self, user: User, key: str) -> str | None:
"""Get a user preference value.
Args:
user: User model instance
key: Preference key
Returns:
Preference value or None if not set
"""
stmt = select(UserPreference).where(
UserPreference.user_id == user.id,
UserPreference.preference_key == key,
)
result = await self._session.execute(stmt)
pref = result.scalar_one_or_none()
return pref.preference_value if pref else None
async def get_user_context(self, user: User) -> str:
"""Build a context string about a user for the AI.
Args:
user: User model instance
Returns:
Formatted context string
"""
lines = [f"User's preferred name: {user.display_name}"]
if user.custom_name:
lines.append(f"(You should address them as: {user.custom_name})")
facts = await self.get_user_facts(user, active_only=True)
if facts:
lines.append("\nKnown facts about this user:")
for fact in facts[:20]: # Limit to 20 most recent facts
lines.append(f"- [{fact.fact_type}] {fact.fact_content}")
# Mark as referenced
fact.last_referenced_at = datetime.utcnow()
return "\n".join(lines)
async def get_user_with_facts(self, discord_id: int) -> User | None:
"""Get a user with their facts eagerly loaded.
Args:
discord_id: Discord user ID
Returns:
User with facts loaded, or None if not found
"""
stmt = select(User).where(User.discord_id == discord_id).options(selectinload(User.facts))
result = await self._session.execute(stmt)
return result.scalar_one_or_none()