# Files
# loyal_companion/tests/test_providers.py
# 2026-01-12 20:41:04 +01:00
#
# 291 lines
# 11 KiB
# Python

"""Tests for AI provider implementations."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from daemon_boyfriend.services.providers.base import (
AIProvider,
AIResponse,
Message,
ImageAttachment,
)
from daemon_boyfriend.services.providers.openai import OpenAIProvider
from daemon_boyfriend.services.providers.anthropic import AnthropicProvider
from daemon_boyfriend.services.providers.gemini import GeminiProvider
from daemon_boyfriend.services.providers.openrouter import OpenRouterProvider
class TestMessage:
    """Unit tests covering construction of the ``Message`` dataclass."""

    def test_message_creation(self):
        """A plain text message keeps its role/content and starts with no images."""
        message = Message(role="user", content="Hello")
        assert (message.role, message.content) == ("user", "Hello")
        assert message.images == []

    def test_message_with_images(self):
        """Image attachments passed at construction are stored on the message."""
        url = "https://example.com/image.png"
        message = Message(
            role="user",
            content="Look at this",
            images=[ImageAttachment(url=url)],
        )
        assert len(message.images) == 1
        assert message.images[0].url == url
class TestImageAttachment:
    """Unit tests for the ``ImageAttachment`` dataclass."""

    def test_default_media_type(self):
        """When no media type is given, ``image/png`` is used."""
        attachment = ImageAttachment(url="https://example.com/image.png")
        assert attachment.media_type == "image/png"

    def test_custom_media_type(self):
        """An explicitly supplied media type overrides the default."""
        attachment = ImageAttachment(
            url="https://example.com/image.jpg",
            media_type="image/jpeg",
        )
        assert attachment.media_type == "image/jpeg"
class TestAIResponse:
    """Unit tests for the ``AIResponse`` dataclass."""

    def test_response_creation(self):
        """Constructor arguments are exposed unchanged as attributes."""
        token_usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        result = AIResponse(content="Hello!", model="test-model", usage=token_usage)
        assert result.content == "Hello!"
        assert result.model == "test-model"
        assert result.usage["total_tokens"] == 15
class TestOpenAIProvider:
    """Tests for ``OpenAIProvider`` driven through a mocked ``AsyncOpenAI`` client."""

    @pytest.fixture
    def provider(self, mock_openai_client):
        """Build an ``OpenAIProvider`` whose client is the ``mock_openai_client`` fixture."""
        with patch("daemon_boyfriend.services.providers.openai.AsyncOpenAI") as client_cls:
            client_cls.return_value = mock_openai_client
            instance = OpenAIProvider(api_key="test_key", model="gpt-4o-mini")
            instance.client = mock_openai_client
        return instance

    def test_provider_name(self, provider):
        """The provider reports its name as ``openai``."""
        assert provider.provider_name == "openai"

    def test_model_setting(self, provider):
        """The model passed to the constructor is retained."""
        assert provider.model == "gpt-4o-mini"

    @pytest.mark.asyncio
    async def test_generate_simple_message(self, provider, mock_openai_client):
        """A single user message returns the mocked completion and calls the API once."""
        result = await provider.generate([Message(role="user", content="Hello")])
        assert result.content == "Test OpenAI response"
        assert result.model == "gpt-4o-mini"
        mock_openai_client.chat.completions.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_with_system_prompt(self, provider, mock_openai_client):
        """A system prompt is sent as the first chat message with role ``system``."""
        prompt = "You are a helpful assistant."
        await provider.generate([Message(role="user", content="Hello")], system_prompt=prompt)
        call = mock_openai_client.chat.completions.create.call_args
        first = call.kwargs["messages"][0]
        assert first["role"] == "system"
        assert first["content"] == prompt

    @pytest.mark.asyncio
    async def test_generate_with_images(self, provider, mock_openai_client):
        """A message with images is sent as a list: a text part, then an image_url part."""
        attachments = [ImageAttachment(url="https://example.com/image.png")]
        await provider.generate(
            [Message(role="user", content="What's in this image?", images=attachments)]
        )
        call = mock_openai_client.chat.completions.create.call_args
        parts = call.kwargs["messages"][0]["content"]
        assert isinstance(parts, list)
        assert parts[0]["type"] == "text"
        assert parts[1]["type"] == "image_url"

    def test_build_message_content_no_images(self, provider):
        """Text-only messages are built as the plain content string."""
        built = provider._build_message_content(Message(role="user", content="Hello"))
        assert built == "Hello"

    def test_build_message_content_with_images(self, provider):
        """A message with one image is built as a two-element content list."""
        message = Message(
            role="user",
            content="Hello",
            images=[ImageAttachment(url="https://example.com/image.png")],
        )
        built = provider._build_message_content(message)
        assert isinstance(built, list)
        assert len(built) == 2
class TestAnthropicProvider:
    """Tests for ``AnthropicProvider`` driven through a mocked ``AsyncAnthropic`` client."""

    @pytest.fixture
    def provider(self, mock_anthropic_client):
        """Create an Anthropic provider with mocked client.

        ``anthropic.AsyncAnthropic`` (as referenced from the provider module) is
        patched during construction, so a constructor call to it yields
        ``mock_anthropic_client`` instead of a real SDK client.
        """
        # NOTE: this fixture previously contained a truncated, duplicated copy of
        # itself (an empty ``with patch(...)`` block followed by a repeated class
        # docstring and fixture header), which made the module fail to import.
        with patch(
            "daemon_boyfriend.services.providers.anthropic.anthropic.AsyncAnthropic"
        ) as mock_class:
            mock_class.return_value = mock_anthropic_client
            provider = AnthropicProvider(api_key="test_key", model="claude-sonnet-4-20250514")
            # Assign explicitly as well, so the tests use the mock regardless of
            # how the constructor obtains its client.
            provider.client = mock_anthropic_client
        return provider

    def test_provider_name(self, provider):
        """Test provider name."""
        assert provider.provider_name == "anthropic"

    def test_model_setting(self, provider):
        """Test model is set correctly."""
        assert provider.model == "claude-sonnet-4-20250514"

    @pytest.mark.asyncio
    async def test_generate_simple_message(self, provider, mock_anthropic_client):
        """A single user message returns the mocked text and calls the API once."""
        messages = [Message(role="user", content="Hello")]
        response = await provider.generate(messages)
        assert response.content == "Test Anthropic response"
        mock_anthropic_client.messages.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_with_system_prompt(self, provider, mock_anthropic_client):
        """The system prompt is passed as the ``system`` keyword, not as a message."""
        messages = [Message(role="user", content="Hello")]
        await provider.generate(messages, system_prompt="You are a helpful assistant.")
        call_args = mock_anthropic_client.messages.create.call_args
        assert call_args.kwargs["system"] == "You are a helpful assistant."

    def test_build_message_content_no_images(self, provider):
        """Text-only messages are built as the plain content string."""
        msg = Message(role="user", content="Hello")
        content = provider._build_message_content(msg)
        assert content == "Hello"

    def test_build_message_content_with_images(self, provider):
        """With images, the content is a list whose image block precedes the text block."""
        images = [ImageAttachment(url="https://example.com/image.png")]
        msg = Message(role="user", content="Hello", images=images)
        content = provider._build_message_content(msg)
        assert isinstance(content, list)
        assert content[0]["type"] == "image"
        assert content[1]["type"] == "text"
class TestGeminiProvider:
    """Tests for ``GeminiProvider`` driven through a mocked ``genai.Client``."""

    @pytest.fixture
    def provider(self, mock_gemini_client):
        """Build a ``GeminiProvider`` whose client is the ``mock_gemini_client`` fixture."""
        with patch("daemon_boyfriend.services.providers.gemini.genai.Client") as client_cls:
            client_cls.return_value = mock_gemini_client
            instance = GeminiProvider(api_key="test_key", model="gemini-2.0-flash")
            instance.client = mock_gemini_client
        return instance

    def test_provider_name(self, provider):
        """The provider reports its name as ``gemini``."""
        assert provider.provider_name == "gemini"

    def test_model_setting(self, provider):
        """The model passed to the constructor is retained."""
        assert provider.model == "gemini-2.0-flash"

    @pytest.mark.asyncio
    async def test_generate_simple_message(self, provider, mock_gemini_client):
        """A single user message returns the mocked text and calls the API once."""
        result = await provider.generate([Message(role="user", content="Hello")])
        assert result.content == "Test Gemini response"
        mock_gemini_client.aio.models.generate_content.assert_called_once()

    @pytest.mark.asyncio
    async def test_role_mapping(self, provider, mock_gemini_client):
        """``assistant`` turns are sent with role ``model``; ``user`` turns are unchanged."""
        conversation = [
            Message(role="user", content="Hello"),
            Message(role="assistant", content="Hi there!"),
            Message(role="user", content="How are you?"),
        ]
        await provider.generate(conversation)
        call = mock_gemini_client.aio.models.generate_content.call_args
        sent = call.kwargs["contents"]
        assert [entry.role for entry in sent[:3]] == ["user", "model", "user"]
class TestOpenRouterProvider:
    """Tests for ``OpenRouterProvider``, which shares the mocked OpenAI-style client."""

    @pytest.fixture
    def provider(self, mock_openai_client):
        """Build an ``OpenRouterProvider`` whose client is the ``mock_openai_client`` fixture."""
        with patch("daemon_boyfriend.services.providers.openrouter.AsyncOpenAI") as client_cls:
            client_cls.return_value = mock_openai_client
            instance = OpenRouterProvider(api_key="test_key", model="openai/gpt-4o")
            instance.client = mock_openai_client
        return instance

    def test_provider_name(self, provider):
        """The provider reports its name as ``openrouter``."""
        assert provider.provider_name == "openrouter"

    def test_model_setting(self, provider):
        """The model passed to the constructor is retained."""
        assert provider.model == "openai/gpt-4o"

    @pytest.mark.asyncio
    async def test_generate_simple_message(self, provider, mock_openai_client):
        """A single user message returns the mocked completion and calls the API once."""
        result = await provider.generate([Message(role="user", content="Hello")])
        assert result.content == "Test OpenAI response"
        mock_openai_client.chat.completions.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_extra_headers(self, provider, mock_openai_client):
        """The OpenRouter-specific headers are passed via ``extra_headers``."""
        await provider.generate([Message(role="user", content="Hello")])
        call = mock_openai_client.chat.completions.create.call_args
        headers = call.kwargs.get("extra_headers", {})
        for header_name in ("HTTP-Referer", "X-Title"):
            assert header_name in headers