Add PostgreSQL memory system for persistent user and conversation storage

- Add PostgreSQL database with SQLAlchemy async support
- Create models: User, UserFact, UserPreference, Conversation, Message, Guild, GuildMember
- Add custom name support so bot knows 'who is who'
- Add user facts system for remembering things about users
- Add persistent conversation history that survives restarts
- Add memory commands cog (!setname, !remember, !whatdoyouknow, !forgetme)
- Add admin commands (!setusername, !teachbot)
- Set up Alembic for database migrations
- Update docker-compose with PostgreSQL service
- Gracefully falls back to in-memory storage when DB not configured
This commit is contained in:
2026-01-12 14:00:06 +01:00
parent 853e7c9fcd
commit e00d4fd501
20 changed files with 1623 additions and 13 deletions

View File

@@ -11,9 +11,12 @@ pip install -r requirements.txt
# Run the bot (requires .env with DISCORD_TOKEN and AI provider key) # Run the bot (requires .env with DISCORD_TOKEN and AI provider key)
python -m daemon_boyfriend python -m daemon_boyfriend
# Run with Docker # Run with Docker (includes PostgreSQL)
docker-compose up -d docker-compose up -d
# Run database migrations
alembic upgrade head
# Syntax check all Python files # Syntax check all Python files
python -m py_compile src/daemon_boyfriend/**/*.py python -m py_compile src/daemon_boyfriend/**/*.py
``` ```
@@ -33,9 +36,25 @@ The AI system uses a provider abstraction pattern:
### Cog System ### Cog System
Discord functionality is in `cogs/`: Discord functionality is in `cogs/`:
- `ai_chat.py` - `@mention` handler (responds when bot is mentioned) - `ai_chat.py` - `@mention` handler (responds when bot is mentioned)
- `memory.py` - Memory management commands (`!setname`, `!remember`, etc.)
- `status.py` - Bot health and status commands
Cogs are auto-loaded by `bot.py` from the `cogs/` directory. Cogs are auto-loaded by `bot.py` from the `cogs/` directory.
### Database & Memory System
The bot uses PostgreSQL for persistent memory (optional, falls back to in-memory):
- `models/` - SQLAlchemy models (User, UserFact, Conversation, Message, Guild, GuildMember)
- `services/database.py` - Connection pool and async session management
- `services/user_service.py` - User CRUD, custom names, facts management
- `services/persistent_conversation.py` - Database-backed conversation history
- `alembic/` - Database migrations
Key features:
- Custom names: Set preferred names for users so the bot knows "who is who"
- User facts: Bot remembers things about users (hobbies, preferences, etc.)
- Persistent conversations: Chat history survives restarts
- Conversation timeout: New conversation starts after 60 minutes of inactivity
### Configuration ### Configuration
All config flows through `config.py` using pydantic-settings. The `settings` singleton is created at module load, so env vars must be set before importing. All config flows through `config.py` using pydantic-settings. The `settings` singleton is created at module load, so env vars must be set before importing.
@@ -47,13 +66,31 @@ The bot can search the web for current information via SearXNG:
- Configured via `SEARXNG_URL`, `SEARXNG_ENABLED`, and `SEARXNG_MAX_RESULTS` env vars - Configured via `SEARXNG_URL`, `SEARXNG_ENABLED`, and `SEARXNG_MAX_RESULTS` env vars
### Key Design Decisions ### Key Design Decisions
- `ConversationManager` stores per-user chat history in memory with configurable max length - `PersistentConversationManager` stores conversations in PostgreSQL when `DATABASE_URL` is set
- `ConversationManager` is the in-memory fallback when database is not configured
- Long AI responses are split via `split_message()` in `ai_chat.py` to respect Discord's 2000 char limit - Long AI responses are split via `split_message()` in `ai_chat.py` to respect Discord's 2000 char limit
- The bot responds only to @mentions via `on_message` listener - The bot responds only to @mentions via `on_message` listener
- Web search uses AI to decide when to search, avoiding unnecessary API calls for general knowledge questions - Web search uses AI to decide when to search, avoiding unnecessary API calls for general knowledge questions
- User context (custom name + known facts) is included in AI prompts for personalized responses
## Environment Variables ## Environment Variables
Required: `DISCORD_TOKEN`, plus one of `OPENAI_API_KEY`, `OPENROUTER_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` depending on `AI_PROVIDER` setting. Required: `DISCORD_TOKEN`, plus one of `OPENAI_API_KEY`, `OPENROUTER_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` depending on `AI_PROVIDER` setting.
Optional: `SEARXNG_URL` for web search capability. Optional:
- `DATABASE_URL` - PostgreSQL connection string (e.g., `postgresql+asyncpg://user:pass@host:5432/db`)
- `POSTGRES_PASSWORD` - Used by docker-compose for the PostgreSQL container
- `SEARXNG_URL` - SearXNG instance URL for web search capability
## Memory Commands
User commands:
- `!setname <name>` - Set your preferred name
- `!clearname` - Reset to Discord display name
- `!remember <fact>` - Tell the bot something about you
- `!whatdoyouknow` - See what the bot remembers about you
- `!forgetme` - Clear all facts about you
Admin commands:
- `!setusername @user <name>` - Set name for another user
- `!teachbot @user <fact>` - Add a fact about a user

71
alembic.ini Normal file
View File

