314 lines
11 KiB
Python
314 lines
11 KiB
Python
"""Tests for LLM client and prompts."""
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from arbiter.llm.client import LiteLLMClient, LLMResponse
|
|
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
|
|
from tests.conftest import MockLLMClient
|
|
|
|
|
|
class TestLiteLLMClient:
    """Exercise LiteLLMClient construction and ``complete`` with litellm patched out."""

    # Dotted paths of the litellm callables replaced in every ``complete`` test.
    _ACOMPLETION = "arbiter.llm.client.litellm.acompletion"
    _COMPLETION_COST = "arbiter.llm.client.litellm.completion_cost"

    @staticmethod
    def _fake_completion(content, model, usage) -> MagicMock:
        """Build a MagicMock carrying the attributes these tests set on a completion.

        Args:
            content: Value exposed as ``choices[0].message.content``.
            model: Value exposed as ``model``.
            usage: Value exposed as ``usage`` (a mock with token counts, or None).

        Returns:
            The configured mock, ready to be returned by the patched ``acompletion``.
        """
        choice = MagicMock()
        choice.message.content = content
        fake = MagicMock()
        fake.choices = [choice]
        fake.model = model
        fake.usage = usage
        return fake

    def test_init_default_values(self) -> None:
        """With no arguments the client uses timeout 60 and 3 retries."""
        default_client = LiteLLMClient()
        assert default_client.timeout == 60
        assert default_client.max_retries == 3

    def test_init_custom_values(self) -> None:
        """Explicit ``timeout`` / ``max_retries`` arguments are stored on the client."""
        custom_client = LiteLLMClient(timeout=120, max_retries=5)
        assert custom_client.timeout == 120
        assert custom_client.max_retries == 5

    @pytest.mark.asyncio
    async def test_complete_returns_response(self) -> None:
        """A fully populated completion is mapped onto every response field."""
        llm = LiteLLMClient()
        prompt = [{"role": "user", "content": "Hello"}]
        # Stand-in for what litellm.acompletion would hand back.
        fake = self._fake_completion(
            content="Test response",
            model="gpt-4o",
            usage=MagicMock(prompt_tokens=10, completion_tokens=5),
        )

        with (
            patch(self._ACOMPLETION, new_callable=AsyncMock, return_value=fake) as acompletion,
            patch(self._COMPLETION_COST, return_value=0.001),
        ):
            result = await llm.complete(prompt, "gpt-4o")

            assert result.content == "Test response"
            assert result.model == "gpt-4o"
            assert result.tokens_in == 10
            assert result.tokens_out == 5
            assert result.cost_usd == 0.001

            # The client must forward its default timeout / retry settings.
            acompletion.assert_called_once_with(
                model="gpt-4o",
                messages=prompt,
                timeout=60,
                num_retries=3,
            )

    @pytest.mark.asyncio
    async def test_complete_handles_empty_content(self) -> None:
        """A ``None`` message content comes back as an empty string."""
        llm = LiteLLMClient()
        prompt = [{"role": "user", "content": "Hello"}]
        fake = self._fake_completion(
            content=None,  # the provider returned no text
            model="gpt-4o",
            usage=MagicMock(prompt_tokens=5, completion_tokens=0),
        )

        with (
            patch(self._ACOMPLETION, new_callable=AsyncMock, return_value=fake),
            patch(self._COMPLETION_COST, return_value=0.0),
        ):
            result = await llm.complete(prompt, "gpt-4o")

            # Empty string, not None.
            assert result.content == ""

    @pytest.mark.asyncio
    async def test_complete_handles_missing_usage(self) -> None:
        """When the completion has no usage data both token counts are zero."""
        llm = LiteLLMClient()
        prompt = [{"role": "user", "content": "Hello"}]
        fake = self._fake_completion(content="Response", model="gpt-4o", usage=None)

        with (
            patch(self._ACOMPLETION, new_callable=AsyncMock, return_value=fake),
            patch(self._COMPLETION_COST, return_value=0.0),
        ):
            result = await llm.complete(prompt, "gpt-4o")

            assert result.tokens_in == 0
            assert result.tokens_out == 0

    @pytest.mark.asyncio
    async def test_complete_uses_fallback_model(self) -> None:
        """If the completion carries no model, the requested model name is reported."""
        llm = LiteLLMClient()
        prompt = [{"role": "user", "content": "Hello"}]
        fake = self._fake_completion(
            content="Response",
            model=None,  # response omits the model
            usage=MagicMock(prompt_tokens=5, completion_tokens=3),
        )

        with (
            patch(self._ACOMPLETION, new_callable=AsyncMock, return_value=fake),
            patch(self._COMPLETION_COST, return_value=0.0),
        ):
            result = await llm.complete(prompt, "claude-3-opus")

            # The model passed to ``complete`` is used as the fallback.
            assert result.model == "claude-3-opus"
