From d371fb77cf867b183be27478fc1cc8c875c60e42 Mon Sep 17 00:00:00 2001 From: latte Date: Mon, 12 Jan 2026 20:41:04 +0100 Subject: [PATCH] quick adding (not working) --- schema.sql | 4 +- tests/__init__.py | 1 + tests/conftest.py | 309 ++++++++++++++++++++ tests/test_models.py | 487 +++++++++++++++++++++++++++++++ tests/test_providers.py | 290 +++++++++++++++++++ tests/test_services.py | 620 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 1710 insertions(+), 1 deletion(-) create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_models.py create mode 100644 tests/test_providers.py create mode 100644 tests/test_services.py diff --git a/schema.sql b/schema.sql index 791a8a7..4d377db 100644 --- a/schema.sql +++ b/schema.sql @@ -245,7 +245,9 @@ CREATE TABLE IF NOT EXISTS mood_history ( trigger_type VARCHAR(50) NOT NULL, -- conversation, time_decay, event trigger_user_id BIGINT REFERENCES users(id) ON DELETE SET NULL, trigger_description TEXT, - recorded_at TIMESTAMPTZ DEFAULT NOW() + recorded_at TIMESTAMPTZ DEFAULT NOW(), + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() ); CREATE INDEX IF NOT EXISTS ix_mood_history_guild_id ON mood_history(guild_id); diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..901f328 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for Daemon Boyfriend Discord bot.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..86cfc36 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,309 @@ +"""Pytest configuration and fixtures for the test suite.""" + +import asyncio +from datetime import datetime, timezone +from typing import AsyncGenerator, Generator +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from sqlalchemy import event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from daemon_boyfriend.config import Settings +from daemon_boyfriend.models.base import Base + +# --- Event Loop Fixture --- + + +@pytest.fixture(scope="session") +def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: + """Create an event loop for the test session.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +# --- Database Fixtures --- + + +@pytest_asyncio.fixture +async def async_engine(): + """Create an async SQLite engine for testing.""" + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + echo=False, + ) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + yield engine + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + await engine.dispose() + + +@pytest_asyncio.fixture +async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]: + """Create a database session for testing.""" + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + async with async_session_maker() as session: + yield session + await session.rollback() + + +# --- Mock Settings Fixture --- + + +@pytest.fixture +def mock_settings() -> Settings: + """Create mock settings for testing.""" + with patch.dict( + "os.environ", + { + "DISCORD_TOKEN": "test_token", + "AI_PROVIDER": "openai", + "AI_MODEL": "gpt-4o-mini", + "OPENAI_API_KEY": "test_openai_key", + "ANTHROPIC_API_KEY": "test_anthropic_key", + "GEMINI_API_KEY": "test_gemini_key", + "OPENROUTER_API_KEY": "test_openrouter_key", + "BOT_NAME": "TestBot", + "BOT_PERSONALITY": "helpful and friendly", + "DATABASE_URL": "", + "LIVING_AI_ENABLED": "true", + "MOOD_ENABLED": "true", + "RELATIONSHIP_ENABLED": "true", + }, + ): + return Settings() + + +# --- Mock Discord Fixtures --- + + +@pytest.fixture +def mock_discord_user() -> MagicMock: + """Create a mock Discord user.""" + user = MagicMock() + user.id = 123456789 + user.name = "TestUser" + user.display_name = "Test User" + user.mention = "<@123456789>" + user.bot = False + return user + + +@pytest.fixture +def mock_discord_message(mock_discord_user) -> MagicMock: + """Create a mock Discord message.""" + message = MagicMock() + message.author = mock_discord_user + message.content = "Hello, bot!" + message.channel = MagicMock() + message.channel.id = 987654321 + message.channel.send = AsyncMock() + message.channel.typing = MagicMock(return_value=AsyncMock()) + message.guild = MagicMock() + message.guild.id = 111222333 + message.guild.name = "Test Guild" + message.id = 555666777 + message.mentions = [] + return message + + +@pytest.fixture +def mock_discord_bot() -> MagicMock: + """Create a mock Discord bot.""" + bot = MagicMock() + bot.user = MagicMock() + bot.user.id = 999888777 + bot.user.name = "TestBot" + bot.user.mentioned_in = MagicMock(return_value=True) + return bot + + +# --- Mock AI Provider Fixtures --- + + +@pytest.fixture +def mock_ai_response() -> MagicMock: + """Create a mock AI response.""" + from daemon_boyfriend.services.providers.base import AIResponse + + return AIResponse( + content="This is a test response from the AI.", + model="test-model", + usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + ) + + +@pytest.fixture +def mock_openai_client() -> MagicMock: + """Create a mock OpenAI client.""" + client = MagicMock() + + response = MagicMock() + response.choices = [MagicMock()] + response.choices[0].message.content = "Test OpenAI response" + response.model = "gpt-4o-mini" + response.usage = MagicMock() + response.usage.prompt_tokens = 10 + response.usage.completion_tokens = 20 + response.usage.total_tokens = 30 + + client.chat.completions.create = AsyncMock(return_value=response) + return client + + +@pytest.fixture +def mock_anthropic_client() -> MagicMock: + """Create a mock Anthropic client.""" + client = MagicMock() + + response = MagicMock() + response.content = [MagicMock()] + response.content[0].type = "text" + response.content[0].text = "Test Anthropic response" + response.model = "claude-sonnet-4-20250514" + response.usage = MagicMock() + response.usage.input_tokens = 10 + response.usage.output_tokens = 20 + + client.messages.create = AsyncMock(return_value=response) + return client + + +@pytest.fixture +def mock_gemini_client() -> MagicMock: + """Create a mock Gemini client.""" + client = MagicMock() + + response = MagicMock() + response.text = "Test Gemini response" + response.usage_metadata = MagicMock() + response.usage_metadata.prompt_token_count = 10 + response.usage_metadata.candidates_token_count = 20 + response.usage_metadata.total_token_count = 30 + + client.aio.models.generate_content = AsyncMock(return_value=response) + return client + + +# --- Model Fixtures --- + + +@pytest_asyncio.fixture +async def sample_user(db_session: AsyncSession): + """Create a sample user in the database.""" + from daemon_boyfriend.models import User + + user = User( + discord_id=123456789, + discord_username="testuser", + discord_display_name="Test User", + ) + db_session.add(user) + await db_session.commit() + await db_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def sample_user_with_facts(db_session: AsyncSession, sample_user): + """Create a sample user with facts.""" + from daemon_boyfriend.models import UserFact + + facts = [ + UserFact( + user_id=sample_user.id, + fact_type="hobby", + fact_content="likes programming", + confidence=1.0, + source="explicit", + ), + UserFact( + user_id=sample_user.id, + fact_type="preference", + fact_content="prefers dark mode", + confidence=0.8, + source="conversation", + ), + ] + + for fact in facts: + db_session.add(fact) + + await db_session.commit() + return sample_user + + +@pytest_asyncio.fixture +async def sample_conversation(db_session: AsyncSession, sample_user): + """Create a sample conversation.""" + from daemon_boyfriend.models import Conversation + + conversation = Conversation( + user_id=sample_user.id, + guild_id=111222333, + channel_id=987654321, + ) + db_session.add(conversation) + await db_session.commit() + await db_session.refresh(conversation) + return conversation + + +@pytest_asyncio.fixture +async def sample_bot_state(db_session: AsyncSession): + """Create a sample bot state.""" + from daemon_boyfriend.models import BotState + + bot_state = BotState( + guild_id=111222333, + mood_valence=0.5, + mood_arousal=0.3, + ) + db_session.add(bot_state) + await db_session.commit() + await db_session.refresh(bot_state) + return bot_state + + +@pytest_asyncio.fixture +async def sample_user_relationship(db_session: AsyncSession, sample_user): + """Create a sample user relationship.""" + from daemon_boyfriend.models import UserRelationship + + relationship = UserRelationship( + user_id=sample_user.id, + guild_id=111222333, + relationship_score=50.0, + total_interactions=10, + positive_interactions=8, + negative_interactions=1, + ) + db_session.add(relationship) + await db_session.commit() + await db_session.refresh(relationship) + return relationship + + +# --- Utility Fixtures --- + + +@pytest.fixture +def utc_now() -> datetime: + """Get current UTC time.""" + return datetime.now(timezone.utc) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..8ffd5b4 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,487 @@ +"""Tests for database models.""" + +from datetime import datetime, timedelta, timezone + +import pytest + +from daemon_boyfriend.models import ( + BotOpinion, + BotState, + Conversation, + FactAssociation, + Guild, + GuildMember, + Message, + MoodHistory, + ScheduledEvent, + User, + UserCommunicationStyle, + UserFact, + UserPreference, + UserRelationship, +) +from daemon_boyfriend.models.base import utc_now + + +class TestUtcNow: + """Tests for the utc_now helper function.""" + + def test_returns_timezone_aware(self): + """Test that utc_now returns timezone-aware datetime.""" + now = utc_now() + assert now.tzinfo is not None + assert now.tzinfo == timezone.utc + + def test_returns_current_time(self): + """Test that utc_now returns approximately current time.""" + before = datetime.now(timezone.utc) + now = utc_now() + after = datetime.now(timezone.utc) + assert before <= now <= after + + +class TestUserModel: + """Tests for the User model.""" + + @pytest.mark.asyncio + async def test_create_user(self, db_session): + """Test creating a user.""" + user = User( + discord_id=123456789, + discord_username="testuser", + discord_display_name="Test User", + ) + db_session.add(user) + await db_session.commit() + + assert user.id is not None + assert user.discord_id == 123456789 + assert user.is_active is True + assert user.created_at is not None + + @pytest.mark.asyncio + async def test_user_display_name_property(self, db_session): + """Test the display_name property.""" + user = User( + discord_id=123456789, + discord_username="testuser", + discord_display_name="Test User", + ) + db_session.add(user) + await db_session.commit() + + # Uses discord_display_name when no custom_name + assert user.display_name == "Test User" + + # Uses custom_name when set + user.custom_name = "Custom Name" + assert user.display_name == "Custom Name" + + @pytest.mark.asyncio + async def test_user_display_name_fallback(self, db_session): + """Test display_name falls back to username.""" + user = User( + discord_id=123456789, + discord_username="testuser", + ) + db_session.add(user) + await db_session.commit() + + assert user.display_name == "testuser" + + @pytest.mark.asyncio + async def test_user_timestamps(self, db_session): + """Test user timestamp fields.""" + user = User( + discord_id=123456789, + discord_username="testuser", + ) + db_session.add(user) + await db_session.commit() + + assert user.first_seen_at is not None + assert user.last_seen_at is not None + assert user.created_at is not None + assert user.updated_at is not None + + +class TestUserFactModel: + """Tests for the UserFact model.""" + + @pytest.mark.asyncio + async def test_create_fact(self, db_session, sample_user): + """Test creating a user fact.""" + fact = UserFact( + user_id=sample_user.id, + fact_type="hobby", + fact_content="likes gaming", + confidence=0.9, + source="conversation", + ) + db_session.add(fact) + await db_session.commit() + + assert fact.id is not None + assert fact.is_active is True + assert fact.learned_at is not None + + @pytest.mark.asyncio + async def test_fact_default_values(self, db_session, sample_user): + """Test fact default values.""" + fact = UserFact( + user_id=sample_user.id, + fact_type="general", + fact_content="test fact", + ) + db_session.add(fact) + await db_session.commit() + + assert fact.confidence == 1.0 + assert fact.source == "conversation" + assert fact.is_active is True + + +class TestConversationModel: + """Tests for the Conversation model.""" + + @pytest.mark.asyncio + async def test_create_conversation(self, db_session, sample_user): + """Test creating a conversation.""" + conv = Conversation( + user_id=sample_user.id, + guild_id=111222333, + channel_id=444555666, + ) + db_session.add(conv) + await db_session.commit() + + assert conv.id is not None + assert conv.message_count == 0 + assert conv.is_active is True + assert conv.started_at is not None + + @pytest.mark.asyncio + async def test_conversation_with_messages(self, db_session, sample_user): + """Test conversation with messages.""" + conv = Conversation( + user_id=sample_user.id, + channel_id=444555666, + ) + db_session.add(conv) + await db_session.commit() + + msg = Message( + conversation_id=conv.id, + user_id=sample_user.id, + role="user", + content="Hello!", + ) + db_session.add(msg) + await db_session.commit() + + assert msg.id is not None + assert msg.role == "user" + assert msg.has_images is False + + +class TestMessageModel: + """Tests for the Message model.""" + + @pytest.mark.asyncio + async def test_create_message(self, db_session, sample_conversation, sample_user): + """Test creating a message.""" + msg = Message( + conversation_id=sample_conversation.id, + user_id=sample_user.id, + role="user", + content="Test message", + ) + db_session.add(msg) + await db_session.commit() + + assert msg.id is not None + assert msg.has_images is False + assert msg.image_urls is None + + @pytest.mark.asyncio + async def test_message_with_images(self, db_session, sample_conversation, sample_user): + """Test message with images.""" + msg = Message( + conversation_id=sample_conversation.id, + user_id=sample_user.id, + role="user", + content="Look at this", + has_images=True, + image_urls=["https://example.com/image.png"], + ) + db_session.add(msg) + await db_session.commit() + + assert msg.has_images is True + assert len(msg.image_urls) == 1 + + +class TestGuildModel: + """Tests for the Guild model.""" + + @pytest.mark.asyncio + async def test_create_guild(self, db_session): + """Test creating a guild.""" + guild = Guild( + discord_id=111222333, + name="Test Guild", + ) + db_session.add(guild) + await db_session.commit() + + assert guild.id is not None + assert guild.is_active is True + assert guild.settings == {} + + @pytest.mark.asyncio + async def test_guild_with_settings(self, db_session): + """Test guild with custom settings.""" + guild = Guild( + discord_id=111222333, + name="Test Guild", + settings={"prefix": "!", "language": "en"}, + ) + db_session.add(guild) + await db_session.commit() + + assert guild.settings["prefix"] == "!" + assert guild.settings["language"] == "en" + + +class TestGuildMemberModel: + """Tests for the GuildMember model.""" + + @pytest.mark.asyncio + async def test_create_guild_member(self, db_session, sample_user): + """Test creating a guild member.""" + guild = Guild(discord_id=111222333, name="Test Guild") + db_session.add(guild) + await db_session.commit() + + member = GuildMember( + guild_id=guild.id, + user_id=sample_user.id, + guild_nickname="TestNick", + ) + db_session.add(member) + await db_session.commit() + + assert member.id is not None + assert member.guild_nickname == "TestNick" + + +class TestBotStateModel: + """Tests for the BotState model.""" + + @pytest.mark.asyncio + async def test_create_bot_state(self, db_session): + """Test creating a bot state.""" + state = BotState(guild_id=111222333) + db_session.add(state) + await db_session.commit() + + assert state.id is not None + assert state.mood_valence == 0.0 + assert state.mood_arousal == 0.0 + assert state.total_messages_sent == 0 + + @pytest.mark.asyncio + async def test_bot_state_defaults(self, db_session): + """Test bot state default values.""" + state = BotState() + db_session.add(state) + await db_session.commit() + + assert state.guild_id is None + assert state.preferences == {} + assert state.total_facts_learned == 0 + assert state.total_users_known == 0 + + +class TestBotOpinionModel: + """Tests for the BotOpinion model.""" + + @pytest.mark.asyncio + async def test_create_opinion(self, db_session): + """Test creating a bot opinion.""" + opinion = BotOpinion( + topic="programming", + sentiment=0.8, + interest_level=0.9, + ) + db_session.add(opinion) + await db_session.commit() + + assert opinion.id is not None + assert opinion.discussion_count == 0 + assert opinion.formed_at is not None + + +class TestUserRelationshipModel: + """Tests for the UserRelationship model.""" + + @pytest.mark.asyncio + async def test_create_relationship(self, db_session, sample_user): + """Test creating a user relationship.""" + rel = UserRelationship( + user_id=sample_user.id, + guild_id=111222333, + ) + db_session.add(rel) + await db_session.commit() + + assert rel.id is not None + assert rel.relationship_score == 10.0 + assert rel.total_interactions == 0 + + @pytest.mark.asyncio + async def test_relationship_defaults(self, db_session, sample_user): + """Test relationship default values.""" + rel = UserRelationship(user_id=sample_user.id) + db_session.add(rel) + await db_session.commit() + + assert rel.shared_references == {} + assert rel.positive_interactions == 0 + assert rel.negative_interactions == 0 + assert rel.avg_message_length == 0.0 + + +class TestUserCommunicationStyleModel: + """Tests for the UserCommunicationStyle model.""" + + @pytest.mark.asyncio + async def test_create_style(self, db_session, sample_user): + """Test creating a communication style.""" + style = UserCommunicationStyle(user_id=sample_user.id) + db_session.add(style) + await db_session.commit() + + assert style.id is not None + assert style.preferred_length == "medium" + assert style.preferred_formality == 0.5 + + @pytest.mark.asyncio + async def test_style_defaults(self, db_session, sample_user): + """Test communication style defaults.""" + style = UserCommunicationStyle(user_id=sample_user.id) + db_session.add(style) + await db_session.commit() + + assert style.emoji_affinity == 0.5 + assert style.humor_affinity == 0.5 + assert style.detail_preference == 0.5 + assert style.samples_collected == 0 + assert style.confidence == 0.0 + + +class TestScheduledEventModel: + """Tests for the ScheduledEvent model.""" + + @pytest.mark.asyncio + async def test_create_event(self, db_session, sample_user): + """Test creating a scheduled event.""" + trigger_time = datetime.now(timezone.utc) + timedelta(days=1) + event = ScheduledEvent( + user_id=sample_user.id, + event_type="birthday", + trigger_at=trigger_time, + title="Birthday reminder", + ) + db_session.add(event) + await db_session.commit() + + assert event.id is not None + assert event.status == "pending" + assert event.is_recurring is False + + @pytest.mark.asyncio + async def test_recurring_event(self, db_session, sample_user): + """Test creating a recurring event.""" + trigger_time = datetime.now(timezone.utc) + timedelta(days=1) + event = ScheduledEvent( + user_id=sample_user.id, + event_type="birthday", + trigger_at=trigger_time, + title="Birthday", + is_recurring=True, + recurrence_rule="yearly", + ) + db_session.add(event) + await db_session.commit() + + assert event.is_recurring is True + assert event.recurrence_rule == "yearly" + + +class TestFactAssociationModel: + """Tests for the FactAssociation model.""" + + @pytest.mark.asyncio + async def test_create_association(self, db_session, sample_user): + """Test creating a fact association.""" + fact1 = UserFact( + user_id=sample_user.id, + fact_type="hobby", + fact_content="likes programming", + ) + fact2 = UserFact( + user_id=sample_user.id, + fact_type="hobby", + fact_content="likes Python", + ) + db_session.add(fact1) + db_session.add(fact2) + await db_session.commit() + + assoc = FactAssociation( + fact_id_1=fact1.id, + fact_id_2=fact2.id, + association_type="shared_interest", + strength=0.8, + ) + db_session.add(assoc) + await db_session.commit() + + assert assoc.id is not None + assert assoc.discovered_at is not None + + +class TestMoodHistoryModel: + """Tests for the MoodHistory model.""" + + @pytest.mark.asyncio + async def test_create_mood_history(self, db_session, sample_user): + """Test creating a mood history entry.""" + history = MoodHistory( + guild_id=111222333, + valence=0.5, + arousal=0.3, + trigger_type="conversation", + trigger_user_id=sample_user.id, + trigger_description="Had a nice chat", + ) + db_session.add(history) + await db_session.commit() + + assert history.id is not None + assert history.recorded_at is not None + + @pytest.mark.asyncio + async def test_mood_history_without_user(self, db_session): + """Test mood history without trigger user.""" + history = MoodHistory( + valence=-0.2, + arousal=-0.1, + trigger_type="time_decay", + ) + db_session.add(history) + await db_session.commit() + + assert history.trigger_user_id is None + assert history.trigger_description is None diff --git a/tests/test_providers.py b/tests/test_providers.py new file mode 100644 index 0000000..c1b4ed0 --- /dev/null +++ b/tests/test_providers.py @@ -0,0 +1,290 @@ +"""Tests for AI provider implementations.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from daemon_boyfriend.services.providers.base import ( + AIProvider, + AIResponse, + Message, + ImageAttachment, +) +from daemon_boyfriend.services.providers.openai import OpenAIProvider +from daemon_boyfriend.services.providers.anthropic import AnthropicProvider +from daemon_boyfriend.services.providers.gemini import GeminiProvider +from daemon_boyfriend.services.providers.openrouter import OpenRouterProvider + + +class TestMessage: + """Tests for the Message dataclass.""" + + def test_message_creation(self): + """Test creating a basic message.""" + msg = Message(role="user", content="Hello") + assert msg.role == "user" + assert msg.content == "Hello" + assert msg.images == [] + + def test_message_with_images(self): + """Test creating a message with images.""" + images = [ImageAttachment(url="https://example.com/image.png")] + msg = Message(role="user", content="Look at this", images=images) + assert len(msg.images) == 1 + assert msg.images[0].url == "https://example.com/image.png" + + +class TestImageAttachment: + """Tests for the ImageAttachment dataclass.""" + + def test_default_media_type(self): + """Test default media type.""" + img = ImageAttachment(url="https://example.com/image.png") + assert img.media_type == "image/png" + + def test_custom_media_type(self): + """Test custom media type.""" + img = ImageAttachment(url="https://example.com/image.jpg", media_type="image/jpeg") + assert img.media_type == "image/jpeg" + + +class TestAIResponse: + """Tests for the AIResponse dataclass.""" + + def test_response_creation(self): + """Test creating an AI response.""" + response = AIResponse( + content="Hello!", + model="test-model", + usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + ) + assert response.content == "Hello!" + assert response.model == "test-model" + assert response.usage["total_tokens"] == 15 + + +class TestOpenAIProvider: + """Tests for the OpenAI provider.""" + + @pytest.fixture + def provider(self, mock_openai_client): + """Create an OpenAI provider with mocked client.""" + with patch("daemon_boyfriend.services.providers.openai.AsyncOpenAI") as mock_class: + mock_class.return_value = mock_openai_client + provider = OpenAIProvider(api_key="test_key", model="gpt-4o-mini") + provider.client = mock_openai_client + return provider + + def test_provider_name(self, provider): + """Test provider name.""" + assert provider.provider_name == "openai" + + def test_model_setting(self, provider): + """Test model is set correctly.""" + assert provider.model == "gpt-4o-mini" + + @pytest.mark.asyncio + async def test_generate_simple_message(self, provider, mock_openai_client): + """Test generating a response with a simple message.""" + messages = [Message(role="user", content="Hello")] + + response = await provider.generate(messages) + + assert response.content == "Test OpenAI response" + assert response.model == "gpt-4o-mini" + mock_openai_client.chat.completions.create.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_with_system_prompt(self, provider, mock_openai_client): + """Test generating a response with a system prompt.""" + messages = [Message(role="user", content="Hello")] + + await provider.generate(messages, system_prompt="You are a helpful assistant.") + + call_args = mock_openai_client.chat.completions.create.call_args + api_messages = call_args.kwargs["messages"] + assert api_messages[0]["role"] == "system" + assert api_messages[0]["content"] == "You are a helpful assistant." + + @pytest.mark.asyncio + async def test_generate_with_images(self, provider, mock_openai_client): + """Test generating a response with images.""" + images = [ImageAttachment(url="https://example.com/image.png")] + messages = [Message(role="user", content="What's in this image?", images=images)] + + await provider.generate(messages) + + call_args = mock_openai_client.chat.completions.create.call_args + api_messages = call_args.kwargs["messages"] + content = api_messages[0]["content"] + assert isinstance(content, list) + assert content[0]["type"] == "text" + assert content[1]["type"] == "image_url" + + def test_build_message_content_no_images(self, provider): + """Test building message content without images.""" + msg = Message(role="user", content="Hello") + content = provider._build_message_content(msg) + assert content == "Hello" + + def test_build_message_content_with_images(self, provider): + """Test building message content with images.""" + images = [ImageAttachment(url="https://example.com/image.png")] + msg = Message(role="user", content="Hello", images=images) + content = provider._build_message_content(msg) + assert isinstance(content, list) + assert len(content) == 2 + + +class TestAnthropicProvider: + """Tests for the Anthropic provider.""" + + @pytest.fixture + def provider(self, mock_anthropic_client): + """Create an Anthropic provider with mocked client.""" + with patch( + "daemon_boyfriend.services.providers.anthropic.anthropic.AsyncAnthropic" + ) as mock_class: + """Tests for the Anthropic provider.""" + + @pytest.fixture + def provider(self, mock_anthropic_client): + """Create an Anthropic provider with mocked client.""" + with patch("daemon_boyfriend.services.providers.anthropic.anthropic.AsyncAnthropic") as mock_class: + mock_class.return_value = mock_anthropic_client + provider = AnthropicProvider(api_key="test_key", model="claude-sonnet-4-20250514") + provider.client = mock_anthropic_client + return provider + + def test_provider_name(self, provider): + """Test provider name.""" + assert provider.provider_name == "anthropic" + + def test_model_setting(self, provider): + """Test model is set correctly.""" + assert provider.model == "claude-sonnet-4-20250514" + + @pytest.mark.asyncio + async def test_generate_simple_message(self, provider, mock_anthropic_client): + """Test generating a response with a simple message.""" + messages = [Message(role="user", content="Hello")] + + response = await provider.generate(messages) + + assert response.content == "Test Anthropic response" + mock_anthropic_client.messages.create.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_with_system_prompt(self, provider, mock_anthropic_client): + """Test generating a response with a system prompt.""" + messages = [Message(role="user", content="Hello")] + + await provider.generate(messages, system_prompt="You are a helpful assistant.") + + call_args = mock_anthropic_client.messages.create.call_args + assert call_args.kwargs["system"] == "You are a helpful assistant." + + def test_build_message_content_no_images(self, provider): + """Test building message content without images.""" + msg = Message(role="user", content="Hello") + content = provider._build_message_content(msg) + assert content == "Hello" + + def test_build_message_content_with_images(self, provider): + """Test building message content with images.""" + images = [ImageAttachment(url="https://example.com/image.png")] + msg = Message(role="user", content="Hello", images=images) + content = provider._build_message_content(msg) + assert isinstance(content, list) + assert content[0]["type"] == "image" + assert content[1]["type"] == "text" + + +class TestGeminiProvider: + """Tests for the Gemini provider.""" + + @pytest.fixture + def provider(self, mock_gemini_client): + """Create a Gemini provider with mocked client.""" + with patch("daemon_boyfriend.services.providers.gemini.genai.Client") as mock_class: + mock_class.return_value = mock_gemini_client + provider = GeminiProvider(api_key="test_key", model="gemini-2.0-flash") + provider.client = mock_gemini_client + return provider + + def test_provider_name(self, provider): + """Test provider name.""" + assert provider.provider_name == "gemini" + + def test_model_setting(self, provider): + """Test model is set correctly.""" + assert provider.model == "gemini-2.0-flash" + + @pytest.mark.asyncio + async def test_generate_simple_message(self, provider, mock_gemini_client): + """Test generating a response with a simple message.""" + messages = [Message(role="user", content="Hello")] + + response = await provider.generate(messages) + + assert response.content == "Test Gemini response" + mock_gemini_client.aio.models.generate_content.assert_called_once() + + @pytest.mark.asyncio + async def test_role_mapping(self, provider, mock_gemini_client): + """Test that 'assistant' role is mapped to 'model'.""" + messages = [ + Message(role="user", content="Hello"), + Message(role="assistant", content="Hi there!"), + Message(role="user", content="How are you?"), + ] + + await provider.generate(messages) + + call_args = mock_gemini_client.aio.models.generate_content.call_args + contents = call_args.kwargs["contents"] + assert contents[0].role == "user" + assert contents[1].role == "model" + assert contents[2].role == "user" + + +class TestOpenRouterProvider: + """Tests for the OpenRouter provider.""" + + @pytest.fixture + def provider(self, mock_openai_client): + """Create an OpenRouter provider with mocked client.""" + with patch("daemon_boyfriend.services.providers.openrouter.AsyncOpenAI") as mock_class: + mock_class.return_value = mock_openai_client + provider = OpenRouterProvider(api_key="test_key", model="openai/gpt-4o") + provider.client = mock_openai_client + return provider + + def test_provider_name(self, provider): + """Test provider name.""" + assert provider.provider_name == "openrouter" + + def test_model_setting(self, provider): + """Test model is set correctly.""" + assert provider.model == "openai/gpt-4o" + + @pytest.mark.asyncio + async def test_generate_simple_message(self, provider, mock_openai_client): + """Test generating a response with a simple message.""" + messages = [Message(role="user", content="Hello")] + + response = await provider.generate(messages) + + assert response.content == "Test OpenAI response" + mock_openai_client.chat.completions.create.assert_called_once() + + @pytest.mark.asyncio + async def test_extra_headers(self, provider, mock_openai_client): + """Test that OpenRouter-specific headers are included.""" + messages = [Message(role="user", content="Hello")] + + await provider.generate(messages) + + call_args = mock_openai_client.chat.completions.create.call_args + extra_headers = call_args.kwargs.get("extra_headers", {}) + assert "HTTP-Referer" in extra_headers + assert "X-Title" in extra_headers diff --git a/tests/test_services.py b/tests/test_services.py new file mode 100644 index 0000000..f2487eb --- /dev/null +++ b/tests/test_services.py @@ -0,0 +1,620 @@ +"""Tests for service layer.""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from daemon_boyfriend.models import ( + BotOpinion, + BotState, + Conversation, + Message, + User, + UserFact, + UserRelationship, +) +from daemon_boyfriend.services.ai_service import AIService +from daemon_boyfriend.services.fact_extraction_service import FactExtractionService +from daemon_boyfriend.services.mood_service import MoodLabel, MoodService, MoodState +from daemon_boyfriend.services.opinion_service import OpinionService, extract_topics_from_message +from daemon_boyfriend.services.persistent_conversation import PersistentConversationManager +from daemon_boyfriend.services.relationship_service import RelationshipLevel, RelationshipService +from daemon_boyfriend.services.self_awareness_service import SelfAwarenessService +from daemon_boyfriend.services.user_service import UserService + + +class TestUserService: + """Tests for UserService.""" + + @pytest.mark.asyncio + async def test_get_or_create_user_new(self, db_session): + """Test creating a new user.""" + service = UserService(db_session) + + user = await service.get_or_create_user( + discord_id=123456789, + username="testuser", + display_name="Test User", + ) + + assert user.id is not None + assert user.discord_id == 123456789 + assert user.discord_username == "testuser" + + @pytest.mark.asyncio + async def test_get_or_create_user_existing(self, db_session, sample_user): + """Test getting an existing user.""" + service = UserService(db_session) + + user = await service.get_or_create_user( + discord_id=sample_user.discord_id, + username="newname", + display_name="New Display", + ) + + assert user.id == sample_user.id + assert user.discord_username == "newname" + + @pytest.mark.asyncio + async def test_set_custom_name(self, db_session, sample_user): + """Test setting a custom name.""" + service = UserService(db_session) + + user = await service.set_custom_name(sample_user.discord_id, "CustomName") + + assert user.custom_name == "CustomName" + assert user.display_name == "CustomName" + + @pytest.mark.asyncio + async def test_clear_custom_name(self, db_session, sample_user): + """Test clearing a custom name.""" + service = UserService(db_session) + sample_user.custom_name = "OldName" + + user = await service.set_custom_name(sample_user.discord_id, None) + + assert user.custom_name is None + + @pytest.mark.asyncio + async def test_add_fact(self, db_session, sample_user): + """Test adding a fact about a user.""" + service = UserService(db_session) + + fact = await service.add_fact( + user=sample_user, + fact_type="hobby", + fact_content="likes programming", + ) + + assert fact.id is not None + assert fact.user_id == sample_user.id + assert fact.fact_type == "hobby" + + @pytest.mark.asyncio + async def test_get_user_facts(self, db_session, sample_user_with_facts): + """Test getting user facts.""" + service = UserService(db_session) + + facts = await service.get_user_facts(sample_user_with_facts) + + assert len(facts) == 2 + + @pytest.mark.asyncio + async def test_get_user_facts_by_type(self, db_session, sample_user_with_facts): + """Test getting user facts by type.""" + service = UserService(db_session) + + facts = await service.get_user_facts(sample_user_with_facts, fact_type="hobby") + + assert len(facts) == 1 + assert facts[0].fact_type == "hobby" + + @pytest.mark.asyncio + async def test_delete_user_facts(self, db_session, sample_user_with_facts): + """Test deleting user facts.""" + service = UserService(db_session) + + count = await service.delete_user_facts(sample_user_with_facts) + + assert count == 2 + facts = await service.get_user_facts(sample_user_with_facts, active_only=True) + assert len(facts) == 0 + + @pytest.mark.asyncio + async def test_get_user_context(self, db_session, sample_user_with_facts): + """Test getting user context string.""" + service = UserService(db_session) + + context = await service.get_user_context(sample_user_with_facts) + + assert "Test User" in context + assert "likes programming" in context + + +class TestMoodService: + """Tests for MoodService.""" + + @pytest.mark.asyncio + async def test_get_or_create_bot_state(self, db_session): + """Test getting or creating bot state.""" + service = MoodService(db_session) + + state = await service.get_or_create_bot_state(guild_id=111222333) + + assert state.id is not None + assert state.guild_id == 111222333 + + @pytest.mark.asyncio + async def test_get_current_mood(self, db_session, sample_bot_state): + """Test getting current mood.""" + service = MoodService(db_session) + + mood = await service.get_current_mood(guild_id=sample_bot_state.guild_id) + + assert isinstance(mood, MoodState) + assert -1.0 <= mood.valence <= 1.0 + assert -1.0 <= mood.arousal <= 1.0 + + @pytest.mark.asyncio + async def test_update_mood(self, db_session, sample_bot_state): + """Test updating mood.""" + service = MoodService(db_session) + + new_mood = await service.update_mood( + guild_id=sample_bot_state.guild_id, + sentiment_delta=0.5, + engagement_delta=0.3, + trigger_type="conversation", + trigger_description="Had a nice chat", + ) + + assert new_mood.valence > 0 + + @pytest.mark.asyncio + async def test_increment_stats(self, db_session, sample_bot_state): + """Test incrementing bot stats.""" + service = MoodService(db_session) + initial_messages = sample_bot_state.total_messages_sent + + await service.increment_stats( + guild_id=sample_bot_state.guild_id, + messages_sent=5, + facts_learned=2, + ) + + assert sample_bot_state.total_messages_sent == initial_messages + 5 + + def test_classify_mood_excited(self): + """Test mood classification for excited.""" + service = MoodService(None) + label = service._classify_mood(0.5, 0.5) + assert label == MoodLabel.EXCITED + + def test_classify_mood_happy(self): + """Test mood classification for happy.""" + service = MoodService(None) + label = service._classify_mood(0.5, 0.0) + assert label == MoodLabel.HAPPY + + def test_classify_mood_bored(self): + """Test mood classification for bored.""" + service = MoodService(None) + label = service._classify_mood(-0.5, 0.0) + assert label == MoodLabel.BORED + + def test_classify_mood_annoyed(self): + """Test mood classification for annoyed.""" + service = MoodService(None) + label = service._classify_mood(-0.5, 0.5) + assert label == MoodLabel.ANNOYED + + def test_get_mood_prompt_modifier(self): + """Test getting mood prompt modifier.""" + service = MoodService(None) + mood = MoodState(valence=0.8, arousal=0.8, label=MoodLabel.EXCITED, intensity=0.8) + + modifier = service.get_mood_prompt_modifier(mood) + + assert "enthusiastic" in modifier.lower() or "excited" in modifier.lower() + + def test_get_mood_prompt_modifier_low_intensity(self): + """Test mood modifier with low intensity.""" + service = MoodService(None) + mood = MoodState(valence=0.1, arousal=0.1, label=MoodLabel.NEUTRAL, intensity=0.1) + + modifier = service.get_mood_prompt_modifier(mood) + + assert modifier == "" + + +class TestRelationshipService: + """Tests for RelationshipService.""" + + @pytest.mark.asyncio + async def test_get_or_create_relationship(self, db_session, sample_user): + """Test getting or creating a relationship.""" + service = RelationshipService(db_session) + + rel = await service.get_or_create_relationship(sample_user, guild_id=111222333) + + assert rel.id is not None + assert rel.user_id == sample_user.id + + @pytest.mark.asyncio + async def test_record_interaction(self, db_session, sample_user): + """Test recording an interaction.""" + service = RelationshipService(db_session) + + level = await service.record_interaction( + user=sample_user, + guild_id=111222333, + sentiment=0.8, + message_length=100, + conversation_turns=3, + ) + + assert isinstance(level, RelationshipLevel) + + @pytest.mark.asyncio + async def test_record_positive_interaction(self, db_session, sample_user): + """Test that positive interactions are tracked.""" + service = RelationshipService(db_session) + + await service.record_interaction( + user=sample_user, + guild_id=111222333, + sentiment=0.5, + message_length=100, + ) + + rel = await service.get_or_create_relationship(sample_user, guild_id=111222333) + assert rel.positive_interactions >= 1 + + def test_get_level_stranger(self): + """Test level classification for stranger.""" + service = RelationshipService(None) + assert service.get_level(10) == RelationshipLevel.STRANGER + + def test_get_level_acquaintance(self): + """Test level classification for acquaintance.""" + service = RelationshipService(None) + assert service.get_level(30) == RelationshipLevel.ACQUAINTANCE + + def test_get_level_friend(self): + """Test level classification for friend.""" + service = RelationshipService(None) + assert service.get_level(50) == RelationshipLevel.FRIEND + + def test_get_level_good_friend(self): + """Test level classification for good friend.""" + service = RelationshipService(None) + assert service.get_level(70) == RelationshipLevel.GOOD_FRIEND + + def test_get_level_close_friend(self): + """Test level classification for close friend.""" + service = RelationshipService(None) + assert service.get_level(90) == RelationshipLevel.CLOSE_FRIEND + + def test_get_level_display_name(self): + """Test getting display name for level.""" + service = RelationshipService(None) + assert service.get_level_display_name(RelationshipLevel.FRIEND) == "Friend" + + @pytest.mark.asyncio + async def test_get_relationship_info(self, db_session, sample_user_relationship, sample_user): + """Test getting relationship info.""" + service = RelationshipService(db_session) + + info = await service.get_relationship_info(sample_user, guild_id=111222333) + + assert "level" in info + assert "score" in info + assert "total_interactions" in info + + +class TestOpinionService: + """Tests for OpinionService.""" + + @pytest.mark.asyncio + async def test_get_or_create_opinion(self, db_session): + """Test getting or creating an opinion.""" + service = OpinionService(db_session) + + opinion = await service.get_or_create_opinion("programming") + + assert opinion.id is not None + assert opinion.topic == "programming" + assert opinion.sentiment == 0.0 + + @pytest.mark.asyncio + async def test_record_topic_discussion(self, db_session): + """Test recording a topic discussion.""" + service = OpinionService(db_session) + + opinion = await service.record_topic_discussion( + topic="gaming", + guild_id=None, + sentiment=0.8, + engagement_level=0.9, + ) + + assert opinion.discussion_count == 1 + assert opinion.sentiment > 0 + + @pytest.mark.asyncio + async def test_get_top_interests(self, db_session): + """Test getting top interests.""" + service = OpinionService(db_session) + + # Create some opinions with discussions + for topic in ["programming", "gaming", "music"]: + for _ in range(5): + await service.record_topic_discussion( + topic=topic, + guild_id=None, + sentiment=0.8, + engagement_level=0.9, + ) + + await db_session.commit() + + interests = await service.get_top_interests(limit=3) + + assert len(interests) <= 3 + + def test_extract_topics_gaming(self): + """Test extracting gaming topic.""" + topics = extract_topics_from_message("I love playing video games!") + assert "gaming" in topics + + def test_extract_topics_programming(self): + """Test extracting programming topic.""" + topics = extract_topics_from_message("I'm learning Python programming") + assert "programming" in topics + + def test_extract_topics_multiple(self): + """Test extracting multiple topics.""" + topics = extract_topics_from_message("I code while listening to music") + assert "programming" in topics + assert "music" in topics + + def test_extract_topics_none(self): + """Test extracting no topics.""" + topics = extract_topics_from_message("Hello, how are you?") + assert len(topics) == 0 + + +class TestPersistentConversationManager: + """Tests for PersistentConversationManager.""" + + @pytest.mark.asyncio + async def test_get_or_create_conversation_new(self, db_session, sample_user): + """Test creating a new conversation.""" + manager = PersistentConversationManager(db_session) + + conv = await manager.get_or_create_conversation( + user=sample_user, + channel_id=123456, + ) + + assert conv.id is not None + assert conv.user_id == sample_user.id + + @pytest.mark.asyncio + async def test_get_or_create_conversation_existing( + self, db_session, sample_user, sample_conversation + ): + """Test getting an existing conversation.""" + manager = PersistentConversationManager(db_session) + + conv = await manager.get_or_create_conversation( + user=sample_user, + channel_id=sample_conversation.channel_id, + ) + + assert conv.id == sample_conversation.id + + @pytest.mark.asyncio + async def test_add_message(self, db_session, sample_user, sample_conversation): + """Test adding a message.""" + manager = PersistentConversationManager(db_session) + + msg = await manager.add_message( + conversation=sample_conversation, + user=sample_user, + role="user", + content="Hello!", + ) + + assert msg.id is not None + assert msg.content == "Hello!" + + @pytest.mark.asyncio + async def test_add_exchange(self, db_session, sample_user, sample_conversation): + """Test adding a user/assistant exchange.""" + manager = PersistentConversationManager(db_session) + + user_msg, assistant_msg = await manager.add_exchange( + conversation=sample_conversation, + user=sample_user, + user_message="Hello!", + assistant_message="Hi there!", + ) + + assert user_msg.role == "user" + assert assistant_msg.role == "assistant" + + @pytest.mark.asyncio + async def test_get_history(self, db_session, sample_user, sample_conversation): + """Test getting conversation history.""" + manager = PersistentConversationManager(db_session) + + await manager.add_exchange( + conversation=sample_conversation, + user=sample_user, + user_message="Hello!", + assistant_message="Hi there!", + ) + + history = await manager.get_history(sample_conversation) + + assert len(history) == 2 + + @pytest.mark.asyncio + async def test_clear_conversation(self, db_session, sample_conversation): + """Test clearing a conversation.""" + manager = PersistentConversationManager(db_session) + + await manager.clear_conversation(sample_conversation) + + assert sample_conversation.is_active is False + + +class TestSelfAwarenessService: + """Tests for SelfAwarenessService.""" + + @pytest.mark.asyncio + async def test_get_bot_stats(self, db_session, sample_bot_state): + """Test getting bot stats.""" + service = SelfAwarenessService(db_session) + + stats = await service.get_bot_stats(guild_id=sample_bot_state.guild_id) + + assert "age_days" in stats + assert "total_messages_sent" in stats + assert "age_readable" in stats + + @pytest.mark.asyncio + async def test_get_history_with_user(self, db_session, sample_user, sample_user_relationship): + """Test getting history with a user.""" + service = SelfAwarenessService(db_session) + + history = await service.get_history_with_user(sample_user, guild_id=111222333) + + assert "days_known" in history + assert "total_interactions" in history + + @pytest.mark.asyncio + async def test_reflect_on_self(self, db_session, sample_bot_state): + """Test self reflection.""" + service = SelfAwarenessService(db_session) + + reflection = await service.reflect_on_self(guild_id=sample_bot_state.guild_id) + + assert isinstance(reflection, str) + + +class TestFactExtractionService: + """Tests for FactExtractionService.""" + + def test_is_extractable_short_message(self): + """Test that short messages are not extractable.""" + service = FactExtractionService(None) + assert service._is_extractable("hi") is False + + def test_is_extractable_greeting(self): + """Test that greetings are not extractable.""" + service = FactExtractionService(None) + assert service._is_extractable("hello") is False + + def test_is_extractable_command(self): + """Test that commands are not extractable.""" + service = FactExtractionService(None) + assert service._is_extractable("!help me with something") is False + + def test_is_extractable_valid(self): + """Test that valid messages are extractable.""" + service = FactExtractionService(None) + assert ( + service._is_extractable("I really enjoy programming in Python and building bots") + is True + ) + + def test_is_duplicate_exact_match(self): + """Test duplicate detection with exact match.""" + service = FactExtractionService(None) + existing = {"likes programming", "enjoys gaming"} + assert service._is_duplicate("likes programming", existing) is True + + def test_is_duplicate_no_match(self): + """Test duplicate detection with no match.""" + service = FactExtractionService(None) + existing = {"likes programming", "enjoys gaming"} + assert service._is_duplicate("works at a tech company", existing) is False + + def test_validate_fact_valid(self): + """Test fact validation with valid fact.""" + service = FactExtractionService(None) + fact = { + "type": "hobby", + "content": "likes programming", + "confidence": 0.9, + } + assert service._validate_fact(fact) is True + + def test_validate_fact_missing_type(self): + """Test fact validation with missing type.""" + service = FactExtractionService(None) + fact = {"content": "likes programming"} + assert service._validate_fact(fact) is False + + def test_validate_fact_invalid_type(self): + """Test fact validation with invalid type.""" + service = FactExtractionService(None) + fact = {"type": "invalid_type", "content": "test"} + assert service._validate_fact(fact) is False + + def test_validate_fact_empty_content(self): + """Test fact validation with empty content.""" + service = FactExtractionService(None) + fact = {"type": "hobby", "content": ""} + assert service._validate_fact(fact) is False + + +class TestAIService: + """Tests for AIService.""" + + def test_get_system_prompt_default(self, mock_settings): + """Test getting default system prompt.""" + with patch("daemon_boyfriend.services.ai_service.settings", mock_settings): + with patch("daemon_boyfriend.services.ai_service.AIService._init_provider"): + service = AIService(mock_settings) + service._provider = MagicMock() + + prompt = service.get_system_prompt() + + assert "TestBot" in prompt + assert "helpful and friendly" in prompt + + def test_get_system_prompt_custom(self, mock_settings): + """Test getting custom system prompt.""" + mock_settings.system_prompt = "Custom prompt" + with patch("daemon_boyfriend.services.ai_service.settings", mock_settings): + with patch("daemon_boyfriend.services.ai_service.AIService._init_provider"): + service = AIService(mock_settings) + service._provider = MagicMock() + + prompt = service.get_system_prompt() + + assert prompt == "Custom prompt" + + def test_provider_name(self, mock_settings): + """Test getting provider name.""" + with patch("daemon_boyfriend.services.ai_service.settings", mock_settings): + with patch("daemon_boyfriend.services.ai_service.AIService._init_provider"): + service = AIService(mock_settings) + mock_provider = MagicMock() + mock_provider.provider_name = "openai" + service._provider = mock_provider + + assert service.provider_name == "openai" + + def test_model_property(self, mock_settings): + """Test getting model name.""" + with patch("daemon_boyfriend.services.ai_service.settings", mock_settings): + with patch("daemon_boyfriend.services.ai_service.AIService._init_provider"): + service = AIService(mock_settings) + service._provider = MagicMock() + + assert service.model == "gpt-4o-mini"