just why not
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
This commit is contained in:
296
tests/test_dispatcher.py
Normal file
296
tests/test_dispatcher.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""Test Suite for Dispatcher
|
||||
|
||||
Tests for event routing and agent execution.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review"))
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDispatcherCreation:
    """Construction tests for the Dispatcher."""

    def test_create_dispatcher(self):
        """A freshly built Dispatcher exists and has no registered agents."""
        from dispatcher import Dispatcher

        d = Dispatcher()
        assert d is not None
        assert d.agents == []

    def test_create_dispatcher_with_config(self):
        """A config mapping passed at construction is stored verbatim."""
        from dispatcher import Dispatcher

        cfg = {"dispatcher": {"max_workers": 4}}
        assert Dispatcher(config=cfg).config == cfg
|
||||
|
||||
|
||||
class TestAgentRegistration:
    """Test agent registration on the dispatcher."""

    def test_register_agent(self):
        """A registered agent is stored in the dispatcher's agent list."""
        # AgentContext was previously imported here but never used; dropped.
        from agents.base_agent import AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "test"

            def execute(self, context):
                return AgentResult(success=True, message="done")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)

        assert len(dispatcher.agents) == 1
        assert dispatcher.agents[0] == agent

    def test_register_multiple_agents(self):
        """Registering two distinct agents keeps both."""
        from agents.base_agent import AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent1(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "type1"

            def execute(self, context):
                return AgentResult(success=True, message="agent1")

        class MockAgent2(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "type2"

            def execute(self, context):
                return AgentResult(success=True, message="agent2")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent1(config={}, gitea_client=None, llm_client=None)
        )
        dispatcher.register_agent(
            MockAgent2(config={}, gitea_client=None, llm_client=None)
        )

        assert len(dispatcher.agents) == 2
|
||||
|
||||
|
||||
class TestEventRouting:
    """Test event routing to agents."""

    def test_route_to_matching_agent(self):
        """An event is executed by the agent whose can_handle matches it."""
        # AgentContext was previously imported here but never used; dropped.
        from agents.base_agent import AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="handled")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)

        result = dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )

        assert len(result.agents_run) == 1
        assert result.results[0].success is True

    def test_no_matching_agent(self):
        """Dispatch runs no agents when none can handle the event type."""
        from agents.base_agent import AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="handled")

        dispatcher = Dispatcher()
        agent = MockAgent(config={}, gitea_client=None, llm_client=None)
        dispatcher.register_agent(agent)

        result = dispatcher.dispatch(
            event_type="pull_request",  # Different event type
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )

        assert len(result.agents_run) == 0

    def test_multiple_matching_agents(self):
        """Every agent whose can_handle matches gets executed."""
        from agents.base_agent import AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class MockAgent1(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="agent1")

        class MockAgent2(BaseAgent):
            def can_handle(self, event_type, event_data):
                return event_type == "issues"

            def execute(self, context):
                return AgentResult(success=True, message="agent2")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent1(config={}, gitea_client=None, llm_client=None)
        )
        dispatcher.register_agent(
            MockAgent2(config={}, gitea_client=None, llm_client=None)
        )

        result = dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened"},
            owner="test",
            repo="repo",
        )

        assert len(result.agents_run) == 2
|
||||
|
||||
|
||||
class TestDispatchResult:
    """Shape checks for the DispatchResult container."""

    def test_result_structure(self):
        """All constructor fields are stored unchanged."""
        from dispatcher import DispatchResult

        res = DispatchResult(
            agents_run=["Agent1", "Agent2"],
            results=[],
            errors=[],
        )

        assert res.agents_run == ["Agent1", "Agent2"]
        assert res.results == []
        assert res.errors == []

    def test_result_with_errors(self):
        """Errors supplied at construction are all retained."""
        from dispatcher import DispatchResult

        res = DispatchResult(
            agents_run=["Agent1"],
            results=[],
            errors=["Error 1", "Error 2"],
        )
        assert len(res.errors) == 2
