feat(agents): implement agent framework and CLI
This commit is contained in:
313
tests/test_llm.py
Normal file
313
tests/test_llm.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""Tests for LLM client and prompts."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from arbiter.llm.client import LiteLLMClient, LLMResponse
|
||||
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
|
||||
from tests.conftest import MockLLMClient
|
||||
|
||||
|
||||
class TestLiteLLMClient:
    """Exercise ``LiteLLMClient`` with the ``litellm`` calls patched out."""

    # Dotted paths of the two litellm entry points the async tests replace.
    _ACOMPLETION_TARGET = "arbiter.llm.client.litellm.acompletion"
    _COST_TARGET = "arbiter.llm.client.litellm.completion_cost"

    @staticmethod
    def _fake_completion(content, model, usage) -> MagicMock:
        """Build a stand-in for the object returned by ``litellm.acompletion``.

        Args:
            content: Message content of the single choice (may be ``None``).
            model: Model name reported by the response (may be ``None``).
            usage: ``(prompt_tokens, completion_tokens)`` pair, or ``None`` to
                simulate a response that carries no usage data.
        """
        fake = MagicMock()
        fake.choices = [MagicMock()]
        fake.choices[0].message.content = content
        fake.model = model
        if usage is None:
            fake.usage = None
        else:
            fake.usage = MagicMock()
            fake.usage.prompt_tokens, fake.usage.completion_tokens = usage
        return fake

    @classmethod
    def _litellm_patches(cls, completion, cost):
        """Return the (acompletion, completion_cost) patchers, not yet started.

        The first patcher swaps ``acompletion`` for an ``AsyncMock`` resolving
        to ``completion``; the second makes ``completion_cost`` return ``cost``.
        """
        acompletion_patch = patch(
            cls._ACOMPLETION_TARGET,
            new_callable=AsyncMock,
            return_value=completion,
        )
        cost_patch = patch(cls._COST_TARGET, return_value=cost)
        return acompletion_patch, cost_patch

    def test_init_default_values(self) -> None:
        """A client built without arguments uses timeout 60 and 3 retries."""
        default_client = LiteLLMClient()
        assert default_client.timeout == 60
        assert default_client.max_retries == 3

    def test_init_custom_values(self) -> None:
        """Explicit constructor arguments are stored on the client."""
        tuned_client = LiteLLMClient(timeout=120, max_retries=5)
        assert tuned_client.timeout == 120
        assert tuned_client.max_retries == 5

    @pytest.mark.asyncio
    async def test_complete_returns_response(self) -> None:
        """``complete`` copies content, model, token counts and cost over."""
        llm = LiteLLMClient()
        fake = self._fake_completion("Test response", "gpt-4o", (10, 5))
        messages = [{"role": "user", "content": "Hello"}]

        acompletion_patch, cost_patch = self._litellm_patches(fake, 0.001)
        with acompletion_patch as mock_acomp, cost_patch:
            response = await llm.complete(messages, "gpt-4o")

        assert response.content == "Test response"
        assert response.model == "gpt-4o"
        assert response.tokens_in == 10
        assert response.tokens_out == 5
        assert response.cost_usd == 0.001

        # The client must forward its own timeout / retry settings to litellm.
        mock_acomp.assert_called_once_with(
            model="gpt-4o",
            messages=messages,
            timeout=60,
            num_retries=3,
        )

    @pytest.mark.asyncio
    async def test_complete_handles_empty_content(self) -> None:
        """A ``None`` message content comes back as an empty string."""
        llm = LiteLLMClient()
        fake = self._fake_completion(None, "gpt-4o", (5, 0))
        messages = [{"role": "user", "content": "Hello"}]

        acompletion_patch, cost_patch = self._litellm_patches(fake, 0.0)
        with acompletion_patch, cost_patch:
            response = await llm.complete(messages, "gpt-4o")

        # Expect "" rather than None.
        assert response.content == ""

    @pytest.mark.asyncio
    async def test_complete_handles_missing_usage(self) -> None:
        """Token counts are reported as zero when the response has no usage."""
        llm = LiteLLMClient()
        fake = self._fake_completion("Response", "gpt-4o", None)
        messages = [{"role": "user", "content": "Hello"}]

        acompletion_patch, cost_patch = self._litellm_patches(fake, 0.0)
        with acompletion_patch, cost_patch:
            response = await llm.complete(messages, "gpt-4o")

        assert response.tokens_in == 0
        assert response.tokens_out == 0

    @pytest.mark.asyncio
    async def test_complete_uses_fallback_model(self) -> None:
        """The requested model is reported when the response names none."""
        llm = LiteLLMClient()
        fake = self._fake_completion("Response", None, (5, 3))
        messages = [{"role": "user", "content": "Hello"}]

        acompletion_patch, cost_patch = self._litellm_patches(fake, 0.0)
        with acompletion_patch, cost_patch:
            response = await llm.complete(messages, "claude-3-opus")

        # With no model on the response, the argument passed in is used.
        assert response.model == "claude-3-opus"
