diff --git a/src/arbiter/llm/__init__.py b/src/arbiter/llm/__init__.py
index 2b8b971..f87667e 100644
--- a/src/arbiter/llm/__init__.py
+++ b/src/arbiter/llm/__init__.py
@@ -1,12 +1,15 @@
 """LLM client and prompt management."""
 
+from arbiter.llm.cache import LLMCache, compute_policy_hash
 from arbiter.llm.client import LiteLLMClient, LLMClient, LLMResponse
 from arbiter.llm.prompts import PromptRegistry, PromptTemplate
 
 __all__ = [
+    "LLMCache",
     "LLMClient",
     "LLMResponse",
     "LiteLLMClient",
     "PromptRegistry",
     "PromptTemplate",
+    "compute_policy_hash",
 ]
diff --git a/src/arbiter/llm/cache.py b/src/arbiter/llm/cache.py
new file mode 100644
index 0000000..8b2aa59
--- /dev/null
+++ b/src/arbiter/llm/cache.py
@@ -0,0 +1,238 @@
+"""Redis-backed cache for LLM responses."""
+
+import hashlib
+import json
+import logging
+from typing import Any
+
+from redis.asyncio import Redis
+
+from arbiter.config import get_settings
+from arbiter.llm.client import LLMResponse
+
+logger = logging.getLogger(__name__)
+
+
+class LLMCache:
+    """Redis-backed cache for LLM responses.
+
+    Cache key format: arbiter:llm:cache:{agent}:{hash}
+    where hash = sha256 over (diff, agent, prompt_version, policy_hash).
+    The agent name stays readable in the key so that clear_agent can target
+    one agent's entries with a SCAN pattern.
+    """
+
+    PREFIX = "arbiter:llm:cache"
+
+    def __init__(self, redis: Redis) -> None:
+        """Bind the cache to a Redis client and reset hit/miss counters.
+
+        Args:
+            redis: Async Redis client used for all cache operations.
+        """
+        self.redis = redis
+        self.settings = get_settings()
+        self._hits = 0
+        self._misses = 0
+
+    def _compute_key(
+        self,
+        diff: str,
+        agent: str,
+        prompt_version: str,
+        policy_hash: str | None = None,
+    ) -> str:
+        """Compute cache key from inputs.
+
+        Args:
+            diff: The diff content being reviewed.
+            agent: Agent name.
+            prompt_version: Version of the prompt template.
+            policy_hash: Optional hash of policy configuration.
+
+        Returns:
+            Cache key string.
+        """
+        components = [
+            diff,
+            agent,
+            prompt_version,
+            policy_hash or "default",
+        ]
+        # JSON-encode the list rather than joining on "|": a diff can contain
+        # any character, so a plain join lets different inputs collide.
+        content = json.dumps(components)
+        hash_value = hashlib.sha256(content.encode()).hexdigest()
+        return f"{self.PREFIX}:{agent}:{hash_value}"
+
+    def _serialize_response(self, response: LLMResponse) -> str:
+        """Serialize LLMResponse to JSON string."""
+        return json.dumps(
+            {
+                "content": response.content,
+                "model": response.model,
+                "tokens_in": response.tokens_in,
+                "tokens_out": response.tokens_out,
+                "cost_usd": response.cost_usd,
+            }
+        )
+
+    def _deserialize_response(self, data: str) -> LLMResponse:
+        """Rebuild an LLMResponse from its cached JSON string."""
+        parsed = json.loads(data)
+        return LLMResponse(
+            content=parsed["content"],
+            model=parsed["model"],
+            tokens_in=parsed["tokens_in"],
+            tokens_out=parsed["tokens_out"],
+            cost_usd=parsed["cost_usd"],
+        )
+
+    async def get(
+        self,
+        diff: str,
+        agent: str,
+        prompt_version: str,
+        policy_hash: str | None = None,
+    ) -> LLMResponse | None:
+        """Get cached LLM response if available.
+
+        Args:
+            diff: The diff content.
+            agent: Agent name.
+            prompt_version: Prompt version.
+            policy_hash: Optional policy hash.
+
+        Returns:
+            Cached LLMResponse or None if not found.
+        """
+        key = self._compute_key(diff, agent, prompt_version, policy_hash)
+
+        try:
+            data = await self.redis.get(key)
+            # Deserialize before counting the hit so an unreadable entry is
+            # recorded as one miss instead of both a hit and a miss.
+            response = self._deserialize_response(data) if data else None
+        except Exception as e:
+            logger.warning("Cache get error: %s", e)
+            self._misses += 1
+            return None
+
+        if response is None:
+            self._misses += 1
+            return None
+
+        self._hits += 1
+        logger.debug("Cache hit for %s", key[:50])
+        return response
+
+    async def set(
+        self,
+        diff: str,
+        agent: str,
+        prompt_version: str,
+        response: LLMResponse,
+        policy_hash: str | None = None,
+    ) -> None:
+        """Cache an LLM response.
+
+        Args:
+            diff: The diff content.
+            agent: Agent name.
+            prompt_version: Prompt version.
+            response: LLM response to cache.
+            policy_hash: Optional policy hash.
+        """
+        key = self._compute_key(diff, agent, prompt_version, policy_hash)
+        # The expiry is sent as whole seconds, so coerce a fractional TTL.
+        ttl_seconds = int(self.settings.cache_ttl_hours * 3600)
+        if ttl_seconds <= 0:
+            logger.debug("Cache TTL is %ds; skipping write", ttl_seconds)
+            return
+
+        try:
+            serialized = self._serialize_response(response)
+            await self.redis.set(key, serialized, ex=ttl_seconds)
+            logger.debug("Cached response for %s (TTL: %ds)", key[:50], ttl_seconds)
+        except Exception as e:
+            logger.warning("Cache set error: %s", e)
+
+    async def invalidate(
+        self,
+        diff: str,
+        agent: str,
+        prompt_version: str,
+        policy_hash: str | None = None,
+    ) -> bool:
+        """Invalidate a cached response.
+
+        Args:
+            diff: The diff content.
+            agent: Agent name.
+            prompt_version: Prompt version.
+            policy_hash: Optional policy hash.
+
+        Returns:
+            True if a key was deleted.
+        """
+        key = self._compute_key(diff, agent, prompt_version, policy_hash)
+
+        try:
+            deleted: int = await self.redis.delete(key)
+            return deleted > 0
+        except Exception as e:
+            logger.warning("Cache invalidate error: %s", e)
+            return False
+
+    async def clear_agent(self, _agent: str) -> int:
+        """Clear all cached responses for an agent.
+
+        Note: This uses SCAN which may be slow on large datasets.
+
+        Args:
+            _agent: Agent name to clear cache for.
+
+        Returns:
+            Number of keys deleted, including any removed before an error.
+        """
+        # Escape glob metacharacters so the agent name only matches literally.
+        escaped = "".join(f"\\{ch}" if ch in "*?[]\\" else ch for ch in _agent)
+        pattern = f"{self.PREFIX}:{escaped}:*"
+        deleted = 0
+
+        try:
+            async for key in self.redis.scan_iter(match=pattern):
+                deleted += await self.redis.delete(key)
+        except Exception as e:
+            logger.warning("Cache clear error: %s", e)
+        return deleted
+
+    def get_stats(self) -> dict[str, Any]:
+        """Get cache statistics.
+
+        Returns:
+            Dict with hits, misses, and hit rate.
+        """
+        total = self._hits + self._misses
+        hit_rate = self._hits / total if total > 0 else 0.0
+
+        return {
+            "hits": self._hits,
+            "misses": self._misses,
+            "total": total,
+            "hit_rate": hit_rate,
+        }
+
+
+def compute_policy_hash(policy_dict: dict[str, Any]) -> str:
+    """Return a short, stable hash of a policy configuration.
+
+    Args:
+        policy_dict: Policy configuration; must be JSON-serializable.
+
+    Returns:
+        First 16 hex characters of the SHA-256 of the sorted-key JSON.
+    """
+    # Sort keys for consistent hashing
+    content = json.dumps(policy_dict, sort_keys=True)
+    return hashlib.sha256(content.encode()).hexdigest()[:16]
diff --git a/src/arbiter/models/__init__.py b/src/arbiter/models/__init__.py
index f0d1ba2..45015ca 100644
--- a/src/arbiter/models/__init__.py
+++ b/src/arbiter/models/__init__.py
@@ -1,5 +1,6 @@
 """Arbiter data models."""
 
