"""Tests for LLM client and prompts.""" from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest from arbiter.llm.client import LiteLLMClient, LLMResponse from arbiter.llm.prompts import PromptRegistry, PromptTemplate from tests.conftest import MockLLMClient class TestLiteLLMClient: def test_init_default_values(self) -> None: client = LiteLLMClient() assert client.timeout == 60 assert client.max_retries == 3 def test_init_custom_values(self) -> None: client = LiteLLMClient(timeout=120, max_retries=5) assert client.timeout == 120 assert client.max_retries == 5 @pytest.mark.asyncio async def test_complete_returns_response(self) -> None: client = LiteLLMClient() # Mock the litellm.acompletion function mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = "Test response" mock_response.model = "gpt-4o" mock_response.usage = MagicMock() mock_response.usage.prompt_tokens = 10 mock_response.usage.completion_tokens = 5 with ( patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp, patch("arbiter.llm.client.litellm.completion_cost") as mock_cost, ): mock_acomp.return_value = mock_response mock_cost.return_value = 0.001 messages = [{"role": "user", "content": "Hello"}] response = await client.complete(messages, "gpt-4o") assert response.content == "Test response" assert response.model == "gpt-4o" assert response.tokens_in == 10 assert response.tokens_out == 5 assert response.cost_usd == 0.001 mock_acomp.assert_called_once_with( model="gpt-4o", messages=messages, timeout=60, num_retries=3, ) @pytest.mark.asyncio async def test_complete_handles_empty_content(self) -> None: client = LiteLLMClient() mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = None # None content mock_response.model = "gpt-4o" mock_response.usage = MagicMock() mock_response.usage.prompt_tokens = 5 mock_response.usage.completion_tokens = 0 with ( patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp, patch("arbiter.llm.client.litellm.completion_cost") as mock_cost, ): mock_acomp.return_value = mock_response mock_cost.return_value = 0.0 messages = [{"role": "user", "content": "Hello"}] response = await client.complete(messages, "gpt-4o") assert response.content == "" # Should be empty string, not None @pytest.mark.asyncio async def test_complete_handles_missing_usage(self) -> None: client = LiteLLMClient() mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = "Response" mock_response.model = "gpt-4o" mock_response.usage = None # No usage data with ( patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp, patch("arbiter.llm.client.litellm.completion_cost") as mock_cost, ): mock_acomp.return_value = mock_response mock_cost.return_value = 0.0 messages = [{"role": "user", "content": "Hello"}] response = await client.complete(messages, "gpt-4o") assert response.tokens_in == 0 assert response.tokens_out == 0 @pytest.mark.asyncio async def test_complete_uses_fallback_model(self) -> None: client = LiteLLMClient() mock_response = MagicMock() mock_response.choices = [MagicMock()] mock_response.choices[0].message.content = "Response" mock_response.model = None # No model in response mock_response.usage = MagicMock() mock_response.usage.prompt_tokens = 5 mock_response.usage.completion_tokens = 3 with ( patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp, patch("arbiter.llm.client.litellm.completion_cost") as mock_cost, ): mock_acomp.return_value = mock_response mock_cost.return_value = 0.0 messages = [{"role": "user", "content": "Hello"}] response = await client.complete(messages, "claude-3-opus") # Should use the passed model as fallback assert response.model == "claude-3-opus" class TestLLMResponse: def test_response_creation(self) -> None: response = LLMResponse( content="Hello, world!", model="gpt-4o", tokens_in=10, tokens_out=5, cost_usd=0.001, ) assert response.content == "Hello, world!" assert response.model == "gpt-4o" assert response.tokens_in == 10 assert response.tokens_out == 5 assert response.cost_usd == 0.001 class TestMockLLMClient: @pytest.mark.asyncio async def test_records_calls(self) -> None: client = MockLLMClient() messages = [{"role": "user", "content": "Hello"}] await client.complete(messages, "gpt-4o") assert len(client.calls) == 1 assert client.calls[0]["messages"] == messages assert client.calls[0]["model"] == "gpt-4o" @pytest.mark.asyncio async def test_returns_canned_responses(self) -> None: client = MockLLMClient(responses=["First", "Second"]) messages = [{"role": "user", "content": "Hello"}] response1 = await client.complete(messages, "gpt-4o") response2 = await client.complete(messages, "gpt-4o") response3 = await client.complete(messages, "gpt-4o") assert response1.content == "First" assert response2.content == "Second" assert response3.content == "" # Exhausted @pytest.mark.asyncio async def test_reset(self) -> None: client = MockLLMClient(responses=["Hello"]) messages = [{"role": "user", "content": "Hi"}] await client.complete(messages, "gpt-4o") assert len(client.calls) == 1 client.reset() assert len(client.calls) == 0 response = await client.complete(messages, "gpt-4o") assert response.content == "Hello" # Responses reset too class TestPromptTemplate: def test_template_creation(self) -> None: template = PromptTemplate( name="security", version="1.0", content="Review: {{diff}}", ) assert template.name == "security" assert template.version == "1.0" assert template.full_name == "security-v1.0" def test_render_substitution(self) -> None: template = PromptTemplate( name="test", version="1.0", content="File: {{file}}\nDiff: {{diff}}", ) result = template.render(file="test.py", diff="+ added line") assert result == "File: test.py\nDiff: + added line" def test_render_missing_variable(self) -> None: template = PromptTemplate( name="test", version="1.0", content="Value: {{value}}", ) result = template.render() assert result == "Value: {{value}}" def test_render_multiple_occurrences(self) -> None: template = PromptTemplate( name="test", version="1.0", content="{{name}} and {{name}} again", ) result = template.render(name="test") assert result == "test and test again" class TestPromptRegistry: def test_get_template(self, tmp_path: Path) -> None: templates_dir = tmp_path / "templates" templates_dir.mkdir() (templates_dir / "security-v1.0.md").write_text("Security review: {{diff}}") registry = PromptRegistry(templates_dir) template = registry.get("security", "1.0") assert template.name == "security" assert template.version == "1.0" assert "{{diff}}" in template.content def test_get_template_cached(self, tmp_path: Path) -> None: templates_dir = tmp_path / "templates" templates_dir.mkdir() (templates_dir / "security-v1.0.md").write_text("Content") registry = PromptRegistry(templates_dir) template1 = registry.get("security", "1.0") template2 = registry.get("security", "1.0") assert template1 is template2 def test_get_template_not_found(self, tmp_path: Path) -> None: templates_dir = tmp_path / "templates" templates_dir.mkdir() registry = PromptRegistry(templates_dir) with pytest.raises(FileNotFoundError): registry.get("missing", "1.0") def test_list_templates(self, tmp_path: Path) -> None: templates_dir = tmp_path / "templates" templates_dir.mkdir() (templates_dir / "security-v1.0.md").write_text("Content") (templates_dir / "style-v2.0.md").write_text("Content") (templates_dir / "readme.md").write_text("Not a template") registry = PromptRegistry(templates_dir) templates = registry.list_templates() assert len(templates) == 2 assert ("security", "1.0") in templates assert ("style", "2.0") in templates def test_list_templates_empty_dir(self, tmp_path: Path) -> None: templates_dir = tmp_path / "templates" templates_dir.mkdir() registry = PromptRegistry(templates_dir) templates = registry.list_templates() assert templates == [] def test_list_templates_missing_dir(self, tmp_path: Path) -> None: templates_dir = tmp_path / "missing" registry = PromptRegistry(templates_dir) templates = registry.list_templates() assert templates == [] def test_clear_cache(self, tmp_path: Path) -> None: templates_dir = tmp_path / "templates" templates_dir.mkdir() (templates_dir / "test-v1.0.md").write_text("Original") registry = PromptRegistry(templates_dir) template1 = registry.get("test", "1.0") assert template1.content == "Original" # Modify file (templates_dir / "test-v1.0.md").write_text("Modified") # Still cached template2 = registry.get("test", "1.0") assert template2.content == "Original" # Clear cache registry.clear_cache() # Now reads new content template3 = registry.get("test", "1.0") assert template3.content == "Modified"