|
||||
|
||||
|
||||
class TestAgentExecution:
    """Test agent execution through dispatcher."""

    def test_agent_receives_context(self):
        """The dispatcher hands agents a context carrying event metadata."""
        # AgentContext was previously imported here but never used; dropped.
        from agents.base_agent import AgentResult, BaseAgent
        from dispatcher import Dispatcher

        received_context = None

        class MockAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return True

            def execute(self, context):
                # Capture the context for inspection after dispatch.
                nonlocal received_context
                received_context = context
                return AgentResult(success=True, message="done")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            MockAgent(config={}, gitea_client=None, llm_client=None)
        )

        dispatcher.dispatch(
            event_type="issues",
            event_data={"action": "opened", "issue": {"number": 123}},
            owner="testowner",
            repo="testrepo",
        )

        assert received_context is not None
        assert received_context.owner == "testowner"
        assert received_context.repo == "testrepo"
        assert received_context.event_type == "issues"
        assert received_context.event_data["action"] == "opened"

    def test_agent_failure_captured(self):
        """An agent that raises is still recorded, with a failed result."""
        from agents.base_agent import AgentResult, BaseAgent
        from dispatcher import Dispatcher

        class FailingAgent(BaseAgent):
            def can_handle(self, event_type, event_data):
                return True

            def execute(self, context):
                # Deliberately raise to exercise the dispatcher's error path.
                raise Exception("Test error")

        dispatcher = Dispatcher()
        dispatcher.register_agent(
            FailingAgent(config={}, gitea_client=None, llm_client=None)
        )

        result = dispatcher.dispatch(
            event_type="issues",
            event_data={},
            owner="test",
            repo="repo",
        )

        # Agent should still be in agents_run
        assert len(result.agents_run) == 1
        # Result should indicate failure
        assert result.results[0].success is False
|
||||
|
||||
|
||||
class TestGetDispatcher:
    """Tests for the get_dispatcher factory function."""

    def test_get_dispatcher_returns_singleton(self):
        """get_dispatcher yields a usable dispatcher instance."""
        from dispatcher import get_dispatcher

        assert get_dispatcher() is not None

    def test_get_dispatcher_with_config(self):
        """A custom config passed to the factory is visible on the result."""
        from dispatcher import get_dispatcher

        d = get_dispatcher(config={"test": "value"})
        assert d.config.get("test") == "value"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
456
tests/test_llm_client.py
Normal file
456
tests/test_llm_client.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""Test Suite for LLM Client
|
||||
|
||||
Tests for LLM client functionality including provider support,
|
||||
tool calling, and JSON parsing.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "tools", "ai-review"))
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestLLMClientCreation:
    """Construction tests for LLMClient across providers."""

    def test_create_openai_client(self):
        """An OpenAI-backed client reports the right provider name."""
        from clients.llm_client import LLMClient

        cfg = {"model": "gpt-4o-mini", "api_key": "test-key"}
        client = LLMClient(provider="openai", config=cfg)
        assert client.provider_name == "openai"

    def test_create_openrouter_client(self):
        """An OpenRouter-backed client reports the right provider name."""
        from clients.llm_client import LLMClient

        cfg = {"model": "anthropic/claude-3.5-sonnet", "api_key": "test-key"}
        client = LLMClient(provider="openrouter", config=cfg)
        assert client.provider_name == "openrouter"

    def test_create_ollama_client(self):
        """An Ollama-backed client reports the right provider name."""
        from clients.llm_client import LLMClient

        cfg = {"model": "codellama:13b", "host": "http://localhost:11434"}
        client = LLMClient(provider="ollama", config=cfg)
        assert client.provider_name == "ollama"

    def test_invalid_provider_raises_error(self):
        """An unrecognized provider name is rejected with ValueError."""
        from clients.llm_client import LLMClient

        with pytest.raises(ValueError, match="Unknown provider"):
            LLMClient(provider="invalid_provider")

    def test_from_config_openai(self):
        """from_config builds a working client from a full config mapping."""
        from clients.llm_client import LLMClient

        client = LLMClient.from_config(
            {
                "provider": "openai",
                "model": {"openai": "gpt-4o-mini"},
                "temperature": 0,
                "max_tokens": 4096,
            }
        )
        assert client.provider_name == "openai"
|
||||
|
||||
|
||||
class TestLLMResponse:
    """Field checks for the LLMResponse dataclass."""

    def test_response_creation(self):
        """All constructor fields round-trip; tool_calls defaults to None."""
        from clients.llm_client import LLMResponse

        resp = LLMResponse(
            content="Test response",
            model="gpt-4o-mini",
            provider="openai",
            tokens_used=100,
            finish_reason="stop",
        )

        assert resp.content == "Test response"
        assert resp.model == "gpt-4o-mini"
        assert resp.provider == "openai"
        assert resp.tokens_used == 100
        assert resp.finish_reason == "stop"
        assert resp.tool_calls is None

    def test_response_with_tool_calls(self):
        """A response can carry a list of ToolCall objects."""
        from clients.llm_client import LLMResponse, ToolCall

        calls = [ToolCall(id="call_1", name="search", arguments={"query": "test"})]
        resp = LLMResponse(
            content="",
            model="gpt-4o-mini",
            provider="openai",
            tool_calls=calls,
        )

        assert resp.tool_calls is not None
        assert len(resp.tool_calls) == 1
        assert resp.tool_calls[0].name == "search"
