Add PostgreSQL memory system for persistent user and conversation storage
- Add PostgreSQL database with SQLAlchemy async support - Create models: User, UserFact, UserPreference, Conversation, Message, Guild, GuildMember - Add custom name support so bot knows 'who is who' - Add user facts system for remembering things about users - Add persistent conversation history that survives restarts - Add memory commands cog (!setname, !remember, !whatdoyouknow, !forgetme) - Add admin commands (!setusername, !teachbot) - Set up Alembic for database migrations - Update docker-compose with PostgreSQL service - Gracefully falls back to in-memory storage when DB not configured
This commit is contained in:
43
CLAUDE.md
43
CLAUDE.md
@@ -11,9 +11,12 @@ pip install -r requirements.txt
|
|||||||
# Run the bot (requires .env with DISCORD_TOKEN and AI provider key)
|
# Run the bot (requires .env with DISCORD_TOKEN and AI provider key)
|
||||||
python -m daemon_boyfriend
|
python -m daemon_boyfriend
|
||||||
|
|
||||||
# Run with Docker
|
# Run with Docker (includes PostgreSQL)
|
||||||
docker-compose up -d
|
docker-compose up -d
|
||||||
|
|
||||||
|
# Run database migrations
|
||||||
|
alembic upgrade head
|
||||||
|
|
||||||
# Syntax check all Python files
|
# Syntax check all Python files
|
||||||
python -m py_compile src/daemon_boyfriend/**/*.py
|
python -m py_compile src/daemon_boyfriend/**/*.py
|
||||||
```
|
```
|
||||||
@@ -33,9 +36,25 @@ The AI system uses a provider abstraction pattern:
|
|||||||
### Cog System
|
### Cog System
|
||||||
Discord functionality is in `cogs/`:
|
Discord functionality is in `cogs/`:
|
||||||
- `ai_chat.py` - `@mention` handler (responds when bot is mentioned)
|
- `ai_chat.py` - `@mention` handler (responds when bot is mentioned)
|
||||||
|
- `memory.py` - Memory management commands (`!setname`, `!remember`, etc.)
|
||||||
|
- `status.py` - Bot health and status commands
|
||||||
|
|
||||||
Cogs are auto-loaded by `bot.py` from the `cogs/` directory.
|
Cogs are auto-loaded by `bot.py` from the `cogs/` directory.
|
||||||
|
|
||||||
|
### Database & Memory System
|
||||||
|
The bot uses PostgreSQL for persistent memory (optional, falls back to in-memory):
|
||||||
|
- `models/` - SQLAlchemy models (User, UserFact, Conversation, Message, Guild, GuildMember)
|
||||||
|
- `services/database.py` - Connection pool and async session management
|
||||||
|
- `services/user_service.py` - User CRUD, custom names, facts management
|
||||||
|
- `services/persistent_conversation.py` - Database-backed conversation history
|
||||||
|
- `alembic/` - Database migrations
|
||||||
|
|
||||||
|
Key features:
|
||||||
|
- Custom names: Set preferred names for users so the bot knows "who is who"
|
||||||
|
- User facts: Bot remembers things about users (hobbies, preferences, etc.)
|
||||||
|
- Persistent conversations: Chat history survives restarts
|
||||||
|
- Conversation timeout: New conversation starts after 60 minutes of inactivity
|
||||||
|
|
||||||
### Configuration
|
### Configuration
|
||||||
All config flows through `config.py` using pydantic-settings. The `settings` singleton is created at module load, so env vars must be set before importing.
|
All config flows through `config.py` using pydantic-settings. The `settings` singleton is created at module load, so env vars must be set before importing.
|
||||||
|
|
||||||
@@ -47,13 +66,31 @@ The bot can search the web for current information via SearXNG:
|
|||||||
- Configured via `SEARXNG_URL`, `SEARXNG_ENABLED`, and `SEARXNG_MAX_RESULTS` env vars
|
- Configured via `SEARXNG_URL`, `SEARXNG_ENABLED`, and `SEARXNG_MAX_RESULTS` env vars
|
||||||
|
|
||||||
### Key Design Decisions
|
### Key Design Decisions
|
||||||
- `ConversationManager` stores per-user chat history in memory with configurable max length
|
- `PersistentConversationManager` stores conversations in PostgreSQL when `DATABASE_URL` is set
|
||||||
|
- `ConversationManager` is the in-memory fallback when database is not configured
|
||||||
- Long AI responses are split via `split_message()` in `ai_chat.py` to respect Discord's 2000 char limit
|
- Long AI responses are split via `split_message()` in `ai_chat.py` to respect Discord's 2000 char limit
|
||||||
- The bot responds only to @mentions via `on_message` listener
|
- The bot responds only to @mentions via `on_message` listener
|
||||||
- Web search uses AI to decide when to search, avoiding unnecessary API calls for general knowledge questions
|
- Web search uses AI to decide when to search, avoiding unnecessary API calls for general knowledge questions
|
||||||
|
- User context (custom name + known facts) is included in AI prompts for personalized responses
|
||||||
|
|
||||||
## Environment Variables
|
## Environment Variables
|
||||||
|
|
||||||
Required: `DISCORD_TOKEN`, plus one of `OPENAI_API_KEY`, `OPENROUTER_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` depending on `AI_PROVIDER` setting.
|
Required: `DISCORD_TOKEN`, plus one of `OPENAI_API_KEY`, `OPENROUTER_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` depending on `AI_PROVIDER` setting.
|
||||||
|
|
||||||
Optional: `SEARXNG_URL` for web search capability.
|
Optional:
|
||||||
|
- `DATABASE_URL` - PostgreSQL connection string (e.g., `postgresql+asyncpg://user:pass@host:5432/db`)
|
||||||
|
- `POSTGRES_PASSWORD` - Used by docker-compose for the PostgreSQL container
|
||||||
|
- `SEARXNG_URL` - SearXNG instance URL for web search capability
|
||||||
|
|
||||||
|
## Memory Commands
|
||||||
|
|
||||||
|
User commands:
|
||||||
|
- `!setname <name>` - Set your preferred name
|
||||||
|
- `!clearname` - Reset to Discord display name
|
||||||
|
- `!remember <fact>` - Tell the bot something about you
|
||||||
|
- `!whatdoyouknow` - See what the bot remembers about you
|
||||||
|
- `!forgetme` - Clear all facts about you
|
||||||
|
|
||||||
|
Admin commands:
|
||||||
|
- `!setusername @user <name>` - Set name for another user
|
||||||
|
- `!teachbot @user <fact>` - Add a fact about a user
|
||||||
|
|||||||
71
alembic.ini
Normal file
71
alembic.ini
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
# Alembic Configuration File
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts
|
||||||
|
script_location = alembic
|
||||||
|
|
||||||
|
# template used to generate migration file names
|
||||||
|
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path prepend for the migration environment
|
||||||
|
prepend_sys_path = src
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
timezone = UTC
|
||||||
|
|
||||||
|
# max length of characters to apply to the "slug" field
|
||||||
|
truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during the 'revision' command
|
||||||
|
revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without a source .py file
|
||||||
|
sourceless = false
|
||||||
|
|
||||||
|
# version number format
|
||||||
|
version_num_width = 4
|
||||||
|
|
||||||
|
# version path separator
|
||||||
|
version_path_separator = os
|
||||||
|
|
||||||
|
# output encoding used when revision files are written
|
||||||
|
output_encoding = utf-8
|
||||||
|
|
||||||
|
# Database URL - will be overridden by env.py from settings
|
||||||
|
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
86
alembic/env.py
Normal file
86
alembic/env.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""Alembic migration environment configuration."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import pool
|
||||||
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from daemon_boyfriend.config import settings
|
||||||
|
from daemon_boyfriend.models import Base
|
||||||
|
|
||||||
|
# this is the Alembic Config object
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Interpret the config file for Python logging
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# Model metadata for autogenerate support
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def get_url() -> str:
|
||||||
|
"""Get database URL from settings."""
|
||||||
|
url = settings.database_url
|
||||||
|
if not url:
|
||||||
|
raise ValueError("DATABASE_URL must be set for migrations")
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode.
|
||||||
|
|
||||||
|
This configures the context with just a URL and not an Engine,
|
||||||
|
though an Engine is acceptable here as well. By skipping the Engine creation
|
||||||
|
we don't even need a DBAPI to be available.
|
||||||
|
|
||||||
|
Calls to context.execute() here emit the given string to the script output.
|
||||||
|
"""
|
||||||
|
url = get_url()
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection) -> None:
|
||||||
|
"""Run migrations with the given connection."""
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_async_migrations() -> None:
|
||||||
|
"""Run migrations in async mode."""
|
||||||
|
configuration = config.get_section(config.config_ini_section, {})
|
||||||
|
configuration["sqlalchemy.url"] = get_url()
|
||||||
|
|
||||||
|
connectable = async_engine_from_config(
|
||||||
|
configuration,
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode."""
|
||||||
|
asyncio.run(run_async_migrations())
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
26
alembic/script.py.mako
Normal file
26
alembic/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
200
alembic/versions/20250112_0001_initial_schema.py
Normal file
200
alembic/versions/20250112_0001_initial_schema.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
"""Initial schema with users, conversations, messages, and guilds.
|
||||||
|
|
||||||
|
Revision ID: 0001
|
||||||
|
Revises:
|
||||||
|
Create Date: 2025-01-12
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0001"
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Users table
|
||||||
|
op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("discord_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("discord_username", sa.String(255), nullable=False),
|
||||||
|
sa.Column("discord_display_name", sa.String(255), nullable=True),
|
||||||
|
sa.Column("custom_name", sa.String(255), nullable=True),
|
||||||
|
sa.Column("first_seen_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("last_seen_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_users")),
|
||||||
|
sa.UniqueConstraint("discord_id", name=op.f("uq_users_discord_id")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_users_discord_id"), "users", ["discord_id"])
|
||||||
|
op.create_index(op.f("ix_users_last_seen_at"), "users", ["last_seen_at"])
|
||||||
|
|
||||||
|
# User preferences table
|
||||||
|
op.create_table(
|
||||||
|
"user_preferences",
|
||||||
|
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("preference_key", sa.String(100), nullable=False),
|
||||||
|
sa.Column("preference_value", sa.Text(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
name=op.f("fk_user_preferences_user_id_users"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_user_preferences")),
|
||||||
|
sa.UniqueConstraint("user_id", "preference_key", name="uq_user_preferences_user_key"),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_user_preferences_user_id"), "user_preferences", ["user_id"])
|
||||||
|
|
||||||
|
# User facts table
|
||||||
|
op.create_table(
|
||||||
|
"user_facts",
|
||||||
|
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("fact_type", sa.String(50), nullable=False),
|
||||||
|
sa.Column("fact_content", sa.Text(), nullable=False),
|
||||||
|
sa.Column("confidence", sa.Float(), server_default="1.0"),
|
||||||
|
sa.Column("source", sa.String(50), server_default="conversation"),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||||
|
sa.Column("learned_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("last_referenced_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
name=op.f("fk_user_facts_user_id_users"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_user_facts")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_user_facts_user_id"), "user_facts", ["user_id"])
|
||||||
|
op.create_index(op.f("ix_user_facts_fact_type"), "user_facts", ["fact_type"])
|
||||||
|
op.create_index(op.f("ix_user_facts_is_active"), "user_facts", ["is_active"])
|
||||||
|
|
||||||
|
# Guilds table
|
||||||
|
op.create_table(
|
||||||
|
"guilds",
|
||||||
|
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("discord_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("joined_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||||
|
sa.Column("settings", postgresql.JSONB(), server_default="{}"),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_guilds")),
|
||||||
|
sa.UniqueConstraint("discord_id", name=op.f("uq_guilds_discord_id")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_guilds_discord_id"), "guilds", ["discord_id"])
|
||||||
|
|
||||||
|
# Guild members table
|
||||||
|
op.create_table(
|
||||||
|
"guild_members",
|
||||||
|
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("guild_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("guild_nickname", sa.String(255), nullable=True),
|
||||||
|
sa.Column("roles", postgresql.ARRAY(sa.Text()), nullable=True),
|
||||||
|
sa.Column("joined_guild_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["guild_id"],
|
||||||
|
["guilds.id"],
|
||||||
|
name=op.f("fk_guild_members_guild_id_guilds"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
name=op.f("fk_guild_members_user_id_users"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_guild_members")),
|
||||||
|
sa.UniqueConstraint("guild_id", "user_id", name="uq_guild_members_guild_user"),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_guild_members_guild_id"), "guild_members", ["guild_id"])
|
||||||
|
op.create_index(op.f("ix_guild_members_user_id"), "guild_members", ["user_id"])
|
||||||
|
|
||||||
|
# Conversations table
|
||||||
|
op.create_table(
|
||||||
|
"conversations",
|
||||||
|
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("guild_id", sa.BigInteger(), nullable=True),
|
||||||
|
sa.Column("channel_id", sa.BigInteger(), nullable=True),
|
||||||
|
sa.Column("started_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("last_message_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("message_count", sa.Integer(), server_default="0"),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
name=op.f("fk_conversations_user_id_users"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_conversations")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_conversations_user_id"), "conversations", ["user_id"])
|
||||||
|
op.create_index(op.f("ix_conversations_channel_id"), "conversations", ["channel_id"])
|
||||||
|
op.create_index(op.f("ix_conversations_last_message_at"), "conversations", ["last_message_at"])
|
||||||
|
op.create_index(op.f("ix_conversations_is_active"), "conversations", ["is_active"])
|
||||||
|
|
||||||
|
# Messages table
|
||||||
|
op.create_table(
|
||||||
|
"messages",
|
||||||
|
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("conversation_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
|
sa.Column("discord_message_id", sa.BigInteger(), nullable=True),
|
||||||
|
sa.Column("role", sa.String(20), nullable=False),
|
||||||
|
sa.Column("content", sa.Text(), nullable=False),
|
||||||
|
sa.Column("has_images", sa.Boolean(), nullable=False, server_default="false"),
|
||||||
|
sa.Column("image_urls", postgresql.ARRAY(sa.Text()), nullable=True),
|
||||||
|
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["conversation_id"],
|
||||||
|
["conversations.id"],
|
||||||
|
name=op.f("fk_messages_conversation_id_conversations"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.id"],
|
||||||
|
name=op.f("fk_messages_user_id_users"),
|
||||||
|
ondelete="CASCADE",
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("pk_messages")),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_messages_conversation_id"), "messages", ["conversation_id"])
|
||||||
|
op.create_index(op.f("ix_messages_user_id"), "messages", ["user_id"])
|
||||||
|
op.create_index(op.f("ix_messages_created_at"), "messages", ["created_at"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("messages")
|
||||||
|
op.drop_table("conversations")
|
||||||
|
op.drop_table("guild_members")
|
||||||
|
op.drop_table("guilds")
|
||||||
|
op.drop_table("user_facts")
|
||||||
|
op.drop_table("user_preferences")
|
||||||
|
op.drop_table("users")
|
||||||
@@ -1,9 +1,32 @@
|
|||||||
services:
|
services:
|
||||||
daemon-boyfriend:
|
daemon-boyfriend:
|
||||||
build: .
|
build: .
|
||||||
container_name: daemon-boyfriend
|
container_name: daemon-boyfriend
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
environment:
|
environment:
|
||||||
- PYTHONUNBUFFERED=1
|
- PYTHONUNBUFFERED=1
|
||||||
|
- DATABASE_URL=postgresql+asyncpg://daemon:${POSTGRES_PASSWORD:-daemon}@postgres:5432/daemon_boyfriend
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
postgres:
|
||||||
|
image: postgres:16-alpine
|
||||||
|
container_name: daemon-boyfriend-postgres
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: daemon
|
||||||
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-daemon}
|
||||||
|
POSTGRES_DB: daemon_boyfriend
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U daemon -d daemon_boyfriend"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
|||||||
@@ -13,3 +13,8 @@ aiohttp>=3.9.0
|
|||||||
pydantic>=2.6.0
|
pydantic>=2.6.0
|
||||||
pydantic-settings>=2.2.0
|
pydantic-settings>=2.2.0
|
||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
|
|
||||||
|
# Database
|
||||||
|
asyncpg>=0.29.0
|
||||||
|
sqlalchemy[asyncio]>=2.0.0
|
||||||
|
alembic>=1.13.0
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import discord
|
|||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from daemon_boyfriend.config import settings
|
from daemon_boyfriend.config import settings
|
||||||
|
from daemon_boyfriend.services import db
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -27,7 +28,14 @@ class DaemonBoyfriend(commands.Bot):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def setup_hook(self) -> None:
|
async def setup_hook(self) -> None:
|
||||||
"""Load cogs on startup."""
|
"""Initialize database and load cogs on startup."""
|
||||||
|
# Initialize database if configured
|
||||||
|
if db.is_configured:
|
||||||
|
await db.init()
|
||||||
|
logger.info("Database initialized")
|
||||||
|
else:
|
||||||
|
logger.info("Database not configured, using in-memory storage")
|
||||||
|
|
||||||
# Load all cogs
|
# Load all cogs
|
||||||
cogs_path = Path(__file__).parent / "cogs"
|
cogs_path = Path(__file__).parent / "cogs"
|
||||||
for cog_file in cogs_path.glob("*.py"):
|
for cog_file in cogs_path.glob("*.py"):
|
||||||
@@ -55,3 +63,10 @@ class DaemonBoyfriend(commands.Bot):
|
|||||||
name=settings.bot_status,
|
name=settings.bot_status,
|
||||||
)
|
)
|
||||||
await self.change_presence(activity=activity)
|
await self.change_presence(activity=activity)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Clean up resources on shutdown."""
|
||||||
|
logger.info("Shutting down bot...")
|
||||||
|
if db.is_initialized:
|
||||||
|
await db.close()
|
||||||
|
await super().close()
|
||||||
|
|||||||
@@ -12,7 +12,10 @@ from daemon_boyfriend.services import (
|
|||||||
ConversationManager,
|
ConversationManager,
|
||||||
ImageAttachment,
|
ImageAttachment,
|
||||||
Message,
|
Message,
|
||||||
|
PersistentConversationManager,
|
||||||
SearXNGService,
|
SearXNGService,
|
||||||
|
UserService,
|
||||||
|
db,
|
||||||
)
|
)
|
||||||
from daemon_boyfriend.utils import get_monitor
|
from daemon_boyfriend.utils import get_monitor
|
||||||
|
|
||||||
@@ -77,11 +80,17 @@ class AIChatCog(commands.Cog):
|
|||||||
def __init__(self, bot: commands.Bot) -> None:
|
def __init__(self, bot: commands.Bot) -> None:
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.ai_service = AIService()
|
self.ai_service = AIService()
|
||||||
|
# Fallback in-memory conversation manager (used when DB not configured)
|
||||||
self.conversations = ConversationManager()
|
self.conversations = ConversationManager()
|
||||||
self.search_service: SearXNGService | None = None
|
self.search_service: SearXNGService | None = None
|
||||||
if settings.searxng_enabled and settings.searxng_url:
|
if settings.searxng_enabled and settings.searxng_url:
|
||||||
self.search_service = SearXNGService(settings.searxng_url)
|
self.search_service = SearXNGService(settings.searxng_url)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_database(self) -> bool:
|
||||||
|
"""Check if database is available for use."""
|
||||||
|
return db.is_initialized
|
||||||
|
|
||||||
@commands.Cog.listener()
|
@commands.Cog.listener()
|
||||||
async def on_message(self, message: discord.Message) -> None:
|
async def on_message(self, message: discord.Message) -> None:
|
||||||
"""Respond when the bot is mentioned."""
|
"""Respond when the bot is mentioned."""
|
||||||
@@ -395,6 +404,95 @@ class AIChatCog(commands.Cog):
|
|||||||
Returns:
|
Returns:
|
||||||
The AI's response text
|
The AI's response text
|
||||||
"""
|
"""
|
||||||
|
if self.use_database:
|
||||||
|
return await self._generate_response_with_db(message, user_message)
|
||||||
|
else:
|
||||||
|
return await self._generate_response_in_memory(message, user_message)
|
||||||
|
|
||||||
|
async def _generate_response_with_db(self, message: discord.Message, user_message: str) -> str:
|
||||||
|
"""Generate response using database-backed storage."""
|
||||||
|
async with db.session() as session:
|
||||||
|
user_service = UserService(session)
|
||||||
|
conv_manager = PersistentConversationManager(session)
|
||||||
|
|
||||||
|
# Get or create user
|
||||||
|
user = await user_service.get_or_create_user(
|
||||||
|
discord_id=message.author.id,
|
||||||
|
username=message.author.name,
|
||||||
|
display_name=message.author.display_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get or create conversation
|
||||||
|
conversation = await conv_manager.get_or_create_conversation(
|
||||||
|
user=user,
|
||||||
|
guild_id=message.guild.id if message.guild else None,
|
||||||
|
channel_id=message.channel.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get history
|
||||||
|
history = await conv_manager.get_history(conversation)
|
||||||
|
|
||||||
|
# Extract any image attachments from the message
|
||||||
|
images = self._extract_image_attachments(message)
|
||||||
|
image_urls = [img.url for img in images] if images else None
|
||||||
|
|
||||||
|
# Add current message to history for the API call
|
||||||
|
current_message = Message(role="user", content=user_message, images=images)
|
||||||
|
messages = history + [current_message]
|
||||||
|
|
||||||
|
# Check if we should search the web
|
||||||
|
search_context = await self._maybe_search(user_message)
|
||||||
|
|
||||||
|
# Get context about mentioned users
|
||||||
|
mentioned_users_context = self._get_mentioned_users_context(message)
|
||||||
|
|
||||||
|
# Build system prompt with additional context
|
||||||
|
system_prompt = self.ai_service.get_system_prompt()
|
||||||
|
|
||||||
|
# Add user context from database (custom name, known facts)
|
||||||
|
user_context = await user_service.get_user_context(user)
|
||||||
|
system_prompt += f"\n\n--- User Context ---\n{user_context}"
|
||||||
|
|
||||||
|
# Add mentioned users context
|
||||||
|
if mentioned_users_context:
|
||||||
|
system_prompt += f"\n\n--- {mentioned_users_context} ---"
|
||||||
|
|
||||||
|
# Add search results if available
|
||||||
|
if search_context:
|
||||||
|
system_prompt += (
|
||||||
|
"\n\n--- Web Search Results ---\n"
|
||||||
|
"Use the following current information from the web to help answer the user's question. "
|
||||||
|
"Cite sources when relevant.\n\n"
|
||||||
|
f"{search_context}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate response
|
||||||
|
response = await self.ai_service.chat(
|
||||||
|
messages=messages,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the exchange to database
|
||||||
|
await conv_manager.add_exchange(
|
||||||
|
conversation=conversation,
|
||||||
|
user=user,
|
||||||
|
user_message=user_message,
|
||||||
|
assistant_message=response.content,
|
||||||
|
discord_message_id=message.id,
|
||||||
|
image_urls=image_urls,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Generated response for user {user.discord_id}: "
|
||||||
|
f"{len(response.content)} chars, {response.usage}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
async def _generate_response_in_memory(
|
||||||
|
self, message: discord.Message, user_message: str
|
||||||
|
) -> str:
|
||||||
|
"""Generate response using in-memory storage (fallback)."""
|
||||||
user_id = message.author.id
|
user_id = message.author.id
|
||||||
|
|
||||||
# Get conversation history
|
# Get conversation history
|
||||||
|
|||||||
260
src/daemon_boyfriend/cogs/memory.py
Normal file
260
src/daemon_boyfriend/cogs/memory.py
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
"""Memory management cog - commands for managing bot memory about users."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from daemon_boyfriend.services import UserService, db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryCog(commands.Cog):
|
||||||
|
"""Commands for managing bot memory about users."""
|
||||||
|
|
||||||
|
def __init__(self, bot: commands.Bot) -> None:
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
def _check_database(self) -> bool:
|
||||||
|
"""Check if database is available."""
|
||||||
|
return db.is_initialized
|
||||||
|
|
||||||
|
@commands.command(name="setname")
|
||||||
|
async def set_name(self, ctx: commands.Context, *, name: str) -> None:
|
||||||
|
"""Set your preferred name for the bot to use.
|
||||||
|
|
||||||
|
Usage: !setname John
|
||||||
|
"""
|
||||||
|
if not self._check_database():
|
||||||
|
await ctx.reply("Memory features are not available (database not configured).")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(name) > 100:
|
||||||
|
await ctx.reply("Name is too long! Please use 100 characters or less.")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with db.session() as session:
|
||||||
|
user_service = UserService(session)
|
||||||
|
user = await user_service.get_or_create_user(
|
||||||
|
discord_id=ctx.author.id,
|
||||||
|
username=ctx.author.name,
|
||||||
|
display_name=ctx.author.display_name,
|
||||||
|
)
|
||||||
|
await user_service.set_custom_name(ctx.author.id, name)
|
||||||
|
|
||||||
|
await ctx.reply(f"Got it! I'll call you **{name}** from now on.")
|
||||||
|
|
||||||
|
@commands.command(name="clearname")
|
||||||
|
async def clear_name(self, ctx: commands.Context) -> None:
|
||||||
|
"""Clear your custom name and use your Discord name instead.
|
||||||
|
|
||||||
|
Usage: !clearname
|
||||||
|
"""
|
||||||
|
if not self._check_database():
|
||||||
|
await ctx.reply("Memory features are not available (database not configured).")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with db.session() as session:
|
||||||
|
user_service = UserService(session)
|
||||||
|
await user_service.set_custom_name(ctx.author.id, None)
|
||||||
|
|
||||||
|
await ctx.reply("Done! I'll use your Discord display name now.")
|
||||||
|
|
||||||
|
@commands.command(name="remember")
|
||||||
|
async def remember_fact(self, ctx: commands.Context, *, fact: str) -> None:
|
||||||
|
"""Tell the bot something to remember about you.
|
||||||
|
|
||||||
|
Usage: !remember I love pizza
|
||||||
|
Usage: !remember My favorite color is blue
|
||||||
|
"""
|
||||||
|
if not self._check_database():
|
||||||
|
await ctx.reply("Memory features are not available (database not configured).")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(fact) > 500:
|
||||||
|
await ctx.reply("That's too long to remember! Please keep it under 500 characters.")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with db.session() as session:
|
||||||
|
user_service = UserService(session)
|
||||||
|
user = await user_service.get_or_create_user(
|
||||||
|
discord_id=ctx.author.id,
|
||||||
|
username=ctx.author.name,
|
||||||
|
display_name=ctx.author.display_name,
|
||||||
|
)
|
||||||
|
await user_service.add_fact(
|
||||||
|
user=user,
|
||||||
|
fact_type="general",
|
||||||
|
fact_content=fact,
|
||||||
|
source="explicit",
|
||||||
|
confidence=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
await ctx.reply(f"I'll remember that!")
|
||||||
|
|
||||||
|
@commands.command(name="whatdoyouknow", aliases=["aboutme", "myinfo"])
|
||||||
|
async def what_do_you_know(self, ctx: commands.Context) -> None:
|
||||||
|
"""Show what the bot remembers about you.
|
||||||
|
|
||||||
|
Usage: !whatdoyouknow
|
||||||
|
"""
|
||||||
|
if not self._check_database():
|
||||||
|
await ctx.reply("Memory features are not available (database not configured).")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with db.session() as session:
|
||||||
|
user_service = UserService(session)
|
||||||
|
user = await user_service.get_user_by_discord_id(ctx.author.id)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
await ctx.reply("I don't have any information about you yet!")
|
||||||
|
return
|
||||||
|
|
||||||
|
facts = await user_service.get_user_facts(user, active_only=True)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=f"What I know about {user.display_name}",
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.add_field(
|
||||||
|
name="Discord Username",
|
||||||
|
value=user.discord_username,
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if user.custom_name:
|
||||||
|
embed.add_field(
|
||||||
|
name="Preferred Name",
|
||||||
|
value=user.custom_name,
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.add_field(
|
||||||
|
name="First Seen",
|
||||||
|
value=user.first_seen_at.strftime("%Y-%m-%d"),
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if facts:
|
||||||
|
facts_text = "\n".join(f"- {fact.fact_content}" for fact in facts[:15])
|
||||||
|
if len(facts) > 15:
|
||||||
|
facts_text += f"\n... and {len(facts) - 15} more"
|
||||||
|
embed.add_field(
|
||||||
|
name=f"Things I Remember ({len(facts)})",
|
||||||
|
value=facts_text or "Nothing yet!",
|
||||||
|
inline=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embed.add_field(
|
||||||
|
name="Things I Remember",
|
||||||
|
value="Nothing yet! Use `!remember` to tell me something.",
|
||||||
|
inline=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
await ctx.reply(embed=embed)
|
||||||
|
|
||||||
|
@commands.command(name="forgetme")
|
||||||
|
async def forget_me(self, ctx: commands.Context) -> None:
|
||||||
|
"""Clear all facts the bot knows about you.
|
||||||
|
|
||||||
|
Usage: !forgetme
|
||||||
|
"""
|
||||||
|
if not self._check_database():
|
||||||
|
await ctx.reply("Memory features are not available (database not configured).")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with db.session() as session:
|
||||||
|
user_service = UserService(session)
|
||||||
|
user = await user_service.get_user_by_discord_id(ctx.author.id)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
await ctx.reply("I don't have any information about you to forget!")
|
||||||
|
return
|
||||||
|
|
||||||
|
count = await user_service.delete_user_facts(user)
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
await ctx.reply(f"Done! I've forgotten {count} thing(s) about you.")
|
||||||
|
else:
|
||||||
|
await ctx.reply("I didn't have anything to forget about you!")
|
||||||
|
|
||||||
|
@commands.command(name="setusername")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
async def set_user_name(
|
||||||
|
self, ctx: commands.Context, user: discord.Member, *, name: str
|
||||||
|
) -> None:
|
||||||
|
"""[Admin] Set a custom name for another user.
|
||||||
|
|
||||||
|
Usage: !setusername @user John
|
||||||
|
"""
|
||||||
|
if not self._check_database():
|
||||||
|
await ctx.reply("Memory features are not available (database not configured).")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(name) > 100:
|
||||||
|
await ctx.reply("Name is too long! Please use 100 characters or less.")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with db.session() as session:
|
||||||
|
user_service = UserService(session)
|
||||||
|
db_user = await user_service.get_or_create_user(
|
||||||
|
discord_id=user.id,
|
||||||
|
username=user.name,
|
||||||
|
display_name=user.display_name,
|
||||||
|
)
|
||||||
|
await user_service.set_custom_name(user.id, name)
|
||||||
|
|
||||||
|
await ctx.reply(f"Got it! I'll call {user.mention} **{name}** from now on.")
|
||||||
|
|
||||||
|
@commands.command(name="teachbot")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
async def teach_bot(self, ctx: commands.Context, user: discord.Member, *, fact: str) -> None:
|
||||||
|
"""[Admin] Teach the bot a fact about a user.
|
||||||
|
|
||||||
|
Usage: !teachbot @user They are a software developer
|
||||||
|
"""
|
||||||
|
if not self._check_database():
|
||||||
|
await ctx.reply("Memory features are not available (database not configured).")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(fact) > 500:
|
||||||
|
await ctx.reply("That's too long! Please keep it under 500 characters.")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with db.session() as session:
|
||||||
|
user_service = UserService(session)
|
||||||
|
db_user = await user_service.get_or_create_user(
|
||||||
|
discord_id=user.id,
|
||||||
|
username=user.name,
|
||||||
|
display_name=user.display_name,
|
||||||
|
)
|
||||||
|
await user_service.add_fact(
|
||||||
|
user=db_user,
|
||||||
|
fact_type="general",
|
||||||
|
fact_content=fact,
|
||||||
|
source="admin",
|
||||||
|
confidence=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
await ctx.reply(f"I'll remember that about {user.mention}!")
|
||||||
|
|
||||||
|
@set_user_name.error
|
||||||
|
@teach_bot.error
|
||||||
|
async def admin_command_error(
|
||||||
|
self, ctx: commands.Context, error: commands.CommandError
|
||||||
|
) -> None:
|
||||||
|
"""Handle errors for admin commands."""
|
||||||
|
if isinstance(error, commands.MissingPermissions):
|
||||||
|
await ctx.reply("You need administrator permissions to use this command.")
|
||||||
|
elif isinstance(error, commands.MemberNotFound):
|
||||||
|
await ctx.reply("I couldn't find that user.")
|
||||||
|
else:
|
||||||
|
logger.error(f"Error in admin command: {error}", exc_info=True)
|
||||||
|
await ctx.reply(f"An error occurred: {error}")
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot: commands.Bot) -> None:
|
||||||
|
"""Load the Memory cog."""
|
||||||
|
await bot.add_cog(MemoryCog(bot))
|
||||||
@@ -67,6 +67,20 @@ class Settings(BaseSettings):
|
|||||||
max_conversation_history: int = Field(
|
max_conversation_history: int = Field(
|
||||||
20, description="Max messages to keep in conversation memory per user"
|
20, description="Max messages to keep in conversation memory per user"
|
||||||
)
|
)
|
||||||
|
conversation_timeout_minutes: int = Field(
|
||||||
|
60, ge=5, le=1440, description="Minutes of inactivity before starting new conversation"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Database Configuration
|
||||||
|
database_url: str | None = Field(
|
||||||
|
None,
|
||||||
|
description="PostgreSQL connection URL (asyncpg format). If not set, uses in-memory storage.",
|
||||||
|
)
|
||||||
|
database_echo: bool = Field(False, description="Echo SQL statements (for debugging)")
|
||||||
|
database_pool_size: int = Field(5, ge=1, le=20, description="Database connection pool size")
|
||||||
|
database_max_overflow: int = Field(
|
||||||
|
10, ge=0, le=30, description="Max connections beyond pool size"
|
||||||
|
)
|
||||||
|
|
||||||
# SearXNG Configuration
|
# SearXNG Configuration
|
||||||
searxng_url: str | None = Field(None, description="SearXNG instance URL for web search")
|
searxng_url: str | None = Field(None, description="SearXNG instance URL for web search")
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""Database models."""
|
||||||
|
|
||||||
|
from .base import Base
|
||||||
|
from .conversation import Conversation, Message
|
||||||
|
from .guild import Guild, GuildMember
|
||||||
|
from .user import User, UserFact, UserPreference
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Base",
|
||||||
|
"Conversation",
|
||||||
|
"Guild",
|
||||||
|
"GuildMember",
|
||||||
|
"Message",
|
||||||
|
"User",
|
||||||
|
"UserFact",
|
||||||
|
"UserPreference",
|
||||||
|
]
|
||||||
|
|||||||
28
src/daemon_boyfriend/models/base.py
Normal file
28
src/daemon_boyfriend/models/base.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""SQLAlchemy base model and metadata configuration."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import MetaData
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
# Naming convention for constraints (helps with migrations)
|
||||||
|
convention = {
|
||||||
|
"ix": "ix_%(column_0_label)s",
|
||||||
|
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||||
|
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||||
|
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||||
|
"pk": "pk_%(table_name)s",
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata = MetaData(naming_convention=convention)
|
||||||
|
|
||||||
|
|
||||||
|
class Base(AsyncAttrs, DeclarativeBase):
|
||||||
|
"""Base class for all database models."""
|
||||||
|
|
||||||
|
metadata = metadata
|
||||||
|
|
||||||
|
# Common timestamp columns
|
||||||
|
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||||
57
src/daemon_boyfriend/models/conversation.py
Normal file
57
src/daemon_boyfriend/models/conversation.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Conversation and message database models."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import ARRAY, BigInteger, Boolean, ForeignKey, Integer, String, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from .base import Base
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
|
class Conversation(Base):
|
||||||
|
"""A conversation session with a user."""
|
||||||
|
|
||||||
|
__tablename__ = "conversations"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||||
|
guild_id: Mapped[int | None] = mapped_column(BigInteger)
|
||||||
|
channel_id: Mapped[int | None] = mapped_column(BigInteger, index=True)
|
||||||
|
started_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||||
|
last_message_at: Mapped[datetime] = mapped_column(default=datetime.utcnow, index=True)
|
||||||
|
message_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True, index=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
user: Mapped["User"] = relationship(back_populates="conversations")
|
||||||
|
messages: Mapped[list["Message"]] = relationship(
|
||||||
|
back_populates="conversation",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
order_by="Message.created_at",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Message(Base):
|
||||||
|
"""Individual chat message."""
|
||||||
|
|
||||||
|
__tablename__ = "messages"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
conversation_id: Mapped[int] = mapped_column(
|
||||||
|
ForeignKey("conversations.id", ondelete="CASCADE"), index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||||
|
discord_message_id: Mapped[int | None] = mapped_column(BigInteger)
|
||||||
|
role: Mapped[str] = mapped_column(String(20)) # user, assistant, system
|
||||||
|
content: Mapped[str] = mapped_column(Text)
|
||||||
|
has_images: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
image_urls: Mapped[list[str] | None] = mapped_column(ARRAY(Text), default=None)
|
||||||
|
token_count: Mapped[int | None] = mapped_column(Integer)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
conversation: Mapped["Conversation"] = relationship(back_populates="messages")
|
||||||
|
user: Mapped["User"] = relationship(back_populates="messages")
|
||||||
50
src/daemon_boyfriend/models/guild.py
Normal file
50
src/daemon_boyfriend/models/guild.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Guild (Discord server) database models."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import ARRAY, BigInteger, Boolean, ForeignKey, String, Text, UniqueConstraint
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from .base import Base
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
|
class Guild(Base):
|
||||||
|
"""Discord server tracking."""
|
||||||
|
|
||||||
|
__tablename__ = "guilds"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
discord_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255))
|
||||||
|
joined_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
settings: Mapped[dict] = mapped_column(JSONB, default=dict)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
members: Mapped[list["GuildMember"]] = relationship(
|
||||||
|
back_populates="guild", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GuildMember(Base):
|
||||||
|
"""Track users per guild with guild-specific settings."""
|
||||||
|
|
||||||
|
__tablename__ = "guild_members"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
guild_id: Mapped[int] = mapped_column(ForeignKey("guilds.id", ondelete="CASCADE"), index=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||||
|
guild_nickname: Mapped[str | None] = mapped_column(String(255))
|
||||||
|
roles: Mapped[list[str] | None] = mapped_column(ARRAY(Text), default=None)
|
||||||
|
joined_guild_at: Mapped[datetime | None] = mapped_column(default=None)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
guild: Mapped["Guild"] = relationship(back_populates="members")
|
||||||
|
user: Mapped["User"] = relationship(back_populates="guild_memberships")
|
||||||
|
|
||||||
|
__table_args__ = (UniqueConstraint("guild_id", "user_id"),)
|
||||||
83
src/daemon_boyfriend/models/user.py
Normal file
83
src/daemon_boyfriend/models/user.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""User-related database models."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, Boolean, Float, ForeignKey, String, Text, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from .base import Base
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .conversation import Conversation, Message
|
||||||
|
from .guild import GuildMember
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
"""Discord user tracking."""
|
||||||
|
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
discord_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
|
||||||
|
discord_username: Mapped[str] = mapped_column(String(255))
|
||||||
|
discord_display_name: Mapped[str | None] = mapped_column(String(255))
|
||||||
|
custom_name: Mapped[str | None] = mapped_column(String(255))
|
||||||
|
first_seen_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||||
|
last_seen_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
preferences: Mapped[list["UserPreference"]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
facts: Mapped[list["UserFact"]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
conversations: Mapped[list["Conversation"]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
messages: Mapped[list["Message"]] = relationship(back_populates="user")
|
||||||
|
guild_memberships: Mapped[list["GuildMember"]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def display_name(self) -> str:
|
||||||
|
"""Get the name to use when addressing this user."""
|
||||||
|
return self.custom_name or self.discord_display_name or self.discord_username
|
||||||
|
|
||||||
|
|
||||||
|
class UserPreference(Base):
|
||||||
|
"""Per-user preferences/settings."""
|
||||||
|
|
||||||
|
__tablename__ = "user_preferences"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||||
|
preference_key: Mapped[str] = mapped_column(String(100))
|
||||||
|
preference_value: Mapped[str | None] = mapped_column(Text)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
user: Mapped["User"] = relationship(back_populates="preferences")
|
||||||
|
|
||||||
|
__table_args__ = (UniqueConstraint("user_id", "preference_key"),)
|
||||||
|
|
||||||
|
|
||||||
|
class UserFact(Base):
|
||||||
|
"""Facts the bot has learned about users."""
|
||||||
|
|
||||||
|
__tablename__ = "user_facts"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||||
|
fact_type: Mapped[str] = mapped_column(String(50), index=True) # general, preference, hobby
|
||||||
|
fact_content: Mapped[str] = mapped_column(Text)
|
||||||
|
confidence: Mapped[float] = mapped_column(Float, default=1.0)
|
||||||
|
source: Mapped[str] = mapped_column(String(50), default="conversation")
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True, index=True)
|
||||||
|
learned_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||||
|
last_referenced_at: Mapped[datetime | None] = mapped_column(default=None)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
user: Mapped["User"] = relationship(back_populates="facts")
|
||||||
@@ -2,14 +2,22 @@
|
|||||||
|
|
||||||
from .ai_service import AIService
|
from .ai_service import AIService
|
||||||
from .conversation import ConversationManager
|
from .conversation import ConversationManager
|
||||||
|
from .database import DatabaseService, db, get_db
|
||||||
|
from .persistent_conversation import PersistentConversationManager
|
||||||
from .providers import AIResponse, ImageAttachment, Message
|
from .providers import AIResponse, ImageAttachment, Message
|
||||||
from .searxng import SearXNGService
|
from .searxng import SearXNGService
|
||||||
|
from .user_service import UserService
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AIService",
|
"AIService",
|
||||||
"AIResponse",
|
"AIResponse",
|
||||||
|
"ConversationManager",
|
||||||
|
"DatabaseService",
|
||||||
"ImageAttachment",
|
"ImageAttachment",
|
||||||
"Message",
|
"Message",
|
||||||
"ConversationManager",
|
"PersistentConversationManager",
|
||||||
"SearXNGService",
|
"SearXNGService",
|
||||||
|
"UserService",
|
||||||
|
"db",
|
||||||
|
"get_db",
|
||||||
]
|
]
|
||||||
|
|||||||
94
src/daemon_boyfriend/services/database.py
Normal file
94
src/daemon_boyfriend/services/database.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
"""Database connection and session management."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from daemon_boyfriend.config import settings
|
||||||
|
from daemon_boyfriend.models.base import Base
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseService:
|
||||||
|
"""Manages database connections and sessions."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._engine = None
|
||||||
|
self._session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_configured(self) -> bool:
|
||||||
|
"""Check if database URL is configured."""
|
||||||
|
return settings.database_url is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_initialized(self) -> bool:
|
||||||
|
"""Check if database has been initialized."""
|
||||||
|
return self._initialized
|
||||||
|
|
||||||
|
async def init(self) -> None:
|
||||||
|
"""Initialize database connection."""
|
||||||
|
if not self.is_configured:
|
||||||
|
logger.info("Database URL not configured, skipping database initialization")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Initializing database connection...")
|
||||||
|
self._engine = create_async_engine(
|
||||||
|
settings.database_url,
|
||||||
|
echo=settings.database_echo,
|
||||||
|
pool_size=settings.database_pool_size,
|
||||||
|
max_overflow=settings.database_max_overflow,
|
||||||
|
)
|
||||||
|
self._session_factory = async_sessionmaker(
|
||||||
|
self._engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
self._initialized = True
|
||||||
|
logger.info("Database connection initialized")
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close database connection."""
|
||||||
|
if self._engine:
|
||||||
|
logger.info("Closing database connection...")
|
||||||
|
await self._engine.dispose()
|
||||||
|
self._engine = None
|
||||||
|
self._session_factory = None
|
||||||
|
self._initialized = False
|
||||||
|
logger.info("Database connection closed")
|
||||||
|
|
||||||
|
async def create_tables(self) -> None:
|
||||||
|
"""Create all tables (for development/testing)."""
|
||||||
|
if not self._engine:
|
||||||
|
raise RuntimeError("Database not initialized")
|
||||||
|
|
||||||
|
async with self._engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
logger.info("Database tables created")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Get a database session with automatic commit/rollback."""
|
||||||
|
if not self._session_factory:
|
||||||
|
raise RuntimeError("Database not initialized. Call init() first.")
|
||||||
|
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance
|
||||||
|
db = DatabaseService()
|
||||||
|
|
||||||
|
|
||||||
|
def get_db() -> DatabaseService:
|
||||||
|
"""Get the global database service instance."""
|
||||||
|
return db
|
||||||
188
src/daemon_boyfriend/services/persistent_conversation.py
Normal file
188
src/daemon_boyfriend/services/persistent_conversation.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
"""Persistent conversation management using PostgreSQL."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from daemon_boyfriend.config import settings
|
||||||
|
from daemon_boyfriend.models import Conversation, Message, User
|
||||||
|
from daemon_boyfriend.services.providers import Message as ProviderMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PersistentConversationManager:
|
||||||
|
"""Manages conversation history in PostgreSQL."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession, max_history: int | None = None) -> None:
|
||||||
|
self._session = session
|
||||||
|
self.max_history = max_history or settings.max_conversation_history
|
||||||
|
self._timeout = timedelta(minutes=settings.conversation_timeout_minutes)
|
||||||
|
|
||||||
|
async def get_or_create_conversation(
|
||||||
|
self,
|
||||||
|
user: User,
|
||||||
|
guild_id: int | None = None,
|
||||||
|
channel_id: int | None = None,
|
||||||
|
) -> Conversation:
|
||||||
|
"""Get active conversation or create new one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User model instance
|
||||||
|
guild_id: Discord guild ID (None for DMs)
|
||||||
|
channel_id: Discord channel ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Conversation model instance
|
||||||
|
"""
|
||||||
|
# Look for recent active conversation in this channel
|
||||||
|
cutoff = datetime.utcnow() - self._timeout
|
||||||
|
|
||||||
|
stmt = select(Conversation).where(
|
||||||
|
Conversation.user_id == user.id,
|
||||||
|
Conversation.is_active == True, # noqa: E712
|
||||||
|
Conversation.last_message_at > cutoff,
|
||||||
|
)
|
||||||
|
|
||||||
|
if channel_id:
|
||||||
|
stmt = stmt.where(Conversation.channel_id == channel_id)
|
||||||
|
|
||||||
|
stmt = stmt.order_by(Conversation.last_message_at.desc())
|
||||||
|
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
conversation = result.scalar_first()
|
||||||
|
|
||||||
|
if conversation:
|
||||||
|
logger.debug(
|
||||||
|
f"Found existing conversation {conversation.id} for user {user.discord_id}"
|
||||||
|
)
|
||||||
|
return conversation
|
||||||
|
|
||||||
|
# Create new conversation
|
||||||
|
conversation = Conversation(
|
||||||
|
user_id=user.id,
|
||||||
|
guild_id=guild_id,
|
||||||
|
channel_id=channel_id,
|
||||||
|
)
|
||||||
|
self._session.add(conversation)
|
||||||
|
await self._session.flush()
|
||||||
|
logger.info(f"Created new conversation {conversation.id} for user {user.discord_id}")
|
||||||
|
return conversation
|
||||||
|
|
||||||
|
async def get_history(self, conversation: Conversation) -> list[ProviderMessage]:
|
||||||
|
"""Get conversation history as provider messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation: Conversation model instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ProviderMessage instances for the AI
|
||||||
|
"""
|
||||||
|
stmt = (
|
||||||
|
select(Message)
|
||||||
|
.where(Message.conversation_id == conversation.id)
|
||||||
|
.order_by(Message.created_at.desc())
|
||||||
|
.limit(self.max_history)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
messages = list(reversed(result.scalars().all()))
|
||||||
|
|
||||||
|
return [
|
||||||
|
ProviderMessage(
|
||||||
|
role=msg.role,
|
||||||
|
content=msg.content,
|
||||||
|
# Note: Images would need special handling if needed
|
||||||
|
)
|
||||||
|
for msg in messages
|
||||||
|
]
|
||||||
|
|
||||||
|
async def add_message(
|
||||||
|
self,
|
||||||
|
conversation: Conversation,
|
||||||
|
user: User,
|
||||||
|
role: str,
|
||||||
|
content: str,
|
||||||
|
discord_message_id: int | None = None,
|
||||||
|
image_urls: list[str] | None = None,
|
||||||
|
) -> Message:
|
||||||
|
"""Add a message to the conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation: Conversation model instance
|
||||||
|
user: User model instance
|
||||||
|
role: Message role (user, assistant, system)
|
||||||
|
content: Message content
|
||||||
|
discord_message_id: Discord's message ID (optional)
|
||||||
|
image_urls: List of image URLs (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created Message instance
|
||||||
|
"""
|
||||||
|
message = Message(
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
content=content,
|
||||||
|
discord_message_id=discord_message_id,
|
||||||
|
has_images=bool(image_urls),
|
||||||
|
image_urls=image_urls,
|
||||||
|
)
|
||||||
|
self._session.add(message)
|
||||||
|
|
||||||
|
# Update conversation stats
|
||||||
|
conversation.last_message_at = datetime.utcnow()
|
||||||
|
conversation.message_count += 1
|
||||||
|
|
||||||
|
await self._session.flush()
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def add_exchange(
|
||||||
|
self,
|
||||||
|
conversation: Conversation,
|
||||||
|
user: User,
|
||||||
|
user_message: str,
|
||||||
|
assistant_message: str,
|
||||||
|
discord_message_id: int | None = None,
|
||||||
|
image_urls: list[str] | None = None,
|
||||||
|
) -> tuple[Message, Message]:
|
||||||
|
"""Add a user/assistant exchange to the conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation: Conversation model instance
|
||||||
|
user: User model instance
|
||||||
|
user_message: The user's message content
|
||||||
|
assistant_message: The assistant's response
|
||||||
|
discord_message_id: Discord's message ID (optional)
|
||||||
|
image_urls: List of image URLs in user message (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (user_message, assistant_message) Message instances
|
||||||
|
"""
|
||||||
|
user_msg = await self.add_message(
|
||||||
|
conversation, user, "user", user_message, discord_message_id, image_urls
|
||||||
|
)
|
||||||
|
assistant_msg = await self.add_message(conversation, user, "assistant", assistant_message)
|
||||||
|
return user_msg, assistant_msg
|
||||||
|
|
||||||
|
async def clear_conversation(self, conversation: Conversation) -> None:
|
||||||
|
"""Mark a conversation as inactive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation: Conversation model instance
|
||||||
|
"""
|
||||||
|
conversation.is_active = False
|
||||||
|
logger.info(f"Marked conversation {conversation.id} as inactive")
|
||||||
|
|
||||||
|
async def get_message_count(self, conversation: Conversation) -> int:
|
||||||
|
"""Get the number of messages in a conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation: Conversation model instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of messages
|
||||||
|
"""
|
||||||
|
return conversation.message_count
|
||||||
250
src/daemon_boyfriend/services/user_service.py
Normal file
250
src/daemon_boyfriend/services/user_service.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""User management service."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from daemon_boyfriend.models import User, UserFact, UserPreference
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UserService:
|
||||||
|
"""Service for user-related operations."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
async def get_or_create_user(
|
||||||
|
self,
|
||||||
|
discord_id: int,
|
||||||
|
username: str,
|
||||||
|
display_name: str | None = None,
|
||||||
|
) -> User:
|
||||||
|
"""Get existing user or create new one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
discord_id: Discord user ID
|
||||||
|
username: Discord username
|
||||||
|
display_name: Discord display name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User model instance
|
||||||
|
"""
|
||||||
|
stmt = select(User).where(User.discord_id == discord_id)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if user:
|
||||||
|
# Update last seen and current name
|
||||||
|
user.last_seen_at = datetime.utcnow()
|
||||||
|
user.discord_username = username
|
||||||
|
if display_name:
|
||||||
|
user.discord_display_name = display_name
|
||||||
|
return user
|
||||||
|
|
||||||
|
# Create new user
|
||||||
|
user = User(
|
||||||
|
discord_id=discord_id,
|
||||||
|
discord_username=username,
|
||||||
|
discord_display_name=display_name,
|
||||||
|
)
|
||||||
|
self._session.add(user)
|
||||||
|
await self._session.flush()
|
||||||
|
logger.info(f"Created new user: {username} (discord_id={discord_id})")
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_user_by_discord_id(self, discord_id: int) -> User | None:
|
||||||
|
"""Get a user by their Discord ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
discord_id: Discord user ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User if found, None otherwise
|
||||||
|
"""
|
||||||
|
stmt = select(User).where(User.discord_id == discord_id)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def set_custom_name(self, discord_id: int, custom_name: str | None) -> User | None:
|
||||||
|
"""Set a custom name for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
discord_id: Discord user ID
|
||||||
|
custom_name: Custom name to use, or None to clear
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated user if found, None otherwise
|
||||||
|
"""
|
||||||
|
user = await self.get_user_by_discord_id(discord_id)
|
||||||
|
if user:
|
||||||
|
user.custom_name = custom_name
|
||||||
|
logger.info(f"Set custom name for user {discord_id}: {custom_name}")
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def add_fact(
|
||||||
|
self,
|
||||||
|
user: User,
|
||||||
|
fact_type: str,
|
||||||
|
fact_content: str,
|
||||||
|
source: str = "conversation",
|
||||||
|
confidence: float = 1.0,
|
||||||
|
) -> UserFact:
|
||||||
|
"""Add a fact about a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User model instance
|
||||||
|
fact_type: Type of fact (general, preference, hobby, relationship, etc.)
|
||||||
|
fact_content: The fact content
|
||||||
|
source: How the fact was learned (conversation, explicit, inferred)
|
||||||
|
confidence: Confidence level (0.0 to 1.0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created UserFact instance
|
||||||
|
"""
|
||||||
|
fact = UserFact(
|
||||||
|
user_id=user.id,
|
||||||
|
fact_type=fact_type,
|
||||||
|
fact_content=fact_content,
|
||||||
|
source=source,
|
||||||
|
confidence=confidence,
|
||||||
|
)
|
||||||
|
self._session.add(fact)
|
||||||
|
await self._session.flush()
|
||||||
|
logger.debug(f"Added fact for user {user.discord_id}: [{fact_type}] {fact_content}")
|
||||||
|
return fact
|
||||||
|
|
||||||
|
async def get_user_facts(
|
||||||
|
self,
|
||||||
|
user: User,
|
||||||
|
fact_type: str | None = None,
|
||||||
|
active_only: bool = True,
|
||||||
|
) -> list[UserFact]:
|
||||||
|
"""Get facts about a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User model instance
|
||||||
|
fact_type: Optional filter by fact type
|
||||||
|
active_only: Only return active facts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of UserFact instances
|
||||||
|
"""
|
||||||
|
stmt = select(UserFact).where(UserFact.user_id == user.id)
|
||||||
|
|
||||||
|
if active_only:
|
||||||
|
stmt = stmt.where(UserFact.is_active == True) # noqa: E712
|
||||||
|
|
||||||
|
if fact_type:
|
||||||
|
stmt = stmt.where(UserFact.fact_type == fact_type)
|
||||||
|
|
||||||
|
stmt = stmt.order_by(UserFact.learned_at.desc())
|
||||||
|
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def delete_user_facts(self, user: User) -> int:
|
||||||
|
"""Soft-delete all facts for a user (set is_active=False).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User model instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of facts deactivated
|
||||||
|
"""
|
||||||
|
facts = await self.get_user_facts(user, active_only=True)
|
||||||
|
for fact in facts:
|
||||||
|
fact.is_active = False
|
||||||
|
logger.info(f"Deactivated {len(facts)} facts for user {user.discord_id}")
|
||||||
|
return len(facts)
|
||||||
|
|
||||||
|
async def set_preference(self, user: User, key: str, value: str | None) -> UserPreference:
|
||||||
|
"""Set a user preference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User model instance
|
||||||
|
key: Preference key
|
||||||
|
value: Preference value, or None to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserPreference instance
|
||||||
|
"""
|
||||||
|
stmt = select(UserPreference).where(
|
||||||
|
UserPreference.user_id == user.id,
|
||||||
|
UserPreference.preference_key == key,
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
pref = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if pref:
|
||||||
|
pref.preference_value = value
|
||||||
|
else:
|
||||||
|
pref = UserPreference(
|
||||||
|
user_id=user.id,
|
||||||
|
preference_key=key,
|
||||||
|
preference_value=value,
|
||||||
|
)
|
||||||
|
self._session.add(pref)
|
||||||
|
await self._session.flush()
|
||||||
|
|
||||||
|
return pref
|
||||||
|
|
||||||
|
async def get_preference(self, user: User, key: str) -> str | None:
|
||||||
|
"""Get a user preference value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User model instance
|
||||||
|
key: Preference key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preference value or None if not set
|
||||||
|
"""
|
||||||
|
stmt = select(UserPreference).where(
|
||||||
|
UserPreference.user_id == user.id,
|
||||||
|
UserPreference.preference_key == key,
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
pref = result.scalar_one_or_none()
|
||||||
|
return pref.preference_value if pref else None
|
||||||
|
|
||||||
|
async def get_user_context(self, user: User) -> str:
|
||||||
|
"""Build a context string about a user for the AI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User model instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string
|
||||||
|
"""
|
||||||
|
lines = [f"User's preferred name: {user.display_name}"]
|
||||||
|
|
||||||
|
if user.custom_name:
|
||||||
|
lines.append(f"(You should address them as: {user.custom_name})")
|
||||||
|
|
||||||
|
facts = await self.get_user_facts(user, active_only=True)
|
||||||
|
|
||||||
|
if facts:
|
||||||
|
lines.append("\nKnown facts about this user:")
|
||||||
|
for fact in facts[:20]: # Limit to 20 most recent facts
|
||||||
|
lines.append(f"- [{fact.fact_type}] {fact.fact_content}")
|
||||||
|
# Mark as referenced
|
||||||
|
fact.last_referenced_at = datetime.utcnow()
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
async def get_user_with_facts(self, discord_id: int) -> User | None:
|
||||||
|
"""Get a user with their facts eagerly loaded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
discord_id: Discord user ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User with facts loaded, or None if not found
|
||||||
|
"""
|
||||||
|
stmt = select(User).where(User.discord_id == discord_id).options(selectinload(User.facts))
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
Reference in New Issue
Block a user