"""Tests for the service layer."""
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from loyal_companion.models import (
|
|
BotOpinion,
|
|
BotState,
|
|
Conversation,
|
|
Message,
|
|
User,
|
|
UserAttachmentProfile,
|
|
UserFact,
|
|
UserRelationship,
|
|
)
|
|
from loyal_companion.services.ai_service import AIService
|
|
from loyal_companion.services.attachment_service import (
|
|
AttachmentContext,
|
|
AttachmentService,
|
|
AttachmentState,
|
|
AttachmentStyle,
|
|
)
|
|
from loyal_companion.services.fact_extraction_service import FactExtractionService
|
|
from loyal_companion.services.mood_service import MoodLabel, MoodService, MoodState
|
|
from loyal_companion.services.opinion_service import OpinionService, extract_topics_from_message
|
|
from loyal_companion.services.persistent_conversation import PersistentConversationManager
|
|
from loyal_companion.services.relationship_service import RelationshipLevel, RelationshipService
|
|
from loyal_companion.services.self_awareness_service import SelfAwarenessService
|
|
from loyal_companion.services.user_service import UserService
|
|
|
|
|
|
class TestUserService:
    """Exercise UserService: user lookup/creation, custom names and stored facts."""

    @pytest.mark.asyncio
    async def test_get_or_create_user_new(self, db_session):
        """An unseen Discord ID produces a user that has been assigned an id."""
        users = UserService(db_session)

        created = await users.get_or_create_user(
            discord_id=123456789,
            username="testuser",
            display_name="Test User",
        )

        assert created.id is not None
        assert created.discord_id == 123456789
        assert created.discord_username == "testuser"

    @pytest.mark.asyncio
    async def test_get_or_create_user_existing(self, db_session, sample_user):
        """A known Discord ID returns the same user, with the username refreshed."""
        users = UserService(db_session)

        fetched = await users.get_or_create_user(
            discord_id=sample_user.discord_id,
            username="newname",
            display_name="New Display",
        )

        assert fetched.id == sample_user.id
        assert fetched.discord_username == "newname"

    @pytest.mark.asyncio
    async def test_set_custom_name(self, db_session, sample_user):
        """A custom name is stored and also surfaces as the display name."""
        users = UserService(db_session)

        renamed = await users.set_custom_name(sample_user.discord_id, "CustomName")

        assert renamed.custom_name == "CustomName"
        assert renamed.display_name == "CustomName"

    @pytest.mark.asyncio
    async def test_clear_custom_name(self, db_session, sample_user):
        """Passing None removes a previously set custom name."""
        users = UserService(db_session)
        sample_user.custom_name = "OldName"

        cleared = await users.set_custom_name(sample_user.discord_id, None)

        assert cleared.custom_name is None

    @pytest.mark.asyncio
    async def test_add_fact(self, db_session, sample_user):
        """A new fact gets an id and is linked to its owner with the given type."""
        users = UserService(db_session)

        stored = await users.add_fact(
            user=sample_user,
            fact_type="hobby",
            fact_content="likes programming",
        )

        assert stored.id is not None
        assert stored.user_id == sample_user.id
        assert stored.fact_type == "hobby"

    @pytest.mark.asyncio
    async def test_get_user_facts(self, db_session, sample_user_with_facts):
        """All facts from the fixture (two of them) are returned."""
        users = UserService(db_session)

        found = await users.get_user_facts(sample_user_with_facts)

        assert len(found) == 2

    @pytest.mark.asyncio
    async def test_get_user_facts_by_type(self, db_session, sample_user_with_facts):
        """Filtering by fact_type returns only the matching fact."""
        users = UserService(db_session)

        hobbies = await users.get_user_facts(sample_user_with_facts, fact_type="hobby")

        assert len(hobbies) == 1
        assert hobbies[0].fact_type == "hobby"

    @pytest.mark.asyncio
    async def test_delete_user_facts(self, db_session, sample_user_with_facts):
        """Deleting reports how many facts were removed; none stay active."""
        users = UserService(db_session)

        removed = await users.delete_user_facts(sample_user_with_facts)

        assert removed == 2
        remaining = await users.get_user_facts(sample_user_with_facts, active_only=True)
        assert len(remaining) == 0

    @pytest.mark.asyncio
    async def test_get_user_context(self, db_session, sample_user_with_facts):
        """The context string mentions the user's name and a known fact."""
        users = UserService(db_session)

        summary = await users.get_user_context(sample_user_with_facts)

        assert "Test User" in summary
        assert "likes programming" in summary