|
||||
|
||||
|
||||
class TestToolCall:
    """Field checks for the ToolCall dataclass."""

    def test_tool_call_creation(self):
        """Id, name, and arguments are stored as given."""
        from clients.llm_client import ToolCall

        call = ToolCall(
            id="call_123",
            name="search_codebase",
            arguments={"query": "authentication", "file_pattern": "*.py"},
        )

        assert call.id == "call_123"
        assert call.name == "search_codebase"
        assert call.arguments["query"] == "authentication"
        assert call.arguments["file_pattern"] == "*.py"
|
||||
|
||||
|
||||
class TestJSONParsing:
    """Test JSON extraction and parsing via LLMClient._extract_json.

    Every test previously repeated ``LLMClient.__new__(LLMClient)`` inline;
    that boilerplate is factored into ``_bare_client`` below.
    """

    @staticmethod
    def _bare_client():
        """Return an LLMClient without running __init__ (no provider needed)."""
        from clients.llm_client import LLMClient

        return LLMClient.__new__(LLMClient)

    def test_parse_direct_json(self):
        """Test parsing direct JSON response."""
        result = self._bare_client()._extract_json('{"key": "value", "number": 42}')

        assert result["key"] == "value"
        assert result["number"] == 42

    def test_parse_json_in_code_block(self):
        """Test parsing JSON in markdown code block."""
        content = """Here is the analysis:

```json
{
  "type": "bug",
  "priority": "high"
}
```

That's my analysis."""

        result = self._bare_client()._extract_json(content)
        assert result["type"] == "bug"
        assert result["priority"] == "high"

    def test_parse_json_in_plain_code_block(self):
        """Test parsing JSON in plain code block (no json specifier)."""
        content = """Analysis:

```
{"status": "success", "count": 5}
```
"""

        result = self._bare_client()._extract_json(content)
        assert result["status"] == "success"
        assert result["count"] == 5

    def test_parse_json_with_preamble(self):
        """Test parsing JSON with text before it."""
        content = """Based on my analysis, here is the result:
{"findings": ["issue1", "issue2"], "severity": "medium"}
"""

        result = self._bare_client()._extract_json(content)
        assert result["findings"] == ["issue1", "issue2"]
        assert result["severity"] == "medium"

    def test_parse_json_with_postamble(self):
        """Test parsing JSON with text after it."""
        content = """{"result": true}

Let me know if you need more details."""

        result = self._bare_client()._extract_json(content)
        assert result["result"] is True

    def test_parse_nested_json(self):
        """Test parsing nested JSON objects."""
        content = """{
  "outer": {
    "inner": {
      "value": "deep"
    }
  },
  "array": [1, 2, 3]
}"""

        result = self._bare_client()._extract_json(content)
        assert result["outer"]["inner"]["value"] == "deep"
        assert result["array"] == [1, 2, 3]

    def test_parse_invalid_json_raises_error(self):
        """Test that invalid JSON raises ValueError."""
        with pytest.raises(ValueError, match="Failed to parse JSON"):
            self._bare_client()._extract_json("This is not JSON at all")

    def test_parse_truncated_json_raises_error(self):
        """Test that truncated JSON raises ValueError."""
        with pytest.raises(ValueError):
            self._bare_client()._extract_json('{"key": "value", "incomplete')
|
||||
|
||||
|
||||
class TestOpenAIProvider:
    """Tests for the OpenAI provider."""

    def test_provider_creation(self):
        """Constructor arguments are stored on the provider."""
        from clients.llm_client import OpenAIProvider

        p = OpenAIProvider(
            api_key="test-key",
            model="gpt-4o-mini",
            temperature=0.5,
            max_tokens=2048,
        )

        assert p.model == "gpt-4o-mini"
        assert p.temperature == 0.5
        assert p.max_tokens == 2048

    def test_provider_requires_api_key(self):
        """Calling with an empty API key raises ValueError."""
        from clients.llm_client import OpenAIProvider

        with pytest.raises(ValueError, match="API key is required"):
            OpenAIProvider(api_key="").call("test prompt")

    @patch("clients.llm_client.requests.post")
    def test_provider_call_success(self, mock_post):
        """A successful HTTP response maps onto LLMResponse fields."""
        from clients.llm_client import OpenAIProvider

        payload = {
            "choices": [
                {
                    "message": {"content": "Test response"},
                    "finish_reason": "stop",
                }
            ],
            "model": "gpt-4o-mini",
            "usage": {"total_tokens": 50},
        }
        fake = Mock()
        fake.json.return_value = payload
        fake.raise_for_status = Mock()
        mock_post.return_value = fake

        response = OpenAIProvider(api_key="test-key").call("Hello")

        assert response.content == "Test response"
        assert response.provider == "openai"
        assert response.tokens_used == 50
