add llm response cache (redis)
This commit is contained in:
@@ -1,12 +1,15 @@
|
||||
"""LLM client and prompt management."""
|
||||
|
||||
from arbiter.llm.cache import LLMCache, compute_policy_hash
|
||||
from arbiter.llm.client import LiteLLMClient, LLMClient, LLMResponse
|
||||
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
|
||||
|
||||
__all__ = [
|
||||
"LLMCache",
|
||||
"LLMClient",
|
||||
"LLMResponse",
|
||||
"LiteLLMClient",
|
||||
"PromptRegistry",
|
||||
"PromptTemplate",
|
||||
"compute_policy_hash",
|
||||
]
|
||||
|
||||
210
src/arbiter/llm/cache.py
Normal file
210
src/arbiter/llm/cache.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Redis-backed cache for LLM responses."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from arbiter.config import get_settings
|
||||
from arbiter.llm.client import LLMResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMCache:
|
||||
"""Redis-backed cache for LLM responses.
|
||||
|
||||
Cache key format: arbiter:llm:cache:{hash}
|
||||
where hash = sha256(diff + agent + prompt_version + policy_hash)
|
||||
"""
|
||||
|
||||
PREFIX = "arbiter:llm:cache"
|
||||
|
||||
def __init__(self, redis: Redis) -> None:
|
||||
self.redis = redis
|
||||
self.settings = get_settings()
|
||||
self._hits = 0
|
||||
self._misses = 0
|
||||
|
||||
def _compute_key(
|
||||
self,
|
||||
diff: str,
|
||||
agent: str,
|
||||
prompt_version: str,
|
||||
policy_hash: str | None = None,
|
||||
) -> str:
|
||||
"""Compute cache key from inputs.
|
||||
|
||||
Args:
|
||||
diff: The diff content being reviewed.
|
||||
agent: Agent name.
|
||||
prompt_version: Version of the prompt template.
|
||||
policy_hash: Optional hash of policy configuration.
|
||||
|
||||
Returns:
|
||||
Cache key string.
|
||||
"""
|
||||
components = [
|
||||
diff,
|
||||
agent,
|
||||
prompt_version,
|
||||
policy_hash or "default",
|
||||
]
|
||||
content = "|".join(components)
|
||||
hash_value = hashlib.sha256(content.encode()).hexdigest()
|
||||
return f"{self.PREFIX}:{hash_value}"
|
||||
|
||||
def _serialize_response(self, response: LLMResponse) -> str:
|
||||
"""Serialize LLMResponse to JSON string."""
|
||||
return json.dumps(
|
||||
{
|
||||
"content": response.content,
|
||||
"model": response.model,
|
||||
"tokens_in": response.tokens_in,
|
||||
"tokens_out": response.tokens_out,
|
||||
"cost_usd": response.cost_usd,
|
||||
}
|
||||
)
|
||||
|
||||
def _deserialize_response(self, data: str) -> LLMResponse:
|
||||
parsed = json.loads(data)
|
||||
return LLMResponse(
|
||||
content=parsed["content"],
|
||||
model=parsed["model"],
|
||||
tokens_in=parsed["tokens_in"],
|
||||
tokens_out=parsed["tokens_out"],
|
||||
cost_usd=parsed["cost_usd"],
|
||||
)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
diff: str,
|
||||
agent: str,
|
||||
prompt_version: str,
|
||||
policy_hash: str | None = None,
|
||||
) -> LLMResponse | None:
|
||||
"""Get cached LLM response if available.
|
||||
|
||||
Args:
|
||||
diff: The diff content.
|
||||
agent: Agent name.
|
||||
prompt_version: Prompt version.
|
||||
policy_hash: Optional policy hash.
|
||||
|
||||
Returns:
|
||||
Cached LLMResponse or None if not found.
|
||||
"""
|
||||
key = self._compute_key(diff, agent, prompt_version, policy_hash)
|
||||
|
||||
try:
|
||||
data = await self.redis.get(key)
|
||||
if data:
|
||||
self._hits += 1
|
||||
logger.debug("Cache hit for %s", key[:50])
|
||||
return self._deserialize_response(data)
|
||||
self._misses += 1
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("Cache get error: %s", e)
|
||||
self._misses += 1
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self,
|
||||
diff: str,
|
||||
agent: str,
|
||||
prompt_version: str,
|
||||
response: LLMResponse,
|
||||
policy_hash: str | None = None,
|
||||
) -> None:
|
||||
"""Cache an LLM response.
|
||||
|
||||
Args:
|
||||
diff: The diff content.
|
||||
agent: Agent name.
|
||||
prompt_version: Prompt version.
|
||||
response: LLM response to cache.
|
||||
policy_hash: Optional policy hash.
|
||||
"""
|
||||
key = self._compute_key(diff, agent, prompt_version, policy_hash)
|
||||
ttl_seconds = self.settings.cache_ttl_hours * 3600
|
||||
|
||||
try:
|
||||
serialized = self._serialize_response(response)
|
||||
await self.redis.set(key, serialized, ex=ttl_seconds)
|
||||
logger.debug("Cached response for %s (TTL: %ds)", key[:50], ttl_seconds)
|
||||
except Exception as e:
|
||||
logger.warning("Cache set error: %s", e)
|
||||
|
||||
async def invalidate(
|
||||
self,
|
||||
diff: str,
|
||||
agent: str,
|
||||
prompt_version: str,
|
||||
policy_hash: str | None = None,
|
||||
) -> bool:
|
||||
"""Invalidate a cached response.
|
||||
|
||||
Args:
|
||||
diff: The diff content.
|
||||
agent: Agent name.
|
||||
prompt_version: Prompt version.
|
||||
policy_hash: Optional policy hash.
|
||||
|
||||
Returns:
|
||||
True if a key was deleted.
|
||||
"""
|
||||
key = self._compute_key(diff, agent, prompt_version, policy_hash)
|
||||
|
||||
try:
|
||||
deleted: int = await self.redis.delete(key)
|
||||
return deleted > 0
|
||||
except Exception as e:
|
||||
logger.warning("Cache invalidate error: %s", e)
|
||||
return False
|
||||
|
||||
async def clear_agent(self, _agent: str) -> int:
|
||||
"""Clear all cached responses for an agent.
|
||||
|
||||
Note: This uses SCAN which may be slow on large datasets.
|
||||
|
||||
Args:
|
||||
agent: Agent name to clear cache for.
|
||||
|
||||
Returns:
|
||||
Number of keys deleted.
|
||||
"""
|
||||
pattern = f"{self.PREFIX}:*"
|
||||
deleted = 0
|
||||
|
||||
try:
|
||||
async for key in self.redis.scan_iter(match=pattern):
|
||||
deleted += await self.redis.delete(key)
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.warning("Cache clear error: %s", e)
|
||||
return 0
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dict with hits, misses, and hit rate.
|
||||
"""
|
||||
total = self._hits + self._misses
|
||||
hit_rate = self._hits / total if total > 0 else 0.0
|
||||
|
||||
return {
|
||||
"hits": self._hits,
|
||||
"misses": self._misses,
|
||||
"total": total,
|
||||
"hit_rate": hit_rate,
|
||||
}
|
||||
|
||||
|
||||
def compute_policy_hash(policy_dict: dict[str, Any]) -> str:
|
||||
# Sort keys for consistent hashing
|
||||
content = json.dumps(policy_dict, sort_keys=True)
|
||||
return hashlib.sha256(content.encode()).hexdigest()[:16]
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Arbiter data models."""
|
||||
|
||||
from arbiter.models.cost import AgentCost, CostEstimate, ReviewCost
|
||||
from arbiter.models.enums import AgentName, Severity, Verdict
|
||||
from arbiter.models.finding import Finding
|
||||
from arbiter.models.policy import AgentConfig, Policy
|
||||
@@ -7,9 +8,12 @@ from arbiter.models.review import ReviewResult
|
||||
|
||||
__all__ = [
|
||||
"AgentConfig",
|
||||
"AgentCost",
|
||||
"AgentName",
|
||||
"CostEstimate",
|
||||
"Finding",
|
||||
"Policy",
|
||||
"ReviewCost",
|
||||
"ReviewResult",
|
||||
"Severity",
|
||||
"Verdict",
|
||||
|
||||
152
src/arbiter/models/cost.py
Normal file
152
src/arbiter/models/cost.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Cost tracking models for Arbiter."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from arbiter.models.enums import AgentName
|
||||
|
||||
|
||||
class AgentCost(BaseModel):
|
||||
"""Cost breakdown for a single agent."""
|
||||
|
||||
agent: AgentName = Field(description="Agent name")
|
||||
tokens_in: int = Field(ge=0, default=0, description="Input tokens used")
|
||||
tokens_out: int = Field(ge=0, default=0, description="Output tokens used")
|
||||
total_tokens: int = Field(ge=0, default=0, description="Total tokens used")
|
||||
cost_usd: float = Field(ge=0.0, default=0.0, description="Estimated cost in USD")
|
||||
|
||||
|
||||
class ReviewCost(BaseModel):
|
||||
"""Complete cost tracking for a review."""
|
||||
|
||||
# Per-agent costs
|
||||
agent_costs: list[AgentCost] = Field(
|
||||
default_factory=list, description="Cost breakdown by agent"
|
||||
)
|
||||
|
||||
# Deliberation costs (synthesis)
|
||||
deliberation_tokens_in: int = Field(ge=0, default=0, description="Deliberation input tokens")
|
||||
deliberation_tokens_out: int = Field(ge=0, default=0, description="Deliberation output tokens")
|
||||
deliberation_cost_usd: float = Field(ge=0.0, default=0.0, description="Deliberation cost")
|
||||
|
||||
# Totals
|
||||
total_tokens_in: int = Field(ge=0, default=0, description="Total input tokens")
|
||||
total_tokens_out: int = Field(ge=0, default=0, description="Total output tokens")
|
||||
total_tokens: int = Field(ge=0, default=0, description="Total tokens")
|
||||
total_cost_usd: float = Field(ge=0.0, default=0.0, description="Total cost in USD")
|
||||
|
||||
# Cache stats
|
||||
cache_hits: int = Field(ge=0, default=0, description="Number of cache hits")
|
||||
cache_misses: int = Field(ge=0, default=0, description="Number of cache misses")
|
||||
|
||||
def add_agent_cost(
|
||||
self,
|
||||
agent: AgentName,
|
||||
tokens_in: int,
|
||||
tokens_out: int,
|
||||
cost_usd: float,
|
||||
) -> None:
|
||||
"""Add cost for an agent.
|
||||
|
||||
Args:
|
||||
agent: Agent name.
|
||||
tokens_in: Input tokens used.
|
||||
tokens_out: Output tokens used.
|
||||
cost_usd: Cost in USD.
|
||||
"""
|
||||
self.agent_costs.append(
|
||||
AgentCost(
|
||||
agent=agent,
|
||||
tokens_in=tokens_in,
|
||||
tokens_out=tokens_out,
|
||||
total_tokens=tokens_in + tokens_out,
|
||||
cost_usd=cost_usd,
|
||||
)
|
||||
)
|
||||
self._update_totals()
|
||||
|
||||
def add_deliberation_cost(
|
||||
self,
|
||||
tokens_in: int,
|
||||
tokens_out: int,
|
||||
cost_usd: float,
|
||||
) -> None:
|
||||
self.deliberation_tokens_in += tokens_in
|
||||
self.deliberation_tokens_out += tokens_out
|
||||
self.deliberation_cost_usd += cost_usd
|
||||
self._update_totals()
|
||||
|
||||
def _update_totals(self) -> None:
|
||||
"""Recalculate totals from components."""
|
||||
agent_tokens_in = sum(c.tokens_in for c in self.agent_costs)
|
||||
agent_tokens_out = sum(c.tokens_out for c in self.agent_costs)
|
||||
agent_cost = sum(c.cost_usd for c in self.agent_costs)
|
||||
|
||||
self.total_tokens_in = agent_tokens_in + self.deliberation_tokens_in
|
||||
self.total_tokens_out = agent_tokens_out + self.deliberation_tokens_out
|
||||
self.total_tokens = self.total_tokens_in + self.total_tokens_out
|
||||
self.total_cost_usd = agent_cost + self.deliberation_cost_usd
|
||||
|
||||
def to_agent_dict(self) -> dict[str, int]:
|
||||
return {c.agent.value: c.total_tokens for c in self.agent_costs}
|
||||
|
||||
def to_cost_dict(self) -> dict[str, float]:
|
||||
return {c.agent.value: c.cost_usd for c in self.agent_costs}
|
||||
|
||||
def is_within_budget(self, max_tokens: int, max_cost_usd: float) -> bool:
|
||||
return self.total_tokens <= max_tokens and self.total_cost_usd <= max_cost_usd
|
||||
|
||||
|
||||
class CostEstimate(BaseModel):
|
||||
"""Pre-review cost estimate."""
|
||||
|
||||
estimated_tokens: int = Field(ge=0, description="Estimated tokens needed")
|
||||
estimated_cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")
|
||||
agents_enabled: list[AgentName] = Field(description="Agents that will run")
|
||||
model: str = Field(description="Model to be used")
|
||||
within_budget: bool = Field(description="Whether estimate is within budget")
|
||||
|
||||
@classmethod
|
||||
def estimate(
|
||||
cls,
|
||||
diff_size: int,
|
||||
agents: list[AgentName],
|
||||
model: str,
|
||||
max_tokens: int = 50000,
|
||||
max_cost_usd: float = 0.50,
|
||||
) -> "CostEstimate":
|
||||
"""Estimate cost for a review.
|
||||
|
||||
This is a rough estimate based on diff size and model pricing.
|
||||
|
||||
Args:
|
||||
diff_size: Size of diff in characters.
|
||||
agents: Agents that will run.
|
||||
model: Model to be used.
|
||||
max_tokens: Maximum allowed tokens.
|
||||
max_cost_usd: Maximum allowed cost.
|
||||
|
||||
Returns:
|
||||
Cost estimate.
|
||||
"""
|
||||
# Rough token estimate: ~4 chars per token for input
|
||||
# Each agent typically uses 3x input for output
|
||||
tokens_per_agent = (diff_size // 4) * 4 # input + 3x output
|
||||
|
||||
# Deliberation uses ~20% of agent tokens
|
||||
deliberation_tokens = tokens_per_agent // 5
|
||||
|
||||
total_tokens = (tokens_per_agent * len(agents)) + deliberation_tokens
|
||||
|
||||
# Rough cost estimate based on model
|
||||
# GPT-4o: $5/1M input, $15/1M output (~$10/1M average)
|
||||
# GPT-4o-mini: $0.15/1M input, $0.60/1M output (~$0.40/1M average)
|
||||
cost_per_million = 10.0 if "gpt-4o" in model and "mini" not in model else 0.4
|
||||
estimated_cost = (total_tokens / 1_000_000) * cost_per_million
|
||||
|
||||
return cls(
|
||||
estimated_tokens=total_tokens,
|
||||
estimated_cost_usd=round(estimated_cost, 4),
|
||||
agents_enabled=agents,
|
||||
model=model,
|
||||
within_budget=total_tokens <= max_tokens and estimated_cost <= max_cost_usd,
|
||||
)
|
||||
Reference in New Issue
Block a user