|
|
|
|
|
|
class TestLLMResponse:
    """Check that LLMResponse stores the values it is constructed with."""

    def test_response_creation(self) -> None:
        """Every constructor keyword is readable back as an attribute of the same name."""
        expected_fields = {
            "content": "Hello, world!",
            "model": "gpt-4o",
            "tokens_in": 10,
            "tokens_out": 5,
            "cost_usd": 0.001,
        }

        built = LLMResponse(**expected_fields)

        for attribute, expected in expected_fields.items():
            assert getattr(built, attribute) == expected
|
|
|
|
|
|
class TestMockLLMClient:
    """Check the call-recording and canned-reply behaviour of MockLLMClient."""

    @pytest.mark.asyncio
    async def test_records_calls(self) -> None:
        """A ``complete`` call is logged together with its messages and model."""
        prompt = [{"role": "user", "content": "Hello"}]
        mock_client = MockLLMClient()

        await mock_client.complete(prompt, "gpt-4o")

        assert len(mock_client.calls) == 1
        logged = mock_client.calls[0]
        assert logged["messages"] == prompt
        assert logged["model"] == "gpt-4o"

    @pytest.mark.asyncio
    async def test_returns_canned_responses(self) -> None:
        """Canned replies are served in order, then an empty string once they run out."""
        prompt = [{"role": "user", "content": "Hello"}]
        mock_client = MockLLMClient(responses=["First", "Second"])

        contents = []
        for _ in range(3):
            reply = await mock_client.complete(prompt, "gpt-4o")
            contents.append(reply.content)

        # The third call is past the end of the canned list (exhausted).
        assert contents == ["First", "Second", ""]

    @pytest.mark.asyncio
    async def test_reset(self) -> None:
        """``reset`` clears the call log and rewinds the canned replies."""
        prompt = [{"role": "user", "content": "Hi"}]
        mock_client = MockLLMClient(responses=["Hello"])

        await mock_client.complete(prompt, "gpt-4o")
        assert len(mock_client.calls) == 1

        mock_client.reset()
        assert len(mock_client.calls) == 0

        # The single canned reply is available again after the reset.
        replay = await mock_client.complete(prompt, "gpt-4o")
        assert replay.content == "Hello"
|
|
|
|
|
|
class TestPromptTemplate:
    """Check PromptTemplate attributes and ``{{placeholder}}`` rendering."""

    def test_template_creation(self) -> None:
        """Name and version are stored and combined into ``full_name``."""
        tpl = PromptTemplate(name="security", version="1.0", content="Review: {{diff}}")

        assert tpl.name == "security"
        assert tpl.version == "1.0"
        assert tpl.full_name == "security-v1.0"

    def test_render_substitution(self) -> None:
        """Each placeholder is replaced by the keyword argument of the same name."""
        tpl = PromptTemplate(
            name="test",
            version="1.0",
            content="File: {{file}}\nDiff: {{diff}}",
        )

        rendered = tpl.render(file="test.py", diff="+ added line")

        assert rendered == "File: test.py\nDiff: + added line"

    def test_render_missing_variable(self) -> None:
        """A placeholder with no supplied value is left in the output unchanged."""
        tpl = PromptTemplate(name="test", version="1.0", content="Value: {{value}}")

        rendered = tpl.render()

        assert rendered == "Value: {{value}}"

    def test_render_multiple_occurrences(self) -> None:
        """Every occurrence of a repeated placeholder is substituted."""
        tpl = PromptTemplate(
            name="test",
            version="1.0",
            content="{{name}} and {{name}} again",
        )

        rendered = tpl.render(name="test")

        assert rendered == "test and test again"
