Files
openrabbit/tests/test_llm_client.py
latte e8d28225e0
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
just why not
2026-01-07 21:19:46 +01:00

457 lines
14 KiB
Python

"""Test Suite for LLM Client
Tests for LLM client functionality including provider support,
tool calling, and JSON parsing.
"""
import json
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review"))
from unittest.mock import MagicMock, Mock, patch
import pytest
class TestLLMClientCreation:
    """Construction of ``LLMClient`` for each supported provider."""

    def test_create_openai_client(self):
        """An ``openai`` client reports ``openai`` as its provider name."""
        from clients.llm_client import LLMClient

        settings = {"model": "gpt-4o-mini", "api_key": "test-key"}
        client = LLMClient(provider="openai", config=settings)

        assert client.provider_name == "openai"

    def test_create_openrouter_client(self):
        """An ``openrouter`` client reports ``openrouter`` as its provider name."""
        from clients.llm_client import LLMClient

        settings = {"model": "anthropic/claude-3.5-sonnet", "api_key": "test-key"}
        client = LLMClient(provider="openrouter", config=settings)

        assert client.provider_name == "openrouter"

    def test_create_ollama_client(self):
        """An ``ollama`` client is built from a host URL instead of an API key."""
        from clients.llm_client import LLMClient

        settings = {"model": "codellama:13b", "host": "http://localhost:11434"}
        client = LLMClient(provider="ollama", config=settings)

        assert client.provider_name == "ollama"

    def test_invalid_provider_raises_error(self):
        """An unrecognised provider name is rejected with ``ValueError``."""
        from clients.llm_client import LLMClient

        with pytest.raises(ValueError, match="Unknown provider"):
            LLMClient(provider="invalid_provider")

    def test_from_config_openai(self):
        """``from_config`` picks the provider from a full config mapping."""
        from clients.llm_client import LLMClient

        full_config = {
            "provider": "openai",
            # Per-provider model table; only the openai entry is supplied here.
            "model": {"openai": "gpt-4o-mini"},
            "temperature": 0,
            "max_tokens": 4096,
        }

        assert LLMClient.from_config(full_config).provider_name == "openai"
class TestLLMResponse:
    """Field storage and defaults of the ``LLMResponse`` container."""

    def test_response_creation(self):
        """Every constructor field is readable back; ``tool_calls`` defaults to None."""
        from clients.llm_client import LLMResponse

        fields = {
            "content": "Test response",
            "model": "gpt-4o-mini",
            "provider": "openai",
            "tokens_used": 100,
            "finish_reason": "stop",
        }
        response = LLMResponse(**fields)

        for attr_name, expected in fields.items():
            assert getattr(response, attr_name) == expected
        assert response.tool_calls is None

    def test_response_with_tool_calls(self):
        """A supplied list of ``ToolCall`` objects is kept on the response."""
        from clients.llm_client import LLMResponse, ToolCall

        search_call = ToolCall(id="call_1", name="search", arguments={"query": "test"})
        response = LLMResponse(
            content="",
            model="gpt-4o-mini",
            provider="openai",
            tool_calls=[search_call],
        )

        assert response.tool_calls is not None
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "search"
class TestToolCall:
    """Field storage of the ``ToolCall`` container."""

    def test_tool_call_creation(self):
        """id, name and the arguments mapping are exposed unchanged."""
        from clients.llm_client import ToolCall

        call_args = {"query": "authentication", "file_pattern": "*.py"}
        tool_call = ToolCall(id="call_123", name="search_codebase", arguments=call_args)

        assert tool_call.id == "call_123"
        assert tool_call.name == "search_codebase"
        assert tool_call.arguments["query"] == "authentication"
        assert tool_call.arguments["file_pattern"] == "*.py"
