add llm response cache (redis)

This commit is contained in:
2025-03-21 18:17:41 +00:00
parent ea2b70f5a3
commit 74432c9f80
4 changed files with 369 additions and 0 deletions

View File

@@ -1,12 +1,15 @@
"""LLM client and prompt management.""" """LLM client and prompt management."""
from arbiter.llm.cache import LLMCache, compute_policy_hash
from arbiter.llm.client import LiteLLMClient, LLMClient, LLMResponse from arbiter.llm.client import LiteLLMClient, LLMClient, LLMResponse
from arbiter.llm.prompts import PromptRegistry, PromptTemplate from arbiter.llm.prompts import PromptRegistry, PromptTemplate
__all__ = [ __all__ = [
"LLMCache",
"LLMClient", "LLMClient",
"LLMResponse", "LLMResponse",
"LiteLLMClient", "LiteLLMClient",
"PromptRegistry", "PromptRegistry",
"PromptTemplate", "PromptTemplate",
"compute_policy_hash",
] ]

210
src/arbiter/llm/cache.py Normal file
View File

@@ -0,0 +1,210 @@
"""Redis-backed cache for LLM responses."""
import hashlib
import json
import logging
from typing import Any
from redis.asyncio import Redis
from arbiter.config import get_settings
from arbiter.llm.client import LLMResponse
logger = logging.getLogger(__name__)
class LLMCache:
    """Redis-backed cache for LLM responses.

    Cache key format: ``arbiter:llm:cache:{agent}:{hash}``
    where ``hash = sha256(json([diff, agent, prompt_version, policy_hash]))``.

    The agent name is part of the key itself (not only of the hash) so that
    ``clear_agent`` can select a single agent's entries with a SCAN pattern.
    Agent names are expected not to contain ``:``; an agent named ``a`` would
    otherwise also match the keys of an agent named ``a:b``.

    Every Redis failure is logged and swallowed: a cache problem degrades to
    a cache miss (or a skipped write) instead of failing the caller.
    """

    PREFIX = "arbiter:llm:cache"

    # Redis glob metacharacters, backslash-escaped so that an agent name is
    # matched literally when used inside a SCAN MATCH pattern.
    _GLOB_ESCAPES = str.maketrans(
        {"\\": "\\\\", "*": "\\*", "?": "\\?", "[": "\\[", "]": "\\]"}
    )

    def __init__(self, redis: Redis) -> None:
        """Create a cache on top of an async Redis client.

        Args:
            redis: Async Redis client used for all cache operations.
        """
        self.redis = redis
        self.settings = get_settings()
        self._hits = 0
        self._misses = 0

    def _agent_pattern(self, agent: str) -> str:
        """Return the SCAN pattern matching every cache key of ``agent``."""
        return ":".join((self.PREFIX, agent.translate(self._GLOB_ESCAPES), "*"))

    def _compute_key(
        self,
        diff: str,
        agent: str,
        prompt_version: str,
        policy_hash: str | None = None,
    ) -> str:
        """Compute cache key from inputs.

        Args:
            diff: The diff content being reviewed.
            agent: Agent name.
            prompt_version: Version of the prompt template.
            policy_hash: Optional hash of policy configuration.

        Returns:
            Cache key string.
        """
        components = [
            diff,
            agent,
            prompt_version,
            policy_hash or "default",
        ]
        # A JSON array keeps field boundaries unambiguous: with a plain
        # "|".join, ("v1|x", "p") and ("v1", "x|p") would hash identically.
        content = json.dumps(components)
        hash_value = hashlib.sha256(content.encode()).hexdigest()
        return ":".join((self.PREFIX, agent, hash_value))

    def _serialize_response(self, response: LLMResponse) -> str:
        """Serialize LLMResponse to JSON string."""
        return json.dumps(
            {
                "content": response.content,
                "model": response.model,
                "tokens_in": response.tokens_in,
                "tokens_out": response.tokens_out,
                "cost_usd": response.cost_usd,
            }
        )

    def _deserialize_response(self, data: str) -> LLMResponse:
        """Rebuild an LLMResponse from the JSON produced by `_serialize_response`.

        Raises:
            ValueError: If ``data`` is not valid JSON.
            KeyError: If an expected field is missing.
        """
        parsed = json.loads(data)
        return LLMResponse(
            content=parsed["content"],
            model=parsed["model"],
            tokens_in=parsed["tokens_in"],
            tokens_out=parsed["tokens_out"],
            cost_usd=parsed["cost_usd"],
        )

    async def get(
        self,
        diff: str,
        agent: str,
        prompt_version: str,
        policy_hash: str | None = None,
    ) -> LLMResponse | None:
        """Get cached LLM response if available.

        Args:
            diff: The diff content.
            agent: Agent name.
            prompt_version: Prompt version.
            policy_hash: Optional policy hash.

        Returns:
            Cached LLMResponse or None if not found (or on any cache error).
        """
        key = self._compute_key(diff, agent, prompt_version, policy_hash)
        try:
            data = await self.redis.get(key)
            if not data:
                self._misses += 1
                return None
            # Deserialize before counting the hit so that a corrupt entry is
            # recorded as one miss, not as a hit plus a miss.
            response = self._deserialize_response(data)
        except Exception as e:
            logger.warning("Cache get error: %s", e)
            self._misses += 1
            return None
        self._hits += 1
        logger.debug("Cache hit for %s", key[:50])
        return response

    async def set(
        self,
        diff: str,
        agent: str,
        prompt_version: str,
        response: LLMResponse,
        policy_hash: str | None = None,
    ) -> None:
        """Cache an LLM response.

        Args:
            diff: The diff content.
            agent: Agent name.
            prompt_version: Prompt version.
            response: LLM response to cache.
            policy_hash: Optional policy hash.
        """
        key = self._compute_key(diff, agent, prompt_version, policy_hash)
        # Redis only accepts a whole number of seconds for EX.
        ttl_seconds = int(self.settings.cache_ttl_hours * 3600)
        if ttl_seconds <= 0:
            # A non-positive expiry is rejected by Redis; treat it as
            # "caching disabled" rather than logging an error on every write.
            logger.debug("Cache TTL is not positive; skipping cache write")
            return
        try:
            serialized = self._serialize_response(response)
            await self.redis.set(key, serialized, ex=ttl_seconds)
            logger.debug("Cached response for %s (TTL: %ds)", key[:50], ttl_seconds)
        except Exception as e:
            logger.warning("Cache set error: %s", e)

    async def invalidate(
        self,
        diff: str,
        agent: str,
        prompt_version: str,
        policy_hash: str | None = None,
    ) -> bool:
        """Invalidate a cached response.

        Args:
            diff: The diff content.
            agent: Agent name.
            prompt_version: Prompt version.
            policy_hash: Optional policy hash.

        Returns:
            True if a key was deleted.
        """
        key = self._compute_key(diff, agent, prompt_version, policy_hash)
        try:
            deleted: int = await self.redis.delete(key)
            return deleted > 0
        except Exception as e:
            logger.warning("Cache invalidate error: %s", e)
            return False

    async def clear_agent(self, _agent: str) -> int:
        """Clear all cached responses for an agent.

        Only keys written for the given agent are removed; entries of other
        agents are left untouched.

        Note: This uses SCAN which may be slow on large datasets.

        Args:
            _agent: Agent name to clear cache for.

        Returns:
            Number of keys deleted. If Redis fails part-way through, the
            count of keys deleted up to that point is returned.
        """
        pattern = self._agent_pattern(_agent)
        deleted = 0
        try:
            async for key in self.redis.scan_iter(match=pattern):
                deleted += await self.redis.delete(key)
        except Exception as e:
            logger.warning("Cache clear error: %s", e)
        return deleted

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics.

        Returns:
            Dict with hits, misses, total lookups, and hit rate
            (0.0 when no lookups have happened yet).
        """
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0.0
        return {
            "hits": self._hits,
            "misses": self._misses,
            "total": total,
            "hit_rate": hit_rate,
        }
def compute_policy_hash(policy_dict: dict[str, Any]) -> str:
    """Return a short, stable fingerprint of a policy configuration.

    The dict is serialized to JSON with sorted keys, so two dicts that
    differ only in key order produce the same fingerprint.

    Args:
        policy_dict: JSON-serializable policy configuration.

    Returns:
        The first 16 hex characters of the SHA-256 digest of the
        canonical JSON form.
    """
    canonical_json = json.dumps(policy_dict, sort_keys=True)
    digest = hashlib.sha256(canonical_json.encode())
    return digest.hexdigest()[:16]

View File

@@ -1,5 +1,6 @@
"""Arbiter data models.""" """Arbiter data models."""
from arbiter.models.cost import AgentCost, CostEstimate, ReviewCost
from arbiter.models.enums import AgentName, Severity, Verdict from arbiter.models.enums import AgentName, Severity, Verdict
from arbiter.models.finding import Finding from arbiter.models.finding import Finding
from arbiter.models.policy import AgentConfig, Policy from arbiter.models.policy import AgentConfig, Policy
@@ -7,9 +8,12 @@ from arbiter.models.review import ReviewResult
__all__ = [ __all__ = [
"AgentConfig", "AgentConfig",
"AgentCost",
"AgentName", "AgentName",
"CostEstimate",
"Finding", "Finding",
"Policy", "Policy",
"ReviewCost",
"ReviewResult", "ReviewResult",
"Severity", "Severity",
"Verdict", "Verdict",

152
src/arbiter/models/cost.py Normal file
View File

@@ -0,0 +1,152 @@
"""Cost tracking models for Arbiter."""
from pydantic import BaseModel, Field
from arbiter.models.enums import AgentName
class AgentCost(BaseModel):
    """Token usage and cost attributed to one agent.

    All numeric fields default to zero and are validated as non-negative.
    ``total_tokens`` is a stored field, not derived from the other two, so
    whoever builds the model is responsible for supplying it.
    """

    # Which agent this entry belongs to.
    agent: AgentName = Field(description="Agent name")

    # Token usage.
    tokens_in: int = Field(default=0, ge=0, description="Input tokens used")
    tokens_out: int = Field(default=0, ge=0, description="Output tokens used")
    total_tokens: int = Field(default=0, ge=0, description="Total tokens used")

    # Monetary cost.
    cost_usd: float = Field(default=0.0, ge=0.0, description="Estimated cost in USD")
class ReviewCost(BaseModel):
    """Complete cost tracking for a review.

    Costs are recorded through `add_agent_cost` and `add_deliberation_cost`;
    both refresh the ``total_*`` fields. The totals are stored values, so
    mutating ``agent_costs`` directly leaves them stale until the next call
    to one of those methods.
    """

    # Per-agent costs
    agent_costs: list[AgentCost] = Field(
        default_factory=list, description="Cost breakdown by agent"
    )

    # Deliberation costs (synthesis)
    deliberation_tokens_in: int = Field(ge=0, default=0, description="Deliberation input tokens")
    deliberation_tokens_out: int = Field(ge=0, default=0, description="Deliberation output tokens")
    deliberation_cost_usd: float = Field(ge=0.0, default=0.0, description="Deliberation cost")

    # Totals
    total_tokens_in: int = Field(ge=0, default=0, description="Total input tokens")
    total_tokens_out: int = Field(ge=0, default=0, description="Total output tokens")
    total_tokens: int = Field(ge=0, default=0, description="Total tokens")
    total_cost_usd: float = Field(ge=0.0, default=0.0, description="Total cost in USD")

    # Cache stats
    cache_hits: int = Field(ge=0, default=0, description="Number of cache hits")
    cache_misses: int = Field(ge=0, default=0, description="Number of cache misses")

    def add_agent_cost(
        self,
        agent: AgentName,
        tokens_in: int,
        tokens_out: int,
        cost_usd: float,
    ) -> None:
        """Add cost for an agent.

        A new entry is appended on every call; calling this several times
        for the same agent records several entries.

        Args:
            agent: Agent name.
            tokens_in: Input tokens used.
            tokens_out: Output tokens used.
            cost_usd: Cost in USD.
        """
        self.agent_costs.append(
            AgentCost(
                agent=agent,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                total_tokens=tokens_in + tokens_out,
                cost_usd=cost_usd,
            )
        )
        self._update_totals()

    def add_deliberation_cost(
        self,
        tokens_in: int,
        tokens_out: int,
        cost_usd: float,
    ) -> None:
        """Accumulate deliberation (synthesis) usage and refresh the totals.

        Args:
            tokens_in: Input tokens used.
            tokens_out: Output tokens used.
            cost_usd: Cost in USD.
        """
        self.deliberation_tokens_in += tokens_in
        self.deliberation_tokens_out += tokens_out
        self.deliberation_cost_usd += cost_usd
        self._update_totals()

    def _update_totals(self) -> None:
        """Recalculate totals from components."""
        agent_tokens_in = sum(c.tokens_in for c in self.agent_costs)
        agent_tokens_out = sum(c.tokens_out for c in self.agent_costs)
        agent_cost = sum(c.cost_usd for c in self.agent_costs)
        self.total_tokens_in = agent_tokens_in + self.deliberation_tokens_in
        self.total_tokens_out = agent_tokens_out + self.deliberation_tokens_out
        self.total_tokens = self.total_tokens_in + self.total_tokens_out
        self.total_cost_usd = agent_cost + self.deliberation_cost_usd

    def to_agent_dict(self) -> dict[str, int]:
        """Return total tokens per agent, keyed by the agent's enum value.

        Entries for the same agent are summed, so the result stays
        consistent with the totals even when an agent was recorded more
        than once.
        """
        tokens_by_agent: dict[str, int] = {}
        for entry in self.agent_costs:
            name = entry.agent.value
            tokens_by_agent[name] = tokens_by_agent.get(name, 0) + entry.total_tokens
        return tokens_by_agent

    def to_cost_dict(self) -> dict[str, float]:
        """Return cost in USD per agent, keyed by the agent's enum value.

        Entries for the same agent are summed rather than overwritten.
        """
        cost_by_agent: dict[str, float] = {}
        for entry in self.agent_costs:
            name = entry.agent.value
            cost_by_agent[name] = cost_by_agent.get(name, 0.0) + entry.cost_usd
        return cost_by_agent

    def is_within_budget(self, max_tokens: int, max_cost_usd: float) -> bool:
        """Return True when neither the token nor the cost limit is exceeded.

        Args:
            max_tokens: Maximum allowed total tokens (inclusive).
            max_cost_usd: Maximum allowed total cost in USD (inclusive).
        """
        return self.total_tokens <= max_tokens and self.total_cost_usd <= max_cost_usd
class CostEstimate(BaseModel):
    """Pre-review cost estimate."""

    estimated_tokens: int = Field(ge=0, description="Estimated tokens needed")
    estimated_cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")
    agents_enabled: list[AgentName] = Field(description="Agents that will run")
    model: str = Field(description="Model to be used")
    within_budget: bool = Field(description="Whether estimate is within budget")

    @classmethod
    def estimate(
        cls,
        diff_size: int,
        agents: list[AgentName],
        model: str,
        max_tokens: int = 50000,
        max_cost_usd: float = 0.50,
    ) -> "CostEstimate":
        """Build a rough estimate from the diff size and a flat per-model price.

        Args:
            diff_size: Size of diff in characters.
            agents: Agents that will run.
            model: Model to be used.
            max_tokens: Maximum allowed tokens.
            max_cost_usd: Maximum allowed cost.

        Returns:
            Cost estimate. The reported cost is rounded to 4 decimals; the
            budget check uses the unrounded value.
        """
        # ~4 characters per input token; output is assumed to be 3x the
        # input, so each agent consumes 4x the input token count.
        input_tokens = diff_size // 4
        tokens_per_agent = input_tokens * 4

        # Deliberation is budgeted at ~20% of a single agent's tokens.
        deliberation_tokens = tokens_per_agent // 5
        total_tokens = (tokens_per_agent * len(agents)) + deliberation_tokens

        # Blended price per 1M tokens:
        #   GPT-4o:      $5 input / $15 output      (~$10 average)
        #   GPT-4o-mini: $0.15 input / $0.60 output (~$0.40 average)
        # Every model that is not full-size GPT-4o uses the mini rate.
        if "gpt-4o" in model and "mini" not in model:
            cost_per_million = 10.0
        else:
            cost_per_million = 0.4
        estimated_cost = (total_tokens / 1_000_000) * cost_per_million

        fits_tokens = total_tokens <= max_tokens
        fits_cost = estimated_cost <= max_cost_usd
        return cls(
            estimated_tokens=total_tokens,
            estimated_cost_usd=round(estimated_cost, 4),
            agents_enabled=agents,
            model=model,
            within_budget=fits_tokens and fits_cost,
        )