add llm response cache (redis)

This commit is contained in:
2025-03-21 18:17:41 +00:00
parent ea2b70f5a3
commit 74432c9f80
4 changed files with 369 additions and 0 deletions

View File

@@ -1,12 +1,15 @@
"""LLM client and prompt management."""
from arbiter.llm.cache import LLMCache, compute_policy_hash
from arbiter.llm.client import LiteLLMClient, LLMClient, LLMResponse
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
# Public API of the package: names re-exported from the cache, client and
# prompts submodules imported above.
__all__ = [
"LLMCache",
"LLMClient",
"LLMResponse",
"LiteLLMClient",
"PromptRegistry",
"PromptTemplate",
"compute_policy_hash",
]

210
src/arbiter/llm/cache.py Normal file
View File

@@ -0,0 +1,210 @@
"""Redis-backed cache for LLM responses."""
import hashlib
import json
import logging
from typing import Any
from redis.asyncio import Redis
from arbiter.config import get_settings
from arbiter.llm.client import LLMResponse
logger = logging.getLogger(__name__)
class LLMCache:
    """Redis-backed cache for LLM responses.

    Cache key format: ``arbiter:llm:cache:{agent_tag}:{hash}`` where

    * ``agent_tag`` is the first 16 hex characters of ``sha256(agent)``.
      It gives every agent its own key namespace so that ``clear_agent``
      can remove one agent's entries without touching the others.
    * ``hash`` is ``sha256`` of the JSON-encoded list
      ``[diff, agent, prompt_version, policy_hash or "default"]``.

    Every Redis call is wrapped in ``try/except``: failures are logged as
    warnings and reported to the caller as a miss / no-op, never raised.
    Hit and miss counters are kept in memory on this instance only.
    """

    PREFIX = "arbiter:llm:cache"

    def __init__(self, redis: Redis) -> None:
        """Bind the cache to a Redis client and load application settings.

        Args:
            redis: Async Redis client used for all cache operations.
        """
        self.redis = redis
        self.settings = get_settings()
        self._hits = 0
        self._misses = 0

    def _agent_namespace(self, agent: str) -> str:
        """Return the key prefix shared by every entry cached for ``agent``.

        The agent name is hashed rather than embedded verbatim so the
        prefix contains only hex characters: it can be used inside a
        ``MATCH`` glob pattern without escaping, and an agent name that
        contains ``:`` cannot overlap another agent's namespace.
        """
        tag = hashlib.sha256(agent.encode()).hexdigest()[:16]
        return f"{self.PREFIX}:{tag}"

    def _compute_key(
        self,
        diff: str,
        agent: str,
        prompt_version: str,
        policy_hash: str | None = None,
    ) -> str:
        """Compute cache key from inputs.

        Args:
            diff: The diff content being reviewed.
            agent: Agent name.
            prompt_version: Version of the prompt template.
            policy_hash: Optional hash of policy configuration. ``None`` and
                the empty string are both treated as ``"default"``.

        Returns:
            Cache key string.
        """
        components = [
            diff,
            agent,
            prompt_version,
            policy_hash or "default",
        ]
        # JSON-encode instead of "|".join(...): diffs routinely contain "|",
        # and a plain join lets different component lists produce the same
        # string (["v|p", "q"] vs ["v", "p|q"]) and therefore the same key.
        content = json.dumps(components)
        hash_value = hashlib.sha256(content.encode()).hexdigest()
        return f"{self._agent_namespace(agent)}:{hash_value}"

    def _serialize_response(self, response: LLMResponse) -> str:
        """Serialize LLMResponse to JSON string."""
        return json.dumps(
            {
                "content": response.content,
                "model": response.model,
                "tokens_in": response.tokens_in,
                "tokens_out": response.tokens_out,
                "cost_usd": response.cost_usd,
            }
        )

    def _deserialize_response(self, data: str) -> LLMResponse:
        """Rebuild an LLMResponse from the JSON written by ``_serialize_response``.

        Raises:
            ValueError: If ``data`` is not valid JSON.
            KeyError: If an expected field is missing.
        """
        parsed = json.loads(data)
        return LLMResponse(
            content=parsed["content"],
            model=parsed["model"],
            tokens_in=parsed["tokens_in"],
            tokens_out=parsed["tokens_out"],
            cost_usd=parsed["cost_usd"],
        )

    async def get(
        self,
        diff: str,
        agent: str,
        prompt_version: str,
        policy_hash: str | None = None,
    ) -> LLMResponse | None:
        """Get cached LLM response if available.

        Args:
            diff: The diff content.
            agent: Agent name.
            prompt_version: Prompt version.
            policy_hash: Optional policy hash.

        Returns:
            Cached LLMResponse or None if not found. Redis errors and
            undecodable entries are logged and counted as misses.
        """
        key = self._compute_key(diff, agent, prompt_version, policy_hash)
        try:
            data = await self.redis.get(key)
            if data:
                # Decode before counting the hit: a corrupt entry raises here
                # and must be recorded once, as a miss, by the handler below.
                response = self._deserialize_response(data)
                self._hits += 1
                logger.debug("Cache hit for %s", key[:50])
                return response
            self._misses += 1
            return None
        except Exception as e:
            logger.warning("Cache get error: %s", e)
            self._misses += 1
            return None

    async def set(
        self,
        diff: str,
        agent: str,
        prompt_version: str,
        response: LLMResponse,
        policy_hash: str | None = None,
    ) -> None:
        """Cache an LLM response.

        The entry expires after ``settings.cache_ttl_hours`` hours. Errors
        are logged and swallowed.

        Args:
            diff: The diff content.
            agent: Agent name.
            prompt_version: Prompt version.
            response: LLM response to cache.
            policy_hash: Optional policy hash.
        """
        key = self._compute_key(diff, agent, prompt_version, policy_hash)
        ttl_seconds = self.settings.cache_ttl_hours * 3600
        try:
            serialized = self._serialize_response(response)
            await self.redis.set(key, serialized, ex=ttl_seconds)
            logger.debug("Cached response for %s (TTL: %ds)", key[:50], ttl_seconds)
        except Exception as e:
            logger.warning("Cache set error: %s", e)

    async def invalidate(
        self,
        diff: str,
        agent: str,
        prompt_version: str,
        policy_hash: str | None = None,
    ) -> bool:
        """Invalidate a cached response.

        Args:
            diff: The diff content.
            agent: Agent name.
            prompt_version: Prompt version.
            policy_hash: Optional policy hash.

        Returns:
            True if a key was deleted; False if nothing was deleted or the
            Redis call failed.
        """
        key = self._compute_key(diff, agent, prompt_version, policy_hash)
        try:
            deleted: int = await self.redis.delete(key)
            return deleted > 0
        except Exception as e:
            logger.warning("Cache invalidate error: %s", e)
            return False

    async def clear_agent(self, _agent: str) -> int:
        """Clear all cached responses for an agent.

        Only keys inside the given agent's namespace are matched, so other
        agents' entries are left in place.

        Note: This uses SCAN which may be slow on large datasets.

        Args:
            _agent: Agent name to clear cache for. (The leading underscore is
                kept only so existing callers are unaffected.)

        Returns:
            Number of keys deleted. If an error interrupts the scan, the
            number deleted up to that point.
        """
        pattern = f"{self._agent_namespace(_agent)}:*"
        deleted = 0
        try:
            async for key in self.redis.scan_iter(match=pattern):
                deleted += await self.redis.delete(key)
        except Exception as e:
            logger.warning("Cache clear error: %s", e)
        return deleted

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dict with hits, misses, total lookups and hit rate (0.0 when no
            lookups have been made).
        """
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0.0
        return {
            "hits": self._hits,
            "misses": self._misses,
            "total": total,
            "hit_rate": hit_rate,
        }
def compute_policy_hash(policy_dict: dict[str, Any]) -> str:
    """Return a short, order-independent fingerprint of a policy mapping.

    The mapping is dumped to JSON with sorted keys, so two dicts with the
    same content hash identically regardless of insertion order.

    Args:
        policy_dict: Policy configuration; must be JSON-serializable.

    Returns:
        The first 16 hex characters of the SHA-256 digest.
    """
    canonical = json.dumps(policy_dict, sort_keys=True)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()[:16]

View File

@@ -1,5 +1,6 @@
"""Arbiter data models."""
from arbiter.models.cost import AgentCost, CostEstimate, ReviewCost
from arbiter.models.enums import AgentName, Severity, Verdict
from arbiter.models.finding import Finding
from arbiter.models.policy import AgentConfig, Policy
@@ -7,9 +8,12 @@ from arbiter.models.review import ReviewResult
__all__ = [
"AgentConfig",
"AgentCost",
"AgentName",
"CostEstimate",
"Finding",
"Policy",
"ReviewCost",
"ReviewResult",
"Severity",
"Verdict",

152
src/arbiter/models/cost.py Normal file
View File

@@ -0,0 +1,152 @@
"""Cost tracking models for Arbiter."""
from pydantic import BaseModel, Field
from arbiter.models.enums import AgentName
class AgentCost(BaseModel):
    """Token usage and estimated spend attributed to one agent."""

    agent: AgentName = Field(description="Agent name")
    tokens_in: int = Field(
        default=0,
        ge=0,
        description="Input tokens used",
    )
    tokens_out: int = Field(
        default=0,
        ge=0,
        description="Output tokens used",
    )
    # Stored as supplied: this model does not derive it from
    # tokens_in / tokens_out.
    total_tokens: int = Field(
        default=0,
        ge=0,
        description="Total tokens used",
    )
    cost_usd: float = Field(
        default=0.0,
        ge=0.0,
        description="Estimated cost in USD",
    )
class ReviewCost(BaseModel):
    """Complete cost tracking for a review.

    Costs are recorded through ``add_agent_cost`` and
    ``add_deliberation_cost``; both recompute the ``total_*`` fields. The
    cache counters are plain fields that no method of this class updates.
    """

    # Per-agent costs
    agent_costs: list[AgentCost] = Field(
        default_factory=list, description="Cost breakdown by agent"
    )

    # Deliberation costs (synthesis)
    deliberation_tokens_in: int = Field(ge=0, default=0, description="Deliberation input tokens")
    deliberation_tokens_out: int = Field(ge=0, default=0, description="Deliberation output tokens")
    deliberation_cost_usd: float = Field(ge=0.0, default=0.0, description="Deliberation cost")

    # Totals
    total_tokens_in: int = Field(ge=0, default=0, description="Total input tokens")
    total_tokens_out: int = Field(ge=0, default=0, description="Total output tokens")
    total_tokens: int = Field(ge=0, default=0, description="Total tokens")
    total_cost_usd: float = Field(ge=0.0, default=0.0, description="Total cost in USD")

    # Cache stats
    cache_hits: int = Field(ge=0, default=0, description="Number of cache hits")
    cache_misses: int = Field(ge=0, default=0, description="Number of cache misses")

    def add_agent_cost(
        self,
        agent: AgentName,
        tokens_in: int,
        tokens_out: int,
        cost_usd: float,
    ) -> None:
        """Add cost for an agent.

        Each call appends a new entry; calling it again for the same agent
        records an additional entry rather than replacing the earlier one.

        Args:
            agent: Agent name.
            tokens_in: Input tokens used.
            tokens_out: Output tokens used.
            cost_usd: Cost in USD.
        """
        self.agent_costs.append(
            AgentCost(
                agent=agent,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                total_tokens=tokens_in + tokens_out,
                cost_usd=cost_usd,
            )
        )
        self._update_totals()

    def add_deliberation_cost(
        self,
        tokens_in: int,
        tokens_out: int,
        cost_usd: float,
    ) -> None:
        """Accumulate deliberation (synthesis) cost and refresh the totals.

        Args:
            tokens_in: Input tokens used.
            tokens_out: Output tokens used.
            cost_usd: Cost in USD.
        """
        self.deliberation_tokens_in += tokens_in
        self.deliberation_tokens_out += tokens_out
        self.deliberation_cost_usd += cost_usd
        self._update_totals()

    def _update_totals(self) -> None:
        """Recalculate totals from components."""
        agent_tokens_in = sum(c.tokens_in for c in self.agent_costs)
        agent_tokens_out = sum(c.tokens_out for c in self.agent_costs)
        agent_cost = sum(c.cost_usd for c in self.agent_costs)
        self.total_tokens_in = agent_tokens_in + self.deliberation_tokens_in
        self.total_tokens_out = agent_tokens_out + self.deliberation_tokens_out
        self.total_tokens = self.total_tokens_in + self.total_tokens_out
        self.total_cost_usd = agent_cost + self.deliberation_cost_usd

    def to_agent_dict(self) -> dict[str, int]:
        """Return total tokens per agent, keyed by the agent's enum value.

        Entries recorded for the same agent are summed. A plain dict
        comprehension would keep only the last entry per agent and so
        disagree with the aggregate totals.
        """
        tokens_by_agent: dict[str, int] = {}
        for entry in self.agent_costs:
            name = entry.agent.value
            tokens_by_agent[name] = tokens_by_agent.get(name, 0) + entry.total_tokens
        return tokens_by_agent

    def to_cost_dict(self) -> dict[str, float]:
        """Return cost in USD per agent, keyed by the agent's enum value.

        Entries recorded for the same agent are summed, for the same reason
        as in ``to_agent_dict``.
        """
        cost_by_agent: dict[str, float] = {}
        for entry in self.agent_costs:
            name = entry.agent.value
            cost_by_agent[name] = cost_by_agent.get(name, 0.0) + entry.cost_usd
        return cost_by_agent

    def is_within_budget(self, max_tokens: int, max_cost_usd: float) -> bool:
        """Return True when both totals are at or below the given limits.

        Args:
            max_tokens: Maximum allowed total tokens (inclusive).
            max_cost_usd: Maximum allowed total cost in USD (inclusive).
        """
        return self.total_tokens <= max_tokens and self.total_cost_usd <= max_cost_usd
class CostEstimate(BaseModel):
    """Pre-review cost estimate."""

    estimated_tokens: int = Field(ge=0, description="Estimated tokens needed")
    estimated_cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")
    agents_enabled: list[AgentName] = Field(description="Agents that will run")
    model: str = Field(description="Model to be used")
    within_budget: bool = Field(description="Whether estimate is within budget")

    @classmethod
    def estimate(
        cls,
        diff_size: int,
        agents: list[AgentName],
        model: str,
        max_tokens: int = 50000,
        max_cost_usd: float = 0.50,
        cost_per_million: float | None = None,
    ) -> "CostEstimate":
        """Estimate cost for a review.

        This is a rough estimate based on diff size and model pricing.

        Args:
            diff_size: Size of diff in characters.
            agents: Agents that will run.
            model: Model to be used.
            max_tokens: Maximum allowed tokens.
            max_cost_usd: Maximum allowed cost.
            cost_per_million: Blended price in USD per one million tokens.
                When omitted, a built-in heuristic is used: 10.0 if the
                model name contains "gpt-4o" but not "mini", otherwise 0.4.
                Any other model therefore falls back to the 0.4 rate; pass
                an explicit value for models priced differently.

        Returns:
            Cost estimate.
        """
        # Rough token estimate: ~4 chars per token for input
        # Each agent typically uses 3x input for output
        tokens_per_agent = (diff_size // 4) * 4  # input + 3x output

        # Deliberation uses ~20% of agent tokens
        deliberation_tokens = tokens_per_agent // 5
        total_tokens = (tokens_per_agent * len(agents)) + deliberation_tokens

        if cost_per_million is None:
            # Rough cost estimate based on model
            # GPT-4o: $5/1M input, $15/1M output (~$10/1M average)
            # GPT-4o-mini: $0.15/1M input, $0.60/1M output (~$0.40/1M average)
            cost_per_million = 10.0 if "gpt-4o" in model and "mini" not in model else 0.4
        estimated_cost = (total_tokens / 1_000_000) * cost_per_million

        return cls(
            estimated_tokens=total_tokens,
            estimated_cost_usd=round(estimated_cost, 4),
            agents_enabled=agents,
            model=model,
            # Compared on the unrounded cost; only the stored value is rounded.
            within_budget=total_tokens <= max_tokens and estimated_cost <= max_cost_usd,
        )