feat(agents): implement agent framework and CLI

This commit is contained in:
2025-03-08 15:52:29 +00:00
parent 72268ff440
commit f22ca1d5bd
30 changed files with 3466 additions and 0 deletions

313
tests/test_llm.py Normal file
View File

@@ -0,0 +1,313 @@
"""Tests for LLM client and prompts."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from arbiter.llm.client import LiteLLMClient, LLMResponse
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
from tests.conftest import MockLLMClient
class TestLiteLLMClient:
def test_init_default_values(self) -> None:
client = LiteLLMClient()
assert client.timeout == 60
assert client.max_retries == 3
def test_init_custom_values(self) -> None:
client = LiteLLMClient(timeout=120, max_retries=5)
assert client.timeout == 120
assert client.max_retries == 5
@pytest.mark.asyncio
async def test_complete_returns_response(self) -> None:
client = LiteLLMClient()
# Mock the litellm.acompletion function
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Test response"
mock_response.model = "gpt-4o"
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
with (
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
):
mock_acomp.return_value = mock_response
mock_cost.return_value = 0.001
messages = [{"role": "user", "content": "Hello"}]
response = await client.complete(messages, "gpt-4o")
assert response.content == "Test response"
assert response.model == "gpt-4o"
assert response.tokens_in == 10
assert response.tokens_out == 5
assert response.cost_usd == 0.001
mock_acomp.assert_called_once_with(
model="gpt-4o",
messages=messages,
timeout=60,
num_retries=3,
)
@pytest.mark.asyncio
async def test_complete_handles_empty_content(self) -> None:
client = LiteLLMClient()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = None # None content
mock_response.model = "gpt-4o"
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 5
mock_response.usage.completion_tokens = 0
with (
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
):
mock_acomp.return_value = mock_response
mock_cost.return_value = 0.0
messages = [{"role": "user", "content": "Hello"}]
response = await client.complete(messages, "gpt-4o")
assert response.content == "" # Should be empty string, not None
@pytest.mark.asyncio
async def test_complete_handles_missing_usage(self) -> None:
client = LiteLLMClient()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response"
mock_response.model = "gpt-4o"
mock_response.usage = None # No usage data
with (
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
):
mock_acomp.return_value = mock_response
mock_cost.return_value = 0.0
messages = [{"role": "user", "content": "Hello"}]
response = await client.complete(messages, "gpt-4o")
assert response.tokens_in == 0
assert response.tokens_out == 0
@pytest.mark.asyncio
async def test_complete_uses_fallback_model(self) -> None:
client = LiteLLMClient()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response"
mock_response.model = None # No model in response
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 5
mock_response.usage.completion_tokens = 3
with (
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
):
mock_acomp.return_value = mock_response
mock_cost.return_value = 0.0
messages = [{"role": "user", "content": "Hello"}]
response = await client.complete(messages, "claude-3-opus")
# Should use the passed model as fallback
assert response.model == "claude-3-opus"
class TestLLMResponse:
def test_response_creation(self) -> None:
response = LLMResponse(
content="Hello, world!",
model="gpt-4o",
tokens_in=10,
tokens_out=5,
cost_usd=0.001,
)
assert response.content == "Hello, world!"
assert response.model == "gpt-4o"
assert response.tokens_in == 10
assert response.tokens_out == 5
assert response.cost_usd == 0.001
class TestMockLLMClient:
@pytest.mark.asyncio
async def test_records_calls(self) -> None:
client = MockLLMClient()
messages = [{"role": "user", "content": "Hello"}]
await client.complete(messages, "gpt-4o")
assert len(client.calls) == 1
assert client.calls[0]["messages"] == messages
assert client.calls[0]["model"] == "gpt-4o"
@pytest.mark.asyncio
async def test_returns_canned_responses(self) -> None:
client = MockLLMClient(responses=["First", "Second"])
messages = [{"role": "user", "content": "Hello"}]
response1 = await client.complete(messages, "gpt-4o")
response2 = await client.complete(messages, "gpt-4o")
response3 = await client.complete(messages, "gpt-4o")
assert response1.content == "First"
assert response2.content == "Second"
assert response3.content == "" # Exhausted
@pytest.mark.asyncio
async def test_reset(self) -> None:
client = MockLLMClient(responses=["Hello"])
messages = [{"role": "user", "content": "Hi"}]
await client.complete(messages, "gpt-4o")
assert len(client.calls) == 1
client.reset()
assert len(client.calls) == 0
response = await client.complete(messages, "gpt-4o")
assert response.content == "Hello" # Responses reset too
class TestPromptTemplate:
def test_template_creation(self) -> None:
template = PromptTemplate(
name="security",
version="1.0",
content="Review: {{diff}}",
)
assert template.name == "security"
assert template.version == "1.0"
assert template.full_name == "security-v1.0"
def test_render_substitution(self) -> None:
template = PromptTemplate(
name="test",
version="1.0",
content="File: {{file}}\nDiff: {{diff}}",
)
result = template.render(file="test.py", diff="+ added line")
assert result == "File: test.py\nDiff: + added line"
def test_render_missing_variable(self) -> None:
template = PromptTemplate(
name="test",
version="1.0",
content="Value: {{value}}",
)
result = template.render()
assert result == "Value: {{value}}"
def test_render_multiple_occurrences(self) -> None:
template = PromptTemplate(
name="test",
version="1.0",
content="{{name}} and {{name}} again",
)
result = template.render(name="test")
assert result == "test and test again"
class TestPromptRegistry:
def test_get_template(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
(templates_dir / "security-v1.0.md").write_text("Security review: {{diff}}")
registry = PromptRegistry(templates_dir)
template = registry.get("security", "1.0")
assert template.name == "security"
assert template.version == "1.0"
assert "{{diff}}" in template.content
def test_get_template_cached(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
(templates_dir / "security-v1.0.md").write_text("Content")
registry = PromptRegistry(templates_dir)
template1 = registry.get("security", "1.0")
template2 = registry.get("security", "1.0")
assert template1 is template2
def test_get_template_not_found(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
registry = PromptRegistry(templates_dir)
with pytest.raises(FileNotFoundError):
registry.get("missing", "1.0")
def test_list_templates(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
(templates_dir / "security-v1.0.md").write_text("Content")
(templates_dir / "style-v2.0.md").write_text("Content")
(templates_dir / "readme.md").write_text("Not a template")
registry = PromptRegistry(templates_dir)
templates = registry.list_templates()
assert len(templates) == 2
assert ("security", "1.0") in templates
assert ("style", "2.0") in templates
def test_list_templates_empty_dir(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
registry = PromptRegistry(templates_dir)
templates = registry.list_templates()
assert templates == []
def test_list_templates_missing_dir(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "missing"
registry = PromptRegistry(templates_dir)
templates = registry.list_templates()
assert templates == []
def test_clear_cache(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
(templates_dir / "test-v1.0.md").write_text("Original")
registry = PromptRegistry(templates_dir)
template1 = registry.get("test", "1.0")
assert template1.content == "Original"
# Modify file
(templates_dir / "test-v1.0.md").write_text("Modified")
# Still cached
template2 = registry.get("test", "1.0")
assert template2.content == "Original"
# Clear cache
registry.clear_cache()
# Now reads new content
template3 = registry.get("test", "1.0")
assert template3.content == "Modified"