|
||||
|
||||
|
||||
class TestOpenRouterProvider:
    """Tests for the OpenRouter provider."""

    def test_provider_creation(self):
        """The configured model name is stored on the provider."""
        from clients.llm_client import OpenRouterProvider

        p = OpenRouterProvider(
            api_key="test-key",
            model="anthropic/claude-3.5-sonnet",
        )
        assert p.model == "anthropic/claude-3.5-sonnet"

    @patch("clients.llm_client.requests.post")
    def test_provider_call_success(self, mock_post):
        """A successful HTTP response maps onto LLMResponse fields."""
        from clients.llm_client import OpenRouterProvider

        payload = {
            "choices": [
                {
                    "message": {"content": "Claude response"},
                    "finish_reason": "stop",
                }
            ],
            "model": "anthropic/claude-3.5-sonnet",
            "usage": {"total_tokens": 75},
        }
        fake = Mock()
        fake.json.return_value = payload
        fake.raise_for_status = Mock()
        mock_post.return_value = fake

        response = OpenRouterProvider(api_key="test-key").call("Hello")

        assert response.content == "Claude response"
        assert response.provider == "openrouter"
|
||||
|
||||
|
||||
class TestOllamaProvider:
    """Tests for the Ollama provider."""

    def test_provider_creation(self):
        """Host and model are stored on the provider."""
        from clients.llm_client import OllamaProvider

        p = OllamaProvider(
            host="http://localhost:11434",
            model="codellama:13b",
        )

        assert p.model == "codellama:13b"
        assert p.host == "http://localhost:11434"

    @patch("clients.llm_client.requests.post")
    def test_provider_call_success(self, mock_post):
        """A successful Ollama HTTP response maps onto LLMResponse fields."""
        from clients.llm_client import OllamaProvider

        payload = {
            "response": "Ollama response",
            "model": "codellama:13b",
            "done": True,
            "eval_count": 30,
        }
        fake = Mock()
        fake.json.return_value = payload
        fake.raise_for_status = Mock()
        mock_post.return_value = fake

        response = OllamaProvider().call("Hello")

        assert response.content == "Ollama response"
        assert response.provider == "ollama"
|
||||
|
||||
|
||||
class TestToolCalling:
    """Tool/function-calling behavior across providers."""

    @patch("clients.llm_client.requests.post")
    def test_openai_tool_calling(self, mock_post):
        """OpenAI tool_calls in the API payload surface as ToolCall objects."""
        from clients.llm_client import OpenAIProvider

        api_payload = {
            "choices": [
                {
                    "message": {
                        "content": None,
                        "tool_calls": [
                            {
                                "id": "call_abc123",
                                "function": {
                                    "name": "search_codebase",
                                    "arguments": '{"query": "auth"}',
                                },
                            }
                        ],
                    },
                    "finish_reason": "tool_calls",
                }
            ],
            "model": "gpt-4o-mini",
            "usage": {"total_tokens": 100},
        }
        fake = Mock()
        fake.json.return_value = api_payload
        fake.raise_for_status = Mock()
        mock_post.return_value = fake

        tool_spec = {
            "type": "function",
            "function": {
                "name": "search_codebase",
                "description": "Search the codebase",
                "parameters": {
                    "type": "object",
                    "properties": {"query": {"type": "string"}},
                },
            },
        }

        provider = OpenAIProvider(api_key="test-key")
        response = provider.call_with_tools(
            messages=[{"role": "user", "content": "Search for auth"}],
            tools=[tool_spec],
        )

        assert response.tool_calls is not None
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "search_codebase"
        assert response.tool_calls[0].arguments["query"] == "auth"

    def test_ollama_tool_calling_not_supported(self):
        """Ollama rejects tool calling with NotImplementedError."""
        from clients.llm_client import OllamaProvider

        with pytest.raises(NotImplementedError):
            OllamaProvider().call_with_tools(
                messages=[{"role": "user", "content": "test"}],
                tools=[],
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user