tests for api, worker, cache
This commit is contained in:
164
tests/test_cache.py
Normal file
164
tests/test_cache.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Tests for the LLM cache module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from arbiter.llm.cache import LLMCache, compute_policy_hash
|
||||
from arbiter.llm.client import LLMResponse
|
||||
|
||||
|
||||
class TestComputePolicyHash:
|
||||
def test_compute_policy_hash_deterministic(self) -> None:
|
||||
policy = {"agents": {"security": {"enabled": True}}}
|
||||
hash1 = compute_policy_hash(policy)
|
||||
hash2 = compute_policy_hash(policy)
|
||||
assert hash1 == hash2
|
||||
|
||||
def test_policy_hash_varies(self) -> None:
|
||||
policy1 = {"agents": {"security": {"enabled": True}}}
|
||||
policy2 = {"agents": {"security": {"enabled": False}}}
|
||||
assert compute_policy_hash(policy1) != compute_policy_hash(policy2)
|
||||
|
||||
def test_compute_policy_hash_format(self) -> None:
|
||||
policy = {"test": "data"}
|
||||
hash_value = compute_policy_hash(policy)
|
||||
assert len(hash_value) == 16
|
||||
assert all(c in "0123456789abcdef" for c in hash_value)
|
||||
|
||||
|
||||
class MockRedisForCache:
|
||||
"""Mock Redis client for cache testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data: dict[str, str] = {}
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self._data.get(key)
|
||||
|
||||
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
||||
self._data[key] = value
|
||||
return True
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def scan_iter(self, match: str | None = None): # noqa: ARG002
|
||||
async def _gen():
|
||||
for key in list(self._data.keys()):
|
||||
yield key
|
||||
|
||||
return _gen()
|
||||
|
||||
|
||||
class TestLLMCache:
|
||||
@pytest.fixture
|
||||
def cache(self) -> LLMCache:
|
||||
mock_redis = MockRedisForCache()
|
||||
return LLMCache(mock_redis) # type: ignore[arg-type]
|
||||
|
||||
def test_compute_key(self, cache: LLMCache) -> None:
|
||||
key = cache._compute_key("diff content", "security", "v1.0", "policy123")
|
||||
assert key.startswith("arbiter:llm:cache:")
|
||||
assert len(key) > 20 # prefix + hash
|
||||
|
||||
def test_compute_key_deterministic(self, cache: LLMCache) -> None:
|
||||
key1 = cache._compute_key("diff", "security", "v1.0")
|
||||
key2 = cache._compute_key("diff", "security", "v1.0")
|
||||
assert key1 == key2
|
||||
|
||||
def test_compute_key_unique(self, cache: LLMCache) -> None:
|
||||
key1 = cache._compute_key("diff1", "security", "v1.0")
|
||||
key2 = cache._compute_key("diff2", "security", "v1.0")
|
||||
key3 = cache._compute_key("diff1", "style", "v1.0")
|
||||
key4 = cache._compute_key("diff1", "security", "v2.0")
|
||||
|
||||
assert len({key1, key2, key3, key4}) == 4
|
||||
|
||||
def test_serialize_deserialize_response(self, cache: LLMCache) -> None:
|
||||
response = LLMResponse(
|
||||
content="test content",
|
||||
model="gpt-4o",
|
||||
tokens_in=100,
|
||||
tokens_out=50,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
|
||||
serialized = cache._serialize_response(response)
|
||||
deserialized = cache._deserialize_response(serialized)
|
||||
|
||||
assert deserialized.content == response.content
|
||||
assert deserialized.model == response.model
|
||||
assert deserialized.tokens_in == response.tokens_in
|
||||
assert deserialized.tokens_out == response.tokens_out
|
||||
assert deserialized.cost_usd == response.cost_usd
|
||||
|
||||
async def test_cache_get_miss(self, cache: LLMCache) -> None:
|
||||
result = await cache.get("diff", "security", "v1.0")
|
||||
assert result is None
|
||||
assert cache._misses == 1
|
||||
assert cache._hits == 0
|
||||
|
||||
async def test_cache_set_and_get(self, cache: LLMCache) -> None:
|
||||
response = LLMResponse(
|
||||
content="cached content",
|
||||
model="gpt-4o",
|
||||
tokens_in=100,
|
||||
tokens_out=50,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
|
||||
await cache.set("diff", "security", "v1.0", response)
|
||||
result = await cache.get("diff", "security", "v1.0")
|
||||
|
||||
assert result is not None
|
||||
assert result.content == "cached content"
|
||||
assert cache._hits == 1
|
||||
|
||||
async def test_cache_invalidate(self, cache: LLMCache) -> None:
|
||||
response = LLMResponse(
|
||||
content="test",
|
||||
model="gpt-4o",
|
||||
tokens_in=100,
|
||||
tokens_out=50,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
|
||||
await cache.set("diff", "security", "v1.0", response)
|
||||
deleted = await cache.invalidate("diff", "security", "v1.0")
|
||||
assert deleted is True
|
||||
|
||||
result = await cache.get("diff", "security", "v1.0")
|
||||
assert result is None
|
||||
|
||||
async def test_cache_invalidate_nonexistent(self, cache: LLMCache) -> None:
|
||||
deleted = await cache.invalidate("nonexistent", "security", "v1.0")
|
||||
assert deleted is False
|
||||
|
||||
def test_get_stats(self, cache: LLMCache) -> None:
|
||||
stats = cache.get_stats()
|
||||
assert stats["hits"] == 0
|
||||
assert stats["misses"] == 0
|
||||
assert stats["total"] == 0
|
||||
assert stats["hit_rate"] == 0.0
|
||||
|
||||
async def test_get_stats_after_operations(self, cache: LLMCache) -> None:
|
||||
await cache.get("key1", "agent", "v1") # miss
|
||||
await cache.get("key2", "agent", "v1") # miss
|
||||
|
||||
response = LLMResponse(
|
||||
content="test",
|
||||
model="gpt-4o",
|
||||
tokens_in=100,
|
||||
tokens_out=50,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
await cache.set("key1", "agent", "v1", response)
|
||||
await cache.get("key1", "agent", "v1") # hit
|
||||
|
||||
stats = cache.get_stats()
|
||||
assert stats["hits"] == 1
|
||||
assert stats["misses"] == 2
|
||||
assert stats["total"] == 3
|
||||
assert stats["hit_rate"] == pytest.approx(1 / 3)
|
||||
Reference in New Issue
Block a user