|
|
|
|
|
|
class TestMoodService:
    """Exercise MoodService: bot state, mood updates, stats and prompt modifiers."""

    @pytest.mark.asyncio
    async def test_get_or_create_bot_state(self, db_session):
        """Requesting state for a guild yields a row with an id for that guild."""
        moods = MoodService(db_session)

        bot_state = await moods.get_or_create_bot_state(guild_id=111222333)

        assert bot_state.id is not None
        assert bot_state.guild_id == 111222333

    @pytest.mark.asyncio
    async def test_get_current_mood(self, db_session, sample_bot_state):
        """The current mood is a MoodState whose axes lie within [-1, 1]."""
        moods = MoodService(db_session)

        current = await moods.get_current_mood(guild_id=sample_bot_state.guild_id)

        assert isinstance(current, MoodState)
        assert -1.0 <= current.valence <= 1.0
        assert -1.0 <= current.arousal <= 1.0

    @pytest.mark.asyncio
    async def test_update_mood(self, db_session, sample_bot_state):
        """A positive sentiment delta leaves the mood with positive valence."""
        moods = MoodService(db_session)

        updated = await moods.update_mood(
            guild_id=sample_bot_state.guild_id,
            sentiment_delta=0.5,
            engagement_delta=0.3,
            trigger_type="conversation",
            trigger_description="Had a nice chat",
        )

        assert updated.valence > 0

    @pytest.mark.asyncio
    async def test_increment_stats(self, db_session, sample_bot_state):
        """Incrementing stats adds the sent-message count to the running total."""
        moods = MoodService(db_session)
        before = sample_bot_state.total_messages_sent

        await moods.increment_stats(
            guild_id=sample_bot_state.guild_id,
            messages_sent=5,
            facts_learned=2,
        )

        assert sample_bot_state.total_messages_sent == before + 5

    def test_classify_mood_excited(self):
        """(valence=0.5, arousal=0.5) classifies as EXCITED."""
        assert MoodService(None)._classify_mood(0.5, 0.5) == MoodLabel.EXCITED

    def test_classify_mood_happy(self):
        """(valence=0.5, arousal=0.0) classifies as HAPPY."""
        assert MoodService(None)._classify_mood(0.5, 0.0) == MoodLabel.HAPPY

    def test_classify_mood_bored(self):
        """(valence=-0.5, arousal=0.0) classifies as BORED."""
        assert MoodService(None)._classify_mood(-0.5, 0.0) == MoodLabel.BORED

    def test_classify_mood_annoyed(self):
        """(valence=-0.5, arousal=0.5) classifies as ANNOYED."""
        assert MoodService(None)._classify_mood(-0.5, 0.5) == MoodLabel.ANNOYED

    def test_get_mood_prompt_modifier(self):
        """A strongly excited mood produces an enthusiastic/excited modifier."""
        moods = MoodService(None)
        excited = MoodState(valence=0.8, arousal=0.8, label=MoodLabel.EXCITED, intensity=0.8)

        text = moods.get_mood_prompt_modifier(excited).lower()

        assert "enthusiastic" in text or "excited" in text

    def test_get_mood_prompt_modifier_low_intensity(self):
        """A low-intensity neutral mood contributes no modifier text."""
        moods = MoodService(None)
        faint = MoodState(valence=0.1, arousal=0.1, label=MoodLabel.NEUTRAL, intensity=0.1)

        assert moods.get_mood_prompt_modifier(faint) == ""