|
||||
|
||||
|
||||
class TestLLMResponse:
    """Check that ``LLMResponse`` keeps the values it is constructed with."""

    def test_response_creation(self) -> None:
        """Each constructor keyword is readable back as an attribute."""
        expected = {
            "content": "Hello, world!",
            "model": "gpt-4o",
            "tokens_in": 10,
            "tokens_out": 5,
            "cost_usd": 0.001,
        }

        response = LLMResponse(**expected)

        for field_name, value in expected.items():
            assert getattr(response, field_name) == value
|
||||
|
||||
|
||||
class TestMockLLMClient:
    """Verify the ``MockLLMClient`` test double itself."""

    @pytest.mark.asyncio
    async def test_records_calls(self) -> None:
        """Each ``complete`` call is appended to ``calls`` with its arguments."""
        mock_llm = MockLLMClient()
        prompt = [{"role": "user", "content": "Hello"}]

        await mock_llm.complete(prompt, "gpt-4o")

        assert len(mock_llm.calls) == 1
        recorded = mock_llm.calls[0]
        assert recorded["messages"] == prompt
        assert recorded["model"] == "gpt-4o"

    @pytest.mark.asyncio
    async def test_returns_canned_responses(self) -> None:
        """Canned responses are served in order, then empty content follows."""
        mock_llm = MockLLMClient(responses=["First", "Second"])
        prompt = [{"role": "user", "content": "Hello"}]

        replies = [await mock_llm.complete(prompt, "gpt-4o") for _ in range(3)]

        # The third call happens after both canned responses are exhausted.
        assert [reply.content for reply in replies] == ["First", "Second", ""]

    @pytest.mark.asyncio
    async def test_reset(self) -> None:
        """``reset`` clears recorded calls and rewinds the canned responses."""
        mock_llm = MockLLMClient(responses=["Hello"])
        prompt = [{"role": "user", "content": "Hi"}]

        await mock_llm.complete(prompt, "gpt-4o")
        assert len(mock_llm.calls) == 1

        mock_llm.reset()
        assert len(mock_llm.calls) == 0

        # The single canned response is available again after the reset.
        replay = await mock_llm.complete(prompt, "gpt-4o")
        assert replay.content == "Hello"
|
||||
|
||||
|
||||
class TestPromptTemplate:
    """Cover ``PromptTemplate`` construction and ``{{var}}`` rendering."""

    def test_template_creation(self) -> None:
        """Name and version are stored and combined into ``full_name``."""
        tpl = PromptTemplate(name="security", version="1.0", content="Review: {{diff}}")

        assert tpl.name == "security"
        assert tpl.version == "1.0"
        assert tpl.full_name == "security-v1.0"

    def test_render_substitution(self) -> None:
        """Every placeholder is replaced by its keyword argument."""
        tpl = PromptTemplate(
            name="test", version="1.0", content="File: {{file}}\nDiff: {{diff}}"
        )

        rendered = tpl.render(file="test.py", diff="+ added line")

        assert rendered == "File: test.py\nDiff: + added line"

    def test_render_missing_variable(self) -> None:
        """A placeholder with no supplied value is left in the output as-is."""
        tpl = PromptTemplate(name="test", version="1.0", content="Value: {{value}}")

        rendered = tpl.render()

        assert rendered == "Value: {{value}}"

    def test_render_multiple_occurrences(self) -> None:
        """A placeholder that appears twice is substituted both times."""
        tpl = PromptTemplate(
            name="test", version="1.0", content="{{name}} and {{name}} again"
        )

        rendered = tpl.render(name="test")

        assert rendered == "test and test again"
