quick adding (not working)
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test suite for Daemon Boyfriend Discord bot."""
|
||||
309
tests/conftest.py
Normal file
309
tests/conftest.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""Pytest configuration and fixtures for the test suite."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, Generator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from daemon_boyfriend.config import Settings
|
||||
from daemon_boyfriend.models.base import Base
|
||||
|
||||
# --- Event Loop Fixture ---
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
|
||||
"""Create an event loop for the test session."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
# --- Database Fixtures ---
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create a database session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
# --- Mock Settings Fixture ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings() -> Settings:
|
||||
"""Create mock settings for testing."""
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"DISCORD_TOKEN": "test_token",
|
||||
"AI_PROVIDER": "openai",
|
||||
"AI_MODEL": "gpt-4o-mini",
|
||||
"OPENAI_API_KEY": "test_openai_key",
|
||||
"ANTHROPIC_API_KEY": "test_anthropic_key",
|
||||
"GEMINI_API_KEY": "test_gemini_key",
|
||||
"OPENROUTER_API_KEY": "test_openrouter_key",
|
||||
"BOT_NAME": "TestBot",
|
||||
"BOT_PERSONALITY": "helpful and friendly",
|
||||
"DATABASE_URL": "",
|
||||
"LIVING_AI_ENABLED": "true",
|
||||
"MOOD_ENABLED": "true",
|
||||
"RELATIONSHIP_ENABLED": "true",
|
||||
},
|
||||
):
|
||||
return Settings()
|
||||
|
||||
|
||||
# --- Mock Discord Fixtures ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_user() -> MagicMock:
|
||||
"""Create a mock Discord user."""
|
||||
user = MagicMock()
|
||||
user.id = 123456789
|
||||
user.name = "TestUser"
|
||||
user.display_name = "Test User"
|
||||
user.mention = "<@123456789>"
|
||||
user.bot = False
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_message(mock_discord_user) -> MagicMock:
|
||||
"""Create a mock Discord message."""
|
||||
message = MagicMock()
|
||||
message.author = mock_discord_user
|
||||
message.content = "Hello, bot!"
|
||||
message.channel = MagicMock()
|
||||
message.channel.id = 987654321
|
||||
message.channel.send = AsyncMock()
|
||||
message.channel.typing = MagicMock(return_value=AsyncMock())
|
||||
message.guild = MagicMock()
|
||||
message.guild.id = 111222333
|
||||
message.guild.name = "Test Guild"
|
||||
message.id = 555666777
|
||||
message.mentions = []
|
||||
return message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_bot() -> MagicMock:
|
||||
"""Create a mock Discord bot."""
|
||||
bot = MagicMock()
|
||||
bot.user = MagicMock()
|
||||
bot.user.id = 999888777
|
||||
bot.user.name = "TestBot"
|
||||
bot.user.mentioned_in = MagicMock(return_value=True)
|
||||
return bot
|
||||
|
||||
|
||||
# --- Mock AI Provider Fixtures ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_response() -> MagicMock:
|
||||
"""Create a mock AI response."""
|
||||
from daemon_boyfriend.services.providers.base import AIResponse
|
||||
|
||||
return AIResponse(
|
||||
content="This is a test response from the AI.",
|
||||
model="test-model",
|
||||
usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client() -> MagicMock:
|
||||
"""Create a mock OpenAI client."""
|
||||
client = MagicMock()
|
||||
|
||||
response = MagicMock()
|
||||
response.choices = [MagicMock()]
|
||||
response.choices[0].message.content = "Test OpenAI response"
|
||||
response.model = "gpt-4o-mini"
|
||||
response.usage = MagicMock()
|
||||
response.usage.prompt_tokens = 10
|
||||
response.usage.completion_tokens = 20
|
||||
response.usage.total_tokens = 30
|
||||
|
||||
client.chat.completions.create = AsyncMock(return_value=response)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_anthropic_client() -> MagicMock:
|
||||
"""Create a mock Anthropic client."""
|
||||
client = MagicMock()
|
||||
|
||||
response = MagicMock()
|
||||
response.content = [MagicMock()]
|
||||
response.content[0].type = "text"
|
||||
response.content[0].text = "Test Anthropic response"
|
||||
response.model = "claude-sonnet-4-20250514"
|
||||
response.usage = MagicMock()
|
||||
response.usage.input_tokens = 10
|
||||
response.usage.output_tokens = 20
|
||||
|
||||
client.messages.create = AsyncMock(return_value=response)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gemini_client() -> MagicMock:
|
||||
"""Create a mock Gemini client."""
|
||||
client = MagicMock()
|
||||
|
||||
response = MagicMock()
|
||||
response.text = "Test Gemini response"
|
||||
response.usage_metadata = MagicMock()
|
||||
response.usage_metadata.prompt_token_count = 10
|
||||
response.usage_metadata.candidates_token_count = 20
|
||||
response.usage_metadata.total_token_count = 30
|
||||
|
||||
client.aio.models.generate_content = AsyncMock(return_value=response)
|
||||
return client
|
||||
|
||||
|
||||
# --- Model Fixtures ---
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_user(db_session: AsyncSession):
|
||||
"""Create a sample user in the database."""
|
||||
from daemon_boyfriend.models import User
|
||||
|
||||
user = User(
|
||||
discord_id=123456789,
|
||||
discord_username="testuser",
|
||||
discord_display_name="Test User",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_user_with_facts(db_session: AsyncSession, sample_user):
|
||||
"""Create a sample user with facts."""
|
||||
from daemon_boyfriend.models import UserFact
|
||||
|
||||
facts = [
|
||||
UserFact(
|
||||
user_id=sample_user.id,
|
||||
fact_type="hobby",
|
||||
fact_content="likes programming",
|
||||
confidence=1.0,
|
||||
source="explicit",
|
||||
),
|
||||
UserFact(
|
||||
user_id=sample_user.id,
|
||||
fact_type="preference",
|
||||
fact_content="prefers dark mode",
|
||||
confidence=0.8,
|
||||
source="conversation",
|
||||
),
|
||||
]
|
||||
|
||||
for fact in facts:
|
||||
db_session.add(fact)
|
||||
|
||||
await db_session.commit()
|
||||
return sample_user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_conversation(db_session: AsyncSession, sample_user):
|
||||
"""Create a sample conversation."""
|
||||
from daemon_boyfriend.models import Conversation
|
||||
|
||||
conversation = Conversation(
|
||||
user_id=sample_user.id,
|
||||
guild_id=111222333,
|
||||
channel_id=987654321,
|
||||
)
|
||||
db_session.add(conversation)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(conversation)
|
||||
return conversation
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_bot_state(db_session: AsyncSession):
|
||||
"""Create a sample bot state."""
|
||||
from daemon_boyfriend.models import BotState
|
||||
|
||||
bot_state = BotState(
|
||||
guild_id=111222333,
|
||||
mood_valence=0.5,
|
||||
mood_arousal=0.3,
|
||||
)
|
||||
db_session.add(bot_state)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(bot_state)
|
||||
return bot_state
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_user_relationship(db_session: AsyncSession, sample_user):
|
||||
"""Create a sample user relationship."""
|
||||
from daemon_boyfriend.models import UserRelationship
|
||||
|
||||
relationship = UserRelationship(
|
||||
user_id=sample_user.id,
|
||||
guild_id=111222333,
|
||||
relationship_score=50.0,
|
||||
total_interactions=10,
|
||||
positive_interactions=8,
|
||||
negative_interactions=1,
|
||||
)
|
||||
db_session.add(relationship)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(relationship)
|
||||
return relationship
|
||||
|
||||
|
||||
# --- Utility Fixtures ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def utc_now() -> datetime:
|
||||
"""Get current UTC time."""
|
||||
return datetime.now(timezone.utc)
|
||||
487
tests/test_models.py
Normal file
487
tests/test_models.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""Tests for database models."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from daemon_boyfriend.models import (
|
||||
BotOpinion,
|
||||
BotState,
|
||||
Conversation,
|
||||
FactAssociation,
|
||||
Guild,
|
||||
GuildMember,
|
||||
Message,
|
||||
MoodHistory,
|
||||
ScheduledEvent,
|
||||
User,
|
||||
UserCommunicationStyle,
|
||||
UserFact,
|
||||
UserPreference,
|
||||
UserRelationship,
|
||||
)
|
||||
from daemon_boyfriend.models.base import utc_now
|
||||
|
||||
|
||||
class TestUtcNow:
|
||||
"""Tests for the utc_now helper function."""
|
||||
|
||||
def test_returns_timezone_aware(self):
|
||||
"""Test that utc_now returns timezone-aware datetime."""
|
||||
now = utc_now()
|
||||
assert now.tzinfo is not None
|
||||
assert now.tzinfo == timezone.utc
|
||||
|
||||
def test_returns_current_time(self):
|
||||
"""Test that utc_now returns approximately current time."""
|
||||
before = datetime.now(timezone.utc)
|
||||
now = utc_now()
|
||||
after = datetime.now(timezone.utc)
|
||||
assert before <= now <= after
|
||||
|
||||
|
||||
class TestUserModel:
|
||||
"""Tests for the User model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user(self, db_session):
|
||||
"""Test creating a user."""
|
||||
user = User(
|
||||
discord_id=123456789,
|
||||
discord_username="testuser",
|
||||
discord_display_name="Test User",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
assert user.id is not None
|
||||
assert user.discord_id == 123456789
|
||||
assert user.is_active is True
|
||||
assert user.created_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_display_name_property(self, db_session):
|
||||
"""Test the display_name property."""
|
||||
user = User(
|
||||
discord_id=123456789,
|
||||
discord_username="testuser",
|
||||
discord_display_name="Test User",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
# Uses discord_display_name when no custom_name
|
||||
assert user.display_name == "Test User"
|
||||
|
||||
# Uses custom_name when set
|
||||
user.custom_name = "Custom Name"
|
||||
assert user.display_name == "Custom Name"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_display_name_fallback(self, db_session):
|
||||
"""Test display_name falls back to username."""
|
||||
user = User(
|
||||
discord_id=123456789,
|
||||
discord_username="testuser",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
assert user.display_name == "testuser"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_timestamps(self, db_session):
|
||||
"""Test user timestamp fields."""
|
||||
user = User(
|
||||
discord_id=123456789,
|
||||
discord_username="testuser",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
|
||||
assert user.first_seen_at is not None
|
||||
assert user.last_seen_at is not None
|
||||
assert user.created_at is not None
|
||||
assert user.updated_at is not None
|
||||
|
||||
|
||||
class TestUserFactModel:
|
||||
"""Tests for the UserFact model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_fact(self, db_session, sample_user):
|
||||
"""Test creating a user fact."""
|
||||
fact = UserFact(
|
||||
user_id=sample_user.id,
|
||||
fact_type="hobby",
|
||||
fact_content="likes gaming",
|
||||
confidence=0.9,
|
||||
source="conversation",
|
||||
)
|
||||
db_session.add(fact)
|
||||
await db_session.commit()
|
||||
|
||||
assert fact.id is not None
|
||||
assert fact.is_active is True
|
||||
assert fact.learned_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fact_default_values(self, db_session, sample_user):
|
||||
"""Test fact default values."""
|
||||
fact = UserFact(
|
||||
user_id=sample_user.id,
|
||||
fact_type="general",
|
||||
fact_content="test fact",
|
||||
)
|
||||
db_session.add(fact)
|
||||
await db_session.commit()
|
||||
|
||||
assert fact.confidence == 1.0
|
||||
assert fact.source == "conversation"
|
||||
assert fact.is_active is True
|
||||
|
||||
|
||||
class TestConversationModel:
|
||||
"""Tests for the Conversation model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_conversation(self, db_session, sample_user):
|
||||
"""Test creating a conversation."""
|
||||
conv = Conversation(
|
||||
user_id=sample_user.id,
|
||||
guild_id=111222333,
|
||||
channel_id=444555666,
|
||||
)
|
||||
db_session.add(conv)
|
||||
await db_session.commit()
|
||||
|
||||
assert conv.id is not None
|
||||
assert conv.message_count == 0
|
||||
assert conv.is_active is True
|
||||
assert conv.started_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_with_messages(self, db_session, sample_user):
|
||||
"""Test conversation with messages."""
|
||||
conv = Conversation(
|
||||
user_id=sample_user.id,
|
||||
channel_id=444555666,
|
||||
)
|
||||
db_session.add(conv)
|
||||
await db_session.commit()
|
||||
|
||||
msg = Message(
|
||||
conversation_id=conv.id,
|
||||
user_id=sample_user.id,
|
||||
role="user",
|
||||
content="Hello!",
|
||||
)
|
||||
db_session.add(msg)
|
||||
await db_session.commit()
|
||||
|
||||
assert msg.id is not None
|
||||
assert msg.role == "user"
|
||||
assert msg.has_images is False
|
||||
|
||||
|
||||
class TestMessageModel:
|
||||
"""Tests for the Message model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_message(self, db_session, sample_conversation, sample_user):
|
||||
"""Test creating a message."""
|
||||
msg = Message(
|
||||
conversation_id=sample_conversation.id,
|
||||
user_id=sample_user.id,
|
||||
role="user",
|
||||
content="Test message",
|
||||
)
|
||||
db_session.add(msg)
|
||||
await db_session.commit()
|
||||
|
||||
assert msg.id is not None
|
||||
assert msg.has_images is False
|
||||
assert msg.image_urls is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_with_images(self, db_session, sample_conversation, sample_user):
|
||||
"""Test message with images."""
|
||||
msg = Message(
|
||||
conversation_id=sample_conversation.id,
|
||||
user_id=sample_user.id,
|
||||
role="user",
|
||||
content="Look at this",
|
||||
has_images=True,
|
||||
image_urls=["https://example.com/image.png"],
|
||||
)
|
||||
db_session.add(msg)
|
||||
await db_session.commit()
|
||||
|
||||
assert msg.has_images is True
|
||||
assert len(msg.image_urls) == 1
|
||||
|
||||
|
||||
class TestGuildModel:
|
||||
"""Tests for the Guild model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_guild(self, db_session):
|
||||
"""Test creating a guild."""
|
||||
guild = Guild(
|
||||
discord_id=111222333,
|
||||
name="Test Guild",
|
||||
)
|
||||
db_session.add(guild)
|
||||
await db_session.commit()
|
||||
|
||||
assert guild.id is not None
|
||||
assert guild.is_active is True
|
||||
assert guild.settings == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_with_settings(self, db_session):
|
||||
"""Test guild with custom settings."""
|
||||
guild = Guild(
|
||||
discord_id=111222333,
|
||||
name="Test Guild",
|
||||
settings={"prefix": "!", "language": "en"},
|
||||
)
|
||||
db_session.add(guild)
|
||||
await db_session.commit()
|
||||
|
||||
assert guild.settings["prefix"] == "!"
|
||||
assert guild.settings["language"] == "en"
|
||||
|
||||
|
||||
class TestGuildMemberModel:
|
||||
"""Tests for the GuildMember model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_guild_member(self, db_session, sample_user):
|
||||
"""Test creating a guild member."""
|
||||
guild = Guild(discord_id=111222333, name="Test Guild")
|
||||
db_session.add(guild)
|
||||
await db_session.commit()
|
||||
|
||||
member = GuildMember(
|
||||
guild_id=guild.id,
|
||||
user_id=sample_user.id,
|
||||
guild_nickname="TestNick",
|
||||
)
|
||||
db_session.add(member)
|
||||
await db_session.commit()
|
||||
|
||||
assert member.id is not None
|
||||
assert member.guild_nickname == "TestNick"
|
||||
|
||||
|
||||
class TestBotStateModel:
|
||||
"""Tests for the BotState model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_bot_state(self, db_session):
|
||||
"""Test creating a bot state."""
|
||||
state = BotState(guild_id=111222333)
|
||||
db_session.add(state)
|
||||
await db_session.commit()
|
||||
|
||||
assert state.id is not None
|
||||
assert state.mood_valence == 0.0
|
||||
assert state.mood_arousal == 0.0
|
||||
assert state.total_messages_sent == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bot_state_defaults(self, db_session):
|
||||
"""Test bot state default values."""
|
||||
state = BotState()
|
||||
db_session.add(state)
|
||||
await db_session.commit()
|
||||
|
||||
assert state.guild_id is None
|
||||
assert state.preferences == {}
|
||||
assert state.total_facts_learned == 0
|
||||
assert state.total_users_known == 0
|
||||
|
||||
|
||||
class TestBotOpinionModel:
|
||||
"""Tests for the BotOpinion model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_opinion(self, db_session):
|
||||
"""Test creating a bot opinion."""
|
||||
opinion = BotOpinion(
|
||||
topic="programming",
|
||||
sentiment=0.8,
|
||||
interest_level=0.9,
|
||||
)
|
||||
db_session.add(opinion)
|
||||
await db_session.commit()
|
||||
|
||||
assert opinion.id is not None
|
||||
assert opinion.discussion_count == 0
|
||||
assert opinion.formed_at is not None
|
||||
|
||||
|
||||
class TestUserRelationshipModel:
|
||||
"""Tests for the UserRelationship model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_relationship(self, db_session, sample_user):
|
||||
"""Test creating a user relationship."""
|
||||
rel = UserRelationship(
|
||||
user_id=sample_user.id,
|
||||
guild_id=111222333,
|
||||
)
|
||||
db_session.add(rel)
|
||||
await db_session.commit()
|
||||
|
||||
assert rel.id is not None
|
||||
assert rel.relationship_score == 10.0
|
||||
assert rel.total_interactions == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relationship_defaults(self, db_session, sample_user):
|
||||
"""Test relationship default values."""
|
||||
rel = UserRelationship(user_id=sample_user.id)
|
||||
db_session.add(rel)
|
||||
await db_session.commit()
|
||||
|
||||
assert rel.shared_references == {}
|
||||
assert rel.positive_interactions == 0
|
||||
assert rel.negative_interactions == 0
|
||||
assert rel.avg_message_length == 0.0
|
||||
|
||||
|
||||
class TestUserCommunicationStyleModel:
|
||||
"""Tests for the UserCommunicationStyle model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_style(self, db_session, sample_user):
|
||||
"""Test creating a communication style."""
|
||||
style = UserCommunicationStyle(user_id=sample_user.id)
|
||||
db_session.add(style)
|
||||
await db_session.commit()
|
||||
|
||||
assert style.id is not None
|
||||
assert style.preferred_length == "medium"
|
||||
assert style.preferred_formality == 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_style_defaults(self, db_session, sample_user):
|
||||
"""Test communication style defaults."""
|
||||
style = UserCommunicationStyle(user_id=sample_user.id)
|
||||
db_session.add(style)
|
||||
await db_session.commit()
|
||||
|
||||
assert style.emoji_affinity == 0.5
|
||||
assert style.humor_affinity == 0.5
|
||||
assert style.detail_preference == 0.5
|
||||
assert style.samples_collected == 0
|
||||
assert style.confidence == 0.0
|
||||
|
||||
|
||||
class TestScheduledEventModel:
|
||||
"""Tests for the ScheduledEvent model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_event(self, db_session, sample_user):
|
||||
"""Test creating a scheduled event."""
|
||||
trigger_time = datetime.now(timezone.utc) + timedelta(days=1)
|
||||
event = ScheduledEvent(
|
||||
user_id=sample_user.id,
|
||||
event_type="birthday",
|
||||
trigger_at=trigger_time,
|
||||
title="Birthday reminder",
|
||||
)
|
||||
db_session.add(event)
|
||||
await db_session.commit()
|
||||
|
||||
assert event.id is not None
|
||||
assert event.status == "pending"
|
||||
assert event.is_recurring is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recurring_event(self, db_session, sample_user):
|
||||
"""Test creating a recurring event."""
|
||||
trigger_time = datetime.now(timezone.utc) + timedelta(days=1)
|
||||
event = ScheduledEvent(
|
||||
user_id=sample_user.id,
|
||||
event_type="birthday",
|
||||
trigger_at=trigger_time,
|
||||
title="Birthday",
|
||||
is_recurring=True,
|
||||
recurrence_rule="yearly",
|
||||
)
|
||||
db_session.add(event)
|
||||
await db_session.commit()
|
||||
|
||||
assert event.is_recurring is True
|
||||
assert event.recurrence_rule == "yearly"
|
||||
|
||||
|
||||
class TestFactAssociationModel:
|
||||
"""Tests for the FactAssociation model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_association(self, db_session, sample_user):
|
||||
"""Test creating a fact association."""
|
||||
fact1 = UserFact(
|
||||
user_id=sample_user.id,
|
||||
fact_type="hobby",
|
||||
fact_content="likes programming",
|
||||
)
|
||||
fact2 = UserFact(
|
||||
user_id=sample_user.id,
|
||||
fact_type="hobby",
|
||||
fact_content="likes Python",
|
||||
)
|
||||
db_session.add(fact1)
|
||||
db_session.add(fact2)
|
||||
await db_session.commit()
|
||||
|
||||
assoc = FactAssociation(
|
||||
fact_id_1=fact1.id,
|
||||
fact_id_2=fact2.id,
|
||||
association_type="shared_interest",
|
||||
strength=0.8,
|
||||
)
|
||||
db_session.add(assoc)
|
||||
await db_session.commit()
|
||||
|
||||
assert assoc.id is not None
|
||||
assert assoc.discovered_at is not None
|
||||
|
||||
|
||||
class TestMoodHistoryModel:
|
||||
"""Tests for the MoodHistory model."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mood_history(self, db_session, sample_user):
|
||||
"""Test creating a mood history entry."""
|
||||
history = MoodHistory(
|
||||
guild_id=111222333,
|
||||
valence=0.5,
|
||||
arousal=0.3,
|
||||
trigger_type="conversation",
|
||||
trigger_user_id=sample_user.id,
|
||||
trigger_description="Had a nice chat",
|
||||
)
|
||||
db_session.add(history)
|
||||
await db_session.commit()
|
||||
|
||||
assert history.id is not None
|
||||
assert history.recorded_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mood_history_without_user(self, db_session):
|
||||
"""Test mood history without trigger user."""
|
||||
history = MoodHistory(
|
||||
valence=-0.2,
|
||||
arousal=-0.1,
|
||||
trigger_type="time_decay",
|
||||
)
|
||||
db_session.add(history)
|
||||
await db_session.commit()
|
||||
|
||||
assert history.trigger_user_id is None
|
||||
assert history.trigger_description is None
|
||||
290
tests/test_providers.py
Normal file
290
tests/test_providers.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Tests for AI provider implementations."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from daemon_boyfriend.services.providers.base import (
|
||||
AIProvider,
|
||||
AIResponse,
|
||||
Message,
|
||||
ImageAttachment,
|
||||
)
|
||||
from daemon_boyfriend.services.providers.openai import OpenAIProvider
|
||||
from daemon_boyfriend.services.providers.anthropic import AnthropicProvider
|
||||
from daemon_boyfriend.services.providers.gemini import GeminiProvider
|
||||
from daemon_boyfriend.services.providers.openrouter import OpenRouterProvider
|
||||
|
||||
|
||||
class TestMessage:
|
||||
"""Tests for the Message dataclass."""
|
||||
|
||||
def test_message_creation(self):
|
||||
"""Test creating a basic message."""
|
||||
msg = Message(role="user", content="Hello")
|
||||
assert msg.role == "user"
|
||||
assert msg.content == "Hello"
|
||||
assert msg.images == []
|
||||
|
||||
def test_message_with_images(self):
|
||||
"""Test creating a message with images."""
|
||||
images = [ImageAttachment(url="https://example.com/image.png")]
|
||||
msg = Message(role="user", content="Look at this", images=images)
|
||||
assert len(msg.images) == 1
|
||||
assert msg.images[0].url == "https://example.com/image.png"
|
||||
|
||||
|
||||
class TestImageAttachment:
|
||||
"""Tests for the ImageAttachment dataclass."""
|
||||
|
||||
def test_default_media_type(self):
|
||||
"""Test default media type."""
|
||||
img = ImageAttachment(url="https://example.com/image.png")
|
||||
assert img.media_type == "image/png"
|
||||
|
||||
def test_custom_media_type(self):
|
||||
"""Test custom media type."""
|
||||
img = ImageAttachment(url="https://example.com/image.jpg", media_type="image/jpeg")
|
||||
assert img.media_type == "image/jpeg"
|
||||
|
||||
|
||||
class TestAIResponse:
|
||||
"""Tests for the AIResponse dataclass."""
|
||||
|
||||
def test_response_creation(self):
|
||||
"""Test creating an AI response."""
|
||||
response = AIResponse(
|
||||
content="Hello!",
|
||||
model="test-model",
|
||||
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
assert response.content == "Hello!"
|
||||
assert response.model == "test-model"
|
||||
assert response.usage["total_tokens"] == 15
|
||||
|
||||
|
||||
class TestOpenAIProvider:
|
||||
"""Tests for the OpenAI provider."""
|
||||
|
||||
@pytest.fixture
|
||||
def provider(self, mock_openai_client):
|
||||
"""Create an OpenAI provider with mocked client."""
|
||||
with patch("daemon_boyfriend.services.providers.openai.AsyncOpenAI") as mock_class:
|
||||
mock_class.return_value = mock_openai_client
|
||||
provider = OpenAIProvider(api_key="test_key", model="gpt-4o-mini")
|
||||
provider.client = mock_openai_client
|
||||
return provider
|
||||
|
||||
def test_provider_name(self, provider):
|
||||
"""Test provider name."""
|
||||
assert provider.provider_name == "openai"
|
||||
|
||||
def test_model_setting(self, provider):
|
||||
"""Test model is set correctly."""
|
||||
assert provider.model == "gpt-4o-mini"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_simple_message(self, provider, mock_openai_client):
|
||||
"""Test generating a response with a simple message."""
|
||||
messages = [Message(role="user", content="Hello")]
|
||||
|
||||
response = await provider.generate(messages)
|
||||
|
||||
assert response.content == "Test OpenAI response"
|
||||
assert response.model == "gpt-4o-mini"
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_with_system_prompt(self, provider, mock_openai_client):
|
||||
"""Test generating a response with a system prompt."""
|
||||
messages = [Message(role="user", content="Hello")]
|
||||
|
||||
await provider.generate(messages, system_prompt="You are a helpful assistant.")
|
||||
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
api_messages = call_args.kwargs["messages"]
|
||||
assert api_messages[0]["role"] == "system"
|
||||
assert api_messages[0]["content"] == "You are a helpful assistant."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_with_images(self, provider, mock_openai_client):
|
||||
"""Test generating a response with images."""
|
||||
images = [ImageAttachment(url="https://example.com/image.png")]
|
||||
messages = [Message(role="user", content="What's in this image?", images=images)]
|
||||
|
||||
await provider.generate(messages)
|
||||
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
api_messages = call_args.kwargs["messages"]
|
||||
content = api_messages[0]["content"]
|
||||
assert isinstance(content, list)
|
||||
assert content[0]["type"] == "text"
|
||||
assert content[1]["type"] == "image_url"
|
||||
|
||||
def test_build_message_content_no_images(self, provider):
|
||||
"""Test building message content without images."""
|
||||
msg = Message(role="user", content="Hello")
|
||||
content = provider._build_message_content(msg)
|
||||
assert content == "Hello"
|
||||
|
||||
def test_build_message_content_with_images(self, provider):
|
||||
"""Test building message content with images."""
|
||||
images = [ImageAttachment(url="https://example.com/image.png")]
|
||||
msg = Message(role="user", content="Hello", images=images)
|
||||
content = provider._build_message_content(msg)
|
||||
assert isinstance(content, list)
|
||||
assert len(content) == 2
|
||||
|
||||
|
||||
class TestAnthropicProvider:
|
||||
"""Tests for the Anthropic provider."""
|
||||
|
||||
@pytest.fixture
|
||||
def provider(self, mock_anthropic_client):
|
||||
"""Create an Anthropic provider with mocked client."""
|
||||
with patch(
|
||||
"daemon_boyfriend.services.providers.anthropic.anthropic.AsyncAnthropic"
|
||||
) as mock_class:
|
||||
"""Tests for the Anthropic provider."""
|
||||
|
||||
@pytest.fixture
|
||||
def provider(self, mock_anthropic_client):
|
||||
"""Create an Anthropic provider with mocked client."""
|
||||
with patch("daemon_boyfriend.services.providers.anthropic.anthropic.AsyncAnthropic") as mock_class:
|
||||
mock_class.return_value = mock_anthropic_client
|
||||
provider = AnthropicProvider(api_key="test_key", model="claude-sonnet-4-20250514")
|
||||
provider.client = mock_anthropic_client
|
||||
return provider
|
||||
|
||||
def test_provider_name(self, provider):
|
||||
"""Test provider name."""
|
||||
assert provider.provider_name == "anthropic"
|
||||
|
||||
def test_model_setting(self, provider):
|
||||
"""Test model is set correctly."""
|
||||
assert provider.model == "claude-sonnet-4-20250514"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_simple_message(self, provider, mock_anthropic_client):
|
||||
"""Test generating a response with a simple message."""
|
||||
messages = [Message(role="user", content="Hello")]
|
||||
|
||||
response = await provider.generate(messages)
|
||||
|
||||
assert response.content == "Test Anthropic response"
|
||||
mock_anthropic_client.messages.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_with_system_prompt(self, provider, mock_anthropic_client):
|
||||
"""Test generating a response with a system prompt."""
|
||||
messages = [Message(role="user", content="Hello")]
|
||||
|
||||
await provider.generate(messages, system_prompt="You are a helpful assistant.")
|
||||
|
||||
call_args = mock_anthropic_client.messages.create.call_args
|
||||
assert call_args.kwargs["system"] == "You are a helpful assistant."
|
||||
|
||||
def test_build_message_content_no_images(self, provider):
|
||||
"""Test building message content without images."""
|
||||
msg = Message(role="user", content="Hello")
|
||||
content = provider._build_message_content(msg)
|
||||
assert content == "Hello"
|
||||
|
||||
def test_build_message_content_with_images(self, provider):
|
||||
"""Test building message content with images."""
|
||||
images = [ImageAttachment(url="https://example.com/image.png")]
|
||||
msg = Message(role="user", content="Hello", images=images)
|
||||
content = provider._build_message_content(msg)
|
||||
assert isinstance(content, list)
|
||||
assert content[0]["type"] == "image"
|
||||
assert content[1]["type"] == "text"
|
||||
|
||||
|
||||
class TestGeminiProvider:
|
||||
"""Tests for the Gemini provider."""
|
||||
|
||||
@pytest.fixture
|
||||
def provider(self, mock_gemini_client):
|
||||
"""Create a Gemini provider with mocked client."""
|
||||
with patch("daemon_boyfriend.services.providers.gemini.genai.Client") as mock_class:
|
||||
mock_class.return_value = mock_gemini_client
|
||||
provider = GeminiProvider(api_key="test_key", model="gemini-2.0-flash")
|
||||
provider.client = mock_gemini_client
|
||||
return provider
|
||||
|
||||
def test_provider_name(self, provider):
|
||||
"""Test provider name."""
|
||||
assert provider.provider_name == "gemini"
|
||||
|
||||
def test_model_setting(self, provider):
|
||||
"""Test model is set correctly."""
|
||||
assert provider.model == "gemini-2.0-flash"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_simple_message(self, provider, mock_gemini_client):
|
||||
"""Test generating a response with a simple message."""
|
||||
messages = [Message(role="user", content="Hello")]
|
||||
|
||||
response = await provider.generate(messages)
|
||||
|
||||
assert response.content == "Test Gemini response"
|
||||
mock_gemini_client.aio.models.generate_content.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_mapping(self, provider, mock_gemini_client):
|
||||
"""Test that 'assistant' role is mapped to 'model'."""
|
||||
messages = [
|
||||
Message(role="user", content="Hello"),
|
||||
Message(role="assistant", content="Hi there!"),
|
||||
Message(role="user", content="How are you?"),
|
||||
]
|
||||
|
||||
await provider.generate(messages)
|
||||
|
||||
call_args = mock_gemini_client.aio.models.generate_content.call_args
|
||||
contents = call_args.kwargs["contents"]
|
||||
assert contents[0].role == "user"
|
||||
assert contents[1].role == "model"
|
||||
assert contents[2].role == "user"
|
||||
|
||||
|
||||
class TestOpenRouterProvider:
|
||||
"""Tests for the OpenRouter provider."""
|
||||
|
||||
@pytest.fixture
|
||||
def provider(self, mock_openai_client):
|
||||
"""Create an OpenRouter provider with mocked client."""
|
||||
with patch("daemon_boyfriend.services.providers.openrouter.AsyncOpenAI") as mock_class:
|
||||
mock_class.return_value = mock_openai_client
|
||||
provider = OpenRouterProvider(api_key="test_key", model="openai/gpt-4o")
|
||||
provider.client = mock_openai_client
|
||||
return provider
|
||||
|
||||
def test_provider_name(self, provider):
|
||||
"""Test provider name."""
|
||||
assert provider.provider_name == "openrouter"
|
||||
|
||||
def test_model_setting(self, provider):
|
||||
"""Test model is set correctly."""
|
||||
assert provider.model == "openai/gpt-4o"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_simple_message(self, provider, mock_openai_client):
|
||||
"""Test generating a response with a simple message."""
|
||||
messages = [Message(role="user", content="Hello")]
|
||||
|
||||
response = await provider.generate(messages)
|
||||
|
||||
assert response.content == "Test OpenAI response"
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_headers(self, provider, mock_openai_client):
|
||||
"""Test that OpenRouter-specific headers are included."""
|
||||
messages = [Message(role="user", content="Hello")]
|
||||
|
||||
await provider.generate(messages)
|
||||
|
||||
call_args = mock_openai_client.chat.completions.create.call_args
|
||||
extra_headers = call_args.kwargs.get("extra_headers", {})
|
||||
assert "HTTP-Referer" in extra_headers
|
||||
assert "X-Title" in extra_headers
|
||||
620
tests/test_services.py
Normal file
620
tests/test_services.py
Normal file
@@ -0,0 +1,620 @@
|
||||
"""Tests for service layer."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from daemon_boyfriend.models import (
|
||||
BotOpinion,
|
||||
BotState,
|
||||
Conversation,
|
||||
Message,
|
||||
User,
|
||||
UserFact,
|
||||
UserRelationship,
|
||||
)
|
||||
from daemon_boyfriend.services.ai_service import AIService
|
||||
from daemon_boyfriend.services.fact_extraction_service import FactExtractionService
|
||||
from daemon_boyfriend.services.mood_service import MoodLabel, MoodService, MoodState
|
||||
from daemon_boyfriend.services.opinion_service import OpinionService, extract_topics_from_message
|
||||
from daemon_boyfriend.services.persistent_conversation import PersistentConversationManager
|
||||
from daemon_boyfriend.services.relationship_service import RelationshipLevel, RelationshipService
|
||||
from daemon_boyfriend.services.self_awareness_service import SelfAwarenessService
|
||||
from daemon_boyfriend.services.user_service import UserService
|
||||
|
||||
|
||||
class TestUserService:
|
||||
"""Tests for UserService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_user_new(self, db_session):
|
||||
"""Test creating a new user."""
|
||||
service = UserService(db_session)
|
||||
|
||||
user = await service.get_or_create_user(
|
||||
discord_id=123456789,
|
||||
username="testuser",
|
||||
display_name="Test User",
|
||||
)
|
||||
|
||||
assert user.id is not None
|
||||
assert user.discord_id == 123456789
|
||||
assert user.discord_username == "testuser"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_user_existing(self, db_session, sample_user):
|
||||
"""Test getting an existing user."""
|
||||
service = UserService(db_session)
|
||||
|
||||
user = await service.get_or_create_user(
|
||||
discord_id=sample_user.discord_id,
|
||||
username="newname",
|
||||
display_name="New Display",
|
||||
)
|
||||
|
||||
assert user.id == sample_user.id
|
||||
assert user.discord_username == "newname"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_custom_name(self, db_session, sample_user):
|
||||
"""Test setting a custom name."""
|
||||
service = UserService(db_session)
|
||||
|
||||
user = await service.set_custom_name(sample_user.discord_id, "CustomName")
|
||||
|
||||
assert user.custom_name == "CustomName"
|
||||
assert user.display_name == "CustomName"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_custom_name(self, db_session, sample_user):
|
||||
"""Test clearing a custom name."""
|
||||
service = UserService(db_session)
|
||||
sample_user.custom_name = "OldName"
|
||||
|
||||
user = await service.set_custom_name(sample_user.discord_id, None)
|
||||
|
||||
assert user.custom_name is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_fact(self, db_session, sample_user):
|
||||
"""Test adding a fact about a user."""
|
||||
service = UserService(db_session)
|
||||
|
||||
fact = await service.add_fact(
|
||||
user=sample_user,
|
||||
fact_type="hobby",
|
||||
fact_content="likes programming",
|
||||
)
|
||||
|
||||
assert fact.id is not None
|
||||
assert fact.user_id == sample_user.id
|
||||
assert fact.fact_type == "hobby"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_facts(self, db_session, sample_user_with_facts):
|
||||
"""Test getting user facts."""
|
||||
service = UserService(db_session)
|
||||
|
||||
facts = await service.get_user_facts(sample_user_with_facts)
|
||||
|
||||
assert len(facts) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_facts_by_type(self, db_session, sample_user_with_facts):
|
||||
"""Test getting user facts by type."""
|
||||
service = UserService(db_session)
|
||||
|
||||
facts = await service.get_user_facts(sample_user_with_facts, fact_type="hobby")
|
||||
|
||||
assert len(facts) == 1
|
||||
assert facts[0].fact_type == "hobby"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_facts(self, db_session, sample_user_with_facts):
|
||||
"""Test deleting user facts."""
|
||||
service = UserService(db_session)
|
||||
|
||||
count = await service.delete_user_facts(sample_user_with_facts)
|
||||
|
||||
assert count == 2
|
||||
facts = await service.get_user_facts(sample_user_with_facts, active_only=True)
|
||||
assert len(facts) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_context(self, db_session, sample_user_with_facts):
|
||||
"""Test getting user context string."""
|
||||
service = UserService(db_session)
|
||||
|
||||
context = await service.get_user_context(sample_user_with_facts)
|
||||
|
||||
assert "Test User" in context
|
||||
assert "likes programming" in context
|
||||
|
||||
|
||||
class TestMoodService:
|
||||
"""Tests for MoodService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_bot_state(self, db_session):
|
||||
"""Test getting or creating bot state."""
|
||||
service = MoodService(db_session)
|
||||
|
||||
state = await service.get_or_create_bot_state(guild_id=111222333)
|
||||
|
||||
assert state.id is not None
|
||||
assert state.guild_id == 111222333
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_mood(self, db_session, sample_bot_state):
|
||||
"""Test getting current mood."""
|
||||
service = MoodService(db_session)
|
||||
|
||||
mood = await service.get_current_mood(guild_id=sample_bot_state.guild_id)
|
||||
|
||||
assert isinstance(mood, MoodState)
|
||||
assert -1.0 <= mood.valence <= 1.0
|
||||
assert -1.0 <= mood.arousal <= 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_mood(self, db_session, sample_bot_state):
|
||||
"""Test updating mood."""
|
||||
service = MoodService(db_session)
|
||||
|
||||
new_mood = await service.update_mood(
|
||||
guild_id=sample_bot_state.guild_id,
|
||||
sentiment_delta=0.5,
|
||||
engagement_delta=0.3,
|
||||
trigger_type="conversation",
|
||||
trigger_description="Had a nice chat",
|
||||
)
|
||||
|
||||
assert new_mood.valence > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increment_stats(self, db_session, sample_bot_state):
|
||||
"""Test incrementing bot stats."""
|
||||
service = MoodService(db_session)
|
||||
initial_messages = sample_bot_state.total_messages_sent
|
||||
|
||||
await service.increment_stats(
|
||||
guild_id=sample_bot_state.guild_id,
|
||||
messages_sent=5,
|
||||
facts_learned=2,
|
||||
)
|
||||
|
||||
assert sample_bot_state.total_messages_sent == initial_messages + 5
|
||||
|
||||
def test_classify_mood_excited(self):
|
||||
"""Test mood classification for excited."""
|
||||
service = MoodService(None)
|
||||
label = service._classify_mood(0.5, 0.5)
|
||||
assert label == MoodLabel.EXCITED
|
||||
|
||||
def test_classify_mood_happy(self):
|
||||
"""Test mood classification for happy."""
|
||||
service = MoodService(None)
|
||||
label = service._classify_mood(0.5, 0.0)
|
||||
assert label == MoodLabel.HAPPY
|
||||
|
||||
def test_classify_mood_bored(self):
|
||||
"""Test mood classification for bored."""
|
||||
service = MoodService(None)
|
||||
label = service._classify_mood(-0.5, 0.0)
|
||||
assert label == MoodLabel.BORED
|
||||
|
||||
def test_classify_mood_annoyed(self):
|
||||
"""Test mood classification for annoyed."""
|
||||
service = MoodService(None)
|
||||
label = service._classify_mood(-0.5, 0.5)
|
||||
assert label == MoodLabel.ANNOYED
|
||||
|
||||
def test_get_mood_prompt_modifier(self):
|
||||
"""Test getting mood prompt modifier."""
|
||||
service = MoodService(None)
|
||||
mood = MoodState(valence=0.8, arousal=0.8, label=MoodLabel.EXCITED, intensity=0.8)
|
||||
|
||||
modifier = service.get_mood_prompt_modifier(mood)
|
||||
|
||||
assert "enthusiastic" in modifier.lower() or "excited" in modifier.lower()
|
||||
|
||||
def test_get_mood_prompt_modifier_low_intensity(self):
|
||||
"""Test mood modifier with low intensity."""
|
||||
service = MoodService(None)
|
||||
mood = MoodState(valence=0.1, arousal=0.1, label=MoodLabel.NEUTRAL, intensity=0.1)
|
||||
|
||||
modifier = service.get_mood_prompt_modifier(mood)
|
||||
|
||||
assert modifier == ""
|
||||
|
||||
|
||||
class TestRelationshipService:
|
||||
"""Tests for RelationshipService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_relationship(self, db_session, sample_user):
|
||||
"""Test getting or creating a relationship."""
|
||||
service = RelationshipService(db_session)
|
||||
|
||||
rel = await service.get_or_create_relationship(sample_user, guild_id=111222333)
|
||||
|
||||
assert rel.id is not None
|
||||
assert rel.user_id == sample_user.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_interaction(self, db_session, sample_user):
|
||||
"""Test recording an interaction."""
|
||||
service = RelationshipService(db_session)
|
||||
|
||||
level = await service.record_interaction(
|
||||
user=sample_user,
|
||||
guild_id=111222333,
|
||||
sentiment=0.8,
|
||||
message_length=100,
|
||||
conversation_turns=3,
|
||||
)
|
||||
|
||||
assert isinstance(level, RelationshipLevel)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_positive_interaction(self, db_session, sample_user):
|
||||
"""Test that positive interactions are tracked."""
|
||||
service = RelationshipService(db_session)
|
||||
|
||||
await service.record_interaction(
|
||||
user=sample_user,
|
||||
guild_id=111222333,
|
||||
sentiment=0.5,
|
||||
message_length=100,
|
||||
)
|
||||
|
||||
rel = await service.get_or_create_relationship(sample_user, guild_id=111222333)
|
||||
assert rel.positive_interactions >= 1
|
||||
|
||||
def test_get_level_stranger(self):
|
||||
"""Test level classification for stranger."""
|
||||
service = RelationshipService(None)
|
||||
assert service.get_level(10) == RelationshipLevel.STRANGER
|
||||
|
||||
def test_get_level_acquaintance(self):
|
||||
"""Test level classification for acquaintance."""
|
||||
service = RelationshipService(None)
|
||||
assert service.get_level(30) == RelationshipLevel.ACQUAINTANCE
|
||||
|
||||
def test_get_level_friend(self):
|
||||
"""Test level classification for friend."""
|
||||
service = RelationshipService(None)
|
||||
assert service.get_level(50) == RelationshipLevel.FRIEND
|
||||
|
||||
def test_get_level_good_friend(self):
|
||||
"""Test level classification for good friend."""
|
||||
service = RelationshipService(None)
|
||||
assert service.get_level(70) == RelationshipLevel.GOOD_FRIEND
|
||||
|
||||
def test_get_level_close_friend(self):
|
||||
"""Test level classification for close friend."""
|
||||
service = RelationshipService(None)
|
||||
assert service.get_level(90) == RelationshipLevel.CLOSE_FRIEND
|
||||
|
||||
def test_get_level_display_name(self):
|
||||
"""Test getting display name for level."""
|
||||
service = RelationshipService(None)
|
||||
assert service.get_level_display_name(RelationshipLevel.FRIEND) == "Friend"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_relationship_info(self, db_session, sample_user_relationship, sample_user):
|
||||
"""Test getting relationship info."""
|
||||
service = RelationshipService(db_session)
|
||||
|
||||
info = await service.get_relationship_info(sample_user, guild_id=111222333)
|
||||
|
||||
assert "level" in info
|
||||
assert "score" in info
|
||||
assert "total_interactions" in info
|
||||
|
||||
|
||||
class TestOpinionService:
|
||||
"""Tests for OpinionService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_opinion(self, db_session):
|
||||
"""Test getting or creating an opinion."""
|
||||
service = OpinionService(db_session)
|
||||
|
||||
opinion = await service.get_or_create_opinion("programming")
|
||||
|
||||
assert opinion.id is not None
|
||||
assert opinion.topic == "programming"
|
||||
assert opinion.sentiment == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_topic_discussion(self, db_session):
|
||||
"""Test recording a topic discussion."""
|
||||
service = OpinionService(db_session)
|
||||
|
||||
opinion = await service.record_topic_discussion(
|
||||
topic="gaming",
|
||||
guild_id=None,
|
||||
sentiment=0.8,
|
||||
engagement_level=0.9,
|
||||
)
|
||||
|
||||
assert opinion.discussion_count == 1
|
||||
assert opinion.sentiment > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_top_interests(self, db_session):
|
||||
"""Test getting top interests."""
|
||||
service = OpinionService(db_session)
|
||||
|
||||
# Create some opinions with discussions
|
||||
for topic in ["programming", "gaming", "music"]:
|
||||
for _ in range(5):
|
||||
await service.record_topic_discussion(
|
||||
topic=topic,
|
||||
guild_id=None,
|
||||
sentiment=0.8,
|
||||
engagement_level=0.9,
|
||||
)
|
||||
|
||||
await db_session.commit()
|
||||
|
||||
interests = await service.get_top_interests(limit=3)
|
||||
|
||||
assert len(interests) <= 3
|
||||
|
||||
def test_extract_topics_gaming(self):
|
||||
"""Test extracting gaming topic."""
|
||||
topics = extract_topics_from_message("I love playing video games!")
|
||||
assert "gaming" in topics
|
||||
|
||||
def test_extract_topics_programming(self):
|
||||
"""Test extracting programming topic."""
|
||||
topics = extract_topics_from_message("I'm learning Python programming")
|
||||
assert "programming" in topics
|
||||
|
||||
def test_extract_topics_multiple(self):
|
||||
"""Test extracting multiple topics."""
|
||||
topics = extract_topics_from_message("I code while listening to music")
|
||||
assert "programming" in topics
|
||||
assert "music" in topics
|
||||
|
||||
def test_extract_topics_none(self):
|
||||
"""Test extracting no topics."""
|
||||
topics = extract_topics_from_message("Hello, how are you?")
|
||||
assert len(topics) == 0
|
||||
|
||||
|
||||
class TestPersistentConversationManager:
|
||||
"""Tests for PersistentConversationManager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_conversation_new(self, db_session, sample_user):
|
||||
"""Test creating a new conversation."""
|
||||
manager = PersistentConversationManager(db_session)
|
||||
|
||||
conv = await manager.get_or_create_conversation(
|
||||
user=sample_user,
|
||||
channel_id=123456,
|
||||
)
|
||||
|
||||
assert conv.id is not None
|
||||
assert conv.user_id == sample_user.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_conversation_existing(
|
||||
self, db_session, sample_user, sample_conversation
|
||||
):
|
||||
"""Test getting an existing conversation."""
|
||||
manager = PersistentConversationManager(db_session)
|
||||
|
||||
conv = await manager.get_or_create_conversation(
|
||||
user=sample_user,
|
||||
channel_id=sample_conversation.channel_id,
|
||||
)
|
||||
|
||||
assert conv.id == sample_conversation.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_message(self, db_session, sample_user, sample_conversation):
|
||||
"""Test adding a message."""
|
||||
manager = PersistentConversationManager(db_session)
|
||||
|
||||
msg = await manager.add_message(
|
||||
conversation=sample_conversation,
|
||||
user=sample_user,
|
||||
role="user",
|
||||
content="Hello!",
|
||||
)
|
||||
|
||||
assert msg.id is not None
|
||||
assert msg.content == "Hello!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_exchange(self, db_session, sample_user, sample_conversation):
|
||||
"""Test adding a user/assistant exchange."""
|
||||
manager = PersistentConversationManager(db_session)
|
||||
|
||||
user_msg, assistant_msg = await manager.add_exchange(
|
||||
conversation=sample_conversation,
|
||||
user=sample_user,
|
||||
user_message="Hello!",
|
||||
assistant_message="Hi there!",
|
||||
)
|
||||
|
||||
assert user_msg.role == "user"
|
||||
assert assistant_msg.role == "assistant"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history(self, db_session, sample_user, sample_conversation):
|
||||
"""Test getting conversation history."""
|
||||
manager = PersistentConversationManager(db_session)
|
||||
|
||||
await manager.add_exchange(
|
||||
conversation=sample_conversation,
|
||||
user=sample_user,
|
||||
user_message="Hello!",
|
||||
assistant_message="Hi there!",
|
||||
)
|
||||
|
||||
history = await manager.get_history(sample_conversation)
|
||||
|
||||
assert len(history) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_conversation(self, db_session, sample_conversation):
|
||||
"""Test clearing a conversation."""
|
||||
manager = PersistentConversationManager(db_session)
|
||||
|
||||
await manager.clear_conversation(sample_conversation)
|
||||
|
||||
assert sample_conversation.is_active is False
|
||||
|
||||
|
||||
class TestSelfAwarenessService:
|
||||
"""Tests for SelfAwarenessService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_bot_stats(self, db_session, sample_bot_state):
|
||||
"""Test getting bot stats."""
|
||||
service = SelfAwarenessService(db_session)
|
||||
|
||||
stats = await service.get_bot_stats(guild_id=sample_bot_state.guild_id)
|
||||
|
||||
assert "age_days" in stats
|
||||
assert "total_messages_sent" in stats
|
||||
assert "age_readable" in stats
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history_with_user(self, db_session, sample_user, sample_user_relationship):
|
||||
"""Test getting history with a user."""
|
||||
service = SelfAwarenessService(db_session)
|
||||
|
||||
history = await service.get_history_with_user(sample_user, guild_id=111222333)
|
||||
|
||||
assert "days_known" in history
|
||||
assert "total_interactions" in history
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reflect_on_self(self, db_session, sample_bot_state):
|
||||
"""Test self reflection."""
|
||||
service = SelfAwarenessService(db_session)
|
||||
|
||||
reflection = await service.reflect_on_self(guild_id=sample_bot_state.guild_id)
|
||||
|
||||
assert isinstance(reflection, str)
|
||||
|
||||
|
||||
class TestFactExtractionService:
|
||||
"""Tests for FactExtractionService."""
|
||||
|
||||
def test_is_extractable_short_message(self):
|
||||
"""Test that short messages are not extractable."""
|
||||
service = FactExtractionService(None)
|
||||
assert service._is_extractable("hi") is False
|
||||
|
||||
def test_is_extractable_greeting(self):
|
||||
"""Test that greetings are not extractable."""
|
||||
service = FactExtractionService(None)
|
||||
assert service._is_extractable("hello") is False
|
||||
|
||||
def test_is_extractable_command(self):
|
||||
"""Test that commands are not extractable."""
|
||||
service = FactExtractionService(None)
|
||||
assert service._is_extractable("!help me with something") is False
|
||||
|
||||
def test_is_extractable_valid(self):
|
||||
"""Test that valid messages are extractable."""
|
||||
service = FactExtractionService(None)
|
||||
assert (
|
||||
service._is_extractable("I really enjoy programming in Python and building bots")
|
||||
is True
|
||||
)
|
||||
|
||||
def test_is_duplicate_exact_match(self):
|
||||
"""Test duplicate detection with exact match."""
|
||||
service = FactExtractionService(None)
|
||||
existing = {"likes programming", "enjoys gaming"}
|
||||
assert service._is_duplicate("likes programming", existing) is True
|
||||
|
||||
def test_is_duplicate_no_match(self):
|
||||
"""Test duplicate detection with no match."""
|
||||
service = FactExtractionService(None)
|
||||
existing = {"likes programming", "enjoys gaming"}
|
||||
assert service._is_duplicate("works at a tech company", existing) is False
|
||||
|
||||
def test_validate_fact_valid(self):
|
||||
"""Test fact validation with valid fact."""
|
||||
service = FactExtractionService(None)
|
||||
fact = {
|
||||
"type": "hobby",
|
||||
"content": "likes programming",
|
||||
"confidence": 0.9,
|
||||
}
|
||||
assert service._validate_fact(fact) is True
|
||||
|
||||
def test_validate_fact_missing_type(self):
|
||||
"""Test fact validation with missing type."""
|
||||
service = FactExtractionService(None)
|
||||
fact = {"content": "likes programming"}
|
||||
assert service._validate_fact(fact) is False
|
||||
|
||||
def test_validate_fact_invalid_type(self):
|
||||
"""Test fact validation with invalid type."""
|
||||
service = FactExtractionService(None)
|
||||
fact = {"type": "invalid_type", "content": "test"}
|
||||
assert service._validate_fact(fact) is False
|
||||
|
||||
def test_validate_fact_empty_content(self):
|
||||
"""Test fact validation with empty content."""
|
||||
service = FactExtractionService(None)
|
||||
fact = {"type": "hobby", "content": ""}
|
||||
assert service._validate_fact(fact) is False
|
||||
|
||||
|
||||
class TestAIService:
|
||||
"""Tests for AIService."""
|
||||
|
||||
def test_get_system_prompt_default(self, mock_settings):
|
||||
"""Test getting default system prompt."""
|
||||
with patch("daemon_boyfriend.services.ai_service.settings", mock_settings):
|
||||
with patch("daemon_boyfriend.services.ai_service.AIService._init_provider"):
|
||||
service = AIService(mock_settings)
|
||||
service._provider = MagicMock()
|
||||
|
||||
prompt = service.get_system_prompt()
|
||||
|
||||
assert "TestBot" in prompt
|
||||
assert "helpful and friendly" in prompt
|
||||
|
||||
def test_get_system_prompt_custom(self, mock_settings):
|
||||
"""Test getting custom system prompt."""
|
||||
mock_settings.system_prompt = "Custom prompt"
|
||||
with patch("daemon_boyfriend.services.ai_service.settings", mock_settings):
|
||||
with patch("daemon_boyfriend.services.ai_service.AIService._init_provider"):
|
||||
service = AIService(mock_settings)
|
||||
service._provider = MagicMock()
|
||||
|
||||
prompt = service.get_system_prompt()
|
||||
|
||||
assert prompt == "Custom prompt"
|
||||
|
||||
def test_provider_name(self, mock_settings):
|
||||
"""Test getting provider name."""
|
||||
with patch("daemon_boyfriend.services.ai_service.settings", mock_settings):
|
||||
with patch("daemon_boyfriend.services.ai_service.AIService._init_provider"):
|
||||
service = AIService(mock_settings)
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.provider_name = "openai"
|
||||
service._provider = mock_provider
|
||||
|
||||
assert service.provider_name == "openai"
|
||||
|
||||
def test_model_property(self, mock_settings):
|
||||
"""Test getting model name."""
|
||||
with patch("daemon_boyfriend.services.ai_service.settings", mock_settings):
|
||||
with patch("daemon_boyfriend.services.ai_service.AIService._init_provider"):
|
||||
service = AIService(mock_settings)
|
||||
service._provider = MagicMock()
|
||||
|
||||
assert service.model == "gpt-4o-mini"
|
||||
Reference in New Issue
Block a user