Add PostgreSQL memory system for persistent user and conversation storage
- Add PostgreSQL database with SQLAlchemy async support - Create models: User, UserFact, UserPreference, Conversation, Message, Guild, GuildMember - Add custom name support so bot knows 'who is who' - Add user facts system for remembering things about users - Add persistent conversation history that survives restarts - Add memory commands cog (!setname, !remember, !whatdoyouknow, !forgetme) - Add admin commands (!setusername, !teachbot) - Set up Alembic for database migrations - Update docker-compose with PostgreSQL service - Gracefully falls back to in-memory storage when DB not configured
This commit is contained in:
43
CLAUDE.md
43
CLAUDE.md
@@ -11,9 +11,12 @@ pip install -r requirements.txt
|
||||
# Run the bot (requires .env with DISCORD_TOKEN and AI provider key)
|
||||
python -m daemon_boyfriend
|
||||
|
||||
# Run with Docker
|
||||
# Run with Docker (includes PostgreSQL)
|
||||
docker-compose up -d
|
||||
|
||||
# Run database migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Syntax check all Python files
|
||||
python -m py_compile src/daemon_boyfriend/**/*.py
|
||||
```
|
||||
@@ -33,9 +36,25 @@ The AI system uses a provider abstraction pattern:
|
||||
### Cog System
|
||||
Discord functionality is in `cogs/`:
|
||||
- `ai_chat.py` - `@mention` handler (responds when bot is mentioned)
|
||||
- `memory.py` - Memory management commands (`!setname`, `!remember`, etc.)
|
||||
- `status.py` - Bot health and status commands
|
||||
|
||||
Cogs are auto-loaded by `bot.py` from the `cogs/` directory.
|
||||
|
||||
### Database & Memory System
|
||||
The bot uses PostgreSQL for persistent memory (optional, falls back to in-memory):
|
||||
- `models/` - SQLAlchemy models (User, UserFact, Conversation, Message, Guild, GuildMember)
|
||||
- `services/database.py` - Connection pool and async session management
|
||||
- `services/user_service.py` - User CRUD, custom names, facts management
|
||||
- `services/persistent_conversation.py` - Database-backed conversation history
|
||||
- `alembic/` - Database migrations
|
||||
|
||||
Key features:
|
||||
- Custom names: Set preferred names for users so the bot knows "who is who"
|
||||
- User facts: Bot remembers things about users (hobbies, preferences, etc.)
|
||||
- Persistent conversations: Chat history survives restarts
|
||||
- Conversation timeout: New conversation starts after 60 minutes of inactivity
|
||||
|
||||
### Configuration
|
||||
All config flows through `config.py` using pydantic-settings. The `settings` singleton is created at module load, so env vars must be set before importing.
|
||||
|
||||
@@ -47,13 +66,31 @@ The bot can search the web for current information via SearXNG:
|
||||
- Configured via `SEARXNG_URL`, `SEARXNG_ENABLED`, and `SEARXNG_MAX_RESULTS` env vars
|
||||
|
||||
### Key Design Decisions
|
||||
- `ConversationManager` stores per-user chat history in memory with configurable max length
|
||||
- `PersistentConversationManager` stores conversations in PostgreSQL when `DATABASE_URL` is set
|
||||
- `ConversationManager` is the in-memory fallback when database is not configured
|
||||
- Long AI responses are split via `split_message()` in `ai_chat.py` to respect Discord's 2000 char limit
|
||||
- The bot responds only to @mentions via `on_message` listener
|
||||
- Web search uses AI to decide when to search, avoiding unnecessary API calls for general knowledge questions
|
||||
- User context (custom name + known facts) is included in AI prompts for personalized responses
|
||||
|
||||
## Environment Variables
|
||||
|
||||
Required: `DISCORD_TOKEN`, plus one of `OPENAI_API_KEY`, `OPENROUTER_API_KEY`, `ANTHROPIC_API_KEY`, or `GEMINI_API_KEY` depending on `AI_PROVIDER` setting.
|
||||
|
||||
Optional: `SEARXNG_URL` for web search capability.
|
||||
Optional:
|
||||
- `DATABASE_URL` - PostgreSQL connection string (e.g., `postgresql+asyncpg://user:pass@host:5432/db`)
|
||||
- `POSTGRES_PASSWORD` - Used by docker-compose for the PostgreSQL container
|
||||
- `SEARXNG_URL` - SearXNG instance URL for web search capability
|
||||
|
||||
## Memory Commands
|
||||
|
||||
User commands:
|
||||
- `!setname <name>` - Set your preferred name
|
||||
- `!clearname` - Reset to Discord display name
|
||||
- `!remember <fact>` - Tell the bot something about you
|
||||
- `!whatdoyouknow` - See what the bot remembers about you
|
||||
- `!forgetme` - Clear all facts about you
|
||||
|
||||
Admin commands:
|
||||
- `!setusername @user <name>` - Set name for another user
|
||||
- `!teachbot @user <fact>` - Add a fact about a user
|
||||
|
||||
71
alembic.ini
Normal file
71
alembic.ini
Normal file
@@ -0,0 +1,71 @@
|
||||
# Alembic Configuration File
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = alembic
|
||||
|
||||
# template used to generate migration file names
|
||||
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path prepend for the migration environment
|
||||
prepend_sys_path = src
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
timezone = UTC
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during the 'revision' command
|
||||
revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without a source .py file
|
||||
sourceless = false
|
||||
|
||||
# version number format
|
||||
version_num_width = 4
|
||||
|
||||
# version path separator
|
||||
version_path_separator = os
|
||||
|
||||
# output encoding used when revision files are written
|
||||
output_encoding = utf-8
|
||||
|
||||
# Database URL - will be overridden by env.py from settings
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
86
alembic/env.py
Normal file
86
alembic/env.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Alembic migration environment configuration."""
|
||||
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from alembic import context
|
||||
from daemon_boyfriend.config import settings
|
||||
from daemon_boyfriend.models import Base
|
||||
|
||||
# this is the Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Model metadata for autogenerate support
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def get_url() -> str:
|
||||
"""Get database URL from settings."""
|
||||
url = settings.database_url
|
||||
if not url:
|
||||
raise ValueError("DATABASE_URL must be set for migrations")
|
||||
return url
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL and not an Engine,
|
||||
though an Engine is acceptable here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the script output.
|
||||
"""
|
||||
url = get_url()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection) -> None:
|
||||
"""Run migrations with the given connection."""
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""Run migrations in async mode."""
|
||||
configuration = config.get_section(config.config_ini_section, {})
|
||||
configuration["sqlalchemy.url"] = get_url()
|
||||
|
||||
connectable = async_engine_from_config(
|
||||
configuration,
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
26
alembic/script.py.mako
Normal file
26
alembic/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
200
alembic/versions/20250112_0001_initial_schema.py
Normal file
200
alembic/versions/20250112_0001_initial_schema.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Initial schema with users, conversations, messages, and guilds.
|
||||
|
||||
Revision ID: 0001
|
||||
Revises:
|
||||
Create Date: 2025-01-12
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0001"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Users table
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||
sa.Column("discord_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("discord_username", sa.String(255), nullable=False),
|
||||
sa.Column("discord_display_name", sa.String(255), nullable=True),
|
||||
sa.Column("custom_name", sa.String(255), nullable=True),
|
||||
sa.Column("first_seen_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("last_seen_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_users")),
|
||||
sa.UniqueConstraint("discord_id", name=op.f("uq_users_discord_id")),
|
||||
)
|
||||
op.create_index(op.f("ix_users_discord_id"), "users", ["discord_id"])
|
||||
op.create_index(op.f("ix_users_last_seen_at"), "users", ["last_seen_at"])
|
||||
|
||||
# User preferences table
|
||||
op.create_table(
|
||||
"user_preferences",
|
||||
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("preference_key", sa.String(100), nullable=False),
|
||||
sa.Column("preference_value", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
name=op.f("fk_user_preferences_user_id_users"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_user_preferences")),
|
||||
sa.UniqueConstraint("user_id", "preference_key", name="uq_user_preferences_user_key"),
|
||||
)
|
||||
op.create_index(op.f("ix_user_preferences_user_id"), "user_preferences", ["user_id"])
|
||||
|
||||
# User facts table
|
||||
op.create_table(
|
||||
"user_facts",
|
||||
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("fact_type", sa.String(50), nullable=False),
|
||||
sa.Column("fact_content", sa.Text(), nullable=False),
|
||||
sa.Column("confidence", sa.Float(), server_default="1.0"),
|
||||
sa.Column("source", sa.String(50), server_default="conversation"),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("learned_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("last_referenced_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
name=op.f("fk_user_facts_user_id_users"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_user_facts")),
|
||||
)
|
||||
op.create_index(op.f("ix_user_facts_user_id"), "user_facts", ["user_id"])
|
||||
op.create_index(op.f("ix_user_facts_fact_type"), "user_facts", ["fact_type"])
|
||||
op.create_index(op.f("ix_user_facts_is_active"), "user_facts", ["is_active"])
|
||||
|
||||
# Guilds table
|
||||
op.create_table(
|
||||
"guilds",
|
||||
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||
sa.Column("discord_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("joined_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("settings", postgresql.JSONB(), server_default="{}"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_guilds")),
|
||||
sa.UniqueConstraint("discord_id", name=op.f("uq_guilds_discord_id")),
|
||||
)
|
||||
op.create_index(op.f("ix_guilds_discord_id"), "guilds", ["discord_id"])
|
||||
|
||||
# Guild members table
|
||||
op.create_table(
|
||||
"guild_members",
|
||||
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||
sa.Column("guild_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("guild_nickname", sa.String(255), nullable=True),
|
||||
sa.Column("roles", postgresql.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("joined_guild_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.ForeignKeyConstraint(
|
||||
["guild_id"],
|
||||
["guilds.id"],
|
||||
name=op.f("fk_guild_members_guild_id_guilds"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
name=op.f("fk_guild_members_user_id_users"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_guild_members")),
|
||||
sa.UniqueConstraint("guild_id", "user_id", name="uq_guild_members_guild_user"),
|
||||
)
|
||||
op.create_index(op.f("ix_guild_members_guild_id"), "guild_members", ["guild_id"])
|
||||
op.create_index(op.f("ix_guild_members_user_id"), "guild_members", ["user_id"])
|
||||
|
||||
# Conversations table
|
||||
op.create_table(
|
||||
"conversations",
|
||||
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("guild_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("channel_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("started_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("last_message_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("message_count", sa.Integer(), server_default="0"),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
name=op.f("fk_conversations_user_id_users"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_conversations")),
|
||||
)
|
||||
op.create_index(op.f("ix_conversations_user_id"), "conversations", ["user_id"])
|
||||
op.create_index(op.f("ix_conversations_channel_id"), "conversations", ["channel_id"])
|
||||
op.create_index(op.f("ix_conversations_last_message_at"), "conversations", ["last_message_at"])
|
||||
op.create_index(op.f("ix_conversations_is_active"), "conversations", ["is_active"])
|
||||
|
||||
# Messages table
|
||||
op.create_table(
|
||||
"messages",
|
||||
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
||||
sa.Column("conversation_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("discord_message_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("role", sa.String(20), nullable=False),
|
||||
sa.Column("content", sa.Text(), nullable=False),
|
||||
sa.Column("has_images", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column("image_urls", postgresql.ARRAY(sa.Text()), nullable=True),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.ForeignKeyConstraint(
|
||||
["conversation_id"],
|
||||
["conversations.id"],
|
||||
name=op.f("fk_messages_conversation_id_conversations"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
name=op.f("fk_messages_user_id_users"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("pk_messages")),
|
||||
)
|
||||
op.create_index(op.f("ix_messages_conversation_id"), "messages", ["conversation_id"])
|
||||
op.create_index(op.f("ix_messages_user_id"), "messages", ["user_id"])
|
||||
op.create_index(op.f("ix_messages_created_at"), "messages", ["created_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("messages")
|
||||
op.drop_table("conversations")
|
||||
op.drop_table("guild_members")
|
||||
op.drop_table("guilds")
|
||||
op.drop_table("user_facts")
|
||||
op.drop_table("user_preferences")
|
||||
op.drop_table("users")
|
||||
@@ -1,9 +1,32 @@
|
||||
services:
|
||||
daemon-boyfriend:
|
||||
build: .
|
||||
container_name: daemon-boyfriend
|
||||
restart: unless-stopped
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
daemon-boyfriend:
|
||||
build: .
|
||||
container_name: daemon-boyfriend
|
||||
restart: unless-stopped
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
- DATABASE_URL=postgresql+asyncpg://daemon:${POSTGRES_PASSWORD:-daemon}@postgres:5432/daemon_boyfriend
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
|
||||
postgres:
|
||||
image: postgres:16-alpine
|
||||
container_name: daemon-boyfriend-postgres
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
POSTGRES_USER: daemon
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-daemon}
|
||||
POSTGRES_DB: daemon_boyfriend
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U daemon -d daemon_boyfriend"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
|
||||
@@ -13,3 +13,8 @@ aiohttp>=3.9.0
|
||||
pydantic>=2.6.0
|
||||
pydantic-settings>=2.2.0
|
||||
python-dotenv>=1.0.0
|
||||
|
||||
# Database
|
||||
asyncpg>=0.29.0
|
||||
sqlalchemy[asyncio]>=2.0.0
|
||||
alembic>=1.13.0
|
||||
|
||||
@@ -7,6 +7,7 @@ import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from daemon_boyfriend.config import settings
|
||||
from daemon_boyfriend.services import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,7 +28,14 @@ class DaemonBoyfriend(commands.Bot):
|
||||
)
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Load cogs on startup."""
|
||||
"""Initialize database and load cogs on startup."""
|
||||
# Initialize database if configured
|
||||
if db.is_configured:
|
||||
await db.init()
|
||||
logger.info("Database initialized")
|
||||
else:
|
||||
logger.info("Database not configured, using in-memory storage")
|
||||
|
||||
# Load all cogs
|
||||
cogs_path = Path(__file__).parent / "cogs"
|
||||
for cog_file in cogs_path.glob("*.py"):
|
||||
@@ -55,3 +63,10 @@ class DaemonBoyfriend(commands.Bot):
|
||||
name=settings.bot_status,
|
||||
)
|
||||
await self.change_presence(activity=activity)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources on shutdown."""
|
||||
logger.info("Shutting down bot...")
|
||||
if db.is_initialized:
|
||||
await db.close()
|
||||
await super().close()
|
||||
|
||||
@@ -12,7 +12,10 @@ from daemon_boyfriend.services import (
|
||||
ConversationManager,
|
||||
ImageAttachment,
|
||||
Message,
|
||||
PersistentConversationManager,
|
||||
SearXNGService,
|
||||
UserService,
|
||||
db,
|
||||
)
|
||||
from daemon_boyfriend.utils import get_monitor
|
||||
|
||||
@@ -77,11 +80,17 @@ class AIChatCog(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot) -> None:
|
||||
self.bot = bot
|
||||
self.ai_service = AIService()
|
||||
# Fallback in-memory conversation manager (used when DB not configured)
|
||||
self.conversations = ConversationManager()
|
||||
self.search_service: SearXNGService | None = None
|
||||
if settings.searxng_enabled and settings.searxng_url:
|
||||
self.search_service = SearXNGService(settings.searxng_url)
|
||||
|
||||
@property
|
||||
def use_database(self) -> bool:
|
||||
"""Check if database is available for use."""
|
||||
return db.is_initialized
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""Respond when the bot is mentioned."""
|
||||
@@ -395,6 +404,95 @@ class AIChatCog(commands.Cog):
|
||||
Returns:
|
||||
The AI's response text
|
||||
"""
|
||||
if self.use_database:
|
||||
return await self._generate_response_with_db(message, user_message)
|
||||
else:
|
||||
return await self._generate_response_in_memory(message, user_message)
|
||||
|
||||
async def _generate_response_with_db(self, message: discord.Message, user_message: str) -> str:
|
||||
"""Generate response using database-backed storage."""
|
||||
async with db.session() as session:
|
||||
user_service = UserService(session)
|
||||
conv_manager = PersistentConversationManager(session)
|
||||
|
||||
# Get or create user
|
||||
user = await user_service.get_or_create_user(
|
||||
discord_id=message.author.id,
|
||||
username=message.author.name,
|
||||
display_name=message.author.display_name,
|
||||
)
|
||||
|
||||
# Get or create conversation
|
||||
conversation = await conv_manager.get_or_create_conversation(
|
||||
user=user,
|
||||
guild_id=message.guild.id if message.guild else None,
|
||||
channel_id=message.channel.id,
|
||||
)
|
||||
|
||||
# Get history
|
||||
history = await conv_manager.get_history(conversation)
|
||||
|
||||
# Extract any image attachments from the message
|
||||
images = self._extract_image_attachments(message)
|
||||
image_urls = [img.url for img in images] if images else None
|
||||
|
||||
# Add current message to history for the API call
|
||||
current_message = Message(role="user", content=user_message, images=images)
|
||||
messages = history + [current_message]
|
||||
|
||||
# Check if we should search the web
|
||||
search_context = await self._maybe_search(user_message)
|
||||
|
||||
# Get context about mentioned users
|
||||
mentioned_users_context = self._get_mentioned_users_context(message)
|
||||
|
||||
# Build system prompt with additional context
|
||||
system_prompt = self.ai_service.get_system_prompt()
|
||||
|
||||
# Add user context from database (custom name, known facts)
|
||||
user_context = await user_service.get_user_context(user)
|
||||
system_prompt += f"\n\n--- User Context ---\n{user_context}"
|
||||
|
||||
# Add mentioned users context
|
||||
if mentioned_users_context:
|
||||
system_prompt += f"\n\n--- {mentioned_users_context} ---"
|
||||
|
||||
# Add search results if available
|
||||
if search_context:
|
||||
system_prompt += (
|
||||
"\n\n--- Web Search Results ---\n"
|
||||
"Use the following current information from the web to help answer the user's question. "
|
||||
"Cite sources when relevant.\n\n"
|
||||
f"{search_context}"
|
||||
)
|
||||
|
||||
# Generate response
|
||||
response = await self.ai_service.chat(
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
# Save the exchange to database
|
||||
await conv_manager.add_exchange(
|
||||
conversation=conversation,
|
||||
user=user,
|
||||
user_message=user_message,
|
||||
assistant_message=response.content,
|
||||
discord_message_id=message.id,
|
||||
image_urls=image_urls,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Generated response for user {user.discord_id}: "
|
||||
f"{len(response.content)} chars, {response.usage}"
|
||||
)
|
||||
|
||||
return response.content
|
||||
|
||||
async def _generate_response_in_memory(
|
||||
self, message: discord.Message, user_message: str
|
||||
) -> str:
|
||||
"""Generate response using in-memory storage (fallback)."""
|
||||
user_id = message.author.id
|
||||
|
||||
# Get conversation history
|
||||
|
||||
260
src/daemon_boyfriend/cogs/memory.py
Normal file
260
src/daemon_boyfriend/cogs/memory.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""Memory management cog - commands for managing bot memory about users."""
|
||||
|
||||
import logging
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from daemon_boyfriend.services import UserService, db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryCog(commands.Cog):
|
||||
"""Commands for managing bot memory about users."""
|
||||
|
||||
def __init__(self, bot: commands.Bot) -> None:
|
||||
self.bot = bot
|
||||
|
||||
def _check_database(self) -> bool:
|
||||
"""Check if database is available."""
|
||||
return db.is_initialized
|
||||
|
||||
@commands.command(name="setname")
|
||||
async def set_name(self, ctx: commands.Context, *, name: str) -> None:
|
||||
"""Set your preferred name for the bot to use.
|
||||
|
||||
Usage: !setname John
|
||||
"""
|
||||
if not self._check_database():
|
||||
await ctx.reply("Memory features are not available (database not configured).")
|
||||
return
|
||||
|
||||
if len(name) > 100:
|
||||
await ctx.reply("Name is too long! Please use 100 characters or less.")
|
||||
return
|
||||
|
||||
async with db.session() as session:
|
||||
user_service = UserService(session)
|
||||
user = await user_service.get_or_create_user(
|
||||
discord_id=ctx.author.id,
|
||||
username=ctx.author.name,
|
||||
display_name=ctx.author.display_name,
|
||||
)
|
||||
await user_service.set_custom_name(ctx.author.id, name)
|
||||
|
||||
await ctx.reply(f"Got it! I'll call you **{name}** from now on.")
|
||||
|
||||
@commands.command(name="clearname")
|
||||
async def clear_name(self, ctx: commands.Context) -> None:
|
||||
"""Clear your custom name and use your Discord name instead.
|
||||
|
||||
Usage: !clearname
|
||||
"""
|
||||
if not self._check_database():
|
||||
await ctx.reply("Memory features are not available (database not configured).")
|
||||
return
|
||||
|
||||
async with db.session() as session:
|
||||
user_service = UserService(session)
|
||||
await user_service.set_custom_name(ctx.author.id, None)
|
||||
|
||||
await ctx.reply("Done! I'll use your Discord display name now.")
|
||||
|
||||
@commands.command(name="remember")
|
||||
async def remember_fact(self, ctx: commands.Context, *, fact: str) -> None:
|
||||
"""Tell the bot something to remember about you.
|
||||
|
||||
Usage: !remember I love pizza
|
||||
Usage: !remember My favorite color is blue
|
||||
"""
|
||||
if not self._check_database():
|
||||
await ctx.reply("Memory features are not available (database not configured).")
|
||||
return
|
||||
|
||||
if len(fact) > 500:
|
||||
await ctx.reply("That's too long to remember! Please keep it under 500 characters.")
|
||||
return
|
||||
|
||||
async with db.session() as session:
|
||||
user_service = UserService(session)
|
||||
user = await user_service.get_or_create_user(
|
||||
discord_id=ctx.author.id,
|
||||
username=ctx.author.name,
|
||||
display_name=ctx.author.display_name,
|
||||
)
|
||||
await user_service.add_fact(
|
||||
user=user,
|
||||
fact_type="general",
|
||||
fact_content=fact,
|
||||
source="explicit",
|
||||
confidence=1.0,
|
||||
)
|
||||
|
||||
await ctx.reply(f"I'll remember that!")
|
||||
|
||||
@commands.command(name="whatdoyouknow", aliases=["aboutme", "myinfo"])
|
||||
async def what_do_you_know(self, ctx: commands.Context) -> None:
|
||||
"""Show what the bot remembers about you.
|
||||
|
||||
Usage: !whatdoyouknow
|
||||
"""
|
||||
if not self._check_database():
|
||||
await ctx.reply("Memory features are not available (database not configured).")
|
||||
return
|
||||
|
||||
async with db.session() as session:
|
||||
user_service = UserService(session)
|
||||
user = await user_service.get_user_by_discord_id(ctx.author.id)
|
||||
|
||||
if not user:
|
||||
await ctx.reply("I don't have any information about you yet!")
|
||||
return
|
||||
|
||||
facts = await user_service.get_user_facts(user, active_only=True)
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"What I know about {user.display_name}",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Discord Username",
|
||||
value=user.discord_username,
|
||||
inline=True,
|
||||
)
|
||||
|
||||
if user.custom_name:
|
||||
embed.add_field(
|
||||
name="Preferred Name",
|
||||
value=user.custom_name,
|
||||
inline=True,
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="First Seen",
|
||||
value=user.first_seen_at.strftime("%Y-%m-%d"),
|
||||
inline=True,
|
||||
)
|
||||
|
||||
if facts:
|
||||
facts_text = "\n".join(f"- {fact.fact_content}" for fact in facts[:15])
|
||||
if len(facts) > 15:
|
||||
facts_text += f"\n... and {len(facts) - 15} more"
|
||||
embed.add_field(
|
||||
name=f"Things I Remember ({len(facts)})",
|
||||
value=facts_text or "Nothing yet!",
|
||||
inline=False,
|
||||
)
|
||||
else:
|
||||
embed.add_field(
|
||||
name="Things I Remember",
|
||||
value="Nothing yet! Use `!remember` to tell me something.",
|
||||
inline=False,
|
||||
)
|
||||
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
@commands.command(name="forgetme")
|
||||
async def forget_me(self, ctx: commands.Context) -> None:
|
||||
"""Clear all facts the bot knows about you.
|
||||
|
||||
Usage: !forgetme
|
||||
"""
|
||||
if not self._check_database():
|
||||
await ctx.reply("Memory features are not available (database not configured).")
|
||||
return
|
||||
|
||||
async with db.session() as session:
|
||||
user_service = UserService(session)
|
||||
user = await user_service.get_user_by_discord_id(ctx.author.id)
|
||||
|
||||
if not user:
|
||||
await ctx.reply("I don't have any information about you to forget!")
|
||||
return
|
||||
|
||||
count = await user_service.delete_user_facts(user)
|
||||
|
||||
if count > 0:
|
||||
await ctx.reply(f"Done! I've forgotten {count} thing(s) about you.")
|
||||
else:
|
||||
await ctx.reply("I didn't have anything to forget about you!")
|
||||
|
||||
@commands.command(name="setusername")
|
||||
@commands.has_permissions(administrator=True)
|
||||
async def set_user_name(
|
||||
self, ctx: commands.Context, user: discord.Member, *, name: str
|
||||
) -> None:
|
||||
"""[Admin] Set a custom name for another user.
|
||||
|
||||
Usage: !setusername @user John
|
||||
"""
|
||||
if not self._check_database():
|
||||
await ctx.reply("Memory features are not available (database not configured).")
|
||||
return
|
||||
|
||||
if len(name) > 100:
|
||||
await ctx.reply("Name is too long! Please use 100 characters or less.")
|
||||
return
|
||||
|
||||
async with db.session() as session:
|
||||
user_service = UserService(session)
|
||||
db_user = await user_service.get_or_create_user(
|
||||
discord_id=user.id,
|
||||
username=user.name,
|
||||
display_name=user.display_name,
|
||||
)
|
||||
await user_service.set_custom_name(user.id, name)
|
||||
|
||||
await ctx.reply(f"Got it! I'll call {user.mention} **{name}** from now on.")
|
||||
|
||||
@commands.command(name="teachbot")
|
||||
@commands.has_permissions(administrator=True)
|
||||
async def teach_bot(self, ctx: commands.Context, user: discord.Member, *, fact: str) -> None:
|
||||
"""[Admin] Teach the bot a fact about a user.
|
||||
|
||||
Usage: !teachbot @user They are a software developer
|
||||
"""
|
||||
if not self._check_database():
|
||||
await ctx.reply("Memory features are not available (database not configured).")
|
||||
return
|
||||
|
||||
if len(fact) > 500:
|
||||
await ctx.reply("That's too long! Please keep it under 500 characters.")
|
||||
return
|
||||
|
||||
async with db.session() as session:
|
||||
user_service = UserService(session)
|
||||
db_user = await user_service.get_or_create_user(
|
||||
discord_id=user.id,
|
||||
username=user.name,
|
||||
display_name=user.display_name,
|
||||
)
|
||||
await user_service.add_fact(
|
||||
user=db_user,
|
||||
fact_type="general",
|
||||
fact_content=fact,
|
||||
source="admin",
|
||||
confidence=1.0,
|
||||
)
|
||||
|
||||
await ctx.reply(f"I'll remember that about {user.mention}!")
|
||||
|
||||
@set_user_name.error
|
||||
@teach_bot.error
|
||||
async def admin_command_error(
|
||||
self, ctx: commands.Context, error: commands.CommandError
|
||||
) -> None:
|
||||
"""Handle errors for admin commands."""
|
||||
if isinstance(error, commands.MissingPermissions):
|
||||
await ctx.reply("You need administrator permissions to use this command.")
|
||||
elif isinstance(error, commands.MemberNotFound):
|
||||
await ctx.reply("I couldn't find that user.")
|
||||
else:
|
||||
logger.error(f"Error in admin command: {error}", exc_info=True)
|
||||
await ctx.reply(f"An error occurred: {error}")
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot) -> None:
|
||||
"""Load the Memory cog."""
|
||||
await bot.add_cog(MemoryCog(bot))
|
||||
@@ -67,6 +67,20 @@ class Settings(BaseSettings):
|
||||
max_conversation_history: int = Field(
|
||||
20, description="Max messages to keep in conversation memory per user"
|
||||
)
|
||||
conversation_timeout_minutes: int = Field(
|
||||
60, ge=5, le=1440, description="Minutes of inactivity before starting new conversation"
|
||||
)
|
||||
|
||||
# Database Configuration
|
||||
database_url: str | None = Field(
|
||||
None,
|
||||
description="PostgreSQL connection URL (asyncpg format). If not set, uses in-memory storage.",
|
||||
)
|
||||
database_echo: bool = Field(False, description="Echo SQL statements (for debugging)")
|
||||
database_pool_size: int = Field(5, ge=1, le=20, description="Database connection pool size")
|
||||
database_max_overflow: int = Field(
|
||||
10, ge=0, le=30, description="Max connections beyond pool size"
|
||||
)
|
||||
|
||||
# SearXNG Configuration
|
||||
searxng_url: str | None = Field(None, description="SearXNG instance URL for web search")
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Database models."""
|
||||
|
||||
from .base import Base
|
||||
from .conversation import Conversation, Message
|
||||
from .guild import Guild, GuildMember
|
||||
from .user import User, UserFact, UserPreference
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"Conversation",
|
||||
"Guild",
|
||||
"GuildMember",
|
||||
"Message",
|
||||
"User",
|
||||
"UserFact",
|
||||
"UserPreference",
|
||||
]
|
||||
|
||||
28
src/daemon_boyfriend/models/base.py
Normal file
28
src/daemon_boyfriend/models/base.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""SQLAlchemy base model and metadata configuration."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
# Naming convention for constraints (helps with migrations)
|
||||
convention = {
|
||||
"ix": "ix_%(column_0_label)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
}
|
||||
|
||||
metadata = MetaData(naming_convention=convention)
|
||||
|
||||
|
||||
class Base(AsyncAttrs, DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
|
||||
metadata = metadata
|
||||
|
||||
# Common timestamp columns
|
||||
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
57
src/daemon_boyfriend/models/conversation.py
Normal file
57
src/daemon_boyfriend/models/conversation.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Conversation and message database models."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ARRAY, BigInteger, Boolean, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
"""A conversation session with a user."""
|
||||
|
||||
__tablename__ = "conversations"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||
guild_id: Mapped[int | None] = mapped_column(BigInteger)
|
||||
channel_id: Mapped[int | None] = mapped_column(BigInteger, index=True)
|
||||
started_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||
last_message_at: Mapped[datetime] = mapped_column(default=datetime.utcnow, index=True)
|
||||
message_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, index=True)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship(back_populates="conversations")
|
||||
messages: Mapped[list["Message"]] = relationship(
|
||||
back_populates="conversation",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="Message.created_at",
|
||||
)
|
||||
|
||||
|
||||
class Message(Base):
|
||||
"""Individual chat message."""
|
||||
|
||||
__tablename__ = "messages"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
conversation_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("conversations.id", ondelete="CASCADE"), index=True
|
||||
)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||
discord_message_id: Mapped[int | None] = mapped_column(BigInteger)
|
||||
role: Mapped[str] = mapped_column(String(20)) # user, assistant, system
|
||||
content: Mapped[str] = mapped_column(Text)
|
||||
has_images: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
image_urls: Mapped[list[str] | None] = mapped_column(ARRAY(Text), default=None)
|
||||
token_count: Mapped[int | None] = mapped_column(Integer)
|
||||
|
||||
# Relationships
|
||||
conversation: Mapped["Conversation"] = relationship(back_populates="messages")
|
||||
user: Mapped["User"] = relationship(back_populates="messages")
|
||||
50
src/daemon_boyfriend/models/guild.py
Normal file
50
src/daemon_boyfriend/models/guild.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Guild (Discord server) database models."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ARRAY, BigInteger, Boolean, ForeignKey, String, Text, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class Guild(Base):
|
||||
"""Discord server tracking."""
|
||||
|
||||
__tablename__ = "guilds"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
discord_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
joined_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
settings: Mapped[dict] = mapped_column(JSONB, default=dict)
|
||||
|
||||
# Relationships
|
||||
members: Mapped[list["GuildMember"]] = relationship(
|
||||
back_populates="guild", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class GuildMember(Base):
|
||||
"""Track users per guild with guild-specific settings."""
|
||||
|
||||
__tablename__ = "guild_members"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
guild_id: Mapped[int] = mapped_column(ForeignKey("guilds.id", ondelete="CASCADE"), index=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||
guild_nickname: Mapped[str | None] = mapped_column(String(255))
|
||||
roles: Mapped[list[str] | None] = mapped_column(ARRAY(Text), default=None)
|
||||
joined_guild_at: Mapped[datetime | None] = mapped_column(default=None)
|
||||
|
||||
# Relationships
|
||||
guild: Mapped["Guild"] = relationship(back_populates="members")
|
||||
user: Mapped["User"] = relationship(back_populates="guild_memberships")
|
||||
|
||||
__table_args__ = (UniqueConstraint("guild_id", "user_id"),)
|
||||
83
src/daemon_boyfriend/models/user.py
Normal file
83
src/daemon_boyfriend/models/user.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""User-related database models."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, Float, ForeignKey, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .base import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .conversation import Conversation, Message
|
||||
from .guild import GuildMember
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""Discord user tracking."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
discord_id: Mapped[int] = mapped_column(BigInteger, unique=True, index=True)
|
||||
discord_username: Mapped[str] = mapped_column(String(255))
|
||||
discord_display_name: Mapped[str | None] = mapped_column(String(255))
|
||||
custom_name: Mapped[str | None] = mapped_column(String(255))
|
||||
first_seen_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||
last_seen_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# Relationships
|
||||
preferences: Mapped[list["UserPreference"]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
facts: Mapped[list["UserFact"]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
conversations: Mapped[list["Conversation"]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
messages: Mapped[list["Message"]] = relationship(back_populates="user")
|
||||
guild_memberships: Mapped[list["GuildMember"]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Get the name to use when addressing this user."""
|
||||
return self.custom_name or self.discord_display_name or self.discord_username
|
||||
|
||||
|
||||
class UserPreference(Base):
|
||||
"""Per-user preferences/settings."""
|
||||
|
||||
__tablename__ = "user_preferences"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||
preference_key: Mapped[str] = mapped_column(String(100))
|
||||
preference_value: Mapped[str | None] = mapped_column(Text)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship(back_populates="preferences")
|
||||
|
||||
__table_args__ = (UniqueConstraint("user_id", "preference_key"),)
|
||||
|
||||
|
||||
class UserFact(Base):
|
||||
"""Facts the bot has learned about users."""
|
||||
|
||||
__tablename__ = "user_facts"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||
fact_type: Mapped[str] = mapped_column(String(50), index=True) # general, preference, hobby
|
||||
fact_content: Mapped[str] = mapped_column(Text)
|
||||
confidence: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
source: Mapped[str] = mapped_column(String(50), default="conversation")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, index=True)
|
||||
learned_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||
last_referenced_at: Mapped[datetime | None] = mapped_column(default=None)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship(back_populates="facts")
|
||||
@@ -2,14 +2,22 @@
|
||||
|
||||
from .ai_service import AIService
|
||||
from .conversation import ConversationManager
|
||||
from .database import DatabaseService, db, get_db
|
||||
from .persistent_conversation import PersistentConversationManager
|
||||
from .providers import AIResponse, ImageAttachment, Message
|
||||
from .searxng import SearXNGService
|
||||
from .user_service import UserService
|
||||
|
||||
__all__ = [
|
||||
"AIService",
|
||||
"AIResponse",
|
||||
"ConversationManager",
|
||||
"DatabaseService",
|
||||
"ImageAttachment",
|
||||
"Message",
|
||||
"ConversationManager",
|
||||
"PersistentConversationManager",
|
||||
"SearXNGService",
|
||||
"UserService",
|
||||
"db",
|
||||
"get_db",
|
||||
]
|
||||
|
||||
94
src/daemon_boyfriend/services/database.py
Normal file
94
src/daemon_boyfriend/services/database.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Database connection and session management."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from daemon_boyfriend.config import settings
|
||||
from daemon_boyfriend.models.base import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseService:
|
||||
"""Manages database connections and sessions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._engine = None
|
||||
self._session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
self._initialized = False
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if database URL is configured."""
|
||||
return settings.database_url is not None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if database has been initialized."""
|
||||
return self._initialized
|
||||
|
||||
async def init(self) -> None:
|
||||
"""Initialize database connection."""
|
||||
if not self.is_configured:
|
||||
logger.info("Database URL not configured, skipping database initialization")
|
||||
return
|
||||
|
||||
logger.info("Initializing database connection...")
|
||||
self._engine = create_async_engine(
|
||||
settings.database_url,
|
||||
echo=settings.database_echo,
|
||||
pool_size=settings.database_pool_size,
|
||||
max_overflow=settings.database_max_overflow,
|
||||
)
|
||||
self._session_factory = async_sessionmaker(
|
||||
self._engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
self._initialized = True
|
||||
logger.info("Database connection initialized")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close database connection."""
|
||||
if self._engine:
|
||||
logger.info("Closing database connection...")
|
||||
await self._engine.dispose()
|
||||
self._engine = None
|
||||
self._session_factory = None
|
||||
self._initialized = False
|
||||
logger.info("Database connection closed")
|
||||
|
||||
async def create_tables(self) -> None:
|
||||
"""Create all tables (for development/testing)."""
|
||||
if not self._engine:
|
||||
raise RuntimeError("Database not initialized")
|
||||
|
||||
async with self._engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("Database tables created")
|
||||
|
||||
@asynccontextmanager
|
||||
async def session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get a database session with automatic commit/rollback."""
|
||||
if not self._session_factory:
|
||||
raise RuntimeError("Database not initialized. Call init() first.")
|
||||
|
||||
async with self._session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
# Global instance
|
||||
db = DatabaseService()
|
||||
|
||||
|
||||
def get_db() -> DatabaseService:
|
||||
"""Get the global database service instance."""
|
||||
return db
|
||||
188
src/daemon_boyfriend/services/persistent_conversation.py
Normal file
188
src/daemon_boyfriend/services/persistent_conversation.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Persistent conversation management using PostgreSQL."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from daemon_boyfriend.config import settings
|
||||
from daemon_boyfriend.models import Conversation, Message, User
|
||||
from daemon_boyfriend.services.providers import Message as ProviderMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PersistentConversationManager:
|
||||
"""Manages conversation history in PostgreSQL."""
|
||||
|
||||
def __init__(self, session: AsyncSession, max_history: int | None = None) -> None:
|
||||
self._session = session
|
||||
self.max_history = max_history or settings.max_conversation_history
|
||||
self._timeout = timedelta(minutes=settings.conversation_timeout_minutes)
|
||||
|
||||
async def get_or_create_conversation(
|
||||
self,
|
||||
user: User,
|
||||
guild_id: int | None = None,
|
||||
channel_id: int | None = None,
|
||||
) -> Conversation:
|
||||
"""Get active conversation or create new one.
|
||||
|
||||
Args:
|
||||
user: User model instance
|
||||
guild_id: Discord guild ID (None for DMs)
|
||||
channel_id: Discord channel ID
|
||||
|
||||
Returns:
|
||||
Conversation model instance
|
||||
"""
|
||||
# Look for recent active conversation in this channel
|
||||
cutoff = datetime.utcnow() - self._timeout
|
||||
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.user_id == user.id,
|
||||
Conversation.is_active == True, # noqa: E712
|
||||
Conversation.last_message_at > cutoff,
|
||||
)
|
||||
|
||||
if channel_id:
|
||||
stmt = stmt.where(Conversation.channel_id == channel_id)
|
||||
|
||||
stmt = stmt.order_by(Conversation.last_message_at.desc())
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
conversation = result.scalar_first()
|
||||
|
||||
if conversation:
|
||||
logger.debug(
|
||||
f"Found existing conversation {conversation.id} for user {user.discord_id}"
|
||||
)
|
||||
return conversation
|
||||
|
||||
# Create new conversation
|
||||
conversation = Conversation(
|
||||
user_id=user.id,
|
||||
guild_id=guild_id,
|
||||
channel_id=channel_id,
|
||||
)
|
||||
self._session.add(conversation)
|
||||
await self._session.flush()
|
||||
logger.info(f"Created new conversation {conversation.id} for user {user.discord_id}")
|
||||
return conversation
|
||||
|
||||
async def get_history(self, conversation: Conversation) -> list[ProviderMessage]:
|
||||
"""Get conversation history as provider messages.
|
||||
|
||||
Args:
|
||||
conversation: Conversation model instance
|
||||
|
||||
Returns:
|
||||
List of ProviderMessage instances for the AI
|
||||
"""
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(self.max_history)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
messages = list(reversed(result.scalars().all()))
|
||||
|
||||
return [
|
||||
ProviderMessage(
|
||||
role=msg.role,
|
||||
content=msg.content,
|
||||
# Note: Images would need special handling if needed
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
async def add_message(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
user: User,
|
||||
role: str,
|
||||
content: str,
|
||||
discord_message_id: int | None = None,
|
||||
image_urls: list[str] | None = None,
|
||||
) -> Message:
|
||||
"""Add a message to the conversation.
|
||||
|
||||
Args:
|
||||
conversation: Conversation model instance
|
||||
user: User model instance
|
||||
role: Message role (user, assistant, system)
|
||||
content: Message content
|
||||
discord_message_id: Discord's message ID (optional)
|
||||
image_urls: List of image URLs (optional)
|
||||
|
||||
Returns:
|
||||
Created Message instance
|
||||
"""
|
||||
message = Message(
|
||||
conversation_id=conversation.id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
content=content,
|
||||
discord_message_id=discord_message_id,
|
||||
has_images=bool(image_urls),
|
||||
image_urls=image_urls,
|
||||
)
|
||||
self._session.add(message)
|
||||
|
||||
# Update conversation stats
|
||||
conversation.last_message_at = datetime.utcnow()
|
||||
conversation.message_count += 1
|
||||
|
||||
await self._session.flush()
|
||||
return message
|
||||
|
||||
async def add_exchange(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
user: User,
|
||||
user_message: str,
|
||||
assistant_message: str,
|
||||
discord_message_id: int | None = None,
|
||||
image_urls: list[str] | None = None,
|
||||
) -> tuple[Message, Message]:
|
||||
"""Add a user/assistant exchange to the conversation.
|
||||
|
||||
Args:
|
||||
conversation: Conversation model instance
|
||||
user: User model instance
|
||||
user_message: The user's message content
|
||||
assistant_message: The assistant's response
|
||||
discord_message_id: Discord's message ID (optional)
|
||||
image_urls: List of image URLs in user message (optional)
|
||||
|
||||
Returns:
|
||||
Tuple of (user_message, assistant_message) Message instances
|
||||
"""
|
||||
user_msg = await self.add_message(
|
||||
conversation, user, "user", user_message, discord_message_id, image_urls
|
||||
)
|
||||
assistant_msg = await self.add_message(conversation, user, "assistant", assistant_message)
|
||||
return user_msg, assistant_msg
|
||||
|
||||
async def clear_conversation(self, conversation: Conversation) -> None:
|
||||
"""Mark a conversation as inactive.
|
||||
|
||||
Args:
|
||||
conversation: Conversation model instance
|
||||
"""
|
||||
conversation.is_active = False
|
||||
logger.info(f"Marked conversation {conversation.id} as inactive")
|
||||
|
||||
async def get_message_count(self, conversation: Conversation) -> int:
|
||||
"""Get the number of messages in a conversation.
|
||||
|
||||
Args:
|
||||
conversation: Conversation model instance
|
||||
|
||||
Returns:
|
||||
Number of messages
|
||||
"""
|
||||
return conversation.message_count
|
||||
250
src/daemon_boyfriend/services/user_service.py
Normal file
250
src/daemon_boyfriend/services/user_service.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""User management service."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from daemon_boyfriend.models import User, UserFact, UserPreference
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserService:
|
||||
"""Service for user-related operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get_or_create_user(
|
||||
self,
|
||||
discord_id: int,
|
||||
username: str,
|
||||
display_name: str | None = None,
|
||||
) -> User:
|
||||
"""Get existing user or create new one.
|
||||
|
||||
Args:
|
||||
discord_id: Discord user ID
|
||||
username: Discord username
|
||||
display_name: Discord display name
|
||||
|
||||
Returns:
|
||||
User model instance
|
||||
"""
|
||||
stmt = select(User).where(User.discord_id == discord_id)
|
||||
result = await self._session.execute(stmt)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user:
|
||||
# Update last seen and current name
|
||||
user.last_seen_at = datetime.utcnow()
|
||||
user.discord_username = username
|
||||
if display_name:
|
||||
user.discord_display_name = display_name
|
||||
return user
|
||||
|
||||
# Create new user
|
||||
user = User(
|
||||
discord_id=discord_id,
|
||||
discord_username=username,
|
||||
discord_display_name=display_name,
|
||||
)
|
||||
self._session.add(user)
|
||||
await self._session.flush()
|
||||
logger.info(f"Created new user: {username} (discord_id={discord_id})")
|
||||
return user
|
||||
|
||||
async def get_user_by_discord_id(self, discord_id: int) -> User | None:
|
||||
"""Get a user by their Discord ID.
|
||||
|
||||
Args:
|
||||
discord_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
stmt = select(User).where(User.discord_id == discord_id)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def set_custom_name(self, discord_id: int, custom_name: str | None) -> User | None:
|
||||
"""Set a custom name for a user.
|
||||
|
||||
Args:
|
||||
discord_id: Discord user ID
|
||||
custom_name: Custom name to use, or None to clear
|
||||
|
||||
Returns:
|
||||
Updated user if found, None otherwise
|
||||
"""
|
||||
user = await self.get_user_by_discord_id(discord_id)
|
||||
if user:
|
||||
user.custom_name = custom_name
|
||||
logger.info(f"Set custom name for user {discord_id}: {custom_name}")
|
||||
return user
|
||||
|
||||
async def add_fact(
|
||||
self,
|
||||
user: User,
|
||||
fact_type: str,
|
||||
fact_content: str,
|
||||
source: str = "conversation",
|
||||
confidence: float = 1.0,
|
||||
) -> UserFact:
|
||||
"""Add a fact about a user.
|
||||
|
||||
Args:
|
||||
user: User model instance
|
||||
fact_type: Type of fact (general, preference, hobby, relationship, etc.)
|
||||
fact_content: The fact content
|
||||
source: How the fact was learned (conversation, explicit, inferred)
|
||||
confidence: Confidence level (0.0 to 1.0)
|
||||
|
||||
Returns:
|
||||
Created UserFact instance
|
||||
"""
|
||||
fact = UserFact(
|
||||
user_id=user.id,
|
||||
fact_type=fact_type,
|
||||
fact_content=fact_content,
|
||||
source=source,
|
||||
confidence=confidence,
|
||||
)
|
||||
self._session.add(fact)
|
||||
await self._session.flush()
|
||||
logger.debug(f"Added fact for user {user.discord_id}: [{fact_type}] {fact_content}")
|
||||
return fact
|
||||
|
||||
async def get_user_facts(
|
||||
self,
|
||||
user: User,
|
||||
fact_type: str | None = None,
|
||||
active_only: bool = True,
|
||||
) -> list[UserFact]:
|
||||
"""Get facts about a user.
|
||||
|
||||
Args:
|
||||
user: User model instance
|
||||
fact_type: Optional filter by fact type
|
||||
active_only: Only return active facts
|
||||
|
||||
Returns:
|
||||
List of UserFact instances
|
||||
"""
|
||||
stmt = select(UserFact).where(UserFact.user_id == user.id)
|
||||
|
||||
if active_only:
|
||||
stmt = stmt.where(UserFact.is_active == True) # noqa: E712
|
||||
|
||||
if fact_type:
|
||||
stmt = stmt.where(UserFact.fact_type == fact_type)
|
||||
|
||||
stmt = stmt.order_by(UserFact.learned_at.desc())
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_user_facts(self, user: User) -> int:
|
||||
"""Soft-delete all facts for a user (set is_active=False).
|
||||
|
||||
Args:
|
||||
user: User model instance
|
||||
|
||||
Returns:
|
||||
Number of facts deactivated
|
||||
"""
|
||||
facts = await self.get_user_facts(user, active_only=True)
|
||||
for fact in facts:
|
||||
fact.is_active = False
|
||||
logger.info(f"Deactivated {len(facts)} facts for user {user.discord_id}")
|
||||
return len(facts)
|
||||
|
||||
async def set_preference(self, user: User, key: str, value: str | None) -> UserPreference:
|
||||
"""Set a user preference.
|
||||
|
||||
Args:
|
||||
user: User model instance
|
||||
key: Preference key
|
||||
value: Preference value, or None to delete
|
||||
|
||||
Returns:
|
||||
UserPreference instance
|
||||
"""
|
||||
stmt = select(UserPreference).where(
|
||||
UserPreference.user_id == user.id,
|
||||
UserPreference.preference_key == key,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
pref = result.scalar_one_or_none()
|
||||
|
||||
if pref:
|
||||
pref.preference_value = value
|
||||
else:
|
||||
pref = UserPreference(
|
||||
user_id=user.id,
|
||||
preference_key=key,
|
||||
preference_value=value,
|
||||
)
|
||||
self._session.add(pref)
|
||||
await self._session.flush()
|
||||
|
||||
return pref
|
||||
|
||||
async def get_preference(self, user: User, key: str) -> str | None:
|
||||
"""Get a user preference value.
|
||||
|
||||
Args:
|
||||
user: User model instance
|
||||
key: Preference key
|
||||
|
||||
Returns:
|
||||
Preference value or None if not set
|
||||
"""
|
||||
stmt = select(UserPreference).where(
|
||||
UserPreference.user_id == user.id,
|
||||
UserPreference.preference_key == key,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
pref = result.scalar_one_or_none()
|
||||
return pref.preference_value if pref else None
|
||||
|
||||
async def get_user_context(self, user: User) -> str:
|
||||
"""Build a context string about a user for the AI.
|
||||
|
||||
Args:
|
||||
user: User model instance
|
||||
|
||||
Returns:
|
||||
Formatted context string
|
||||
"""
|
||||
lines = [f"User's preferred name: {user.display_name}"]
|
||||
|
||||
if user.custom_name:
|
||||
lines.append(f"(You should address them as: {user.custom_name})")
|
||||
|
||||
facts = await self.get_user_facts(user, active_only=True)
|
||||
|
||||
if facts:
|
||||
lines.append("\nKnown facts about this user:")
|
||||
for fact in facts[:20]: # Limit to 20 most recent facts
|
||||
lines.append(f"- [{fact.fact_type}] {fact.fact_content}")
|
||||
# Mark as referenced
|
||||
fact.last_referenced_at = datetime.utcnow()
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def get_user_with_facts(self, discord_id: int) -> User | None:
|
||||
"""Get a user with their facts eagerly loaded.
|
||||
|
||||
Args:
|
||||
discord_id: Discord user ID
|
||||
|
||||
Returns:
|
||||
User with facts loaded, or None if not found
|
||||
"""
|
||||
stmt = select(User).where(User.discord_id == discord_id).options(selectinload(User.facts))
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
Reference in New Issue
Block a user