+from arbiter.models.cost import AgentCost, CostEstimate, ReviewCost
 from arbiter.models.enums import AgentName, Severity, Verdict
 from arbiter.models.finding import Finding
 from arbiter.models.policy import AgentConfig, Policy
@@ -7,9 +8,12 @@ from arbiter.models.review import ReviewResult
 
 __all__ = [
     "AgentConfig",
+    "AgentCost",
     "AgentName",
+    "CostEstimate",
     "Finding",
     "Policy",
+    "ReviewCost",
     "ReviewResult",
     "Severity",
     "Verdict",
diff --git a/src/arbiter/models/cost.py b/src/arbiter/models/cost.py
new file mode 100644
index 0000000..295f84f
--- /dev/null
+++ b/src/arbiter/models/cost.py
@@ -0,0 +1,184 @@
+"""Cost tracking models for Arbiter."""
+
+from pydantic import BaseModel, Field
+
+from arbiter.models.enums import AgentName
+
+
+class AgentCost(BaseModel):
+    """Cost breakdown for a single agent."""
+
+    agent: AgentName = Field(description="Agent name")
+    tokens_in: int = Field(ge=0, default=0, description="Input tokens used")
+    tokens_out: int = Field(ge=0, default=0, description="Output tokens used")
+    total_tokens: int = Field(ge=0, default=0, description="Total tokens used")
+    cost_usd: float = Field(ge=0.0, default=0.0, description="Estimated cost in USD")
+
+
+class ReviewCost(BaseModel):
+    """Complete cost tracking for a review."""
+
+    # Per-agent costs
+    agent_costs: list[AgentCost] = Field(
+        default_factory=list, description="Cost breakdown by agent"
+    )
+
+    # Deliberation costs (synthesis)
+    deliberation_tokens_in: int = Field(ge=0, default=0, description="Deliberation input tokens")
+    deliberation_tokens_out: int = Field(ge=0, default=0, description="Deliberation output tokens")
+    deliberation_cost_usd: float = Field(ge=0.0, default=0.0, description="Deliberation cost")
+
+    # Totals
+    total_tokens_in: int = Field(ge=0, default=0, description="Total input tokens")
+    total_tokens_out: int = Field(ge=0, default=0, description="Total output tokens")
+    total_tokens: int = Field(ge=0, default=0, description="Total tokens")
+    total_cost_usd: float = Field(ge=0.0, default=0.0, description="Total cost in USD")
+
+    # Cache stats
+    cache_hits: int = Field(ge=0, default=0, description="Number of cache hits")
+    cache_misses: int = Field(ge=0, default=0, description="Number of cache misses")
+
+    def add_agent_cost(
+        self,
+        agent: AgentName,
+        tokens_in: int,
+        tokens_out: int,
+        cost_usd: float,
+    ) -> None:
+        """Add cost for an agent.
+
+        Args:
+            agent: Agent name.
+            tokens_in: Input tokens used.
+            tokens_out: Output tokens used.
+            cost_usd: Cost in USD.
+        """
+        self.agent_costs.append(
+            AgentCost(
+                agent=agent,
+                tokens_in=tokens_in,
+                tokens_out=tokens_out,
+                total_tokens=tokens_in + tokens_out,
+                cost_usd=cost_usd,
+            )
+        )
+        self._update_totals()
+
+    def add_deliberation_cost(
+        self,
+        tokens_in: int,
+        tokens_out: int,
+        cost_usd: float,
+    ) -> None:
+        """Add cost for the deliberation (synthesis) step.
+
+        Args:
+            tokens_in: Input tokens used.
+            tokens_out: Output tokens used.
+            cost_usd: Cost in USD.
+        """
+        self.deliberation_tokens_in += tokens_in
+        self.deliberation_tokens_out += tokens_out
+        self.deliberation_cost_usd += cost_usd
+        self._update_totals()
+
+    def _update_totals(self) -> None:
+        """Recalculate totals from components."""
+        agent_tokens_in = sum(c.tokens_in for c in self.agent_costs)
+        agent_tokens_out = sum(c.tokens_out for c in self.agent_costs)
+        agent_cost = sum(c.cost_usd for c in self.agent_costs)
+
+        self.total_tokens_in = agent_tokens_in + self.deliberation_tokens_in
+        self.total_tokens_out = agent_tokens_out + self.deliberation_tokens_out
+        self.total_tokens = self.total_tokens_in + self.total_tokens_out
+        self.total_cost_usd = agent_cost + self.deliberation_cost_usd
+
+    def to_agent_dict(self) -> dict[str, int]:
+        """Return total tokens per agent, keyed by agent value.
+
+        Entries for the same agent are summed so the mapping stays
+        consistent with the review totals.
+        """
+        tokens: dict[str, int] = {}
+        for c in self.agent_costs:
+            tokens[c.agent.value] = tokens.get(c.agent.value, 0) + c.total_tokens
+        return tokens
+
+    def to_cost_dict(self) -> dict[str, float]:
+        """Return cost in USD per agent, keyed by agent value.
+
+        Entries for the same agent are summed so the mapping stays
+        consistent with the review totals.
+        """
+        costs: dict[str, float] = {}
+        for c in self.agent_costs:
+            costs[c.agent.value] = costs.get(c.agent.value, 0.0) + c.cost_usd
+        return costs
+
+    def is_within_budget(self, max_tokens: int, max_cost_usd: float) -> bool:
+        """Return True if total tokens and total cost are both within limits.
+
+        Args:
+            max_tokens: Maximum allowed tokens.
+            max_cost_usd: Maximum allowed cost in USD.
+        """
+        return self.total_tokens <= max_tokens and self.total_cost_usd <= max_cost_usd
+
+
+class CostEstimate(BaseModel):
+    """Pre-review cost estimate."""
+
+    estimated_tokens: int = Field(ge=0, description="Estimated tokens needed")
+    estimated_cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")
+    agents_enabled: list[AgentName] = Field(description="Agents that will run")
+    model: str = Field(description="Model to be used")
+    within_budget: bool = Field(description="Whether estimate is within budget")
+
+    @classmethod
+    def estimate(
+        cls,
+        diff_size: int,
+        agents: list[AgentName],
+        model: str,
+        max_tokens: int = 50000,
+        max_cost_usd: float = 0.50,
+    ) -> "CostEstimate":
+        """Estimate cost for a review.
+
+        This is a rough estimate based on diff size and model pricing.
+
+        Args:
+            diff_size: Size of diff in characters.
+            agents: Agents that will run.
+            model: Model to be used.
+            max_tokens: Maximum allowed tokens.
+            max_cost_usd: Maximum allowed cost.
+
+        Returns:
+            Cost estimate.
+        """
+        # Rough token estimate: ~4 chars per token for input
+        # Each agent typically uses 3x input for output
+        # Clamp so a negative size cannot produce a negative token count.
+        tokens_per_agent = (max(diff_size, 0) // 4) * 4  # input + 3x output
+
+        # Deliberation uses ~20% of agent tokens
+        deliberation_tokens = tokens_per_agent // 5
+
+        total_tokens = (tokens_per_agent * len(agents)) + deliberation_tokens
+
+        # Rough cost estimate based on model
+        # GPT-4o: $5/1M input, $15/1M output (~$10/1M average)
+        # GPT-4o-mini: $0.15/1M input, $0.60/1M output (~$0.40/1M average)
+        # NOTE(review): any model other than full GPT-4o gets the mini rate -
+        # confirm that is intended for other model families.
+        cost_per_million = 10.0 if "gpt-4o" in model and "mini" not in model else 0.4
+        estimated_cost = (total_tokens / 1_000_000) * cost_per_million
+
+        return cls(
+            estimated_tokens=total_tokens,
+            estimated_cost_usd=round(estimated_cost, 4),
+            agents_enabled=agents,
+            model=model,
+            within_budget=total_tokens <= max_tokens and estimated_cost <= max_cost_usd,
+        )