From e00d4fd501760340f4b3fc0f50696ad2e13755eb Mon Sep 17 00:00:00 2001 From: latte Date: Mon, 12 Jan 2026 14:00:06 +0100 Subject: [PATCH] Add PostgreSQL memory system for persistent user and conversation storage - Add PostgreSQL database with SQLAlchemy async support - Create models: User, UserFact, UserPreference, Conversation, Message, Guild, GuildMember - Add custom name support so bot knows 'who is who' - Add user facts system for remembering things about users - Add persistent conversation history that survives restarts - Add memory commands cog (!setname, !remember, !whatdoyouknow, !forgetme) - Add admin commands (!setusername, !teachbot) - Set up Alembic for database migrations - Update docker-compose with PostgreSQL service - Gracefully falls back to in-memory storage when DB not configured --- CLAUDE.md | 43 ++- alembic.ini | 71 +++++ alembic/env.py | 86 ++++++ alembic/script.py.mako | 26 ++ .../versions/20250112_0001_initial_schema.py | 200 ++++++++++++++ docker-compose.yml | 39 ++- requirements.txt | 5 + src/daemon_boyfriend/bot.py | 17 +- src/daemon_boyfriend/cogs/ai_chat.py | 98 +++++++ src/daemon_boyfriend/cogs/memory.py | 260 ++++++++++++++++++ src/daemon_boyfriend/config.py | 14 + src/daemon_boyfriend/models/__init__.py | 17 ++ src/daemon_boyfriend/models/base.py | 28 ++ src/daemon_boyfriend/models/conversation.py | 57 ++++ src/daemon_boyfriend/models/guild.py | 50 ++++ src/daemon_boyfriend/models/user.py | 83 ++++++ src/daemon_boyfriend/services/__init__.py | 10 +- src/daemon_boyfriend/services/database.py | 94 +++++++ .../services/persistent_conversation.py | 188 +++++++++++++ src/daemon_boyfriend/services/user_service.py | 250 +++++++++++++++++ 20 files changed, 1623 insertions(+), 13 deletions(-) create mode 100644 alembic.ini create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/20250112_0001_initial_schema.py create mode 100644 src/daemon_boyfriend/cogs/memory.py create mode 100644 
src/daemon_boyfriend/models/base.py create mode 100644 src/daemon_boyfriend/models/conversation.py create mode 100644 src/daemon_boyfriend/models/guild.py create mode 100644 src/daemon_boyfriend/models/user.py create mode 100644 src/daemon_boyfriend/services/database.py create mode 100644 src/daemon_boyfriend/services/persistent_conversation.py create mode 100644 src/daemon_boyfriend/services/user_service.py diff --git a/CLAUDE.md b/CLAUDE.md index 10808b4..c00eca7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -11,9 +11,12 @@ pip install -r requirements.txt # Run the bot (requires .env with DISCORD_TOKEN and AI provider key) python -m daemon_boyfriend -# Run with Docker +# Run with Docker (includes PostgreSQL) docker-compose up -d +# Run database migrations +alembic upgrade head + # Syntax check all Python files python -m py_compile src/daemon_boyfriend/**/*.py ``` @@ -33,9 +36,25 @@ The AI system uses a provider abstraction pattern: ### Cog System Discord functionality is in `cogs/`: - `ai_chat.py` - `@mention` handler (responds when bot is mentioned) +- `memory.py` - Memory management commands (`!setname`, `!remember`, etc.) +- `status.py` - Bot health and status commands Cogs are auto-loaded by `bot.py` from the `cogs/` directory. +### Database & Memory System +The bot uses PostgreSQL for persistent memory (optional, falls back to in-memory): +- `models/` - SQLAlchemy models (User, UserFact, Conversation, Message, Guild, GuildMember) +- `services/database.py` - Connection pool and async session management +- `services/user_service.py` - User CRUD, custom names, facts management +- `services/persistent_conversation.py` - Database-backed conversation history +- `alembic/` - Database migrations + +Key features: +- Custom names: Set preferred names for users so the bot knows "who is who" +- User facts: Bot remembers things about users (hobbies, preferences, etc.) 
+- Persistent conversations: Chat history survives restarts +- Conversation timeout: New conversation starts after 60 minutes of inactivity + ### Configuration All config flows through `config.py` using pydantic-settings. The `settings` singleton is created at module load, so env vars must be set before importing. @@ -47,13 +66,31 @@ The bot can search the web for current information via SearXNG: - Configured via `SEARXNG_URL`, `SEARXNG_ENABLED`, and `SEARXNG_MAX_RESULTS` env vars ### Key Design Decisions -- `ConversationManager` stores per-user chat history in memory with configurable max length +- `PersistentConversationManager` stores conversations in PostgreSQL when `DATABASE_URL` is set +- `ConversationManager` is the in-memory fallback when database is not configured - Long AI responses are split via `split_message()` in `ai_chat.py` to respect Discord's 2000 char limit - The bot responds only to @mentions via `on_message` listener - Web search uses AI to decide when to search, avoiding unnecessary API calls for general knowledge questions +- User context (custom name + known facts) is included in AI prompts for personalized responses ## Environment Variables Required: `DISCORD_TOKEN`, plus one of `OPENAI_API_KEY`, `OPENROUTER_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` depending on `AI_PROVIDER` setting. -Optional: `SEARXNG_URL` for web search capability. 
+Optional: +- `DATABASE_URL` - PostgreSQL connection string (e.g., `postgresql+asyncpg://user:pass@host:5432/db`) +- `POSTGRES_PASSWORD` - Used by docker-compose for the PostgreSQL container +- `SEARXNG_URL` - SearXNG instance URL for web search capability + +## Memory Commands + +User commands: +- `!setname <name>` - Set your preferred name +- `!clearname` - Reset to Discord display name +- `!remember <fact>` - Tell the bot something about you +- `!whatdoyouknow` - See what the bot remembers about you +- `!forgetme` - Clear all facts about you + +Admin commands: +- `!setusername @user <name>` - Set name for another user +- `!teachbot @user <fact>` - Add a fact about a user diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..10c13c7 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,71 @@ +# Alembic Configuration File + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names +file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s + +# sys.path path prepend for the migration environment +prepend_sys_path = src + +# timezone to use when rendering the date within the migration file +timezone = UTC + +# max length of characters to apply to the "slug" field +truncate_slug_length = 40 + +# set to 'true' to run the environment during the 'revision' command +revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without a source .py file +sourceless = false + +# version number format +version_num_width = 4 + +# version path separator +version_path_separator = os + +# output encoding used when revision files are written +output_encoding = utf-8 + +# Database URL - will be overridden by env.py from settings +sqlalchemy.url = driver://user:pass@localhost/dbname + +[post_write_hooks] + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = 
+[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..b85ee68 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,86 @@ +"""Alembic migration environment configuration.""" + +import asyncio +from logging.config import fileConfig + +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import async_engine_from_config + +from alembic import context +from daemon_boyfriend.config import settings +from daemon_boyfriend.models import Base + +# this is the Alembic Config object +config = context.config + +# Interpret the config file for Python logging +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Model metadata for autogenerate support +target_metadata = Base.metadata + + +def get_url() -> str: + """Get database URL from settings.""" + url = settings.database_url + if not url: + raise ValueError("DATABASE_URL must be set for migrations") + return url + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL and not an Engine, + though an Engine is acceptable here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the script output. 
+ """ + url = get_url() + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection) -> None: + """Run migrations with the given connection.""" + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """Run migrations in async mode.""" + configuration = config.get_section(config.config_ini_section, {}) + configuration["sqlalchemy.url"] = get_url() + + connectable = async_engine_from_config( + configuration, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/20250112_0001_initial_schema.py b/alembic/versions/20250112_0001_initial_schema.py new file mode 100644 index 0000000..2c9998d --- /dev/null +++ b/alembic/versions/20250112_0001_initial_schema.py @@ -0,0 +1,200 @@ +"""Initial schema with users, conversations, messages, and guilds. + +Revision ID: 0001 +Revises: +Create Date: 2025-01-12 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0001" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Users table + op.create_table( + "users", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column("discord_id", sa.BigInteger(), nullable=False), + sa.Column("discord_username", sa.String(255), nullable=False), + sa.Column("discord_display_name", sa.String(255), nullable=True), + sa.Column("custom_name", sa.String(255), nullable=True), + sa.Column("first_seen_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("last_seen_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.PrimaryKeyConstraint("id", name=op.f("pk_users")), + 
sa.UniqueConstraint("discord_id", name=op.f("uq_users_discord_id")), + ) + op.create_index(op.f("ix_users_discord_id"), "users", ["discord_id"]) + op.create_index(op.f("ix_users_last_seen_at"), "users", ["last_seen_at"]) + + # User preferences table + op.create_table( + "user_preferences", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("preference_key", sa.String(100), nullable=False), + sa.Column("preference_value", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + name=op.f("fk_user_preferences_user_id_users"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_user_preferences")), + sa.UniqueConstraint("user_id", "preference_key", name="uq_user_preferences_user_key"), + ) + op.create_index(op.f("ix_user_preferences_user_id"), "user_preferences", ["user_id"]) + + # User facts table + op.create_table( + "user_facts", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("fact_type", sa.String(50), nullable=False), + sa.Column("fact_content", sa.Text(), nullable=False), + sa.Column("confidence", sa.Float(), server_default="1.0"), + sa.Column("source", sa.String(50), server_default="conversation"), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("learned_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("last_referenced_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + 
name=op.f("fk_user_facts_user_id_users"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_user_facts")), + ) + op.create_index(op.f("ix_user_facts_user_id"), "user_facts", ["user_id"]) + op.create_index(op.f("ix_user_facts_fact_type"), "user_facts", ["fact_type"]) + op.create_index(op.f("ix_user_facts_is_active"), "user_facts", ["is_active"]) + + # Guilds table + op.create_table( + "guilds", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column("discord_id", sa.BigInteger(), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("joined_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("settings", postgresql.JSONB(), server_default="{}"), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.PrimaryKeyConstraint("id", name=op.f("pk_guilds")), + sa.UniqueConstraint("discord_id", name=op.f("uq_guilds_discord_id")), + ) + op.create_index(op.f("ix_guilds_discord_id"), "guilds", ["discord_id"]) + + # Guild members table + op.create_table( + "guild_members", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column("guild_id", sa.BigInteger(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("guild_nickname", sa.String(255), nullable=True), + sa.Column("roles", postgresql.ARRAY(sa.Text()), nullable=True), + sa.Column("joined_guild_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.ForeignKeyConstraint( + ["guild_id"], + ["guilds.id"], + name=op.f("fk_guild_members_guild_id_guilds"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + 
["user_id"], + ["users.id"], + name=op.f("fk_guild_members_user_id_users"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_guild_members")), + sa.UniqueConstraint("guild_id", "user_id", name="uq_guild_members_guild_user"), + ) + op.create_index(op.f("ix_guild_members_guild_id"), "guild_members", ["guild_id"]) + op.create_index(op.f("ix_guild_members_user_id"), "guild_members", ["user_id"]) + + # Conversations table + op.create_table( + "conversations", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("guild_id", sa.BigInteger(), nullable=True), + sa.Column("channel_id", sa.BigInteger(), nullable=True), + sa.Column("started_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("last_message_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("message_count", sa.Integer(), server_default="0"), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + name=op.f("fk_conversations_user_id_users"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_conversations")), + ) + op.create_index(op.f("ix_conversations_user_id"), "conversations", ["user_id"]) + op.create_index(op.f("ix_conversations_channel_id"), "conversations", ["channel_id"]) + op.create_index(op.f("ix_conversations_last_message_at"), "conversations", ["last_message_at"]) + op.create_index(op.f("ix_conversations_is_active"), "conversations", ["is_active"]) + + # Messages table + op.create_table( + "messages", + sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False), + sa.Column("conversation_id", sa.BigInteger(), nullable=False), + sa.Column("user_id", 
sa.BigInteger(), nullable=False), + sa.Column("discord_message_id", sa.BigInteger(), nullable=True), + sa.Column("role", sa.String(20), nullable=False), + sa.Column("content", sa.Text(), nullable=False), + sa.Column("has_images", sa.Boolean(), nullable=False, server_default="false"), + sa.Column("image_urls", postgresql.ARRAY(sa.Text()), nullable=True), + sa.Column("token_count", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.ForeignKeyConstraint( + ["conversation_id"], + ["conversations.id"], + name=op.f("fk_messages_conversation_id_conversations"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + name=op.f("fk_messages_user_id_users"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_messages")), + ) + op.create_index(op.f("ix_messages_conversation_id"), "messages", ["conversation_id"]) + op.create_index(op.f("ix_messages_user_id"), "messages", ["user_id"]) + op.create_index(op.f("ix_messages_created_at"), "messages", ["created_at"]) + + +def downgrade() -> None: + op.drop_table("messages") + op.drop_table("conversations") + op.drop_table("guild_members") + op.drop_table("guilds") + op.drop_table("user_facts") + op.drop_table("user_preferences") + op.drop_table("users") diff --git a/docker-compose.yml b/docker-compose.yml index 15201a5..80318ae 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,9 +1,32 @@ services: - daemon-boyfriend: - build: . - container_name: daemon-boyfriend - restart: unless-stopped - env_file: - - .env - environment: - - PYTHONUNBUFFERED=1 + daemon-boyfriend: + build: . 
+ container_name: daemon-boyfriend + restart: unless-stopped + env_file: + - .env + environment: + - PYTHONUNBUFFERED=1 + - DATABASE_URL=postgresql+asyncpg://daemon:${POSTGRES_PASSWORD:-daemon}@postgres:5432/daemon_boyfriend + depends_on: + postgres: + condition: service_healthy + + postgres: + image: postgres:16-alpine + container_name: daemon-boyfriend-postgres + restart: unless-stopped + environment: + POSTGRES_USER: daemon + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-daemon} + POSTGRES_DB: daemon_boyfriend + volumes: + - postgres_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U daemon -d daemon_boyfriend"] + interval: 10s + timeout: 5s + retries: 5 + +volumes: + postgres_data: diff --git a/requirements.txt b/requirements.txt index 22516f5..a2ec44e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,8 @@ aiohttp>=3.9.0 pydantic>=2.6.0 pydantic-settings>=2.2.0 python-dotenv>=1.0.0 + +# Database +asyncpg>=0.29.0 +sqlalchemy[asyncio]>=2.0.0 +alembic>=1.13.0 diff --git a/src/daemon_boyfriend/bot.py b/src/daemon_boyfriend/bot.py index 99f7467..3c3dfc6 100644 --- a/src/daemon_boyfriend/bot.py +++ b/src/daemon_boyfriend/bot.py @@ -7,6 +7,7 @@ import discord from discord.ext import commands from daemon_boyfriend.config import settings +from daemon_boyfriend.services import db logger = logging.getLogger(__name__) @@ -27,7 +28,14 @@ class DaemonBoyfriend(commands.Bot): ) async def setup_hook(self) -> None: - """Load cogs on startup.""" + """Initialize database and load cogs on startup.""" + # Initialize database if configured + if db.is_configured: + await db.init() + logger.info("Database initialized") + else: + logger.info("Database not configured, using in-memory storage") + # Load all cogs cogs_path = Path(__file__).parent / "cogs" for cog_file in cogs_path.glob("*.py"): @@ -55,3 +63,10 @@ class DaemonBoyfriend(commands.Bot): name=settings.bot_status, ) await self.change_presence(activity=activity) + + async def close(self) 
-> None: + """Clean up resources on shutdown.""" + logger.info("Shutting down bot...") + if db.is_initialized: + await db.close() + await super().close() diff --git a/src/daemon_boyfriend/cogs/ai_chat.py b/src/daemon_boyfriend/cogs/ai_chat.py index 63d6b41..eff3ebc 100644 --- a/src/daemon_boyfriend/cogs/ai_chat.py +++ b/src/daemon_boyfriend/cogs/ai_chat.py @@ -12,7 +12,10 @@ from daemon_boyfriend.services import ( ConversationManager, ImageAttachment, Message, + PersistentConversationManager, SearXNGService, + UserService, + db, ) from daemon_boyfriend.utils import get_monitor @@ -77,11 +80,17 @@ class AIChatCog(commands.Cog): def __init__(self, bot: commands.Bot) -> None: self.bot = bot self.ai_service = AIService() + # Fallback in-memory conversation manager (used when DB not configured) self.conversations = ConversationManager() self.search_service: SearXNGService | None = None if settings.searxng_enabled and settings.searxng_url: self.search_service = SearXNGService(settings.searxng_url) + @property + def use_database(self) -> bool: + """Check if database is available for use.""" + return db.is_initialized + @commands.Cog.listener() async def on_message(self, message: discord.Message) -> None: """Respond when the bot is mentioned.""" @@ -395,6 +404,95 @@ class AIChatCog(commands.Cog): Returns: The AI's response text """ + if self.use_database: + return await self._generate_response_with_db(message, user_message) + else: + return await self._generate_response_in_memory(message, user_message) + + async def _generate_response_with_db(self, message: discord.Message, user_message: str) -> str: + """Generate response using database-backed storage.""" + async with db.session() as session: + user_service = UserService(session) + conv_manager = PersistentConversationManager(session) + + # Get or create user + user = await user_service.get_or_create_user( + discord_id=message.author.id, + username=message.author.name, + display_name=message.author.display_name, + ) + + # 
Get or create conversation + conversation = await conv_manager.get_or_create_conversation( + user=user, + guild_id=message.guild.id if message.guild else None, + channel_id=message.channel.id, + ) + + # Get history + history = await conv_manager.get_history(conversation) + + # Extract any image attachments from the message + images = self._extract_image_attachments(message) + image_urls = [img.url for img in images] if images else None + + # Add current message to history for the API call + current_message = Message(role="user", content=user_message, images=images) + messages = history + [current_message] + + # Check if we should search the web + search_context = await self._maybe_search(user_message) + + # Get context about mentioned users + mentioned_users_context = self._get_mentioned_users_context(message) + + # Build system prompt with additional context + system_prompt = self.ai_service.get_system_prompt() + + # Add user context from database (custom name, known facts) + user_context = await user_service.get_user_context(user) + system_prompt += f"\n\n--- User Context ---\n{user_context}" + + # Add mentioned users context + if mentioned_users_context: + system_prompt += f"\n\n--- {mentioned_users_context} ---" + + # Add search results if available + if search_context: + system_prompt += ( + "\n\n--- Web Search Results ---\n" + "Use the following current information from the web to help answer the user's question. 
" + "Cite sources when relevant.\n\n" + f"{search_context}" + ) + + # Generate response + response = await self.ai_service.chat( + messages=messages, + system_prompt=system_prompt, + ) + + # Save the exchange to database + await conv_manager.add_exchange( + conversation=conversation, + user=user, + user_message=user_message, + assistant_message=response.content, + discord_message_id=message.id, + image_urls=image_urls, + ) + + logger.debug( + f"Generated response for user {user.discord_id}: " + f"{len(response.content)} chars, {response.usage}" + ) + + return response.content + + async def _generate_response_in_memory( + self, message: discord.Message, user_message: str + ) -> str: + """Generate response using in-memory storage (fallback).""" user_id = message.author.id # Get conversation history diff --git a/src/daemon_boyfriend/cogs/memory.py b/src/daemon_boyfriend/cogs/memory.py new file mode 100644 index 0000000..e1dca31 --- /dev/null +++ b/src/daemon_boyfriend/cogs/memory.py @@ -0,0 +1,260 @@ +"""Memory management cog - commands for managing bot memory about users.""" + +import logging + +import discord +from discord.ext import commands + +from daemon_boyfriend.services import UserService, db + +logger = logging.getLogger(__name__) + + +class MemoryCog(commands.Cog): + """Commands for managing bot memory about users.""" + + def __init__(self, bot: commands.Bot) -> None: + self.bot = bot + + def _check_database(self) -> bool: + """Check if database is available.""" + return db.is_initialized + + @commands.command(name="setname") + async def set_name(self, ctx: commands.Context, *, name: str) -> None: + """Set your preferred name for the bot to use. + + Usage: !setname John + """ + if not self._check_database(): + await ctx.reply("Memory features are not available (database not configured).") + return + + if len(name) > 100: + await ctx.reply("Name is too long! 
Please use 100 characters or less.") + return + + async with db.session() as session: + user_service = UserService(session) + user = await user_service.get_or_create_user( + discord_id=ctx.author.id, + username=ctx.author.name, + display_name=ctx.author.display_name, + ) + await user_service.set_custom_name(ctx.author.id, name) + + await ctx.reply(f"Got it! I'll call you **{name}** from now on.") + + @commands.command(name="clearname") + async def clear_name(self, ctx: commands.Context) -> None: + """Clear your custom name and use your Discord name instead. + + Usage: !clearname + """ + if not self._check_database(): + await ctx.reply("Memory features are not available (database not configured).") + return + + async with db.session() as session: + user_service = UserService(session) + await user_service.set_custom_name(ctx.author.id, None) + + await ctx.reply("Done! I'll use your Discord display name now.") + + @commands.command(name="remember") + async def remember_fact(self, ctx: commands.Context, *, fact: str) -> None: + """Tell the bot something to remember about you. + + Usage: !remember I love pizza + Usage: !remember My favorite color is blue + """ + if not self._check_database(): + await ctx.reply("Memory features are not available (database not configured).") + return + + if len(fact) > 500: + await ctx.reply("That's too long to remember! 
Please keep it under 500 characters.") + return + + async with db.session() as session: + user_service = UserService(session) + user = await user_service.get_or_create_user( + discord_id=ctx.author.id, + username=ctx.author.name, + display_name=ctx.author.display_name, + ) + await user_service.add_fact( + user=user, + fact_type="general", + fact_content=fact, + source="explicit", + confidence=1.0, + ) + + await ctx.reply(f"I'll remember that!") + + @commands.command(name="whatdoyouknow", aliases=["aboutme", "myinfo"]) + async def what_do_you_know(self, ctx: commands.Context) -> None: + """Show what the bot remembers about you. + + Usage: !whatdoyouknow + """ + if not self._check_database(): + await ctx.reply("Memory features are not available (database not configured).") + return + + async with db.session() as session: + user_service = UserService(session) + user = await user_service.get_user_by_discord_id(ctx.author.id) + + if not user: + await ctx.reply("I don't have any information about you yet!") + return + + facts = await user_service.get_user_facts(user, active_only=True) + + embed = discord.Embed( + title=f"What I know about {user.display_name}", + color=discord.Color.blue(), + ) + + embed.add_field( + name="Discord Username", + value=user.discord_username, + inline=True, + ) + + if user.custom_name: + embed.add_field( + name="Preferred Name", + value=user.custom_name, + inline=True, + ) + + embed.add_field( + name="First Seen", + value=user.first_seen_at.strftime("%Y-%m-%d"), + inline=True, + ) + + if facts: + facts_text = "\n".join(f"- {fact.fact_content}" for fact in facts[:15]) + if len(facts) > 15: + facts_text += f"\n... and {len(facts) - 15} more" + embed.add_field( + name=f"Things I Remember ({len(facts)})", + value=facts_text or "Nothing yet!", + inline=False, + ) + else: + embed.add_field( + name="Things I Remember", + value="Nothing yet! 
Use `!remember` to tell me something.", + inline=False, + ) + + await ctx.reply(embed=embed) + + @commands.command(name="forgetme") + async def forget_me(self, ctx: commands.Context) -> None: + """Clear all facts the bot knows about you. + + Usage: !forgetme + """ + if not self._check_database(): + await ctx.reply("Memory features are not available (database not configured).") + return + + async with db.session() as session: + user_service = UserService(session) + user = await user_service.get_user_by_discord_id(ctx.author.id) + + if not user: + await ctx.reply("I don't have any information about you to forget!") + return + + count = await user_service.delete_user_facts(user) + + if count > 0: + await ctx.reply(f"Done! I've forgotten {count} thing(s) about you.") + else: + await ctx.reply("I didn't have anything to forget about you!") + + @commands.command(name="setusername") + @commands.has_permissions(administrator=True) + async def set_user_name( + self, ctx: commands.Context, user: discord.Member, *, name: str + ) -> None: + """[Admin] Set a custom name for another user. + + Usage: !setusername @user John + """ + if not self._check_database(): + await ctx.reply("Memory features are not available (database not configured).") + return + + if len(name) > 100: + await ctx.reply("Name is too long! Please use 100 characters or less.") + return + + async with db.session() as session: + user_service = UserService(session) + db_user = await user_service.get_or_create_user( + discord_id=user.id, + username=user.name, + display_name=user.display_name, + ) + await user_service.set_custom_name(user.id, name) + + await ctx.reply(f"Got it! I'll call {user.mention} **{name}** from now on.") + + @commands.command(name="teachbot") + @commands.has_permissions(administrator=True) + async def teach_bot(self, ctx: commands.Context, user: discord.Member, *, fact: str) -> None: + """[Admin] Teach the bot a fact about a user. 
+ + Usage: !teachbot @user They are a software developer + """ + if not self._check_database(): + await ctx.reply("Memory features are not available (database not configured).") + return + + if len(fact) > 500: + await ctx.reply("That's too long! Please keep it under 500 characters.") + return + + async with db.session() as session: + user_service = UserService(session) + db_user = await user_service.get_or_create_user( + discord_id=user.id, + username=user.name, + display_name=user.display_name, + ) + await user_service.add_fact( + user=db_user, + fact_type="general", + fact_content=fact, + source="admin", + confidence=1.0, + ) + + await ctx.reply(f"I'll remember that about {user.mention}!") + + @set_user_name.error + @teach_bot.error + async def admin_command_error( + self, ctx: commands.Context, error: commands.CommandError + ) -> None: + """Handle errors for admin commands.""" + if isinstance(error, commands.MissingPermissions): + await ctx.reply("You need administrator permissions to use this command.") + elif isinstance(error, commands.MemberNotFound): + await ctx.reply("I couldn't find that user.") + else: + logger.error(f"Error in admin command: {error}", exc_info=True) + await ctx.reply(f"An error occurred: {error}") + + +async def setup(bot: commands.Bot) -> None: + """Load the Memory cog.""" + await bot.add_cog(MemoryCog(bot)) diff --git a/src/daemon_boyfriend/config.py b/src/daemon_boyfriend/config.py index 20a2a99..e9c6131 100644 --- a/src/daemon_boyfriend/config.py +++ b/src/daemon_boyfriend/config.py @@ -67,6 +67,20 @@ class Settings(BaseSettings): max_conversation_history: int = Field( 20, description="Max messages to keep in conversation memory per user" ) + conversation_timeout_minutes: int = Field( + 60, ge=5, le=1440, description="Minutes of inactivity before starting new conversation" + ) + + # Database Configuration + database_url: str | None = Field( + None, + description="PostgreSQL connection URL (asyncpg format). 
class Base(AsyncAttrs, DeclarativeBase):
    """Declarative base class for all database models.

    Binds the module-level ``metadata`` object (which carries the constraint
    naming convention) to the declarative registry, and declares the
    ``created_at`` / ``updated_at`` timestamp columns that models built on
    this base are meant to share.
    """

    # Use the module-level MetaData so generated constraint and index names
    # follow the configured naming convention.
    metadata = metadata

    # Common timestamp columns
    # ``datetime.utcnow`` is passed as a callable, so it is evaluated each time
    # a default is needed (not once at import). It returns naive datetimes
    # expressed in UTC.
    created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
    # ``onupdate`` refreshes the value whenever the ORM emits an UPDATE for the row.
    updated_at: Mapped[datetime] = mapped_column(default=datetime.utcnow, onupdate=datetime.utcnow)
class Guild(Base):
    """Discord server (guild) known to the bot.

    One row per guild, keyed internally by ``id`` and externally by the
    guild's Discord ID. Owns its ``GuildMember`` rows.
    """

    __tablename__ = "guilds"

    # Internal surrogate key; other tables reference this, not the Discord ID.
    id: Mapped[int] = mapped_column(primary_key=True)
    # Discord's own ID for the guild; stored as BIGINT and required to be unique.
    discord_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
    name: Mapped[str] = mapped_column(String(255))
    # Defaults to the (naive, UTC) time the row is created.
    joined_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    # ``default=dict`` is a callable, so every new row gets its own empty dict.
    # NOTE(review): in-place changes to this dict may not be detected by the
    # ORM unless the column is wrapped in a mutable type - confirm how
    # settings are updated by callers.
    settings: Mapped[dict] = mapped_column(JSONB, default=dict)

    # Relationships
    # Members are deleted together with the guild and when removed from this list.
    members: Mapped[list["GuildMember"]] = relationship(
        back_populates="guild", cascade="all, delete-orphan"
    )
class User(Base):
    """Discord user known to the bot.

    Stores the names Discord reports for the user plus an optional custom
    name, and links to the user's preferences, facts, conversations,
    messages and guild memberships.
    """

    __tablename__ = "users"

    # Internal surrogate key; foreign keys in other tables point at this column.
    id: Mapped[int] = mapped_column(primary_key=True)
    # Discord's own ID for the user; stored as BIGINT and required to be unique.
    discord_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
    discord_username: Mapped[str] = mapped_column(String(255))
    discord_display_name: Mapped[str | None] = mapped_column(String(255))
    # Preferred name; takes priority in ``display_name`` when set.
    custom_name: Mapped[str | None] = mapped_column(String(255))
    # Both timestamps default to the (naive, UTC) time the row is created.
    first_seen_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
    last_seen_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)

    # Relationships
    # Collections declared with "all, delete-orphan" are removed together with
    # the user and when detached from the collection.
    preferences: Mapped[list["UserPreference"]] = relationship(
        back_populates="user", cascade="all, delete-orphan"
    )
    facts: Mapped[list["UserFact"]] = relationship(
        back_populates="user", cascade="all, delete-orphan"
    )
    conversations: Mapped[list["Conversation"]] = relationship(
        back_populates="user", cascade="all, delete-orphan"
    )
    # No ORM-level cascade is configured for messages, unlike the other collections.
    messages: Mapped[list["Message"]] = relationship(back_populates="user")
    guild_memberships: Mapped[list["GuildMember"]] = relationship(
        back_populates="user", cascade="all, delete-orphan"
    )

    @property
    def display_name(self) -> str:
        """Return the name to use when addressing this user.

        Falls back in order: custom name, Discord display name, Discord
        username. Empty strings are skipped as well as ``None``.
        """
        return self.custom_name or self.discord_display_name or self.discord_username
a/src/daemon_boyfriend/services/__init__.py b/src/daemon_boyfriend/services/__init__.py index 2b33978..44f9713 100644 --- a/src/daemon_boyfriend/services/__init__.py +++ b/src/daemon_boyfriend/services/__init__.py @@ -2,14 +2,22 @@ from .ai_service import AIService from .conversation import ConversationManager +from .database import DatabaseService, db, get_db +from .persistent_conversation import PersistentConversationManager from .providers import AIResponse, ImageAttachment, Message from .searxng import SearXNGService +from .user_service import UserService __all__ = [ "AIService", "AIResponse", + "ConversationManager", + "DatabaseService", "ImageAttachment", "Message", - "ConversationManager", + "PersistentConversationManager", "SearXNGService", + "UserService", + "db", + "get_db", ] diff --git a/src/daemon_boyfriend/services/database.py b/src/daemon_boyfriend/services/database.py new file mode 100644 index 0000000..28a417f --- /dev/null +++ b/src/daemon_boyfriend/services/database.py @@ -0,0 +1,94 @@ +"""Database connection and session management.""" + +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from daemon_boyfriend.config import settings +from daemon_boyfriend.models.base import Base + +logger = logging.getLogger(__name__) + + +class DatabaseService: + """Manages database connections and sessions.""" + + def __init__(self) -> None: + self._engine = None + self._session_factory: async_sessionmaker[AsyncSession] | None = None + self._initialized = False + + @property + def is_configured(self) -> bool: + """Check if database URL is configured.""" + return settings.database_url is not None + + @property + def is_initialized(self) -> bool: + """Check if database has been initialized.""" + return self._initialized + + async def init(self) -> None: + """Initialize database connection.""" + if not self.is_configured: + 
logger.info("Database URL not configured, skipping database initialization") + return + + logger.info("Initializing database connection...") + self._engine = create_async_engine( + settings.database_url, + echo=settings.database_echo, + pool_size=settings.database_pool_size, + max_overflow=settings.database_max_overflow, + ) + self._session_factory = async_sessionmaker( + self._engine, + class_=AsyncSession, + expire_on_commit=False, + ) + self._initialized = True + logger.info("Database connection initialized") + + async def close(self) -> None: + """Close database connection.""" + if self._engine: + logger.info("Closing database connection...") + await self._engine.dispose() + self._engine = None + self._session_factory = None + self._initialized = False + logger.info("Database connection closed") + + async def create_tables(self) -> None: + """Create all tables (for development/testing).""" + if not self._engine: + raise RuntimeError("Database not initialized") + + async with self._engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + logger.info("Database tables created") + + @asynccontextmanager + async def session(self) -> AsyncGenerator[AsyncSession, None]: + """Get a database session with automatic commit/rollback.""" + if not self._session_factory: + raise RuntimeError("Database not initialized. 
Call init() first.") + + async with self._session_factory() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + +# Global instance +db = DatabaseService() + + +def get_db() -> DatabaseService: + """Get the global database service instance.""" + return db diff --git a/src/daemon_boyfriend/services/persistent_conversation.py b/src/daemon_boyfriend/services/persistent_conversation.py new file mode 100644 index 0000000..9ee6258 --- /dev/null +++ b/src/daemon_boyfriend/services/persistent_conversation.py @@ -0,0 +1,188 @@ +"""Persistent conversation management using PostgreSQL.""" + +import logging +from datetime import datetime, timedelta + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from daemon_boyfriend.config import settings +from daemon_boyfriend.models import Conversation, Message, User +from daemon_boyfriend.services.providers import Message as ProviderMessage + +logger = logging.getLogger(__name__) + + +class PersistentConversationManager: + """Manages conversation history in PostgreSQL.""" + + def __init__(self, session: AsyncSession, max_history: int | None = None) -> None: + self._session = session + self.max_history = max_history or settings.max_conversation_history + self._timeout = timedelta(minutes=settings.conversation_timeout_minutes) + + async def get_or_create_conversation( + self, + user: User, + guild_id: int | None = None, + channel_id: int | None = None, + ) -> Conversation: + """Get active conversation or create new one. 
+ + Args: + user: User model instance + guild_id: Discord guild ID (None for DMs) + channel_id: Discord channel ID + + Returns: + Conversation model instance + """ + # Look for recent active conversation in this channel + cutoff = datetime.utcnow() - self._timeout + + stmt = select(Conversation).where( + Conversation.user_id == user.id, + Conversation.is_active == True, # noqa: E712 + Conversation.last_message_at > cutoff, + ) + + if channel_id: + stmt = stmt.where(Conversation.channel_id == channel_id) + + stmt = stmt.order_by(Conversation.last_message_at.desc()) + + result = await self._session.execute(stmt) + conversation = result.scalar_first() + + if conversation: + logger.debug( + f"Found existing conversation {conversation.id} for user {user.discord_id}" + ) + return conversation + + # Create new conversation + conversation = Conversation( + user_id=user.id, + guild_id=guild_id, + channel_id=channel_id, + ) + self._session.add(conversation) + await self._session.flush() + logger.info(f"Created new conversation {conversation.id} for user {user.discord_id}") + return conversation + + async def get_history(self, conversation: Conversation) -> list[ProviderMessage]: + """Get conversation history as provider messages. 
+ + Args: + conversation: Conversation model instance + + Returns: + List of ProviderMessage instances for the AI + """ + stmt = ( + select(Message) + .where(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(self.max_history) + ) + + result = await self._session.execute(stmt) + messages = list(reversed(result.scalars().all())) + + return [ + ProviderMessage( + role=msg.role, + content=msg.content, + # Note: Images would need special handling if needed + ) + for msg in messages + ] + + async def add_message( + self, + conversation: Conversation, + user: User, + role: str, + content: str, + discord_message_id: int | None = None, + image_urls: list[str] | None = None, + ) -> Message: + """Add a message to the conversation. + + Args: + conversation: Conversation model instance + user: User model instance + role: Message role (user, assistant, system) + content: Message content + discord_message_id: Discord's message ID (optional) + image_urls: List of image URLs (optional) + + Returns: + Created Message instance + """ + message = Message( + conversation_id=conversation.id, + user_id=user.id, + role=role, + content=content, + discord_message_id=discord_message_id, + has_images=bool(image_urls), + image_urls=image_urls, + ) + self._session.add(message) + + # Update conversation stats + conversation.last_message_at = datetime.utcnow() + conversation.message_count += 1 + + await self._session.flush() + return message + + async def add_exchange( + self, + conversation: Conversation, + user: User, + user_message: str, + assistant_message: str, + discord_message_id: int | None = None, + image_urls: list[str] | None = None, + ) -> tuple[Message, Message]: + """Add a user/assistant exchange to the conversation. 
    async def clear_conversation(self, conversation: Conversation) -> None:
        """Mark a conversation as inactive.

        Only the ``is_active`` flag changes; the conversation and its
        messages are not deleted. Nothing is flushed or committed here, so
        the change is persisted when the surrounding session is.

        Args:
            conversation: Conversation model instance
        """
        conversation.is_active = False
        logger.info(f"Marked conversation {conversation.id} as inactive")
+ + Args: + conversation: Conversation model instance + + Returns: + Number of messages + """ + return conversation.message_count diff --git a/src/daemon_boyfriend/services/user_service.py b/src/daemon_boyfriend/services/user_service.py new file mode 100644 index 0000000..d06c98b --- /dev/null +++ b/src/daemon_boyfriend/services/user_service.py @@ -0,0 +1,250 @@ +"""User management service.""" + +import logging +from datetime import datetime + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from daemon_boyfriend.models import User, UserFact, UserPreference + +logger = logging.getLogger(__name__) + + +class UserService: + """Service for user-related operations.""" + + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def get_or_create_user( + self, + discord_id: int, + username: str, + display_name: str | None = None, + ) -> User: + """Get existing user or create new one. + + Args: + discord_id: Discord user ID + username: Discord username + display_name: Discord display name + + Returns: + User model instance + """ + stmt = select(User).where(User.discord_id == discord_id) + result = await self._session.execute(stmt) + user = result.scalar_one_or_none() + + if user: + # Update last seen and current name + user.last_seen_at = datetime.utcnow() + user.discord_username = username + if display_name: + user.discord_display_name = display_name + return user + + # Create new user + user = User( + discord_id=discord_id, + discord_username=username, + discord_display_name=display_name, + ) + self._session.add(user) + await self._session.flush() + logger.info(f"Created new user: {username} (discord_id={discord_id})") + return user + + async def get_user_by_discord_id(self, discord_id: int) -> User | None: + """Get a user by their Discord ID. 
+ + Args: + discord_id: Discord user ID + + Returns: + User if found, None otherwise + """ + stmt = select(User).where(User.discord_id == discord_id) + result = await self._session.execute(stmt) + return result.scalar_one_or_none() + + async def set_custom_name(self, discord_id: int, custom_name: str | None) -> User | None: + """Set a custom name for a user. + + Args: + discord_id: Discord user ID + custom_name: Custom name to use, or None to clear + + Returns: + Updated user if found, None otherwise + """ + user = await self.get_user_by_discord_id(discord_id) + if user: + user.custom_name = custom_name + logger.info(f"Set custom name for user {discord_id}: {custom_name}") + return user + + async def add_fact( + self, + user: User, + fact_type: str, + fact_content: str, + source: str = "conversation", + confidence: float = 1.0, + ) -> UserFact: + """Add a fact about a user. + + Args: + user: User model instance + fact_type: Type of fact (general, preference, hobby, relationship, etc.) + fact_content: The fact content + source: How the fact was learned (conversation, explicit, inferred) + confidence: Confidence level (0.0 to 1.0) + + Returns: + Created UserFact instance + """ + fact = UserFact( + user_id=user.id, + fact_type=fact_type, + fact_content=fact_content, + source=source, + confidence=confidence, + ) + self._session.add(fact) + await self._session.flush() + logger.debug(f"Added fact for user {user.discord_id}: [{fact_type}] {fact_content}") + return fact + + async def get_user_facts( + self, + user: User, + fact_type: str | None = None, + active_only: bool = True, + ) -> list[UserFact]: + """Get facts about a user. 
+ + Args: + user: User model instance + fact_type: Optional filter by fact type + active_only: Only return active facts + + Returns: + List of UserFact instances + """ + stmt = select(UserFact).where(UserFact.user_id == user.id) + + if active_only: + stmt = stmt.where(UserFact.is_active == True) # noqa: E712 + + if fact_type: + stmt = stmt.where(UserFact.fact_type == fact_type) + + stmt = stmt.order_by(UserFact.learned_at.desc()) + + result = await self._session.execute(stmt) + return list(result.scalars().all()) + + async def delete_user_facts(self, user: User) -> int: + """Soft-delete all facts for a user (set is_active=False). + + Args: + user: User model instance + + Returns: + Number of facts deactivated + """ + facts = await self.get_user_facts(user, active_only=True) + for fact in facts: + fact.is_active = False + logger.info(f"Deactivated {len(facts)} facts for user {user.discord_id}") + return len(facts) + + async def set_preference(self, user: User, key: str, value: str | None) -> UserPreference: + """Set a user preference. + + Args: + user: User model instance + key: Preference key + value: Preference value, or None to delete + + Returns: + UserPreference instance + """ + stmt = select(UserPreference).where( + UserPreference.user_id == user.id, + UserPreference.preference_key == key, + ) + result = await self._session.execute(stmt) + pref = result.scalar_one_or_none() + + if pref: + pref.preference_value = value + else: + pref = UserPreference( + user_id=user.id, + preference_key=key, + preference_value=value, + ) + self._session.add(pref) + await self._session.flush() + + return pref + + async def get_preference(self, user: User, key: str) -> str | None: + """Get a user preference value. 
+ + Args: + user: User model instance + key: Preference key + + Returns: + Preference value or None if not set + """ + stmt = select(UserPreference).where( + UserPreference.user_id == user.id, + UserPreference.preference_key == key, + ) + result = await self._session.execute(stmt) + pref = result.scalar_one_or_none() + return pref.preference_value if pref else None + + async def get_user_context(self, user: User) -> str: + """Build a context string about a user for the AI. + + Args: + user: User model instance + + Returns: + Formatted context string + """ + lines = [f"User's preferred name: {user.display_name}"] + + if user.custom_name: + lines.append(f"(You should address them as: {user.custom_name})") + + facts = await self.get_user_facts(user, active_only=True) + + if facts: + lines.append("\nKnown facts about this user:") + for fact in facts[:20]: # Limit to 20 most recent facts + lines.append(f"- [{fact.fact_type}] {fact.fact_content}") + # Mark as referenced + fact.last_referenced_at = datetime.utcnow() + + return "\n".join(lines) + + async def get_user_with_facts(self, discord_id: int) -> User | None: + """Get a user with their facts eagerly loaded. + + Args: + discord_id: Discord user ID + + Returns: + User with facts loaded, or None if not found + """ + stmt = select(User).where(User.discord_id == discord_id).options(selectinload(User.facts)) + result = await self._session.execute(stmt) + return result.scalar_one_or_none()