|
|
|
|
|
|
class TestRelationshipService:
    """Exercise RelationshipService: relationship rows, interactions and levels."""

    @pytest.mark.asyncio
    async def test_get_or_create_relationship(self, db_session, sample_user):
        """A relationship row is created for the user and given an id."""
        relationships = RelationshipService(db_session)

        relationship = await relationships.get_or_create_relationship(
            sample_user, guild_id=111222333
        )

        assert relationship.id is not None
        assert relationship.user_id == sample_user.id

    @pytest.mark.asyncio
    async def test_record_interaction(self, db_session, sample_user):
        """Recording an interaction reports the resulting RelationshipLevel."""
        relationships = RelationshipService(db_session)

        resulting_level = await relationships.record_interaction(
            user=sample_user,
            guild_id=111222333,
            sentiment=0.8,
            message_length=100,
            conversation_turns=3,
        )

        assert isinstance(resulting_level, RelationshipLevel)

    @pytest.mark.asyncio
    async def test_record_positive_interaction(self, db_session, sample_user):
        """A positive-sentiment interaction increments the positive counter."""
        relationships = RelationshipService(db_session)

        await relationships.record_interaction(
            user=sample_user,
            guild_id=111222333,
            sentiment=0.5,
            message_length=100,
        )

        relationship = await relationships.get_or_create_relationship(
            sample_user, guild_id=111222333
        )
        assert relationship.positive_interactions >= 1

    def test_get_level_stranger(self):
        """A score of 10 maps to STRANGER."""
        assert RelationshipService(None).get_level(10) == RelationshipLevel.STRANGER

    def test_get_level_acquaintance(self):
        """A score of 30 maps to ACQUAINTANCE."""
        assert RelationshipService(None).get_level(30) == RelationshipLevel.ACQUAINTANCE

    def test_get_level_friend(self):
        """A score of 50 maps to FRIEND."""
        assert RelationshipService(None).get_level(50) == RelationshipLevel.FRIEND

    def test_get_level_good_friend(self):
        """A score of 70 maps to GOOD_FRIEND."""
        assert RelationshipService(None).get_level(70) == RelationshipLevel.GOOD_FRIEND

    def test_get_level_close_friend(self):
        """A score of 90 maps to CLOSE_FRIEND."""
        assert RelationshipService(None).get_level(90) == RelationshipLevel.CLOSE_FRIEND

    def test_get_level_display_name(self):
        """FRIEND is rendered as the human-readable label "Friend"."""
        display = RelationshipService(None).get_level_display_name(RelationshipLevel.FRIEND)
        assert display == "Friend"

    @pytest.mark.asyncio
    async def test_get_relationship_info(self, db_session, sample_user_relationship, sample_user):
        """The info mapping exposes level, score and total interaction count."""
        relationships = RelationshipService(db_session)

        summary = await relationships.get_relationship_info(sample_user, guild_id=111222333)

        assert "level" in summary
        assert "score" in summary
        assert "total_interactions" in summary
|
|
|
|
|
|
class TestOpinionService:
    """Exercise OpinionService and the keyword-based topic extractor."""

    @pytest.mark.asyncio
    async def test_get_or_create_opinion(self, db_session):
        """A new topic opinion gets an id and starts with neutral sentiment."""
        opinions = OpinionService(db_session)

        fresh = await opinions.get_or_create_opinion("programming")

        assert fresh.id is not None
        assert fresh.topic == "programming"
        assert fresh.sentiment == 0.0

    @pytest.mark.asyncio
    async def test_record_topic_discussion(self, db_session):
        """One positive discussion bumps the count and yields positive sentiment."""
        opinions = OpinionService(db_session)

        recorded = await opinions.record_topic_discussion(
            topic="gaming",
            guild_id=None,
            sentiment=0.8,
            engagement_level=0.9,
        )

        assert recorded.discussion_count == 1
        assert recorded.sentiment > 0

    @pytest.mark.asyncio
    async def test_get_top_interests(self, db_session):
        """Top interests respect the requested limit."""
        opinions = OpinionService(db_session)

        # Seed three topics, each discussed several times.
        for subject in ("programming", "gaming", "music"):
            for _repeat in range(5):
                await opinions.record_topic_discussion(
                    topic=subject,
                    guild_id=None,
                    sentiment=0.8,
                    engagement_level=0.9,
                )

        await db_session.commit()

        top = await opinions.get_top_interests(limit=3)

        assert len(top) <= 3

    def test_extract_topics_gaming(self):
        """Talk about video games is tagged as "gaming"."""
        assert "gaming" in extract_topics_from_message("I love playing video games!")

    def test_extract_topics_programming(self):
        """Talk about Python programming is tagged as "programming"."""
        assert "programming" in extract_topics_from_message("I'm learning Python programming")

    def test_extract_topics_multiple(self):
        """A message touching two subjects yields both topics."""
        found = extract_topics_from_message("I code while listening to music")

        assert "programming" in found
        assert "music" in found

    def test_extract_topics_none(self):
        """Small talk with no topical keywords yields no topics."""
        assert len(extract_topics_from_message("Hello, how are you?")) == 0