class TestJSONParsing:
    """Tests for ``LLMClient._extract_json`` on typical LLM reply shapes."""

    @staticmethod
    def _bare_client():
        """Return an ``LLMClient`` allocated without running ``__init__``.

        ``_extract_json`` is exercised on its own, so ``__new__`` is used to
        skip the constructor and the provider/config it would otherwise need.
        """
        from clients.llm_client import LLMClient

        return LLMClient.__new__(LLMClient)

    def test_parse_direct_json(self):
        """A reply that is only a JSON object is parsed directly."""
        client = self._bare_client()
        content = '{"key": "value", "number": 42}'

        result = client._extract_json(content)

        assert result["key"] == "value"
        assert result["number"] == 42

    def test_parse_json_in_code_block(self):
        """JSON inside a ```json fenced block surrounded by prose is found."""
        client = self._bare_client()
        # Explicit concatenation keeps the exact reply text independent of
        # source indentation.
        content = (
            "Here is the analysis:\n"
            "```json\n"
            "{\n"
            '"type": "bug",\n'
            '"priority": "high"\n'
            "}\n"
            "```\n"
            "That's my analysis."
        )

        result = client._extract_json(content)

        assert result["type"] == "bug"
        assert result["priority"] == "high"

    def test_parse_json_in_plain_code_block(self):
        """JSON inside a fence with no language tag is found."""
        client = self._bare_client()
        content = (
            "Analysis:\n"
            "```\n"
            '{"status": "success", "count": 5}\n'
            "```\n"
        )

        result = client._extract_json(content)

        assert result["status"] == "success"
        assert result["count"] == 5

    def test_parse_json_with_preamble(self):
        """Leading prose before the JSON object is skipped."""
        client = self._bare_client()
        content = (
            "Based on my analysis, here is the result:\n"
            '{"findings": ["issue1", "issue2"], "severity": "medium"}\n'
        )

        result = client._extract_json(content)

        assert result["findings"] == ["issue1", "issue2"]
        assert result["severity"] == "medium"

    def test_parse_json_with_postamble(self):
        """Trailing prose after the JSON object is ignored."""
        client = self._bare_client()
        content = '{"result": true}\nLet me know if you need more details.'

        result = client._extract_json(content)

        assert result["result"] is True

    def test_parse_nested_json(self):
        """Nested objects and arrays spread over several lines are parsed."""
        client = self._bare_client()
        content = (
            "{\n"
            '"outer": {\n'
            '"inner": {\n'
            '"value": "deep"\n'
            "}\n"
            "},\n"
            '"array": [1, 2, 3]\n'
            "}"
        )

        result = client._extract_json(content)

        assert result["outer"]["inner"]["value"] == "deep"
        assert result["array"] == [1, 2, 3]

    def test_parse_invalid_json_raises_error(self):
        """Text with no JSON at all raises ``ValueError``."""
        client = self._bare_client()
        content = "This is not JSON at all"

        with pytest.raises(ValueError, match="Failed to parse JSON"):
            client._extract_json(content)

    def test_parse_truncated_json_raises_error(self):
        """A JSON object cut off mid-string raises ``ValueError``."""
        client = self._bare_client()
        content = '{"key": "value", "incomplete'

        with pytest.raises(ValueError):
            client._extract_json(content)
class TestOpenAIProvider:
    """Tests for ``OpenAIProvider`` construction, key validation and calls."""

    def test_provider_creation(self):
        """Constructor arguments are stored on the provider unchanged."""
        from clients.llm_client import OpenAIProvider

        provider = OpenAIProvider(
            api_key="test-key",
            model="gpt-4o-mini",
            temperature=0.5,
            max_tokens=2048,
        )

        assert provider.model == "gpt-4o-mini"
        assert provider.temperature == 0.5
        assert provider.max_tokens == 2048

    def test_provider_requires_api_key(self, monkeypatch):
        """An empty API key is rejected before any HTTP request is sent."""
        from clients.llm_client import OpenAIProvider

        # NOTE(review): guards against the provider falling back to an
        # environment key (assumed name OPENAI_API_KEY — confirm against
        # llm_client). Without this, a machine with a real key exported would
        # not raise and could reach the live API.
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        provider = OpenAIProvider(api_key="")

        # Patch the transport so a regression can never hit the real endpoint.
        with patch("clients.llm_client.requests.post") as mock_post:
            with pytest.raises(ValueError, match="API key is required"):
                provider.call("test prompt")
        mock_post.assert_not_called()

    @patch("clients.llm_client.requests.post")
    def test_provider_call_success(self, mock_post):
        """Content, provider and token usage come back from a mocked completion."""
        from clients.llm_client import OpenAIProvider

        mock_response = Mock()
        mock_response.json.return_value = {
            "choices": [
                {
                    "message": {"content": "Test response"},
                    "finish_reason": "stop",
                }
            ],
            "model": "gpt-4o-mini",
            "usage": {"total_tokens": 50},
        }
        mock_response.raise_for_status = Mock()
        mock_post.return_value = mock_response

        provider = OpenAIProvider(api_key="test-key")
        response = provider.call("Hello")

        assert response.content == "Test response"
        assert response.provider == "openai"
        assert response.tokens_used == 50