|
||||
|
||||
|
||||
class TestPromptRegistry:
    """Cover ``PromptRegistry`` lookup, listing and caching on a temp dir."""

    @staticmethod
    def _make_templates_dir(tmp_path: Path, files=None) -> Path:
        """Create ``tmp_path / "templates"`` and write the given files into it.

        ``files`` maps file name to text content; entries are written in
        mapping order. ``None`` creates an empty directory.
        """
        directory = tmp_path / "templates"
        directory.mkdir()
        for filename, text in (files or {}).items():
            (directory / filename).write_text(text)
        return directory

    def test_get_template(self, tmp_path: Path) -> None:
        """``get`` loads ``<name>-v<version>.md`` into a template object."""
        directory = self._make_templates_dir(
            tmp_path, {"security-v1.0.md": "Security review: {{diff}}"}
        )

        loaded = PromptRegistry(directory).get("security", "1.0")

        assert loaded.name == "security"
        assert loaded.version == "1.0"
        assert "{{diff}}" in loaded.content

    def test_get_template_cached(self, tmp_path: Path) -> None:
        """Two lookups of the same template return the very same object."""
        directory = self._make_templates_dir(tmp_path, {"security-v1.0.md": "Content"})
        registry = PromptRegistry(directory)

        first = registry.get("security", "1.0")
        second = registry.get("security", "1.0")

        assert first is second

    def test_get_template_not_found(self, tmp_path: Path) -> None:
        """Requesting an absent template raises ``FileNotFoundError``."""
        registry = PromptRegistry(self._make_templates_dir(tmp_path))

        with pytest.raises(FileNotFoundError):
            registry.get("missing", "1.0")

    def test_list_templates(self, tmp_path: Path) -> None:
        """Only ``<name>-v<version>.md`` files are listed; others are skipped."""
        directory = self._make_templates_dir(
            tmp_path,
            {
                "security-v1.0.md": "Content",
                "style-v2.0.md": "Content",
                "readme.md": "Not a template",
            },
        )

        found = PromptRegistry(directory).list_templates()

        assert len(found) == 2
        assert ("security", "1.0") in found
        assert ("style", "2.0") in found

    def test_list_templates_empty_dir(self, tmp_path: Path) -> None:
        """An empty templates directory yields an empty list."""
        registry = PromptRegistry(self._make_templates_dir(tmp_path))

        assert registry.list_templates() == []

    def test_list_templates_missing_dir(self, tmp_path: Path) -> None:
        """A templates directory that does not exist yields an empty list."""
        # Deliberately never created on disk.
        absent_dir = tmp_path / "missing"

        registry = PromptRegistry(absent_dir)

        assert registry.list_templates() == []

    def test_clear_cache(self, tmp_path: Path) -> None:
        """Edits on disk become visible only after ``clear_cache``."""
        directory = self._make_templates_dir(tmp_path, {"test-v1.0.md": "Original"})
        template_file = directory / "test-v1.0.md"
        registry = PromptRegistry(directory)

        before_edit = registry.get("test", "1.0")
        assert before_edit.content == "Original"

        # Change the file behind the registry's back.
        template_file.write_text("Modified")

        # The cached template is still served.
        while_cached = registry.get("test", "1.0")
        assert while_cached.content == "Original"

        # Dropping the cache forces a re-read from disk.
        registry.clear_cache()
        after_clear = registry.get("test", "1.0")
        assert after_clear.content == "Modified"
|
||||
Reference in New Issue
Block a user