|
|
|
|
|
|
class TestPromptRegistry:
    """Check PromptRegistry lookup, caching and listing against a temporary directory."""

    @staticmethod
    def _templates_dir(tmp_path: Path, files: dict) -> Path:
        """Create ``tmp_path / "templates"`` and write the given files into it.

        Args:
            tmp_path: Per-test temporary directory supplied by pytest.
            files: Mapping of file name to text content, written in order.

        Returns:
            Path of the newly created templates directory.
        """
        root = tmp_path / "templates"
        root.mkdir()
        for filename, text in files.items():
            (root / filename).write_text(text)
        return root

    def test_get_template(self, tmp_path: Path) -> None:
        """``get`` loads ``<name>-v<version>.md`` and exposes name, version and content."""
        root = self._templates_dir(
            tmp_path, {"security-v1.0.md": "Security review: {{diff}}"}
        )

        loaded = PromptRegistry(root).get("security", "1.0")

        assert loaded.name == "security"
        assert loaded.version == "1.0"
        assert "{{diff}}" in loaded.content

    def test_get_template_cached(self, tmp_path: Path) -> None:
        """Two lookups of the same template return the very same object."""
        root = self._templates_dir(tmp_path, {"security-v1.0.md": "Content"})
        registry = PromptRegistry(root)

        first = registry.get("security", "1.0")
        second = registry.get("security", "1.0")

        assert first is second

    def test_get_template_not_found(self, tmp_path: Path) -> None:
        """Requesting a template with no backing file raises FileNotFoundError."""
        root = self._templates_dir(tmp_path, {})
        registry = PromptRegistry(root)

        with pytest.raises(FileNotFoundError):
            registry.get("missing", "1.0")

    def test_list_templates(self, tmp_path: Path) -> None:
        """Versioned template files are listed; a plain ``readme.md`` is not."""
        root = self._templates_dir(
            tmp_path,
            {
                "security-v1.0.md": "Content",
                "style-v2.0.md": "Content",
                "readme.md": "Not a template",
            },
        )

        found = PromptRegistry(root).list_templates()

        assert len(found) == 2
        for expected in (("security", "1.0"), ("style", "2.0")):
            assert expected in found

    def test_list_templates_empty_dir(self, tmp_path: Path) -> None:
        """An existing but empty directory yields an empty list."""
        root = self._templates_dir(tmp_path, {})

        found = PromptRegistry(root).list_templates()

        assert found == []

    def test_list_templates_missing_dir(self, tmp_path: Path) -> None:
        """A directory that was never created yields an empty list."""
        absent = tmp_path / "missing"

        found = PromptRegistry(absent).list_templates()

        assert found == []

    def test_clear_cache(self, tmp_path: Path) -> None:
        """Edits on disk become visible only after ``clear_cache`` is called."""
        root = self._templates_dir(tmp_path, {"test-v1.0.md": "Original"})
        template_file = root / "test-v1.0.md"
        registry = PromptRegistry(root)

        before_edit = registry.get("test", "1.0")
        assert before_edit.content == "Original"

        # Change the file behind the registry's back.
        template_file.write_text("Modified")

        # The cached copy is still served.
        while_cached = registry.get("test", "1.0")
        assert while_cached.content == "Original"

        # Dropping the cache forces a re-read from disk.
        registry.clear_cache()
        after_clear = registry.get("test", "1.0")
        assert after_clear.content == "Modified"
|