class TestOpenRouterProvider:
    """Tests for ``OpenRouterProvider``."""

    def test_provider_creation(self):
        """The requested model slug is kept on the provider."""
        from clients.llm_client import OpenRouterProvider

        model_slug = "anthropic/claude-3.5-sonnet"
        provider = OpenRouterProvider(api_key="test-key", model=model_slug)

        assert provider.model == model_slug

    @patch("clients.llm_client.requests.post")
    def test_provider_call_success(self, mock_post):
        """A mocked completion is returned with ``openrouter`` as the provider."""
        from clients.llm_client import OpenRouterProvider

        payload = {
            "choices": [
                {
                    "message": {"content": "Claude response"},
                    "finish_reason": "stop",
                }
            ],
            "model": "anthropic/claude-3.5-sonnet",
            "usage": {"total_tokens": 75},
        }
        # Configure json() and raise_for_status in one step via dotted kwargs.
        mock_post.return_value = Mock(
            **{"json.return_value": payload, "raise_for_status": Mock()}
        )

        reply = OpenRouterProvider(api_key="test-key").call("Hello")

        assert reply.content == "Claude response"
        assert reply.provider == "openrouter"
class TestOllamaProvider:
    """Tests for ``OllamaProvider``."""

    def test_provider_creation(self):
        """Host URL and model tag are kept on the provider."""
        from clients.llm_client import OllamaProvider

        host_url = "http://localhost:11434"
        model_tag = "codellama:13b"
        provider = OllamaProvider(host=host_url, model=model_tag)

        assert provider.model == model_tag
        assert provider.host == host_url

    @patch("clients.llm_client.requests.post")
    def test_provider_call_success(self, mock_post):
        """A mocked generate response is returned with ``ollama`` as the provider."""
        from clients.llm_client import OllamaProvider

        payload = {
            "response": "Ollama response",
            "model": "codellama:13b",
            "done": True,
            "eval_count": 30,
        }
        mock_post.return_value = Mock(
            **{"json.return_value": payload, "raise_for_status": Mock()}
        )

        reply = OllamaProvider().call("Hello")

        assert reply.content == "Ollama response"
        assert reply.provider == "ollama"
class TestToolCalling:
    """Tests for tool/function-calling support across providers."""

    @patch("clients.llm_client.requests.post")
    def test_openai_tool_calling(self, mock_post):
        """OpenAI ``tool_calls`` come back as ToolCall objects with decoded args."""
        from clients.llm_client import OpenAIProvider

        api_tool_call = {
            "id": "call_abc123",
            "function": {
                "name": "search_codebase",
                # Arguments arrive as a JSON string here; the assertions below
                # expect them decoded into a mapping.
                "arguments": '{"query": "auth"}',
            },
        }
        payload = {
            "choices": [
                {
                    "message": {"content": None, "tool_calls": [api_tool_call]},
                    "finish_reason": "tool_calls",
                }
            ],
            "model": "gpt-4o-mini",
            "usage": {"total_tokens": 100},
        }
        mock_post.return_value = Mock(
            **{"json.return_value": payload, "raise_for_status": Mock()}
        )

        search_tool = {
            "type": "function",
            "function": {
                "name": "search_codebase",
                "description": "Search the codebase",
                "parameters": {
                    "type": "object",
                    "properties": {"query": {"type": "string"}},
                },
            },
        }
        provider = OpenAIProvider(api_key="test-key")
        response = provider.call_with_tools(
            messages=[{"role": "user", "content": "Search for auth"}],
            tools=[search_tool],
        )

        assert response.tool_calls is not None
        assert len(response.tool_calls) == 1
        first_call = response.tool_calls[0]
        assert first_call.name == "search_codebase"
        assert first_call.arguments["query"] == "auth"

    def test_ollama_tool_calling_not_supported(self):
        """Ollama rejects ``call_with_tools`` with ``NotImplementedError``."""
        from clients.llm_client import OllamaProvider

        provider = OllamaProvider()
        user_turn = [{"role": "user", "content": "test"}]

        with pytest.raises(NotImplementedError):
            provider.call_with_tools(messages=user_turn, tools=[])
if __name__ == "__main__":
    # pytest.main() returns the session exit code; forward it so running this
    # file directly reports test failures to the shell/CI instead of exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))