"""Tests for the LLM cache module.""" import pytest from arbiter.llm.cache import LLMCache, compute_policy_hash from arbiter.llm.client import LLMResponse class TestComputePolicyHash: def test_compute_policy_hash_deterministic(self) -> None: policy = {"agents": {"security": {"enabled": True}}} hash1 = compute_policy_hash(policy) hash2 = compute_policy_hash(policy) assert hash1 == hash2 def test_policy_hash_varies(self) -> None: policy1 = {"agents": {"security": {"enabled": True}}} policy2 = {"agents": {"security": {"enabled": False}}} assert compute_policy_hash(policy1) != compute_policy_hash(policy2) def test_compute_policy_hash_format(self) -> None: policy = {"test": "data"} hash_value = compute_policy_hash(policy) assert len(hash_value) == 16 assert all(c in "0123456789abcdef" for c in hash_value) class MockRedisForCache: """Mock Redis client for cache testing.""" def __init__(self) -> None: self._data: dict[str, str] = {} async def get(self, key: str) -> str | None: return self._data.get(key) async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002 self._data[key] = value return True async def delete(self, key: str) -> int: if key in self._data: del self._data[key] return 1 return 0 def scan_iter(self, match: str | None = None): # noqa: ARG002 async def _gen(): for key in list(self._data.keys()): yield key return _gen() class TestLLMCache: @pytest.fixture def cache(self) -> LLMCache: mock_redis = MockRedisForCache() return LLMCache(mock_redis) # type: ignore[arg-type] def test_compute_key(self, cache: LLMCache) -> None: key = cache._compute_key("diff content", "security", "v1.0", "policy123") assert key.startswith("arbiter:llm:cache:") assert len(key) > 20 # prefix + hash def test_compute_key_deterministic(self, cache: LLMCache) -> None: key1 = cache._compute_key("diff", "security", "v1.0") key2 = cache._compute_key("diff", "security", "v1.0") assert key1 == key2 def test_compute_key_unique(self, cache: LLMCache) -> None: key1 = cache._compute_key("diff1", "security", "v1.0") key2 = cache._compute_key("diff2", "security", "v1.0") key3 = cache._compute_key("diff1", "style", "v1.0") key4 = cache._compute_key("diff1", "security", "v2.0") assert len({key1, key2, key3, key4}) == 4 def test_serialize_deserialize_response(self, cache: LLMCache) -> None: response = LLMResponse( content="test content", model="gpt-4o", tokens_in=100, tokens_out=50, cost_usd=0.01, ) serialized = cache._serialize_response(response) deserialized = cache._deserialize_response(serialized) assert deserialized.content == response.content assert deserialized.model == response.model assert deserialized.tokens_in == response.tokens_in assert deserialized.tokens_out == response.tokens_out assert deserialized.cost_usd == response.cost_usd async def test_cache_get_miss(self, cache: LLMCache) -> None: result = await cache.get("diff", "security", "v1.0") assert result is None assert cache._misses == 1 assert cache._hits == 0 async def test_cache_set_and_get(self, cache: LLMCache) -> None: response = LLMResponse( content="cached content", model="gpt-4o", tokens_in=100, tokens_out=50, cost_usd=0.01, ) await cache.set("diff", "security", "v1.0", response) result = await cache.get("diff", "security", "v1.0") assert result is not None assert result.content == "cached content" assert cache._hits == 1 async def test_cache_invalidate(self, cache: LLMCache) -> None: response = LLMResponse( content="test", model="gpt-4o", tokens_in=100, tokens_out=50, cost_usd=0.01, ) await cache.set("diff", "security", "v1.0", response) deleted = await cache.invalidate("diff", "security", "v1.0") assert deleted is True result = await cache.get("diff", "security", "v1.0") assert result is None async def test_cache_invalidate_nonexistent(self, cache: LLMCache) -> None: deleted = await cache.invalidate("nonexistent", "security", "v1.0") assert deleted is False def test_get_stats(self, cache: LLMCache) -> None: stats = cache.get_stats() assert stats["hits"] == 0 assert stats["misses"] == 0 assert stats["total"] == 0 assert stats["hit_rate"] == 0.0 async def test_get_stats_after_operations(self, cache: LLMCache) -> None: await cache.get("key1", "agent", "v1") # miss await cache.get("key2", "agent", "v1") # miss response = LLMResponse( content="test", model="gpt-4o", tokens_in=100, tokens_out=50, cost_usd=0.01, ) await cache.set("key1", "agent", "v1", response) await cache.get("key1", "agent", "v1") # hit stats = cache.get_stats() assert stats["hits"] == 1 assert stats["misses"] == 2 assert stats["total"] == 3 assert stats["hit_rate"] == pytest.approx(1 / 3)