@@ -0,0 +1,71 @@
# Alembic Configuration File
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration file names
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
# sys.path path prepend for the migration environment
prepend_sys_path = src
# timezone to use when rendering the date within the migration file
timezone = UTC
# max length of characters to apply to the "slug" field
truncate_slug_length = 40
# set to 'true' to run the environment during the 'revision' command
revision_environment = false
# set to 'true' to allow .pyc and .pyo files without a source .py file
sourceless = false
# version number format
version_num_width = 4
# version path separator
version_path_separator = os
# output encoding used when revision files are written
output_encoding = utf-8
# Database URL - will be overridden by env.py from settings
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

86
alembic/env.py Normal file
View File

@@ -0,0 +1,86 @@
"""Alembic migration environment configuration."""
import asyncio
from logging.config import fileConfig
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
from daemon_boyfriend.config import settings
from daemon_boyfriend.models import Base
# this is the Alembic Config object
config = context.config
# Interpret the config file for Python logging
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Model metadata for autogenerate support
target_metadata = Base.metadata
def get_url() -> str:
"""Get database URL from settings."""
url = settings.database_url
if not url:
raise ValueError("DATABASE_URL must be set for migrations")
return url
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL and not an Engine,
though an Engine is acceptable here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the script output.
"""
url = get_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection) -> None:
"""Run migrations with the given connection."""
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""Run migrations in async mode."""
configuration = config.get_section(config.config_ini_section, {})
configuration["sqlalchemy.url"] = get_url()
connectable = async_engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

