165 lines
5.5 KiB
Python
165 lines
5.5 KiB
Python
"""Tests for the LLM cache module."""
|
|
|
|
import fnmatch
from collections.abc import AsyncIterator

import pytest

from arbiter.llm.cache import LLMCache, compute_policy_hash
from arbiter.llm.client import LLMResponse
|
|
|
|
|
|
class TestComputePolicyHash:
    """Unit tests for :func:`compute_policy_hash`."""

    def test_compute_policy_hash_deterministic(self) -> None:
        """Hashing the same policy object twice yields the same digest."""
        policy = {"agents": {"security": {"enabled": True}}}
        first, second = compute_policy_hash(policy), compute_policy_hash(policy)
        assert first == second

    def test_policy_hash_varies(self) -> None:
        """Policies differing only in one nested value hash differently."""
        enabled, disabled = (
            {"agents": {"security": {"enabled": flag}}} for flag in (True, False)
        )
        assert compute_policy_hash(enabled) != compute_policy_hash(disabled)

    def test_compute_policy_hash_format(self) -> None:
        """The digest is exactly 16 characters of lowercase hexadecimal."""
        digest = compute_policy_hash({"test": "data"})
        assert len(digest) == 16
        # Every character must come from the lowercase hex alphabet.
        assert set(digest) <= set("0123456789abcdef")
|
|
|
|
|
|
class MockRedisForCache:
|
|
"""Mock Redis client for cache testing."""
|
|
|
|
def __init__(self) -> None:
|
|
self._data: dict[str, str] = {}
|
|
|
|
async def get(self, key: str) -> str | None:
|
|
return self._data.get(key)
|
|
|
|
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
|
self._data[key] = value
|
|
return True
|
|
|
|
async def delete(self, key: str) -> int:
|
|
if key in self._data:
|
|
del self._data[key]
|
|
return 1
|
|
return 0
|
|
|
|
def scan_iter(self, match: str | None = None): # noqa: ARG002
|
|
async def _gen():
|
|
for key in list(self._data.keys()):
|
|
yield key
|
|
|
|
return _gen()
|
|
|
|
|
|
class TestLLMCache:
    """Tests for :class:`LLMCache` backed by an in-memory mock Redis.

    NOTE(review): the coroutine tests below carry no explicit asyncio marker;
    this presumably relies on pytest-asyncio (or similar) running in auto
    mode via project config. Confirm against the pytest configuration.
    """

    @staticmethod
    def _response(content: str) -> LLMResponse:
        """Build an ``LLMResponse`` with *content* and fixed model/usage values."""
        return LLMResponse(
            content=content,
            model="gpt-4o",
            tokens_in=100,
            tokens_out=50,
            cost_usd=0.01,
        )

    @pytest.fixture
    def cache(self) -> LLMCache:
        """Provide an ``LLMCache`` wired to a fresh ``MockRedisForCache``."""
        # The mock is duck-typed rather than a real client, hence the ignore.
        return LLMCache(MockRedisForCache())  # type: ignore[arg-type]

    def test_compute_key(self, cache: LLMCache) -> None:
        """Computed keys start with the cache namespace prefix."""
        cache_key = cache._compute_key("diff content", "security", "v1.0", "policy123")
        assert cache_key.startswith("arbiter:llm:cache:")
        assert len(cache_key) > 20  # prefix + hash

    def test_compute_key_deterministic(self, cache: LLMCache) -> None:
        """Identical inputs always produce the identical key."""
        args = ("diff", "security", "v1.0")
        assert cache._compute_key(*args) == cache._compute_key(*args)

    def test_compute_key_unique(self, cache: LLMCache) -> None:
        """Changing the diff, agent, or version each yields a distinct key."""
        variants = [
            ("diff1", "security", "v1.0"),
            ("diff2", "security", "v1.0"),
            ("diff1", "style", "v1.0"),
            ("diff1", "security", "v2.0"),
        ]
        distinct_keys = {cache._compute_key(*args) for args in variants}
        assert len(distinct_keys) == len(variants)

    def test_serialize_deserialize_response(self, cache: LLMCache) -> None:
        """Serializing then deserializing preserves every response field."""
        original = self._response("test content")
        roundtripped = cache._deserialize_response(cache._serialize_response(original))
        for attr in ("content", "model", "tokens_in", "tokens_out", "cost_usd"):
            assert getattr(roundtripped, attr) == getattr(original, attr)

    async def test_cache_get_miss(self, cache: LLMCache) -> None:
        """A lookup on an empty cache returns None and counts one miss."""
        assert await cache.get("diff", "security", "v1.0") is None
        assert (cache._misses, cache._hits) == (1, 0)

    async def test_cache_set_and_get(self, cache: LLMCache) -> None:
        """A stored response is returned on the next lookup and counts a hit."""
        await cache.set("diff", "security", "v1.0", self._response("cached content"))
        hit = await cache.get("diff", "security", "v1.0")

        assert hit is not None
        assert hit.content == "cached content"
        assert cache._hits == 1

    async def test_cache_invalidate(self, cache: LLMCache) -> None:
        """Invalidating a stored entry reports success and removes it."""
        lookup = ("diff", "security", "v1.0")
        await cache.set(*lookup, self._response("test"))
        assert await cache.invalidate(*lookup) is True
        assert await cache.get(*lookup) is None

    async def test_cache_invalidate_nonexistent(self, cache: LLMCache) -> None:
        """Invalidating an entry that was never stored reports False."""
        assert await cache.invalidate("nonexistent", "security", "v1.0") is False

    def test_get_stats(self, cache: LLMCache) -> None:
        """A fresh cache reports all-zero statistics."""
        stats = cache.get_stats()
        for name, expected in {"hits": 0, "misses": 0, "total": 0}.items():
            assert stats[name] == expected
        assert stats["hit_rate"] == 0.0

    async def test_get_stats_after_operations(self, cache: LLMCache) -> None:
        """Two misses followed by one hit give a hit rate of one third."""
        for diff in ("key1", "key2"):
            await cache.get(diff, "agent", "v1")  # miss

        await cache.set("key1", "agent", "v1", self._response("test"))
        await cache.get("key1", "agent", "v1")  # hit

        stats = cache.get_stats()
        assert (stats["hits"], stats["misses"], stats["total"]) == (1, 2, 3)
        assert stats["hit_rate"] == pytest.approx(1 / 3)
|