|
|
|
|
|
|
class TestPersistentConversationManager:
    """Exercise PersistentConversationManager: conversations, messages and history."""

    @pytest.mark.asyncio
    async def test_get_or_create_conversation_new(self, db_session, sample_user):
        """A channel without a conversation gets a new one owned by the user."""
        conversations = PersistentConversationManager(db_session)

        created = await conversations.get_or_create_conversation(
            user=sample_user,
            channel_id=123456,
        )

        assert created.id is not None
        assert created.user_id == sample_user.id

    @pytest.mark.asyncio
    async def test_get_or_create_conversation_existing(
        self, db_session, sample_user, sample_conversation
    ):
        """A channel that already has a conversation returns that same row."""
        conversations = PersistentConversationManager(db_session)

        found = await conversations.get_or_create_conversation(
            user=sample_user,
            channel_id=sample_conversation.channel_id,
        )

        assert found.id == sample_conversation.id

    @pytest.mark.asyncio
    async def test_add_message(self, db_session, sample_user, sample_conversation):
        """A single message is stored with an id and its original content."""
        conversations = PersistentConversationManager(db_session)

        stored = await conversations.add_message(
            conversation=sample_conversation,
            user=sample_user,
            role="user",
            content="Hello!",
        )

        assert stored.id is not None
        assert stored.content == "Hello!"

    @pytest.mark.asyncio
    async def test_add_exchange(self, db_session, sample_user, sample_conversation):
        """An exchange returns the user message first, then the assistant reply."""
        conversations = PersistentConversationManager(db_session)

        from_user, from_bot = await conversations.add_exchange(
            conversation=sample_conversation,
            user=sample_user,
            user_message="Hello!",
            assistant_message="Hi there!",
        )

        assert from_user.role == "user"
        assert from_bot.role == "assistant"

    @pytest.mark.asyncio
    async def test_get_history(self, db_session, sample_user, sample_conversation):
        """After one exchange the history holds exactly two entries."""
        conversations = PersistentConversationManager(db_session)

        await conversations.add_exchange(
            conversation=sample_conversation,
            user=sample_user,
            user_message="Hello!",
            assistant_message="Hi there!",
        )

        transcript = await conversations.get_history(sample_conversation)

        assert len(transcript) == 2

    @pytest.mark.asyncio
    async def test_clear_conversation(self, db_session, sample_conversation):
        """Clearing a conversation marks it inactive."""
        conversations = PersistentConversationManager(db_session)

        await conversations.clear_conversation(sample_conversation)

        assert sample_conversation.is_active is False
|
|
|
|
|
|
class TestSelfAwarenessService:
    """Exercise SelfAwarenessService: bot statistics, user history and reflection."""

    @pytest.mark.asyncio
    async def test_get_bot_stats(self, db_session, sample_bot_state):
        """Bot stats expose age (raw and readable) and the sent-message total."""
        awareness = SelfAwarenessService(db_session)

        report = await awareness.get_bot_stats(guild_id=sample_bot_state.guild_id)

        assert "age_days" in report
        assert "total_messages_sent" in report
        assert "age_readable" in report

    @pytest.mark.asyncio
    async def test_get_history_with_user(self, db_session, sample_user, sample_user_relationship):
        """History with a user reports days known and total interactions."""
        awareness = SelfAwarenessService(db_session)

        shared_past = await awareness.get_history_with_user(sample_user, guild_id=111222333)

        assert "days_known" in shared_past
        assert "total_interactions" in shared_past

    @pytest.mark.asyncio
    async def test_reflect_on_self(self, db_session, sample_bot_state):
        """Self-reflection produces a string."""
        awareness = SelfAwarenessService(db_session)

        thoughts = await awareness.reflect_on_self(guild_id=sample_bot_state.guild_id)

        assert isinstance(thoughts, str)