26
alembic/script.py.mako Normal file
View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,200 @@
"""Initial schema with users, conversations, messages, and guilds.
Revision ID: 0001
Revises:
Create Date: 2025-01-12
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0001"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Users table
op.create_table(
"users",
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
sa.Column("discord_id", sa.BigInteger(), nullable=False),
sa.Column("discord_username", sa.String(255), nullable=False),
sa.Column("discord_display_name", sa.String(255), nullable=True),
sa.Column("custom_name", sa.String(255), nullable=True),
sa.Column("first_seen_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("last_seen_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.PrimaryKeyConstraint("id", name=op.f("pk_users")),
sa.UniqueConstraint("discord_id", name=op.f("uq_users_discord_id")),
)
op.create_index(op.f("ix_users_discord_id"), "users", ["discord_id"])
op.create_index(op.f("ix_users_last_seen_at"), "users", ["last_seen_at"])
# User preferences table
op.create_table(
"user_preferences",
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("preference_key", sa.String(100), nullable=False),
sa.Column("preference_value", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
name=op.f("fk_user_preferences_user_id_users"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_user_preferences")),
sa.UniqueConstraint("user_id", "preference_key", name="uq_user_preferences_user_key"),
)
op.create_index(op.f("ix_user_preferences_user_id"), "user_preferences", ["user_id"])
# User facts table
op.create_table(
"user_facts",
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("fact_type", sa.String(50), nullable=False),
sa.Column("fact_content", sa.Text(), nullable=False),
sa.Column("confidence", sa.Float(), server_default="1.0"),
sa.Column("source", sa.String(50), server_default="conversation"),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("learned_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("last_referenced_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
name=op.f("fk_user_facts_user_id_users"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_user_facts")),
)
op.create_index(op.f("ix_user_facts_user_id"), "user_facts", ["user_id"])
op.create_index(op.f("ix_user_facts_fact_type"), "user_facts", ["fact_type"])
op.create_index(op.f("ix_user_facts_is_active"), "user_facts", ["is_active"])
# Guilds table
op.create_table(
"guilds",
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
sa.Column("discord_id", sa.BigInteger(), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("joined_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("settings", postgresql.JSONB(), server_default="{}"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.PrimaryKeyConstraint("id", name=op.f("pk_guilds")),
sa.UniqueConstraint("discord_id", name=op.f("uq_guilds_discord_id")),
)
op.create_index(op.f("ix_guilds_discord_id"), "guilds", ["discord_id"])
# Guild members table
op.create_table(
"guild_members",
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
sa.Column("guild_id", sa.BigInteger(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("guild_nickname", sa.String(255), nullable=True),
sa.Column("roles", postgresql.ARRAY(sa.Text()), nullable=True),
sa.Column("joined_guild_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.ForeignKeyConstraint(
["guild_id"],
["guilds.id"],
name=op.f("fk_guild_members_guild_id_guilds"),
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
name=op.f("fk_guild_members_user_id_users"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_guild_members")),
sa.UniqueConstraint("guild_id", "user_id", name="uq_guild_members_guild_user"),
)
op.create_index(op.f("ix_guild_members_guild_id"), "guild_members", ["guild_id"])
op.create_index(op.f("ix_guild_members_user_id"), "guild_members", ["user_id"])
# Conversations table
op.create_table(
"conversations",
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("guild_id", sa.BigInteger(), nullable=True),
sa.Column("channel_id", sa.BigInteger(), nullable=True),
sa.Column("started_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("last_message_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("message_count", sa.Integer(), server_default="0"),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
name=op.f("fk_conversations_user_id_users"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_conversations")),
)
op.create_index(op.f("ix_conversations_user_id"), "conversations", ["user_id"])
op.create_index(op.f("ix_conversations_channel_id"), "conversations", ["channel_id"])
op.create_index(op.f("ix_conversations_last_message_at"), "conversations", ["last_message_at"])
op.create_index(op.f("ix_conversations_is_active"), "conversations", ["is_active"])
# Messages table
op.create_table(
"messages",
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
sa.Column("conversation_id", sa.BigInteger(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("discord_message_id", sa.BigInteger(), nullable=True),
sa.Column("role", sa.String(20), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("has_images", sa.Boolean(), nullable=False, server_default="false"),
sa.Column("image_urls", postgresql.ARRAY(sa.Text()), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.ForeignKeyConstraint(
["conversation_id"],
["conversations.id"],
name=op.f("fk_messages_conversation_id_conversations"),
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
name=op.f("fk_messages_user_id_users"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_messages")),
)
op.create_index(op.f("ix_messages_conversation_id"), "messages", ["conversation_id"])
op.create_index(op.f("ix_messages_user_id"), "messages", ["user_id"])
op.create_index(op.f("ix_messages_created_at"), "messages", ["created_at"])
def downgrade() -> None:
op.drop_table("messages")
op.drop_table("conversations")
op.drop_table("guild_members")
op.drop_table("guilds")
op.drop_table("user_facts")
op.drop_table("user_preferences")
op.drop_table("users")

View File

@@ -7,3 +7,26 @@ services:
- .env - .env
environment: environment:
- PYTHONUNBUFFERED=1 - PYTHONUNBUFFERED=1
- DATABASE_URL=postgresql+asyncpg://daemon:${POSTGRES_PASSWORD:-daemon}@postgres:5432/daemon_boyfriend
depends_on:
postgres:
condition: service_healthy
postgres:
image: postgres:16-alpine
container_name: daemon-boyfriend-postgres
restart: unless-stopped
environment:
POSTGRES_USER: daemon
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-daemon}
POSTGRES_DB: daemon_boyfriend
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U daemon -d daemon_boyfriend"]
interval: 10s
timeout: 5s
retries: 5
volumes:
postgres_data:

View File

@@ -13,3 +13,8 @@ aiohttp>=3.9.0
pydantic>=2.6.0 pydantic>=2.6.0
pydantic-settings>=2.2.0 pydantic-settings>=2.2.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0
# Database
asyncpg>=0.29.0
sqlalchemy[asyncio]>=2.0.0
alembic>=1.13.0

View File

@@ -7,6 +7,7 @@ import discord
from discord.ext import commands from discord.ext import commands
from daemon_boyfriend.config import settings from daemon_boyfriend.config import settings
from daemon_boyfriend.services import db
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -27,7 +28,14 @@ class DaemonBoyfriend(commands.Bot):
) )
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
"""Load cogs on startup.""" """Initialize database and load cogs on startup."""
# Initialize database if configured
if db.is_configured:
await db.init()
logger.info("Database initialized")
else:
logger.info("Database not configured, using in-memory storage")
# Load all cogs # Load all cogs
cogs_path = Path(__file__).parent / "cogs" cogs_path = Path(__file__).parent / "cogs"
for cog_file in cogs_path.glob("*.py"): for cog_file in cogs_path.glob("*.py"):
@@ -55,3 +63,10 @@ class DaemonBoyfriend(commands.Bot):
name=settings.bot_status, name=settings.bot_status,
) )
await self.change_presence(activity=activity) await self.change_presence(activity=activity)
async def close(self) -> None:
"""Clean up resources on shutdown."""
logger.info("Shutting down bot...")
if db.is_initialized:
await db.close()
await super().close()

View File

@@ -12,7 +12,10 @@ from daemon_boyfriend.services import (
ConversationManager, ConversationManager,
ImageAttachment, ImageAttachment,
Message, Message,
PersistentConversationManager,
SearXNGService, SearXNGService,
UserService,
db,
) )
from daemon_boyfriend.utils import get_monitor from daemon_boyfriend.utils import get_monitor
@@ -77,11 +80,17 @@ class AIChatCog(commands.Cog):
def __init__(self, bot: commands.Bot) -> None: def __init__(self, bot: commands.Bot) -> None:
self.bot = bot self.bot = bot
self.ai_service = AIService() self.ai_service = AIService()
# Fallback in-memory conversation manager (used when DB not configured)
self.conversations = ConversationManager() self.conversations = ConversationManager()
self.search_service: SearXNGService | None = None self.search_service: SearXNGService | None = None
if settings.searxng_enabled and settings.searxng_url: if settings.searxng_enabled and settings.searxng_url:
self.search_service = SearXNGService(settings.searxng_url) self.search_service = SearXNGService(settings.searxng_url)
@property
def use_database(self) -> bool:
"""Check if database is available for use."""
return db.is_initialized
@commands.Cog.listener() @commands.Cog.listener()
async def on_message(self, message: discord.Message) -> None: async def on_message(self, message: discord.Message) -> None:
"""Respond when the bot is mentioned.""" """Respond when the bot is mentioned."""
@@ -395,6 +404,95 @@ class AIChatCog(commands.Cog):
Returns: Returns:
The AI's response text The AI's response text
""" """
if self.use_database:
return await self._generate_response_with_db(message, user_message)
else:
return await self._generate_response_in_memory(message, user_message)
async def _generate_response_with_db(self, message: discord.Message, user_message: str) -> str:
"""Generate response using database-backed storage."""
async with db.session() as session:
user_service = UserService(session)
conv_manager = PersistentConversationManager(session)
# Get or create user
user = await user_service.get_or_create_user(
discord_id=message.author.id,
username=message.author.name,
display_name=message.author.display_name,
)
# Get or create conversation
conversation = await conv_manager.get_or_create_conversation(
user=user,
guild_id=message.guild.id if message.guild else None,
channel_id=message.channel.id,
)
# Get history
history = await conv_manager.get_history(conversation)
# Extract any image attachments from the message
images = self._extract_image_attachments(message)
image_urls = [img.url for img in images] if images else None
# Add current message to history for the API call
current_message = Message(role="user", content=user_message, images=images)
messages = history + [current_message]
# Check if we should search the web
search_context = await self._maybe_search(user_message)
# Get context about mentioned users
mentioned_users_context = self._get_mentioned_users_context(message)
# Build system prompt with additional context
system_prompt = self.ai_service.get_system_prompt()
# Add user context from database (custom name, known facts)
user_context = await user_service.get_user_context(user)
system_prompt += f"\n\n--- User Context ---\n{user_context}"
# Add mentioned users context
if mentioned_users_context:
system_prompt += f"\n\n--- {mentioned_users_context} ---"
# Add search results if available
if search_context:
system_prompt += (
"\n\n--- Web Search Results ---\n"
"Use the following current information from the web to help answer the user's question. "
"Cite sources when relevant.\n\n"
f"{search_context}"
)
# Generate response
response = await self.ai_service.chat(
messages=messages,
system_prompt=system_prompt,
)
# Save the exchange to database
await conv_manager.add_exchange(
conversation=conversation,
user=user,
user_message=user_message,
assistant_message=response.content,
discord_message_id=message.id,
image_urls=image_urls,
)
logger.debug(
f"Generated response for user {user.discord_id}: "
f"{len(response.content)} chars, {response.usage}"
)
return response.content
async def _generate_response_in_memory(
self, message: discord.Message, user_message: str
) -> str:
"""Generate response using in-memory storage (fallback)."""
user_id = message.author.id user_id = message.author.id
# Get conversation history # Get conversation history

View File

@@ -0,0 +1,260 @@
"""Memory management cog - commands for managing bot memory about users."""
import logging
import discord
from discord.ext import commands
from daemon_boyfriend.services import UserService, db
logger = logging.getLogger(__name__)
class MemoryCog(commands.Cog):
"""Commands for managing bot memory about users."""
def __init__(self, bot: commands.Bot) -> None:
self.bot = bot
def _check_database(self) -> bool:
"""Check if database is available."""
return db.is_initialized
@commands.command(name="setname")
async def set_name(self, ctx: commands.Context, *, name: str) -> None:
"""Set your preferred name for the bot to use.
Usage: !setname John
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
if len(name) > 100:
await ctx.reply("Name is too long! Please use 100 characters or less.")
return
async with db.session() as session:
user_service = UserService(session)
user = await user_service.get_or_create_user(
discord_id=ctx.author.id,
username=ctx.author.name,
display_name=ctx.author.display_name,
)
await user_service.set_custom_name(ctx.author.id, name)
await ctx.reply(f"Got it! I'll call you **{name}** from now on.")
@commands.command(name="clearname")
async def clear_name(self, ctx: commands.Context) -> None:
"""Clear your custom name and use your Discord name instead.
Usage: !clearname
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
async with db.session() as session:
user_service = UserService(session)
await user_service.set_custom_name(ctx.author.id, None)
await ctx.reply("Done! I'll use your Discord display name now.")
@commands.command(name="remember")
async def remember_fact(self, ctx: commands.Context, *, fact: str) -> None:
"""Tell the bot something to remember about you.
Usage: !remember I love pizza
Usage: !remember My favorite color is blue
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
if len(fact) > 500:
await ctx.reply("That's too long to remember! Please keep it under 500 characters.")
return
async with db.session() as session:
user_service = UserService(session)
user = await user_service.get_or_create_user(
discord_id=ctx.author.id,
username=ctx.author.name,
display_name=ctx.author.display_name,
)
await user_service.add_fact(
user=user,
fact_type="general",
fact_content=fact,
source="explicit",
confidence=1.0,
)
await ctx.reply(f"I'll remember that!")
@commands.command(name="whatdoyouknow", aliases=["aboutme", "myinfo"])
async def what_do_you_know(self, ctx: commands.Context) -> None:
"""Show what the bot remembers about you.
Usage: !whatdoyouknow
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
async with db.session() as session:
user_service = UserService(session)
user = await user_service.get_user_by_discord_id(ctx.author.id)
if not user:
await ctx.reply("I don't have any information about you yet!")
return
facts = await user_service.get_user_facts(user, active_only=True)
embed = discord.Embed(
title=f"What I know about {user.display_name}",
color=discord.Color.blue(),
)
embed.add_field(
name="Discord Username",
value=user.discord_username,
inline=True,
)
if user.custom_name:
embed.add_field(
name="Preferred Name",
value=user.custom_name,
inline=True,
)
embed.add_field(
name="First Seen",
value=user.first_seen_at.strftime("%Y-%m-%d"),
inline=True,
)
if facts:
facts_text = "\n".join(f"- {fact.fact_content}" for fact in facts[:15])
if len(facts) > 15:
facts_text += f"\n... and {len(facts) - 15} more"
embed.add_field(
name=f"Things I Remember ({len(facts)})",
value=facts_text or "Nothing yet!",
inline=False,
)
else:
embed.add_field(
name="Things I Remember",
value="Nothing yet! Use `!remember` to tell me something.",
inline=False,
)
await ctx.reply(embed=embed)
@commands.command(name="forgetme")
async def forget_me(self, ctx: commands.Context) -> None:
"""Clear all facts the bot knows about you.
Usage: !forgetme
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
async with db.session() as session:
user_service = UserService(session)
user = await user_service.get_user_by_discord_id(ctx.author.id)
if not user:
await ctx.reply("I don't have any information about you to forget!")
return
count = await user_service.delete_user_facts(user)
if count > 0:
await ctx.reply(f"Done! I've forgotten {count} thing(s) about you.")
else:
await ctx.reply("I didn't have anything to forget about you!")
@commands.command(name="setusername")
@commands.has_permissions(administrator=True)
async def set_user_name(
self, ctx: commands.Context, user: discord.Member, *, name: str
) -> None:
"""[Admin] Set a custom name for another user.
Usage: !setusername @user John
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
if len(name) > 100:
await ctx.reply("Name is too long! Please use 100 characters or less.")
return
async with db.session() as session:
user_service = UserService(session)
db_user = await user_service.get_or_create_user(
discord_id=user.id,
username=user.name,
display_name=user.display_name,
)
await user_service.set_custom_name(user.id, name)
await ctx.reply(f"Got it! I'll call {user.mention} **{name}** from now on.")
@commands.command(name="teachbot")
@commands.has_permissions(administrator=True)
async def teach_bot(self, ctx: commands.Context, user: discord.Member, *, fact: str) -> None:
"""[Admin] Teach the bot a fact about a user.
Usage: !teachbot @user They are a software developer
"""
if not self._check_database():
await ctx.reply("Memory features are not available (database not configured).")
return
if len(fact) > 500:
await ctx.reply("That's too long! Please keep it under 500 characters.")
return
async with db.session() as session:
user_service = UserService(session)
db_user = await user_service.get_or_create_user(
discord_id=user.id,
username=user.name,
display_name=user.display_name,
)
await user_service.add_fact(
user=db_user,
fact_type="general",
fact_content=fact,
source="admin",
confidence=1.0,
)
await ctx.reply(f"I'll remember that about {user.mention}!")
@set_user_name.error
@teach_bot.error
async def admin_command_error(
self, ctx: commands.Context, error: commands.CommandError
) -> None:
"""Handle errors for admin commands."""
if isinstance(error, commands.MissingPermissions):
await ctx.reply("You need administrator permissions to use this command.")
elif isinstance(error, commands.MemberNotFound):
await ctx.reply("I couldn't find that user.")
else:
logger.error(f"Error in admin command: {error}", exc_info=True)
await ctx.reply(f"An error occurred: {error}")
async def setup(bot: commands.Bot) -> None:
"""Load the Memory cog."""
await bot.add_cog(MemoryCog(bot))

View File

@@ -67,6 +67,20 @@ class Settings(BaseSettings):
max_conversation_history: int = Field( max_conversation_history: int = Field(
20, description="Max messages to keep in conversation memory per user" 20, description="Max messages to keep in conversation memory per user"
) )
conversation_timeout_minutes: int = Field(
60, ge=5, le=1440, description="Minutes of inactivity before starting new conversation"
)
# Database Configuration
database_url: str | None = Field(
None,
description="PostgreSQL connection URL (asyncpg format). If not set, uses in-memory storage.",
)
database_echo: bool = Field(False, description="Echo SQL statements (for debugging)")
database_pool_size: int = Field(5, ge=1, le=20, description="Database connection pool size")
database_max_overflow: int = Field(
10, ge=0, le=30, description="Max connections beyond pool size"
)
# SearXNG Configuration # SearXNG Configuration
searxng_url: str | None = Field(None, description="SearXNG instance URL for web search") searxng_url: str | None = Field(None, description="SearXNG instance URL for web search")

View File

@@ -0,0 +1,17 @@
"""Database models."""
from .base import Base
from .conversation import Conversation, Message
from .guild import Guild, GuildMember
from .user import User, UserFact, UserPreference
__all__ = [
"Base",
"Conversation",
"Guild",
"GuildMember",
"Message",
"User",
"UserFact",
"UserPreference",
]

View File

@@ -0,0 +1,28 @@
"""SQLAlchemy base model and metadata configuration."""
from datetime import datetime
from sqlalchemy import MetaData
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
# Naming convention for constraints (helps with migrations)
convention = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
metadata = MetaData(naming_convention=convention)
class Base(AsyncAttrs, DeclarativeBase):
"""Base class for all database models."""
metadata = metadata
# Common timestamp columns
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
updated_at: Mapped[datetime] = mapped_column(default=datetime.utcnow, onupdate=datetime.utcnow)

View File

@@ -0,0 +1,57 @@
"""Conversation and message database models."""
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import ARRAY, BigInteger, Boolean, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base
if TYPE_CHECKING:
from .user import User
class Conversation(Base):
"""A conversation session with a user."""
__tablename__ = "conversations"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
guild_id: Mapped[int | None] = mapped_column(BigInteger)
channel_id: Mapped[int | None] = mapped_column(BigInteger, index=True)
started_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
last_message_at: Mapped[datetime] = mapped_column(default=datetime.utcnow, index=True)
message_count: Mapped[int] = mapped_column(Integer, default=0)
is_active: Mapped[bool] = mapped_column(Boolean, default=True, index=True)
# Relationships
user: Mapped["User"] = relationship(back_populates="conversations")
messages: Mapped[list["Message"]] = relationship(
back_populates="conversation",
cascade="all, delete-orphan",
order_by="Message.created_at",
)
class Message(Base):
"""Individual chat message."""
__tablename__ = "messages"
id: Mapped[int] = mapped_column(primary_key=True)
conversation_id: Mapped[int] = mapped_column(
ForeignKey("conversations.id", ondelete="CASCADE"), index=True
)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
discord_message_id: Mapped[int | None] = mapped_column(BigInteger)
role: Mapped[str] = mapped_column(String(20)) # user, assistant, system
content: Mapped[str] = mapped_column(Text)
has_images: Mapped[bool] = mapped_column(Boolean, default=False)
image_urls: Mapped[list[str] | None] = mapped_column(ARRAY(Text), default=None)
token_count: Mapped[int | None] = mapped_column(Integer)
# Relationships
conversation: Mapped["Conversation"] = relationship(back_populates="messages")
user: Mapped["User"] = relationship(back_populates="messages")

View File

@@ -0,0 +1,50 @@
"""Guild (Discord server) database models."""
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import ARRAY, BigInteger, Boolean, ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base
if TYPE_CHECKING:
from .user import User
class Guild(Base):
"""Discord server tracking."""
__tablename__ = "guilds"
id: Mapped[int] = mapped_column(primary_key=True)
discord_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
name: Mapped[str] = mapped_column(String(255))
joined_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
settings: Mapped[dict] = mapped_column(JSONB, default=dict)
# Relationships
members: Mapped[list["GuildMember"]] = relationship(
back_populates="guild", cascade="all, delete-orphan"
)
class GuildMember(Base):
"""Track users per guild with guild-specific settings."""
__tablename__ = "guild_members"
id: Mapped[int] = mapped_column(primary_key=True)
guild_id: Mapped[int] = mapped_column(ForeignKey("guilds.id", ondelete="CASCADE"), index=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
guild_nickname: Mapped[str | None] = mapped_column(String(255))
roles: Mapped[list[str] | None] = mapped_column(ARRAY(Text), default=None)
joined_guild_at: Mapped[datetime | None] = mapped_column(default=None)
# Relationships
guild: Mapped["Guild"] = relationship(back_populates="members")
user: Mapped["User"] = relationship(back_populates="guild_memberships")
__table_args__ = (UniqueConstraint("guild_id", "user_id"),)

View File

@@ -0,0 +1,83 @@
"""User-related database models."""
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import BigInteger, Boolean, Float, ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base
if TYPE_CHECKING:
from .conversation import Conversation, Message
from .guild import GuildMember
class User(Base):
"""Discord user tracking."""
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True)
discord_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
discord_username: Mapped[str] = mapped_column(String(255))
discord_display_name: Mapped[str | None] = mapped_column(String(255))
custom_name: Mapped[str | None] = mapped_column(String(255))
first_seen_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
last_seen_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
# Relationships
preferences: Mapped[list["UserPreference"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
facts: Mapped[list["UserFact"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
conversations: Mapped[list["Conversation"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
messages: Mapped[list["Message"]] = relationship(back_populates="user")
guild_memberships: Mapped[list["GuildMember"]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
@property
def display_name(self) -> str:
"""Get the name to use when addressing this user."""
return self.custom_name or self.discord_display_name or self.discord_username
class UserPreference(Base):
"""Per-user preferences/settings."""
__tablename__ = "user_preferences"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
preference_key: Mapped[str] = mapped_column(String(100))
preference_value: Mapped[str | None] = mapped_column(Text)
# Relationships
user: Mapped["User"] = relationship(back_populates="preferences")
__table_args__ = (UniqueConstraint("user_id", "preference_key"),)
class UserFact(Base):
"""Facts the bot has learned about users."""
__tablename__ = "user_facts"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
fact_type: Mapped[str] = mapped_column(String(50), index=True) # general, preference, hobby
fact_content: Mapped[str] = mapped_column(Text)
confidence: Mapped[float] = mapped_column(Float, default=1.0)
source: Mapped[str] = mapped_column(String(50), default="conversation")
is_active: Mapped[bool] = mapped_column(Boolean, default=True, index=True)
learned_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
last_referenced_at: Mapped[datetime | None] = mapped_column(default=None)
# Relationships
user: Mapped["User"] = relationship(back_populates="facts")

View File

@@ -2,14 +2,22 @@
from .ai_service import AIService from .ai_service import AIService
from .conversation import ConversationManager from .conversation import ConversationManager
from .database import DatabaseService, db, get_db
from .persistent_conversation import PersistentConversationManager
from .providers import AIResponse, ImageAttachment, Message from .providers import AIResponse, ImageAttachment, Message
from .searxng import SearXNGService from .searxng import SearXNGService
from .user_service import UserService
__all__ = [ __all__ = [
"AIService", "AIService",
"AIResponse", "AIResponse",
"ConversationManager",
"DatabaseService",
"ImageAttachment", "ImageAttachment",
"Message", "Message",
"ConversationManager", "PersistentConversationManager",
"SearXNGService", "SearXNGService",
"UserService",
"db",
"get_db",
] ]

View File

@@ -0,0 +1,94 @@
"""Database connection and session management."""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from daemon_boyfriend.config import settings
from daemon_boyfriend.models.base import Base
logger = logging.getLogger(__name__)
class DatabaseService:
"""Manages database connections and sessions."""
def __init__(self) -> None:
self._engine = None
self._session_factory: async_sessionmaker[AsyncSession] | None = None
self._initialized = False
@property
def is_configured(self) -> bool:
"""Check if database URL is configured."""
return settings.database_url is not None
@property
def is_initialized(self) -> bool:
"""Check if database has been initialized."""
return self._initialized
async def init(self) -> None:
"""Initialize database connection."""
if not self.is_configured:
logger.info("Database URL not configured, skipping database initialization")
return
logger.info("Initializing database connection...")
self._engine = create_async_engine(
settings.database_url,
echo=settings.database_echo,
pool_size=settings.database_pool_size,
max_overflow=settings.database_max_overflow,
)
self._session_factory = async_sessionmaker(
self._engine,
class_=AsyncSession,
expire_on_commit=False,
)
self._initialized = True
logger.info("Database connection initialized")
async def close(self) -> None:
"""Close database connection."""
if self._engine:
logger.info("Closing database connection...")
await self._engine.dispose()
self._engine = None
self._session_factory = None
self._initialized = False
logger.info("Database connection closed")
async def create_tables(self) -> None:
"""Create all tables (for development/testing)."""
if not self._engine:
raise RuntimeError("Database not initialized")
async with self._engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("Database tables created")
@asynccontextmanager
async def session(self) -> AsyncGenerator[AsyncSession, None]:
"""Get a database session with automatic commit/rollback."""
if not self._session_factory:
raise RuntimeError("Database not initialized. Call init() first.")
async with self._session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
# Global instance
db = DatabaseService()
def get_db() -> DatabaseService:
"""Get the global database service instance."""
return db

View File

@@ -0,0 +1,188 @@
"""Persistent conversation management using PostgreSQL."""
import logging
from datetime import datetime, timedelta
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from daemon_boyfriend.config import settings
from daemon_boyfriend.models import Conversation, Message, User
from daemon_boyfriend.services.providers import Message as ProviderMessage
logger = logging.getLogger(__name__)
class PersistentConversationManager:
"""Manages conversation history in PostgreSQL."""
def __init__(self, session: AsyncSession, max_history: int | None = None) -> None:
self._session = session
self.max_history = max_history or settings.max_conversation_history
self._timeout = timedelta(minutes=settings.conversation_timeout_minutes)
async def get_or_create_conversation(
self,
user: User,
guild_id: int | None = None,
channel_id: int | None = None,
) -> Conversation:
"""Get active conversation or create new one.
Args:
user: User model instance
guild_id: Discord guild ID (None for DMs)
channel_id: Discord channel ID
Returns:
Conversation model instance
"""
# Look for recent active conversation in this channel
cutoff = datetime.utcnow() - self._timeout
stmt = select(Conversation).where(
Conversation.user_id == user.id,
Conversation.is_active == True, # noqa: E712
Conversation.last_message_at > cutoff,
)
if channel_id:
stmt = stmt.where(Conversation.channel_id == channel_id)
stmt = stmt.order_by(Conversation.last_message_at.desc())
result = await self._session.execute(stmt)
conversation = result.scalar_first()
if conversation:
logger.debug(
f"Found existing conversation {conversation.id} for user {user.discord_id}"
)
return conversation
# Create new conversation
conversation = Conversation(
user_id=user.id,
guild_id=guild_id,
channel_id=channel_id,
)
self._session.add(conversation)
await self._session.flush()
logger.info(f"Created new conversation {conversation.id} for user {user.discord_id}")
return conversation
async def get_history(self, conversation: Conversation) -> list[ProviderMessage]:
"""Get conversation history as provider messages.
Args:
conversation: Conversation model instance
Returns:
List of ProviderMessage instances for the AI
"""
stmt = (
select(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(self.max_history)
)
result = await self._session.execute(stmt)
messages = list(reversed(result.scalars().all()))
return [
ProviderMessage(
role=msg.role,
content=msg.content,
# Note: Images would need special handling if needed
)
for msg in messages
]
async def add_message(
self,
conversation: Conversation,
user: User,
role: str,
content: str,
discord_message_id: int | None = None,
image_urls: list[str] | None = None,
) -> Message:
"""Add a message to the conversation.
Args:
conversation: Conversation model instance
user: User model instance
role: Message role (user, assistant, system)
content: Message content
discord_message_id: Discord's message ID (optional)
image_urls: List of image URLs (optional)
Returns:
Created Message instance
"""
message = Message(
conversation_id=conversation.id,
user_id=user.id,
role=role,
content=content,
discord_message_id=discord_message_id,
has_images=bool(image_urls),
image_urls=image_urls,
)
self._session.add(message)
# Update conversation stats
conversation.last_message_at = datetime.utcnow()
conversation.message_count += 1
await self._session.flush()
return message
async def add_exchange(
self,
conversation: Conversation,
user: User,
user_message: str,
assistant_message: str,
discord_message_id: int | None = None,
image_urls: list[str] | None = None,
) -> tuple[Message, Message]:
"""Add a user/assistant exchange to the conversation.
Args:
conversation: Conversation model instance
user: User model instance
user_message: The user's message content
assistant_message: The assistant's response
discord_message_id: Discord's message ID (optional)
image_urls: List of image URLs in user message (optional)
Returns:
Tuple of (user_message, assistant_message) Message instances
"""
user_msg = await self.add_message(
conversation, user, "user", user_message, discord_message_id, image_urls
)
assistant_msg = await self.add_message(conversation, user, "assistant", assistant_message)
return user_msg, assistant_msg
async def clear_conversation(self, conversation: Conversation) -> None:
"""Mark a conversation as inactive.
Args:
conversation: Conversation model instance
"""
conversation.is_active = False
logger.info(f"Marked conversation {conversation.id} as inactive")
async def get_message_count(self, conversation: Conversation) -> int:
"""Get the number of messages in a conversation.
Args:
conversation: Conversation model instance
Returns:
Number of messages
"""
return conversation.message_count

View File

@@ -0,0 +1,250 @@
"""User management service."""
import logging
from datetime import datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from daemon_boyfriend.models import User, UserFact, UserPreference
logger = logging.getLogger(__name__)
class UserService:
"""Service for user-related operations."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def get_or_create_user(
self,
discord_id: int,
username: str,
display_name: str | None = None,
) -> User:
"""Get existing user or create new one.
Args:
discord_id: Discord user ID
username: Discord username
display_name: Discord display name
Returns:
User model instance
"""
stmt = select(User).where(User.discord_id == discord_id)
result = await self._session.execute(stmt)
user = result.scalar_one_or_none()
if user:
# Update last seen and current name
user.last_seen_at = datetime.utcnow()
user.discord_username = username
if display_name:
user.discord_display_name = display_name
return user
# Create new user
user = User(
discord_id=discord_id,
discord_username=username,
discord_display_name=display_name,
)
self._session.add(user)
await self._session.flush()
logger.info(f"Created new user: {username} (discord_id={discord_id})")
return user
async def get_user_by_discord_id(self, discord_id: int) -> User | None:
"""Get a user by their Discord ID.
Args:
discord_id: Discord user ID
Returns:
User if found, None otherwise
"""
stmt = select(User).where(User.discord_id == discord_id)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def set_custom_name(self, discord_id: int, custom_name: str | None) -> User | None:
"""Set a custom name for a user.
Args:
discord_id: Discord user ID
custom_name: Custom name to use, or None to clear
Returns:
Updated user if found, None otherwise
"""
user = await self.get_user_by_discord_id(discord_id)
if user:
user.custom_name = custom_name
logger.info(f"Set custom name for user {discord_id}: {custom_name}")
return user
async def add_fact(
self,
user: User,
fact_type: str,
fact_content: str,
source: str = "conversation",
confidence: float = 1.0,
) -> UserFact:
"""Add a fact about a user.
Args:
user: User model instance
fact_type: Type of fact (general, preference, hobby, relationship, etc.)
fact_content: The fact content
source: How the fact was learned (conversation, explicit, inferred)
confidence: Confidence level (0.0 to 1.0)
Returns:
Created UserFact instance
"""
fact = UserFact(
user_id=user.id,
fact_type=fact_type,
fact_content=fact_content,
source=source,
confidence=confidence,
)
self._session.add(fact)
await self._session.flush()
logger.debug(f"Added fact for user {user.discord_id}: [{fact_type}] {fact_content}")
return fact
async def get_user_facts(
self,
user: User,
fact_type: str | None = None,
active_only: bool = True,
) -> list[UserFact]:
"""Get facts about a user.
Args:
user: User model instance
fact_type: Optional filter by fact type
active_only: Only return active facts
Returns:
List of UserFact instances
"""
stmt = select(UserFact).where(UserFact.user_id == user.id)
if active_only:
stmt = stmt.where(UserFact.is_active == True) # noqa: E712
if fact_type:
stmt = stmt.where(UserFact.fact_type == fact_type)
stmt = stmt.order_by(UserFact.learned_at.desc())
result = await self._session.execute(stmt)
return list(result.scalars().all())
async def delete_user_facts(self, user: User) -> int:
"""Soft-delete all facts for a user (set is_active=False).
Args:
user: User model instance
Returns:
Number of facts deactivated
"""
facts = await self.get_user_facts(user, active_only=True)
for fact in facts:
fact.is_active = False
logger.info(f"Deactivated {len(facts)} facts for user {user.discord_id}")
return len(facts)
async def set_preference(self, user: User, key: str, value: str | None) -> UserPreference:
"""Set a user preference.
Args:
user: User model instance
key: Preference key
value: Preference value, or None to delete
Returns:
UserPreference instance
"""
stmt = select(UserPreference).where(
UserPreference.user_id == user.id,
UserPreference.preference_key == key,
)
result = await self._session.execute(stmt)
pref = result.scalar_one_or_none()
if pref:
pref.preference_value = value
else:
pref = UserPreference(
user_id=user.id,
preference_key=key,
preference_value=value,
)
self._session.add(pref)
await self._session.flush()
return pref
async def get_preference(self, user: User, key: str) -> str | None:
"""Get a user preference value.
Args:
user: User model instance
key: Preference key
Returns:
Preference value or None if not set
"""
stmt = select(UserPreference).where(
UserPreference.user_id == user.id,
UserPreference.preference_key == key,
)
result = await self._session.execute(stmt)
pref = result.scalar_one_or_none()
return pref.preference_value if pref else None
async def get_user_context(self, user: User) -> str:
"""Build a context string about a user for the AI.
Args:
user: User model instance
Returns:
Formatted context string
"""
lines = [f"User's preferred name: {user.display_name}"]
if user.custom_name:
lines.append(f"(You should address them as: {user.custom_name})")
facts = await self.get_user_facts(user, active_only=True)
if facts:
lines.append("\nKnown facts about this user:")
for fact in facts[:20]: # Limit to 20 most recent facts
lines.append(f"- [{fact.fact_type}] {fact.fact_content}")
# Mark as referenced
fact.last_referenced_at = datetime.utcnow()
return "\n".join(lines)
async def get_user_with_facts(self, discord_id: int) -> User | None:
"""Get a user with their facts eagerly loaded.
Args:
discord_id: Discord user ID
Returns:
User with facts loaded, or None if not found
"""
stmt = select(User).where(User.discord_id == discord_id).options(selectinload(User.facts))
result = await self._session.execute(stmt)
return result.scalar_one_or_none()