|
|
|
|
|
|
class TestFactExtractionService:
    """Exercise FactExtractionService's message filters and fact validators."""

    def test_is_extractable_short_message(self):
        """A two-letter message is rejected as too short."""
        assert FactExtractionService(None)._is_extractable("hi") is False

    def test_is_extractable_greeting(self):
        """A bare greeting is rejected."""
        assert FactExtractionService(None)._is_extractable("hello") is False

    def test_is_extractable_command(self):
        """A message starting with the "!" command prefix is rejected."""
        assert FactExtractionService(None)._is_extractable("!help me with something") is False

    def test_is_extractable_valid(self):
        """A substantive personal statement is accepted."""
        extractor = FactExtractionService(None)
        sentence = "I really enjoy programming in Python and building bots"

        assert extractor._is_extractable(sentence) is True

    def test_is_duplicate_exact_match(self):
        """A fact identical to a known one is flagged as a duplicate."""
        known = {"likes programming", "enjoys gaming"}

        assert FactExtractionService(None)._is_duplicate("likes programming", known) is True

    def test_is_duplicate_no_match(self):
        """An unrelated fact is not flagged as a duplicate."""
        known = {"likes programming", "enjoys gaming"}

        assert FactExtractionService(None)._is_duplicate("works at a tech company", known) is False

    def test_validate_fact_valid(self):
        """A fact with a known type, content and confidence passes validation."""
        candidate = {
            "type": "hobby",
            "content": "likes programming",
            "confidence": 0.9,
        }

        assert FactExtractionService(None)._validate_fact(candidate) is True

    def test_validate_fact_missing_type(self):
        """A fact without a "type" key fails validation."""
        candidate = {"content": "likes programming"}

        assert FactExtractionService(None)._validate_fact(candidate) is False

    def test_validate_fact_invalid_type(self):
        """A fact with an unrecognised type fails validation."""
        candidate = {"type": "invalid_type", "content": "test"}

        assert FactExtractionService(None)._validate_fact(candidate) is False

    def test_validate_fact_empty_content(self):
        """A fact with empty content fails validation."""
        candidate = {"type": "hobby", "content": ""}

        assert FactExtractionService(None)._validate_fact(candidate) is False
|
|
|
|
|
|
class TestAIService:
    """Exercise AIService prompt and provider accessors with provider init stubbed out."""

    def test_get_system_prompt_default(self, mock_settings):
        """Without an override, the prompt includes the bot name and persona text."""
        with patch("loyal_companion.services.ai_service.settings", mock_settings), patch(
            "loyal_companion.services.ai_service.AIService._init_provider"
        ):
            ai = AIService(mock_settings)
            ai._provider = MagicMock()

            rendered = ai.get_system_prompt()

            assert "TestBot" in rendered
            assert "helpful and friendly" in rendered

    def test_get_system_prompt_custom(self, mock_settings):
        """An explicit system_prompt setting is returned verbatim."""
        mock_settings.system_prompt = "Custom prompt"
        with patch("loyal_companion.services.ai_service.settings", mock_settings), patch(
            "loyal_companion.services.ai_service.AIService._init_provider"
        ):
            ai = AIService(mock_settings)
            ai._provider = MagicMock()

            rendered = ai.get_system_prompt()

            assert rendered == "Custom prompt"

    def test_provider_name(self, mock_settings):
        """provider_name reflects the name reported by the active provider."""
        with patch("loyal_companion.services.ai_service.settings", mock_settings), patch(
            "loyal_companion.services.ai_service.AIService._init_provider"
        ):
            ai = AIService(mock_settings)
            fake_provider = MagicMock()
            fake_provider.provider_name = "openai"
            ai._provider = fake_provider

            assert ai.provider_name == "openai"

    def test_model_property(self, mock_settings):
        """model reports the configured model name."""
        with patch("loyal_companion.services.ai_service.settings", mock_settings), patch(
            "loyal_companion.services.ai_service.AIService._init_provider"
        ):
            ai = AIService(mock_settings)
            ai._provider = MagicMock()

            assert ai.model == "gpt-4o-mini"
|
|
|
|
|
|
class TestAttachmentService:
    """Tests for AttachmentService.

    Covers profile creation, message analysis (indicator detection and state
    classification), prompt modifiers per attachment style/state, and tracking
    of which response styles helped or did not help a user.
    """

    @pytest.mark.asyncio
    async def test_get_or_create_profile_new(self, db_session, sample_user):
        """Test creating a new attachment profile with its initial defaults."""
        service = AttachmentService(db_session)

        profile = await service.get_or_create_profile(sample_user, guild_id=111222333)

        assert profile.id is not None
        assert profile.user_id == sample_user.id
        assert profile.primary_style == "unknown"
        assert profile.current_state == "regulated"

    @pytest.mark.asyncio
    async def test_get_or_create_profile_existing(self, db_session, sample_user):
        """Test that a second lookup returns the already-created profile."""
        service = AttachmentService(db_session)

        # Create first and persist it so the second call must find it.
        profile1 = await service.get_or_create_profile(sample_user, guild_id=111222333)
        await db_session.commit()

        # Get again
        profile2 = await service.get_or_create_profile(sample_user, guild_id=111222333)

        assert profile1.id == profile2.id

    @pytest.mark.asyncio
    async def test_analyze_message_no_indicators(self, db_session, sample_user):
        """Test analyzing a message with no attachment indicators."""
        service = AttachmentService(db_session)

        context = await service.analyze_message(
            user=sample_user,
            message_content="Hello, how are you today?",
            guild_id=111222333,
        )

        assert context.current_state == AttachmentState.REGULATED
        assert len(context.recent_indicators) == 0

    @pytest.mark.asyncio
    async def test_analyze_message_anxious_indicators(self, db_session, sample_user):
        """Test analyzing a message with anxious attachment indicators."""
        service = AttachmentService(db_session)

        context = await service.analyze_message(
            user=sample_user,
            message_content="Are you still there? Do you still like me? Did I do something wrong?",
            guild_id=111222333,
        )

        assert context.current_state == AttachmentState.ACTIVATED
        assert len(context.recent_indicators) > 0

        # Check profile was updated
        profile = await service.get_or_create_profile(sample_user, guild_id=111222333)
        assert profile.anxious_indicators > 0

    @pytest.mark.asyncio
    async def test_analyze_message_avoidant_indicators(self, db_session, sample_user):
        """Test analyzing a message with avoidant attachment indicators."""
        service = AttachmentService(db_session)

        context = await service.analyze_message(
            user=sample_user,
            message_content="It's fine, whatever. I don't need anyone. I'm better alone.",
            guild_id=111222333,
        )

        assert context.current_state == AttachmentState.ACTIVATED
        assert len(context.recent_indicators) > 0

        profile = await service.get_or_create_profile(sample_user, guild_id=111222333)
        assert profile.avoidant_indicators > 0

    @pytest.mark.asyncio
    async def test_analyze_message_disorganized_indicators(self, db_session, sample_user):
        """Test analyzing a message with disorganized attachment indicators."""
        service = AttachmentService(db_session)

        context = await service.analyze_message(
            user=sample_user,
            message_content="I don't know what I want. I'm so confused and torn.",
            guild_id=111222333,
        )

        # Should detect disorganized patterns
        assert len(context.recent_indicators) > 0

    @pytest.mark.asyncio
    async def test_analyze_message_mixed_state(self, db_session, sample_user):
        """Test that mixed indicators result in mixed state."""
        service = AttachmentService(db_session)

        # Message with both anxious and avoidant indicators
        context = await service.analyze_message(
            user=sample_user,
            message_content="Are you still there? Actually, it's fine, I don't care anyway.",
            guild_id=111222333,
        )

        assert context.current_state == AttachmentState.MIXED

    @pytest.mark.asyncio
    async def test_analyze_message_secure_indicators(self, db_session, sample_user):
        """Test analyzing a message with secure attachment indicators."""
        service = AttachmentService(db_session)

        # The returned context is not inspected here (it was previously bound to
        # an unused local); the secure signal is verified on the stored profile.
        await service.analyze_message(
            user=sample_user,
            message_content="I'm feeling sad today and I need to talk about it. Thank you for listening.",
            guild_id=111222333,
        )

        profile = await service.get_or_create_profile(sample_user, guild_id=111222333)
        assert profile.secure_indicators > 0

    def test_find_indicators_anxious(self, db_session):
        """Test finding anxious indicators in text."""
        service = AttachmentService(db_session)

        matches = service._find_indicators(
            "do you still like me?",
            service.ANXIOUS_INDICATORS,
        )

        assert len(matches) > 0

    def test_find_indicators_none(self, db_session):
        """Test finding no indicators in neutral text."""
        service = AttachmentService(db_session)

        matches = service._find_indicators(
            "the weather is nice today",
            service.ANXIOUS_INDICATORS,
        )

        assert len(matches) == 0

    def test_determine_state_regulated(self, db_session):
        """Test state determination with no indicators."""
        service = AttachmentService(db_session)

        # Arguments are three indicator lists (the mixed-state test below passes
        # anxious then avoidant in the first two positions).
        state, intensity = service._determine_state([], [], [])

        assert state == AttachmentState.REGULATED
        assert intensity == 0.0

    def test_determine_state_activated(self, db_session):
        """Test state determination with single style indicators."""
        service = AttachmentService(db_session)

        state, intensity = service._determine_state(["pattern1", "pattern2"], [], [])

        assert state == AttachmentState.ACTIVATED
        assert intensity > 0

    def test_determine_state_mixed(self, db_session):
        """Test state determination with mixed indicators."""
        service = AttachmentService(db_session)

        state, intensity = service._determine_state(["anxious1"], ["avoidant1"], [])

        assert state == AttachmentState.MIXED

    def test_get_attachment_prompt_modifier_regulated(self, db_session):
        """Test prompt modifier for regulated state."""
        service = AttachmentService(db_session)

        context = AttachmentContext(
            primary_style=AttachmentStyle.UNKNOWN,
            style_confidence=0.0,
            current_state=AttachmentState.REGULATED,
            state_intensity=0.0,
            recent_indicators=[],
            effective_responses=[],
        )

        modifier = service.get_attachment_prompt_modifier(context, "friend")

        assert modifier == ""

    def test_get_attachment_prompt_modifier_anxious_activated(self, db_session):
        """Test prompt modifier for anxious activated state."""
        service = AttachmentService(db_session)

        context = AttachmentContext(
            primary_style=AttachmentStyle.ANXIOUS,
            style_confidence=0.7,
            current_state=AttachmentState.ACTIVATED,
            state_intensity=0.6,
            recent_indicators=["pattern1"],
            effective_responses=[],
        )

        modifier = service.get_attachment_prompt_modifier(context, "friend")

        assert "reassurance" in modifier.lower()
        assert "present" in modifier.lower()

    def test_get_attachment_prompt_modifier_avoidant_activated(self, db_session):
        """Test prompt modifier for avoidant activated state."""
        service = AttachmentService(db_session)

        context = AttachmentContext(
            primary_style=AttachmentStyle.AVOIDANT,
            style_confidence=0.7,
            current_state=AttachmentState.ACTIVATED,
            state_intensity=0.6,
            recent_indicators=["pattern1"],
            effective_responses=[],
        )

        modifier = service.get_attachment_prompt_modifier(context, "friend")

        assert "space" in modifier.lower()
        assert "push" in modifier.lower()

    def test_get_attachment_prompt_modifier_disorganized_activated(self, db_session):
        """Test prompt modifier for disorganized activated state."""
        service = AttachmentService(db_session)

        context = AttachmentContext(
            primary_style=AttachmentStyle.DISORGANIZED,
            style_confidence=0.7,
            current_state=AttachmentState.ACTIVATED,
            state_intensity=0.6,
            recent_indicators=["pattern1"],
            effective_responses=[],
        )

        modifier = service.get_attachment_prompt_modifier(context, "friend")

        assert "steady" in modifier.lower()
        assert "predictable" in modifier.lower()

    def test_get_attachment_prompt_modifier_close_friend_reflection(self, db_session):
        """Test prompt modifier includes reflection at close friend level."""
        service = AttachmentService(db_session)

        context = AttachmentContext(
            primary_style=AttachmentStyle.ANXIOUS,
            style_confidence=0.7,
            current_state=AttachmentState.ACTIVATED,
            state_intensity=0.6,
            recent_indicators=["pattern1"],
            effective_responses=[],
        )

        modifier = service.get_attachment_prompt_modifier(context, "close_friend")

        assert "pattern" in modifier.lower()

    def test_get_attachment_prompt_modifier_with_effective_responses(self, db_session):
        """Test prompt modifier includes effective responses."""
        service = AttachmentService(db_session)

        context = AttachmentContext(
            primary_style=AttachmentStyle.ANXIOUS,
            style_confidence=0.7,
            current_state=AttachmentState.ACTIVATED,
            state_intensity=0.6,
            recent_indicators=["pattern1"],
            effective_responses=["reassurance", "validation"],
        )

        modifier = service.get_attachment_prompt_modifier(context, "friend")

        assert "helped" in modifier.lower()
        assert "reassurance" in modifier.lower()

    @pytest.mark.asyncio
    async def test_record_response_effectiveness_helpful(self, db_session, sample_user):
        """Test recording a helpful response."""
        service = AttachmentService(db_session)

        await service.record_response_effectiveness(
            user=sample_user,
            guild_id=111222333,
            response_style="reassurance",
            was_helpful=True,
        )

        profile = await service.get_or_create_profile(sample_user, guild_id=111222333)
        # The column may be None before anything is recorded, hence the fallback.
        assert "reassurance" in (profile.effective_responses or [])

    @pytest.mark.asyncio
    async def test_record_response_effectiveness_unhelpful(self, db_session, sample_user):
        """Test recording an unhelpful response."""
        service = AttachmentService(db_session)

        await service.record_response_effectiveness(
            user=sample_user,
            guild_id=111222333,
            response_style="advice",
            was_helpful=False,
        )

        profile = await service.get_or_create_profile(sample_user, guild_id=111222333)
        assert "advice" in (profile.ineffective_responses or [])

    def test_default_context(self, db_session):
        """Test default context when tracking is disabled."""
        service = AttachmentService(db_session)

        context = service._default_context()

        assert context.primary_style == AttachmentStyle.UNKNOWN
        assert context.current_state == AttachmentState.REGULATED
        assert context.style_confidence == 0.0

    @pytest.mark.asyncio
    async def test_primary_style_determination(self, db_session, sample_user):
        """Test that repeated anxious messages accumulate style evidence."""
        service = AttachmentService(db_session)

        # Send multiple messages with anxious indicators
        anxious_messages = [
            "Are you still there?",
            "Do you still like me?",
            "Did I do something wrong?",
            "Please don't leave me",
            "Are you mad at me?",
            "I'm scared you'll abandon me",
        ]

        for msg in anxious_messages:
            await service.analyze_message(
                user=sample_user,
                message_content=msg,
                guild_id=111222333,
            )

        profile = await service.get_or_create_profile(sample_user, guild_id=111222333)

        # After several samples the anxious counter has grown and the profile
        # reports non-zero confidence in its style assessment. The resolved
        # primary_style value itself is not asserted here.
        assert profile.anxious_indicators >= 5
        assert profile.style_confidence > 0

    @pytest.mark.asyncio
    async def test_activation_tracking(self, db_session, sample_user):
        """Test that activations are tracked."""
        service = AttachmentService(db_session)

        await service.analyze_message(
            user=sample_user,
            message_content="Are you still there? Do you still like me?",
            guild_id=111222333,
        )

        profile = await service.get_or_create_profile(sample_user, guild_id=111222333)

        assert profile.activation_count >= 1
        assert profile.last_activation_at is not None
|