feat(agents): implement agent framework and CLI

This commit is contained in:
2025-03-08 15:52:29 +00:00
parent 72268ff440
commit f22ca1d5bd
30 changed files with 3466 additions and 0 deletions

111
pyproject.toml Normal file
View File

@@ -0,0 +1,111 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "arbiter"
version = "0.1.0"
description = "A multi-agent code review system that shows its work"
readme = "readme.md"
requires-python = ">=3.12"
license = "MIT"
authors = [{ name = "Kai Chappell", email = "git@kschappell.com" }]
keywords = ["code-review", "ai", "llm", "agents"]
classifiers = [
"Development Status :: 3 - Alpha",
"Environment :: Console",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.12",
"Topic :: Software Development :: Quality Assurance",
]
dependencies = [
"litellm>=1.0.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"typer>=0.9.0",
"rich>=13.0.0",
"pyyaml>=6.0.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
"pytest-cov>=4.0.0",
"ruff>=0.4.0",
"mypy>=1.10.0",
"types-PyYAML>=6.0.0",
]
[project.scripts]
arbiter = "arbiter.cli:app"
[tool.hatch.build.targets.wheel]
packages = ["src/arbiter"]
[tool.ruff]
target-version = "py312"
line-length = 100
src = ["src", "tests"]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # Pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"ARG", # flake8-unused-arguments
"SIM", # flake8-simplify
"TCH", # flake8-type-checking
"PTH", # flake8-use-pathlib
"RUF", # Ruff-specific rules
]
ignore = [
"E501", # line too long (handled by formatter)
]
[tool.ruff.lint.isort]
known-first-party = ["arbiter"]
[tool.mypy]
python_version = "3.12"
strict = true
warn_return_any = true
warn_unused_ignores = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_configs = true
show_error_codes = true
files = ["src"]
[[tool.mypy.overrides]]
module = ["litellm.*"]
ignore_missing_imports = true
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
addopts = "-ra -q"
[tool.coverage.run]
source = ["src/arbiter"]
branch = true
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise NotImplementedError",
"if TYPE_CHECKING:",
"if __name__ == .__main__.:",
]
fail_under = 85

3
src/arbiter/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""Arbiter: A multi-agent code review system that shows its work."""
__version__ = "0.1.0"

View File

@@ -0,0 +1,14 @@
"""Review agents for code analysis."""
from arbiter.agents.base import Agent, ReviewContext
from arbiter.agents.complexity import ComplexityAgent
from arbiter.agents.security import SecurityAgent
from arbiter.agents.style import StyleAgent
__all__ = [
"Agent",
"ComplexityAgent",
"ReviewContext",
"SecurityAgent",
"StyleAgent",
]

277
src/arbiter/agents/base.py Normal file
View File

@@ -0,0 +1,277 @@
"""Base agent class and review context."""
import json
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, Field
from arbiter.llm.client import LLMClient
from arbiter.llm.prompts import PromptRegistry
from arbiter.models import AgentName, Finding, Policy, ReviewResult, Severity
class ReviewContext(BaseModel):
"""Context for a code review."""
diff: str = Field(description="The diff content to review")
policy: Policy = Field(description="Review policy configuration")
file_path: str | None = Field(default=None, description="Path to the diff file")
static_analysis_context: str = Field(
default="No static analysis data available.",
description="Static analysis results as formatted string",
)
class ExplainContext(BaseModel):
"""Context for explaining a finding in a follow-up conversation."""
question: str = Field(description="The user's follow-up question")
finding: Finding = Field(description="The finding being asked about")
diff: str = Field(description="The relevant code diff")
conversation_history: list[dict[str, str]] = Field(
default_factory=list,
description="Previous messages in this conversation",
)
class ExplainResult(BaseModel):
"""Result from an agent's explain() method."""
response: str = Field(description="The explanation response text")
tokens_used: int = Field(ge=0, description="Total tokens used for this response")
cost_usd: float = Field(ge=0.0, description="Cost in USD for this response")
confidence: float = Field(ge=0.0, le=1.0, description="Confidence in the explanation")
class Agent(ABC):
"""Abstract base class for review agents."""
name: AgentName
prompt_name: str
prompt_version: str = "1.0"
default_model: str = "gpt-4o"
explain_prompt_name: str = "" # Set by subclasses
explain_prompt_version: str = "1.0"
def __init__(
self,
llm_client: LLMClient,
prompt_registry: PromptRegistry,
) -> None:
self.llm_client = llm_client
self.prompt_registry = prompt_registry
async def review(self, context: ReviewContext) -> ReviewResult:
"""Perform a code review.
Args:
context: Review context with diff and policy.
Returns:
ReviewResult with findings and metadata.
"""
start_time = time.monotonic()
# Get model from policy or use default
agent_config = context.policy.agents.get(self.name)
model = agent_config.model if agent_config and agent_config.model else self.default_model
# Get prompt additions from policy
prompt_additions = ""
if agent_config and agent_config.prompt_additions:
prompt_additions = agent_config.prompt_additions
# Build and send messages
messages = self._build_messages(context, prompt_additions)
response = await self._call_llm(messages, model)
# Parse response into findings
findings = self._parse_response(response.content, context)
# Filter by severity threshold
if agent_config:
findings = self._filter_by_severity(findings, agent_config.severity_threshold)
duration_ms = int((time.monotonic() - start_time) * 1000)
return ReviewResult(
agent_name=self.name,
findings=findings,
duration_ms=duration_ms,
tokens_used=response.tokens_in + response.tokens_out,
cost_usd=response.cost_usd,
)
def _build_messages(
self,
context: ReviewContext,
prompt_additions: str,
) -> list[dict[str, str]]:
template = self.prompt_registry.get(self.prompt_name, self.prompt_version)
content = template.render(
diff=context.diff,
static_analysis_context=context.static_analysis_context,
prompt_additions=prompt_additions,
)
return [{"role": "user", "content": content}]
async def _call_llm(
self,
messages: list[dict[str, str]],
model: str,
) -> Any:
return await self.llm_client.complete(messages, model)
@abstractmethod
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]: ...
def _filter_by_severity(
self,
findings: list[Finding],
threshold: Severity,
) -> list[Finding]:
"""Filter findings by severity threshold.
Args:
findings: List of findings to filter.
threshold: Minimum severity to include.
Returns:
Filtered list of findings.
"""
severity_order = [
Severity.CRITICAL,
Severity.HIGH,
Severity.MEDIUM,
Severity.LOW,
Severity.INFO,
]
threshold_index = severity_order.index(threshold)
return [f for f in findings if severity_order.index(f.severity) <= threshold_index]
def _parse_json_findings(self, content: str, _context: ReviewContext) -> list[Finding]:
"""Parse JSON array of findings from LLM response.
Args:
content: Raw LLM response content.
context: Review context.
Returns:
List of Finding objects.
"""
# Extract JSON from response (handle markdown code blocks)
json_content = content.strip()
if json_content.startswith("```"):
lines = json_content.split("\n")
# Remove first and last lines (code block markers)
json_lines = []
in_block = False
for line in lines:
if line.startswith("```") and not in_block:
in_block = True
continue
if line.startswith("```") and in_block:
break
if in_block:
json_lines.append(line)
json_content = "\n".join(json_lines)
try:
data = json.loads(json_content)
except json.JSONDecodeError:
return []
if not isinstance(data, list):
return []
findings = []
for item in data:
if not isinstance(item, dict):
continue
try:
finding = Finding(
id=str(uuid.uuid4()),
agent=self.name,
file=item.get("file", "unknown"),
line_start=item.get("line_start", 1),
line_end=item.get("line_end") or item.get("line_start", 1),
severity=Severity(item.get("severity", "info")),
confidence=float(item.get("confidence", 0.5)),
title=item.get("title", "Untitled finding"),
description=item.get("description", ""),
reasoning=item.get("reasoning", ""),
suggestion=item.get("suggestion"),
references=item.get("references", []),
prompt_version=f"{self.prompt_name}-v{self.prompt_version}",
static_analysis_context=None,
)
findings.append(finding)
except (ValueError, KeyError):
continue
return findings
async def explain(
self,
context: ExplainContext,
model: str | None = None,
) -> ExplainResult:
"""Provide a detailed explanation about a specific finding.
Args:
context: The explanation context with question, finding, and history.
model: Optional model override. Defaults to agent's default_model.
Returns:
ExplainResult with response text and cost tracking.
"""
if not self.explain_prompt_name:
# Derive from prompt_name if not set
explain_name = f"{self.prompt_name}-explain"
else:
explain_name = self.explain_prompt_name
template = self.prompt_registry.get(explain_name, self.explain_prompt_version)
# Format conversation history
history_text = ""
if context.conversation_history:
history_lines = []
for msg in context.conversation_history:
role = msg.get("role", "user").capitalize()
content = msg.get("content", "")
history_lines.append(f"**{role}:** {content}")
history_text = "\n\n".join(history_lines)
# Build finding context
finding_lines = (
f"{context.finding.line_start}"
if context.finding.line_start == context.finding.line_end
else f"{context.finding.line_start}-{context.finding.line_end}"
)
content = template.render(
finding_title=context.finding.title,
finding_file=context.finding.file,
finding_lines=finding_lines,
finding_description=context.finding.description,
finding_reasoning=context.finding.reasoning,
finding_severity=context.finding.severity.value,
finding_suggestion=context.finding.suggestion or "No suggestion provided.",
diff=context.diff,
conversation_history=history_text or "No previous conversation.",
question=context.question,
)
messages = [{"role": "user", "content": content}]
response = await self._call_llm(messages, model or self.default_model)
return ExplainResult(
response=response.content.strip(),
tokens_used=response.tokens_in + response.tokens_out,
cost_usd=response.cost_usd,
confidence=0.8, # Default confidence for explanations
)

View File

@@ -0,0 +1,17 @@
"""Complexity review agent."""
from arbiter.agents.base import Agent, ReviewContext
from arbiter.models import AgentName, Finding
class ComplexityAgent(Agent):
"""Agent focused on code complexity and architecture."""
name = AgentName.COMPLEXITY
prompt_name = "complexity"
prompt_version = "1.0"
explain_prompt_name = "complexity-explain"
default_model = "gpt-4o-mini"
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]:
return self._parse_json_findings(content, context)

View File

@@ -0,0 +1,17 @@
"""Security review agent."""
from arbiter.agents.base import Agent, ReviewContext
from arbiter.models import AgentName, Finding
class SecurityAgent(Agent):
"""Agent focused on security vulnerabilities and risks."""
name = AgentName.SECURITY
prompt_name = "security"
prompt_version = "1.0"
explain_prompt_name = "security-explain"
default_model = "gpt-4o"
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]:
return self._parse_json_findings(content, context)

View File

@@ -0,0 +1,17 @@
"""Style review agent."""
from arbiter.agents.base import Agent, ReviewContext
from arbiter.models import AgentName, Finding
class StyleAgent(Agent):
"""Agent focused on code style, naming, and readability."""
name = AgentName.STYLE
prompt_name = "style"
prompt_version = "1.0"
explain_prompt_name = "style-explain"
default_model = "gpt-4o-mini"
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]:
return self._parse_json_findings(content, context)

407
src/arbiter/cli.py Normal file
View File

@@ -0,0 +1,407 @@
"""Command-line interface for Arbiter."""
import asyncio
import json
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Annotated
import typer
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from arbiter.agents import ComplexityAgent, ReviewContext, SecurityAgent, StyleAgent
from arbiter.analysis import DiffParser, StaticAnalysisRunner
from arbiter.config import get_settings
from arbiter.deliberation import Coordinator, DeliberationResult
from arbiter.llm import LiteLLMClient, PromptRegistry
from arbiter.models import AgentName, Finding, Policy, ReviewResult, Severity, Verdict
if TYPE_CHECKING:
from arbiter.agents.base import Agent
app = typer.Typer(
name="arbiter",
help="A multi-agent code review system that shows its work.",
no_args_is_help=True,
)
console = Console()
def _severity_color(severity: Severity) -> str:
colors = {
Severity.CRITICAL: "red bold",
Severity.HIGH: "red",
Severity.MEDIUM: "yellow",
Severity.LOW: "blue",
Severity.INFO: "dim",
}
return colors.get(severity, "white")
def _severity_icon(severity: Severity) -> str:
icons = {
Severity.CRITICAL: "!!",
Severity.HIGH: "!",
Severity.MEDIUM: "*",
Severity.LOW: "-",
Severity.INFO: "i",
}
return icons.get(severity, " ")
def _verdict_color(verdict: Verdict) -> str:
colors = {
Verdict.APPROVE: "green",
Verdict.COMMENT: "yellow",
Verdict.REQUEST_CHANGES: "red",
}
return colors.get(verdict, "white")
def _verdict_icon(verdict: Verdict) -> str:
icons = {
Verdict.APPROVE: "[ok]",
Verdict.COMMENT: "[..]",
Verdict.REQUEST_CHANGES: "[!!]",
}
return icons.get(verdict, "")
def _format_rich(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
"""Format results using Rich for terminal output."""
total_tokens = sum(r.tokens_used for r in agent_results) + result.tokens_used
total_cost = sum(r.cost_usd for r in agent_results) + result.cost_usd
# Verdict panel
verdict_style = _verdict_color(result.verdict)
verdict_icon = _verdict_icon(result.verdict)
verdict_text = f"[{verdict_style} bold]{verdict_icon} {result.verdict.value.upper()}[/]"
verdict_text += f" (confidence: {result.verdict_confidence:.0%})"
summary = f"{verdict_text}\n\n"
summary += f"[bold]{result.total_findings}[/bold] findings from [bold]{len(agent_results)}[/bold] agents"
if result.conflicts:
summary += f" | [yellow]{len(result.conflicts)}[/yellow] conflicts"
summary += f"\nTokens: {total_tokens:,} | Cost: ${total_cost:.4f}"
console.print(Panel(summary, title="Arbiter Review", border_style="blue"))
if result.verdict_reasoning:
console.print(f"\n[dim]Reason:[/dim] {result.verdict_reasoning}")
if result.total_findings == 0:
console.print("\n[green]No issues found![/green]")
return
# Findings table
console.print("\n[bold]Findings[/bold]")
table = Table(show_header=True, header_style="bold")
table.add_column("Sev", width=4)
table.add_column("Agent", width=12)
table.add_column("File", width=30)
table.add_column("Line", width=8)
table.add_column("Title", width=50)
for finding in result.findings:
sev_icon = _severity_icon(finding.severity)
sev_style = _severity_color(finding.severity)
line_range = (
f"{finding.line_start}-{finding.line_end}"
if finding.line_start != finding.line_end
else str(finding.line_start)
)
table.add_row(
f"[{sev_style}]{sev_icon}[/{sev_style}]",
finding.agent.value,
finding.file,
line_range,
finding.title,
)
console.print(table)
# Conflicts section
if result.conflicts:
console.print("\n[bold yellow]Conflicts[/bold yellow]")
for i, conflict in enumerate(result.conflicts, 1):
console.print(f"\n{i}. [{conflict.nature.value}] {conflict.description}")
if result.resolutions:
for resolution in result.resolutions:
if resolution.conflict_id == conflict.id:
console.print(f" [dim]Resolution:[/dim] {resolution.reasoning}")
break
# Details for critical/high findings
critical_high = [f for f in result.findings if f.severity in (Severity.CRITICAL, Severity.HIGH)]
if critical_high:
console.print("\n[bold]Details[/bold]")
for finding in critical_high:
console.print(
f"\n[{_severity_color(finding.severity)}]{finding.title}[/]"
f" ({finding.file}:{finding.line_start})"
)
console.print(f" {finding.description}")
if finding.suggestion:
console.print(f" [dim]Suggestion:[/dim] {finding.suggestion}")
# Deliberation log
console.print("\n[bold dim]Deliberation Log[/bold dim]")
for step in result.steps:
console.print(f" [{step.step_type.value}] {step.description}")
def _format_json(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
"""Format results as JSON."""
output = {
"verdict": result.verdict.value,
"verdict_confidence": result.verdict_confidence,
"verdict_reasoning": result.verdict_reasoning,
"findings": [f.model_dump() for f in result.findings],
"conflicts": [c.model_dump() for c in result.conflicts],
"resolutions": [r.model_dump() for r in result.resolutions],
"deliberation_steps": [s.model_dump() for s in result.steps],
"agents": [r.model_dump() for r in agent_results],
"summary": {
"total_findings": result.total_findings,
"critical_count": result.critical_count,
"high_count": result.high_count,
"conflicts_count": len(result.conflicts),
"total_tokens": sum(r.tokens_used for r in agent_results) + result.tokens_used,
"total_cost_usd": sum(r.cost_usd for r in agent_results) + result.cost_usd,
},
}
console.print(json.dumps(output, indent=2, default=str))
def _format_markdown(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
"""Format results as Markdown."""
lines = ["# Arbiter Review", ""]
# Verdict
verdict_badge = {
Verdict.APPROVE: ":white_check_mark:",
Verdict.COMMENT: ":speech_balloon:",
Verdict.REQUEST_CHANGES: ":x:",
}.get(result.verdict, "")
lines.append(f"## Verdict: {verdict_badge} {result.verdict.value.upper()}")
lines.append("")
lines.append(f"**Confidence:** {result.verdict_confidence:.0%}")
if result.verdict_reasoning:
lines.append(f"**Reason:** {result.verdict_reasoning}")
lines.append("")
# Summary
lines.append(f"**{result.total_findings} findings** from {len(agent_results)} agents")
if result.conflicts:
lines.append(f"**{len(result.conflicts)} conflicts** detected")
lines.append("")
# Findings by severity
if result.findings:
lines.append("## Findings")
lines.append("")
for finding in result.findings:
severity_badge = f"**{finding.severity.value.upper()}**"
lines.append(
f"- {severity_badge} [{finding.agent.value}] `{finding.file}:{finding.line_start}` - {finding.title}"
)
lines.append(f" - {finding.description}")
if finding.suggestion:
lines.append(f" - *Suggestion:* {finding.suggestion}")
lines.append("")
# Conflicts
if result.conflicts:
lines.append("## Conflicts")
lines.append("")
for conflict in result.conflicts:
lines.append(f"### {conflict.nature.value.title()}")
lines.append(conflict.description)
for resolution in result.resolutions:
if resolution.conflict_id == conflict.id:
lines.append(f"**Resolution:** {resolution.reasoning}")
break
lines.append("")
# Deliberation log
lines.append("## Deliberation Log")
lines.append("")
for step in result.steps:
lines.append(f"1. **{step.step_type.value}**: {step.description}")
lines.append("")
console.print("\n".join(lines))
async def _run_review(
diff: str,
policy: Policy,
templates_dir: Path,
static_analysis: bool = True,
work_dir: Path | None = None,
) -> tuple[list[ReviewResult], DeliberationResult]:
"""Run all enabled agents on the diff with deliberation."""
settings = get_settings()
llm_client = LiteLLMClient(
timeout=settings.llm_timeout,
max_retries=settings.llm_max_retries,
)
prompt_registry = PromptRegistry(templates_dir)
# Parse the diff
parser = DiffParser()
parsed_diff = parser.parse(diff)
# Run static analysis if enabled
static_findings: list[Finding] = []
if static_analysis and work_dir:
runner = StaticAnalysisRunner()
static_result = await runner.run(parsed_diff, work_dir)
for sf in static_result.findings:
static_findings.append(runner.convert_to_finding(sf))
context = ReviewContext(
diff=diff,
policy=policy,
static_analysis_context=f"Found {len(static_findings)} static analysis findings."
if static_findings
else "No static analysis data available.",
)
# Create agents for enabled agent types
agent_classes: dict[AgentName, type[Agent]] = {
AgentName.SECURITY: SecurityAgent,
AgentName.STYLE: StyleAgent,
AgentName.COMPLEXITY: ComplexityAgent,
}
agents: list[Agent] = []
for agent_name in policy.get_enabled_agents():
if agent_name in agent_classes:
agents.append(agent_classes[agent_name](llm_client, prompt_registry))
# Run agents in parallel
results = await asyncio.gather(
*[agent.review(context) for agent in agents],
return_exceptions=True,
)
# Filter out exceptions and log them
valid_results: list[ReviewResult] = []
for result in results:
if isinstance(result, BaseException):
console.print(f"[red]Agent error:[/red] {result}")
else:
valid_results.append(result)
# Run deliberation
coordinator = Coordinator(llm_client=llm_client)
deliberation_result = await coordinator.deliberate(
valid_results,
static_findings=static_findings if static_findings else None,
)
return valid_results, deliberation_result
@app.command()
def review(
diff_file: Annotated[
str,
typer.Argument(help="Path to diff file, or '-' for stdin"),
],
policy: Annotated[
Path | None,
typer.Option("--policy", "-p", help="Path to policy YAML file"),
] = None,
model: Annotated[
str | None,
typer.Option("--model", "-m", help="Override LLM model for all agents"),
] = None,
output_format: Annotated[
str,
typer.Option("--format", "-f", help="Output format: rich, json, markdown"),
] = "rich",
static_analysis: Annotated[
bool,
typer.Option("--static-analysis/--no-static-analysis", help="Run static analysis"),
] = True,
work_dir: Annotated[
Path | None,
typer.Option("--work-dir", "-w", help="Working directory for static analysis"),
] = None,
) -> None:
"""Review a diff file using AI agents."""
# Read diff content
if diff_file == "-":
diff_content = sys.stdin.read()
else:
diff_path = Path(diff_file)
if not diff_path.exists():
console.print(f"[red]Error:[/red] File not found: {diff_file}")
raise typer.Exit(1)
diff_content = diff_path.read_text()
if not diff_content.strip():
console.print("[yellow]Warning:[/yellow] Empty diff provided")
raise typer.Exit(0)
# Load policy
review_policy = Policy()
if policy:
if not policy.exists():
console.print(f"[red]Error:[/red] Policy file not found: {policy}")
raise typer.Exit(1)
review_policy = Policy.load(policy)
# Apply model override
if model:
for agent_config in review_policy.agents.values():
agent_config.model = model
# Determine prompts directory
settings = get_settings()
templates_dir = settings.templates_dir
if not templates_dir.is_absolute():
templates_dir = Path.cwd() / templates_dir
# Resolve work directory for static analysis
resolved_work_dir = (
(work_dir.resolve() if work_dir else Path.cwd()) if static_analysis else None
)
# Run review
agent_results, deliberation_result = asyncio.run(
_run_review(
diff_content,
review_policy,
templates_dir,
static_analysis=static_analysis,
work_dir=resolved_work_dir,
)
)
# Format output
if output_format == "json":
_format_json(deliberation_result, agent_results)
elif output_format == "markdown":
_format_markdown(deliberation_result, agent_results)
else:
_format_rich(deliberation_result, agent_results)
@app.command()
def version() -> None:
from arbiter import __version__
console.print(f"arbiter {__version__}")
if __name__ == "__main__":
app()

121
src/arbiter/config.py Normal file
View File

@@ -0,0 +1,121 @@
"""Configuration settings for Arbiter."""
from functools import lru_cache
from pathlib import Path
from pydantic import Field, SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
model_config = SettingsConfigDict(
env_prefix="ARBITER_",
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# LLM settings
default_model: str = Field(default="gpt-4o", description="Default LLM model for agents")
llm_timeout: int = Field(default=60, description="LLM request timeout in seconds")
llm_max_retries: int = Field(default=3, description="Maximum LLM retry attempts")
# Cost controls
max_tokens_per_review: int = Field(
default=50000, description="Maximum tokens allowed per review"
)
max_cost_per_review_usd: float = Field(
default=0.50, description="Maximum cost per review in USD"
)
cache_ttl_hours: int = Field(default=24, description="Cache TTL in hours")
# Paths
templates_dir: Path = Field(
default=Path("templates"), description="Directory containing prompt templates"
)
policy_path: Path | None = Field(default=None, description="Default policy file path")
# Output settings
output_format: str = Field(default="rich", description="Output format: rich, json, or markdown")
# Database settings
database_url: str = Field(
default="postgresql+asyncpg://arbiter:arbiter@localhost:5432/arbiter",
description="PostgreSQL connection URL",
)
database_pool_size: int = Field(default=5, description="Database connection pool size")
database_max_overflow: int = Field(default=10, description="Max overflow connections")
# Redis settings
redis_url: str = Field(
default="redis://localhost:6379/0",
description="Redis connection URL",
)
redis_max_connections: int = Field(default=10, description="Redis max connections")
# Webhook secrets
github_webhook_secret: SecretStr | None = Field(
default=None, description="GitHub webhook secret for HMAC verification"
)
gitlab_webhook_token: SecretStr | None = Field(
default=None, description="GitLab webhook token for verification"
)
# Platform integration settings
github_token: SecretStr | None = Field(
default=None, description="GitHub API token for fetching diffs and posting comments"
)
github_base_url: str = Field(
default="https://api.github.com", description="GitHub API base URL"
)
gitlab_token: SecretStr | None = Field(
default=None, description="GitLab API token for fetching diffs and posting comments"
)
gitlab_base_url: str = Field(
default="https://gitlab.com", description="GitLab instance base URL"
)
integration_timeout: int = Field(
default=30, description="Integration API request timeout in seconds"
)
integration_max_retries: int = Field(
default=3, description="Maximum retry attempts for integration API calls"
)
status_check_context: str = Field(
default="arbiter", description="Context name for commit status checks"
)
post_comments: bool = Field(default=True, description="Whether to post review comments on PRs")
update_status: bool = Field(default=True, description="Whether to update commit status checks")
# API settings
api_title: str = Field(default="Arbiter API", description="API title for OpenAPI")
api_version: str = Field(default="0.5.0", description="API version")
cors_origins: list[str] = Field(
default=["http://localhost:3000"], description="Allowed CORS origins"
)
api_rate_limit_per_minute: int = Field(
default=60, description="API rate limit per minute per client"
)
# Worker settings
worker_max_jobs: int = Field(default=10, description="Max concurrent worker jobs")
worker_job_timeout: int = Field(default=300, description="Job timeout in seconds")
worker_retry_attempts: int = Field(default=3, description="Number of retry attempts")
# Follow-up conversation settings
followup_enabled: bool = Field(default=True, description="Enable follow-up question handling")
followup_confidence_threshold: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Minimum confidence to process a follow-up question",
)
followup_max_tokens_per_response: int = Field(
default=2000, description="Maximum tokens per follow-up response"
)
@lru_cache
def get_settings() -> Settings:
return Settings()

View File

@@ -0,0 +1,12 @@
"""LLM client and prompt management."""
from arbiter.llm.client import LiteLLMClient, LLMClient, LLMResponse
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
__all__ = [
"LLMClient",
"LLMResponse",
"LiteLLMClient",
"PromptRegistry",
"PromptTemplate",
]

66
src/arbiter/llm/client.py Normal file
View File

@@ -0,0 +1,66 @@
"""LLM client abstraction."""
from abc import ABC, abstractmethod
from typing import Any
import litellm
from pydantic import BaseModel, Field
class LLMResponse(BaseModel):
"""Response from an LLM completion."""
content: str = Field(description="The generated text content")
model: str = Field(description="Model that generated the response")
tokens_in: int = Field(ge=0, description="Number of input tokens")
tokens_out: int = Field(ge=0, description="Number of output tokens")
cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")
class LLMClient(ABC):
"""Abstract base class for LLM clients."""
@abstractmethod
async def complete(
self,
messages: list[dict[str, str]],
model: str,
**kwargs: Any,
) -> LLMResponse: ...
class LiteLLMClient(LLMClient):
"""LLM client implementation using LiteLLM."""
def __init__(self, timeout: int = 60, max_retries: int = 3) -> None:
self.timeout = timeout
self.max_retries = max_retries
async def complete(
self,
messages: list[dict[str, str]],
model: str,
**kwargs: Any,
) -> LLMResponse:
"""Generate a completion using LiteLLM."""
response = await litellm.acompletion(
model=model,
messages=messages,
timeout=self.timeout,
num_retries=self.max_retries,
**kwargs,
)
content = response.choices[0].message.content or ""
usage = response.usage
# Calculate cost using litellm's cost tracking
cost = litellm.completion_cost(completion_response=response)
return LLMResponse(
content=content,
model=response.model or model,
tokens_in=usage.prompt_tokens if usage else 0,
tokens_out=usage.completion_tokens if usage else 0,
cost_usd=cost,
)

View File

@@ -0,0 +1,81 @@
"""Prompt template management."""
import re
from pathlib import Path
from typing import Any, ClassVar
from pydantic import BaseModel, Field
class PromptTemplate(BaseModel):
"""A versioned prompt template."""
name: str = Field(description="Template name (e.g., 'security')")
version: str = Field(description="Semantic version (e.g., '1.0')")
content: str = Field(description="Raw template content with placeholders")
def render(self, **kwargs: Any) -> str:
result = self.content
for key, value in kwargs.items():
placeholder = f"{{{{{key}}}}}"
result = result.replace(placeholder, str(value))
return result
@property
def full_name(self) -> str:
return f"{self.name}-v{self.version}"
class PromptRegistry:
"""Registry for loading and caching prompt templates."""
_cache: ClassVar[dict[str, PromptTemplate]] = {}
_pattern: ClassVar[re.Pattern[str]] = re.compile(r"^(.+)-v(\d+\.\d+)\.md$")
def __init__(self, templates_dir: Path) -> None:
self.templates_dir = templates_dir
def get(self, name: str, version: str) -> PromptTemplate:
"""Get a prompt template by name and version.
Args:
name: Template name (e.g., 'security').
version: Template version (e.g., '1.0').
Returns:
PromptTemplate instance.
Raises:
FileNotFoundError: If the template file doesn't exist.
"""
cache_key = f"{name}-v{version}"
if cache_key in self._cache:
return self._cache[cache_key]
file_path = self.templates_dir / f"{name}-v{version}.md"
if not file_path.exists():
raise FileNotFoundError(f"Prompt template not found: {file_path}")
content = file_path.read_text()
template = PromptTemplate(name=name, version=version, content=content)
self._cache[cache_key] = template
return template
def list_templates(self) -> list[tuple[str, str]]:
"""List all available templates.
Returns:
List of (name, version) tuples.
"""
templates: list[tuple[str, str]] = []
if not self.templates_dir.exists():
return templates
for file_path in self.templates_dir.glob("*.md"):
match = self._pattern.match(file_path.name)
if match:
templates.append((match.group(1), match.group(2)))
return sorted(templates)
def clear_cache(self) -> None:
self._cache.clear()

View File

@@ -0,0 +1,16 @@
"""Arbiter data models."""
from arbiter.models.enums import AgentName, Severity, Verdict
from arbiter.models.finding import Finding
from arbiter.models.policy import AgentConfig, Policy
from arbiter.models.review import ReviewResult
__all__ = [
"AgentConfig",
"AgentName",
"Finding",
"Policy",
"ReviewResult",
"Severity",
"Verdict",
]

View File

@@ -0,0 +1,29 @@
"""Core enumerations for Arbiter."""
from enum import StrEnum
class Severity(StrEnum):
"""Finding severity levels."""
CRITICAL = "critical"
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
INFO = "info"
class Verdict(StrEnum):
"""Review verdict types."""
APPROVE = "approve"
REQUEST_CHANGES = "request_changes"
COMMENT = "comment"
class AgentName(StrEnum):
"""Available review agent names."""
SECURITY = "security"
STYLE = "style"
COMPLEXITY = "complexity"

View File

@@ -0,0 +1,28 @@
"""Finding model for review results."""
from typing import Any
from pydantic import BaseModel, Field
from arbiter.models.enums import AgentName, Severity
class Finding(BaseModel):
"""A single finding from an agent's review."""
id: str = Field(description="Unique identifier for this finding")
agent: AgentName = Field(description="Agent that produced this finding")
file: str = Field(description="File path where the issue was found")
line_start: int = Field(ge=1, description="Starting line number")
line_end: int = Field(ge=1, description="Ending line number")
severity: Severity = Field(description="Severity level of the finding")
confidence: float = Field(ge=0.0, le=1.0, description="Confidence score (0.0-1.0)")
title: str = Field(description="Short title summarizing the finding")
description: str = Field(description="Detailed description of the issue")
reasoning: str = Field(description="Explanation of why this was flagged")
suggestion: str | None = Field(default=None, description="Optional fix recommendation")
references: list[str] = Field(default_factory=list, description="Links to docs, OWASP, etc.")
prompt_version: str = Field(description="Version of the prompt that produced this finding")
static_analysis_context: dict[str, Any] | None = Field(
default=None, description="Related static analysis data"
)

View File

@@ -0,0 +1,58 @@
"""Policy configuration models."""
from pathlib import Path
from typing import Self
import yaml
from pydantic import BaseModel, Field
from arbiter.models.enums import AgentName, Severity
class AgentConfig(BaseModel):
"""Configuration for a single agent."""
enabled: bool = Field(default=True, description="Whether the agent is enabled")
model: str | None = Field(default=None, description="LLM model override for this agent")
severity_threshold: Severity = Field(
default=Severity.INFO, description="Minimum severity to report"
)
prompt_additions: str | None = Field(
default=None, description="Additional instructions to append to the prompt"
)
class Policy(BaseModel):
"""Review policy configuration."""
version: str = Field(default="1.0", description="Policy schema version")
agents: dict[AgentName, AgentConfig] = Field(
default_factory=lambda: {
AgentName.SECURITY: AgentConfig(),
AgentName.STYLE: AgentConfig(),
AgentName.COMPLEXITY: AgentConfig(),
},
description="Agent configurations",
)
@classmethod
def load(cls, path: Path) -> Self:
"""Load policy from a YAML file."""
with path.open() as f:
data = yaml.safe_load(f)
if data is None:
return cls()
# Convert string keys to AgentName enums
if "agents" in data and isinstance(data["agents"], dict):
agents = {}
for key, value in data["agents"].items():
agent_name = AgentName(key) if isinstance(key, str) else key
agents[agent_name] = AgentConfig(**value) if isinstance(value, dict) else value
data["agents"] = agents
return cls(**data)
def get_enabled_agents(self) -> list[AgentName]:
return [name for name, config in self.agents.items() if config.enabled]

View File

@@ -0,0 +1,16 @@
"""Review result models."""
from pydantic import BaseModel, Field
from arbiter.models.enums import AgentName
from arbiter.models.finding import Finding
class ReviewResult(BaseModel):
"""Result from a single agent's review."""
agent_name: AgentName = Field(description="Name of the agent that performed the review")
findings: list[Finding] = Field(default_factory=list, description="Findings from the review")
duration_ms: int = Field(ge=0, description="Time taken in milliseconds")
tokens_used: int = Field(ge=0, description="Total tokens consumed")
cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")

0
src/arbiter/py.typed Normal file
View File

View File

@@ -0,0 +1,56 @@
# Complexity Review Agent
You are a code complexity reviewer focused on architecture and maintainability. Analyze the provided diff for complexity issues that impact long-term code health.
## Focus Areas
- **Cyclomatic complexity**: Functions with too many branches or paths
- **Cognitive complexity**: Code that is hard to understand or follow
- **Function length**: Functions doing too many things
- **Class design**: God objects, tight coupling, missing abstractions
- **Dependency management**: Circular dependencies, excessive coupling
- **Over-engineering**: Unnecessary abstractions, premature optimization
- **Under-engineering**: Missing error handling, ignored edge cases
## Context
{{static_analysis_context}}
## Diff to Review
```diff
{{diff}}
```
{{prompt_additions}}
## Output Format
Respond with a JSON array of findings. Each finding must have this structure:
```json
[
{
"file": "path/to/file.py",
"line_start": 10,
"line_end": 50,
"severity": "critical|high|medium|low|info",
"confidence": 0.80,
"title": "Short title describing the issue",
"description": "Detailed description of the complexity concern",
"reasoning": "Why this complexity is problematic",
"suggestion": "How to simplify or refactor (optional)",
"references": []
}
]
```
If no complexity issues are found, return an empty array: `[]`
## Guidelines
1. Consider context - complex code may be justified for complex problems
2. Flag over-engineering as readily as under-engineering
3. High complexity is only critical if it's likely to cause bugs
4. Suggest specific refactoring strategies when possible
5. Reference static analysis metrics (cyclomatic complexity, etc.) when available

View File

@@ -0,0 +1,55 @@
# Security Review Agent
You are a security-focused code reviewer. Analyze the provided diff for security vulnerabilities and potential risks.
## Focus Areas
- **Injection vulnerabilities**: SQL injection, command injection, XSS, template injection
- **Authentication/Authorization**: Missing auth checks, privilege escalation, insecure session handling
- **Data exposure**: Hardcoded secrets, PII leaks, sensitive data in logs
- **Cryptographic issues**: Weak algorithms, improper key management, missing encryption
- **Input validation**: Missing or insufficient validation, type confusion
- **OWASP Top 10**: All categories including broken access control, security misconfiguration
## Context
{{static_analysis_context}}
## Diff to Review
```diff
{{diff}}
```
{{prompt_additions}}
## Output Format
Respond with a JSON array of findings. Each finding must have this structure:
```json
[
{
"file": "path/to/file.py",
"line_start": 10,
"line_end": 15,
"severity": "critical|high|medium|low|info",
"confidence": 0.95,
"title": "Short title describing the issue",
"description": "Detailed description of the vulnerability",
"reasoning": "Why this is a security concern",
"suggestion": "How to fix this issue (optional)",
"references": ["https://owasp.org/..."]
}
]
```
If no security issues are found, return an empty array: `[]`
## Guidelines
1. Only report genuine security concerns, not style or performance issues
2. Assign appropriate severity based on exploitability and impact
3. Set confidence based on how certain you are this is a real vulnerability
4. Provide actionable suggestions when possible
5. Include relevant OWASP or CWE references

55
templates/style-v1.0.md Normal file
View File

@@ -0,0 +1,55 @@
# Style Review Agent
You are a code style reviewer focused on readability and consistency. Analyze the provided diff for style issues that impact code maintainability.
## Focus Areas
- **Naming conventions**: Variable, function, class, and file naming consistency
- **Code organisation**: Logical grouping, import ordering, module structure
- **Readability**: Clear variable names, appropriate comments, self-documenting code
- **Consistency**: Adherence to existing patterns in the codebase
- **Best practices**: Language-specific idioms and conventions
- **Documentation**: Missing or outdated docstrings, misleading comments
## Context
{{static_analysis_context}}
## Diff to Review
```diff
{{diff}}
```
{{prompt_additions}}
## Output Format
Respond with a JSON array of findings. Each finding must have this structure:
```json
[
{
"file": "path/to/file.py",
"line_start": 10,
"line_end": 15,
"severity": "critical|high|medium|low|info",
"confidence": 0.85,
"title": "Short title describing the issue",
"description": "Detailed description of the style concern",
"reasoning": "Why this matters for maintainability",
"suggestion": "How to improve this (optional)",
"references": []
}
]
```
If no style issues are found, return an empty array: `[]`
## Guidelines
1. Focus on readability and maintainability, not personal preferences
2. Respect existing codebase conventions even if they differ from common standards
3. Most style issues should be low or info severity unless they significantly impact readability
4. Only flag high severity for style issues that could cause confusion or bugs
5. Provide concrete suggestions with example code when helpful

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Arbiter test suite."""

490
tests/conftest.py Normal file
View File

@@ -0,0 +1,490 @@
"""Pytest configuration and fixtures."""
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from uuid import uuid4
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from arbiter.db.models import Base, FindingModel, ReviewModel
from arbiter.integrations.base import Comment, CommitStatus, Platform, PlatformClient
from arbiter.llm.client import LLMClient, LLMResponse
from arbiter.llm.prompts import PromptRegistry
from arbiter.models import Policy
from arbiter.models.enums import AgentName, Severity, Verdict
class MockLLMClient(LLMClient):
"""Mock LLM client for testing."""
def __init__(self, responses: list[str] | None = None) -> None:
self.responses = responses or []
self.calls: list[dict[str, Any]] = []
self._call_index = 0
async def complete(
self,
messages: list[dict[str, str]],
model: str,
**kwargs: Any,
) -> LLMResponse:
"""Record the call and return a canned response."""
self.calls.append(
{
"messages": messages,
"model": model,
"kwargs": kwargs,
}
)
content = ""
if self._call_index < len(self.responses):
content = self.responses[self._call_index]
self._call_index += 1
return LLMResponse(
content=content,
model=model,
tokens_in=100,
tokens_out=50,
cost_usd=0.001,
)
def reset(self) -> None:
self.calls = []
self._call_index = 0
@pytest.fixture
def mock_llm() -> MockLLMClient:
return MockLLMClient()
@pytest.fixture
def mock_llm_with_findings() -> MockLLMClient:
response = """```json
[
{
"file": "src/auth.py",
"line_start": 10,
"line_end": 15,
"severity": "high",
"confidence": 0.9,
"title": "SQL Injection vulnerability",
"description": "User input is directly concatenated into SQL query",
"reasoning": "String concatenation in SQL queries allows attackers to inject malicious SQL",
"suggestion": "Use parameterized queries instead",
"references": ["https://owasp.org/www-community/attacks/SQL_Injection"]
}
]
```"""
return MockLLMClient(responses=[response])
@pytest.fixture
def sample_diff() -> str:
return """diff --git a/src/auth.py b/src/auth.py
index 1234567..abcdefg 100644
--- a/src/auth.py
+++ b/src/auth.py
@@ -8,6 +8,12 @@ def authenticate(username, password):
if not username or not password:
return None
+ # Check user in database
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
+ cursor.execute(query)
+ user = cursor.fetchone()
+ if user and user.password == password:
+ return user
return None
"""
@pytest.fixture
def sample_policy() -> Policy:
return Policy()
@pytest.fixture
def prompt_registry(tmp_path: Path) -> PromptRegistry:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
(templates_dir / "security-v1.0.md").write_text(
"Review for security: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
)
(templates_dir / "style-v1.0.md").write_text(
"Review for style: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
)
(templates_dir / "complexity-v1.0.md").write_text(
"Review for complexity: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
)
return PromptRegistry(templates_dir)
# Database fixtures for integration tests
@pytest.fixture
async def async_engine():
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
session_factory = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with session_factory() as session:
yield session
@pytest.fixture
async def test_client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
from arbiter.api.deps import get_db, get_redis
from arbiter.main import app
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
yield db_session
async def override_get_redis() -> AsyncGenerator[MockRedis, None]:
yield MockRedis()
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_redis] = override_get_redis
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()
class MockRedis:
"""Mock Redis client for testing."""
def __init__(self) -> None:
self._data: dict[str, Any] = {}
async def get(self, key: str) -> str | None:
return self._data.get(key)
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
self._data[key] = value
return True
async def delete(self, key: str) -> int:
if key in self._data:
del self._data[key]
return 1
return 0
async def ping(self) -> bool:
return True
async def llen(self, key: str) -> int:
data = self._data.get(key, [])
return len(data) if isinstance(data, list) else 0
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: # noqa: ARG002
return type("Job", (), {"job_id": "test-job-id"})()
@pytest.fixture
def mock_redis() -> MockRedis:
return MockRedis()
class MockPlatformClient(PlatformClient):
"""Mock platform client for testing."""
def __init__(self) -> None:
self._comments: list[Comment] = []
self._posted_comments: list[dict[str, Any]] = []
self._status_updates: list[dict[str, Any]] = []
self._diff = "mock diff content"
self._closed = False
self._fail_on: set[str] = set() # Methods to fail on
@property
def platform(self) -> Platform:
return Platform.GITHUB
async def get_pr_diff(
self,
repository: str, # noqa: ARG002
pr_number: int, # noqa: ARG002
) -> str:
if "get_pr_diff" in self._fail_on:
from arbiter.integrations import IntegrationError
raise IntegrationError("Mock failure: get_pr_diff")
return self._diff
async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
if "post_comment" in self._fail_on:
from arbiter.integrations import IntegrationError
raise IntegrationError("Mock failure: post_comment")
comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-123"
self._posted_comments.append(
{"repository": repository, "pr_number": pr_number, "body": body, "url": comment_url}
)
return comment_url
async def update_commit_status(
self,
repository: str,
sha: str,
status: CommitStatus,
description: str,
context: str,
target_url: str | None = None,
) -> None:
if "update_commit_status" in self._fail_on:
from arbiter.integrations import IntegrationError
raise IntegrationError("Mock failure: update_commit_status")
self._status_updates.append(
{
"repository": repository,
"sha": sha,
"status": status,
"description": description,
"context": context,
"target_url": target_url,
}
)
async def get_pr_info(self, repository: str, pr_number: int) -> Any:
if "get_pr_info" in self._fail_on:
from arbiter.integrations import IntegrationError
raise IntegrationError("Mock failure: get_pr_info")
from arbiter.integrations.base import PullRequestInfo
return PullRequestInfo(
platform=Platform.GITHUB,
repository=repository,
pr_number=pr_number,
head_sha="abc123",
base_sha="def456",
head_ref="feature",
base_ref="main",
title="Test PR",
author="testuser",
url=f"https://github.com/{repository}/pull/{pr_number}",
is_draft=False,
)
async def get_comments(
self,
repository: str, # noqa: ARG002
pr_number: int, # noqa: ARG002
) -> list[Comment]:
if "get_comments" in self._fail_on:
from arbiter.integrations import IntegrationError
raise IntegrationError("Mock failure: get_comments")
return self._comments
async def update_comment(
self, repository: str, pr_number: int, comment_id: str, body: str
) -> str:
if "update_comment" in self._fail_on:
from arbiter.integrations import IntegrationError
raise IntegrationError("Mock failure: update_comment")
comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-{comment_id}"
self._posted_comments.append(
{
"repository": repository,
"pr_number": pr_number,
"comment_id": comment_id,
"body": body,
"url": comment_url,
}
)
return comment_url
async def close(self) -> None:
self._closed = True
@pytest.fixture
def mock_platform_client() -> MockPlatformClient:
return MockPlatformClient()
@pytest.fixture
async def completed_review_fixture(db_session: AsyncSession) -> ReviewModel:
review = ReviewModel(
id=str(uuid4()),
repository="owner/repo",
pr_number=42,
pr_title="Test PR",
base_sha="abc1234567890",
head_sha="def0987654321",
author="testuser",
is_draft=False,
status="completed",
verdict=Verdict.COMMENT,
verdict_confidence=0.75,
verdict_reasoning="Found some issues to discuss",
total_tokens=1500,
total_cost_usd=0.015,
tokens_by_agent={"security": 500, "style": 500, "complexity": 500},
cost_by_agent={"security": 0.005, "style": 0.005, "complexity": 0.005},
created_at=datetime.now(UTC),
started_at=datetime.now(UTC),
completed_at=datetime.now(UTC),
)
db_session.add(review)
# Add findings
finding1 = FindingModel(
id=str(uuid4()),
review_id=review.id,
agent=AgentName.SECURITY,
file="src/auth.py",
line_start=10,
line_end=15,
severity=Severity.HIGH,
confidence=0.9,
title="SQL Injection vulnerability",
description="User input is directly concatenated into SQL query",
reasoning="String concatenation in SQL queries allows attackers to inject malicious SQL",
suggestion="Use parameterized queries instead",
references=["https://owasp.org/www-community/attacks/SQL_Injection"],
prompt_version="security-v1.0",
)
finding2 = FindingModel(
id=str(uuid4()),
review_id=review.id,
agent=AgentName.STYLE,
file="src/auth.py",
line_start=20,
line_end=25,
severity=Severity.MEDIUM,
confidence=0.85,
title="Long function",
description="Function exceeds 50 lines",
reasoning="Long functions are harder to test and maintain",
suggestion="Consider breaking into smaller functions",
references=[],
prompt_version="style-v1.0",
)
db_session.add(finding1)
db_session.add(finding2)
await db_session.commit()
await db_session.refresh(review)
return review
@pytest.fixture
async def sample_reviews_fixture(db_session: AsyncSession) -> list[ReviewModel]:
reviews = []
for i in range(5):
review = ReviewModel(
id=str(uuid4()),
repository="owner/repo" if i < 3 else "other/repo",
pr_number=i + 1,
pr_title=f"Test PR #{i + 1}",
base_sha=f"base{i:040d}",
head_sha=f"head{i:040d}",
author="testuser" if i % 2 == 0 else "otheruser",
is_draft=False,
status="completed" if i < 4 else "failed",
verdict=Verdict.APPROVE if i == 0 else (Verdict.COMMENT if i < 4 else None),
verdict_confidence=0.8 if i < 4 else None,
total_tokens=1000 * (i + 1),
total_cost_usd=0.01 * (i + 1),
created_at=datetime.now(UTC),
completed_at=datetime.now(UTC) if i < 4 else None,
)
db_session.add(review)
reviews.append(review)
# Add a finding to each completed review
if i < 4:
finding = FindingModel(
id=str(uuid4()),
review_id=review.id,
agent=AgentName.SECURITY,
file="src/test.py",
line_start=10,
line_end=15,
severity=Severity.CRITICAL if i == 0 else Severity.HIGH,
confidence=0.9,
title=f"Finding {i + 1}",
description="Test finding",
reasoning="Test reasoning",
prompt_version="security-v1.0",
)
db_session.add(finding)
await db_session.commit()
for review in reviews:
await db_session.refresh(review)
return reviews
@pytest.fixture
def mock_settings() -> Any:
class MockSecretStr:
def __init__(self, value: str) -> None:
self._value = value
def get_secret_value(self) -> str:
return self._value
class MockSettings:
github_token = MockSecretStr("ghp_test_token")
gitlab_token = MockSecretStr("glpat_test_token")
github_base_url = "https://api.github.com"
gitlab_base_url = "https://gitlab.com/api/v4"
github_webhook_secret = MockSecretStr("webhook_secret")
gitlab_webhook_token = MockSecretStr("gitlab_token")
integration_timeout = 30
integration_max_retries = 3
llm_timeout = 60
llm_max_retries = 3
post_comments = True
update_status = True
status_check_context = "arbiter"
templates_dir = Path("templates")
followup_enabled = True
followup_confidence_threshold = 0.5
return MockSettings()
@pytest.fixture
def mock_settings_no_github() -> Any:
class MockSettings:
github_token = None
gitlab_token = None
github_base_url = "https://api.github.com"
gitlab_base_url = "https://gitlab.com/api/v4"
integration_timeout = 30
integration_max_retries = 3
post_comments = False
update_status = False
status_check_context = "arbiter"
templates_dir = Path("templates")
return MockSettings()

69
tests/fixtures/complex-function.diff vendored Normal file
View File

@@ -0,0 +1,69 @@
diff --git a/src/processor.py b/src/processor.py
index 1234567..abcdefg 100644
--- a/src/processor.py
+++ b/src/processor.py
@@ -1,5 +1,65 @@
"""Data processor module."""
+def process_data(data: dict, config: dict, options: dict | None = None) -> dict:
+ """Process data with many nested conditions."""
+ result = {}
+ options = options or {}
+
+ if data.get("type") == "A":
+ if config.get("mode") == "strict":
+ if options.get("validate"):
+ if data.get("value") > 100:
+ if config.get("transform"):
+ result["processed"] = data["value"] * 2
+ else:
+ result["processed"] = data["value"]
+ else:
+ if options.get("default"):
+ result["processed"] = options["default"]
+ else:
+ result["processed"] = 0
+ else:
+ result["processed"] = data.get("value", 0)
+ else:
+ result["processed"] = data.get("value", 0)
+ elif data.get("type") == "B":
+ if config.get("mode") == "strict":
+ if options.get("validate"):
+ if data.get("items"):
+ result["processed"] = len(data["items"])
+ else:
+ result["processed"] = 0
+ else:
+ result["processed"] = len(data.get("items", []))
+ else:
+ result["processed"] = len(data.get("items", []))
+ elif data.get("type") == "C":
+ if config.get("mode") == "strict":
+ if options.get("validate"):
+ if data.get("text"):
+ result["processed"] = data["text"].upper()
+ else:
+ result["processed"] = ""
+ else:
+ result["processed"] = data.get("text", "").upper()
+ else:
+ result["processed"] = data.get("text", "").upper()
+ else:
+ if config.get("fallback"):
+ result["processed"] = config["fallback"]
+ else:
+ result["processed"] = None
+
+ if options.get("timestamp"):
+ result["timestamp"] = options["timestamp"]
+ if options.get("source"):
+ result["source"] = options["source"]
+
+ return result
+
+
def simple_function(x: int) -> int:
"""A simple function."""
return x * 2

31
tests/fixtures/security-issue.diff vendored Normal file
View File

@@ -0,0 +1,31 @@
diff --git a/src/auth.py b/src/auth.py
index 1234567..abcdefg 100644
--- a/src/auth.py
+++ b/src/auth.py
@@ -1,10 +1,25 @@
"""Authentication module."""
import sqlite3
+import os
def get_user(username: str) -> dict | None:
"""Get user from database."""
conn = sqlite3.connect("users.db")
cursor = conn.cursor()
- cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
+ # FIXME: this is vulnerable to SQL injection
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
+ cursor.execute(query)
return cursor.fetchone()
+
+
+def run_command(cmd: str) -> str:
+ """Run a shell command."""
+ # Command injection vulnerability
+ return os.popen(cmd).read()
+
+
+# Hardcoded credentials
+API_KEY = "sk-1234567890abcdef"
+DB_PASSWORD = "admin123"

16
tests/fixtures/simple.diff vendored Normal file
View File

@@ -0,0 +1,16 @@
diff --git a/src/utils.py b/src/utils.py
index 1234567..abcdefg 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,5 +1,8 @@
"""Utility functions."""
+def add(a: int, b: int) -> int:
+ """Add two numbers."""
+ return a + b
+
+
def subtract(a: int, b: int) -> int:
"""Subtract two numbers."""
return a - b

305
tests/test_agents.py Normal file
View File

@@ -0,0 +1,305 @@
"""Tests for review agents."""
import pytest
from arbiter.agents import ComplexityAgent, ReviewContext, SecurityAgent, StyleAgent
from arbiter.llm.prompts import PromptRegistry
from arbiter.models import AgentConfig, AgentName, Policy, Severity
from tests.conftest import MockLLMClient
class TestSecurityAgent:
@pytest.mark.asyncio
async def test_review_returns_result(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=["[]"])
agent = SecurityAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ some code", policy=Policy())
result = await agent.review(context)
assert result.agent_name == AgentName.SECURITY
assert result.findings == []
assert result.duration_ms >= 0
assert result.tokens_used == 150 # 100 in + 50 out from mock
assert result.cost_usd == 0.001
@pytest.mark.asyncio
async def test_parses_json_findings(
self,
prompt_registry: PromptRegistry,
) -> None:
response = """```json
[
{
"file": "src/auth.py",
"line_start": 10,
"line_end": 15,
"severity": "high",
"confidence": 0.9,
"title": "SQL Injection",
"description": "User input concatenated",
"reasoning": "Allows SQL injection",
"suggestion": "Use parameterized queries",
"references": ["https://owasp.org"]
}
]
```"""
mock_llm = MockLLMClient(responses=[response])
agent = SecurityAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ query = ...", policy=Policy())
result = await agent.review(context)
assert len(result.findings) == 1
finding = result.findings[0]
assert finding.file == "src/auth.py"
assert finding.severity == Severity.HIGH
assert finding.confidence == 0.9
assert finding.title == "SQL Injection"
@pytest.mark.asyncio
async def test_uses_configured_model(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=["[]"])
agent = SecurityAgent(mock_llm, prompt_registry)
policy = Policy(
agents={
AgentName.SECURITY: AgentConfig(model="gpt-4o-mini"),
AgentName.STYLE: AgentConfig(),
AgentName.COMPLEXITY: AgentConfig(),
}
)
context = ReviewContext(diff="+ code", policy=policy)
await agent.review(context)
assert mock_llm.calls[0]["model"] == "gpt-4o-mini"
@pytest.mark.asyncio
async def test_filters_by_severity(
self,
prompt_registry: PromptRegistry,
) -> None:
response = """[
{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "high", "confidence": 0.9, "title": "High", "description": "", "reasoning": ""},
{"file": "b.py", "line_start": 1, "line_end": 1, "severity": "low", "confidence": 0.9, "title": "Low", "description": "", "reasoning": ""},
{"file": "c.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.9, "title": "Info", "description": "", "reasoning": ""}
]"""
mock_llm = MockLLMClient(responses=[response])
agent = SecurityAgent(mock_llm, prompt_registry)
policy = Policy(
agents={
AgentName.SECURITY: AgentConfig(severity_threshold=Severity.MEDIUM),
AgentName.STYLE: AgentConfig(),
AgentName.COMPLEXITY: AgentConfig(),
}
)
context = ReviewContext(diff="+ code", policy=policy)
result = await agent.review(context)
# Only high severity should pass (medium threshold filters low and info)
assert len(result.findings) == 1
assert result.findings[0].severity == Severity.HIGH
class TestStyleAgent:
@pytest.mark.asyncio
async def test_review_returns_result(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=["[]"])
agent = StyleAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ some code", policy=Policy())
result = await agent.review(context)
assert result.agent_name == AgentName.STYLE
assert result.findings == []
@pytest.mark.asyncio
async def test_uses_default_model(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=["[]"])
agent = StyleAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ code", policy=Policy())
await agent.review(context)
assert mock_llm.calls[0]["model"] == "gpt-4o-mini"
class TestComplexityAgent:
@pytest.mark.asyncio
async def test_review_returns_result(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=["[]"])
agent = ComplexityAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ some code", policy=Policy())
result = await agent.review(context)
assert result.agent_name == AgentName.COMPLEXITY
assert result.findings == []
@pytest.mark.asyncio
async def test_parses_complexity_findings(
self,
prompt_registry: PromptRegistry,
) -> None:
response = """[
{
"file": "processor.py",
"line_start": 1,
"line_end": 50,
"severity": "medium",
"confidence": 0.8,
"title": "High cyclomatic complexity",
"description": "Function has 15 branches",
"reasoning": "Makes testing and maintenance difficult"
}
]"""
mock_llm = MockLLMClient(responses=[response])
agent = ComplexityAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ complex code", policy=Policy())
result = await agent.review(context)
assert len(result.findings) == 1
assert result.findings[0].severity == Severity.MEDIUM
assert "complexity" in result.findings[0].title.lower()
class TestAgentResponseParsing:
@pytest.mark.asyncio
async def test_handles_empty_response(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=[""])
agent = SecurityAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ code", policy=Policy())
result = await agent.review(context)
assert result.findings == []
@pytest.mark.asyncio
async def test_handles_invalid_json(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=["not valid json"])
agent = SecurityAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ code", policy=Policy())
result = await agent.review(context)
assert result.findings == []
@pytest.mark.asyncio
async def test_handles_json_without_code_block(
self,
prompt_registry: PromptRegistry,
) -> None:
response = '[{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Test", "description": "", "reasoning": ""}]'
mock_llm = MockLLMClient(responses=[response])
agent = SecurityAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ code", policy=Policy())
result = await agent.review(context)
assert len(result.findings) == 1
@pytest.mark.asyncio
async def test_handles_malformed_finding(
self,
prompt_registry: PromptRegistry,
) -> None:
response = """[
{"file": "a.py", "line_start": 1, "severity": "invalid_severity", "confidence": 0.5, "title": "Bad", "description": "", "reasoning": ""},
{"file": "b.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Valid", "description": "", "reasoning": ""}
]"""
mock_llm = MockLLMClient(responses=[response])
agent = SecurityAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ code", policy=Policy())
result = await agent.review(context)
# Only the valid finding should be included (first has invalid severity)
assert len(result.findings) == 1
assert result.findings[0].title == "Valid"
@pytest.mark.asyncio
async def test_includes_prompt_additions(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=["[]"])
agent = SecurityAgent(mock_llm, prompt_registry)
policy = Policy(
agents={
AgentName.SECURITY: AgentConfig(prompt_additions="Focus on authentication"),
AgentName.STYLE: AgentConfig(),
AgentName.COMPLEXITY: AgentConfig(),
}
)
context = ReviewContext(diff="+ code", policy=policy)
await agent.review(context)
message_content = mock_llm.calls[0]["messages"][0]["content"]
assert "Focus on authentication" in message_content
@pytest.mark.asyncio
async def test_handles_non_list_json(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=['{"not": "a list"}'])
agent = SecurityAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ code", policy=Policy())
result = await agent.review(context)
assert result.findings == []
@pytest.mark.asyncio
async def test_handles_non_dict_items(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=['["string", 123, null]'])
agent = SecurityAgent(mock_llm, prompt_registry)
context = ReviewContext(diff="+ code", policy=Policy())
result = await agent.review(context)
assert result.findings == []
@pytest.mark.asyncio
async def test_agent_without_config_uses_defaults(
self,
prompt_registry: PromptRegistry,
) -> None:
mock_llm = MockLLMClient(responses=["[]"])
agent = SecurityAgent(mock_llm, prompt_registry)
# Create policy with empty agents dict
policy = Policy(agents={})
context = ReviewContext(diff="+ code", policy=policy)
result = await agent.review(context)
# Should use default model (gpt-4o for security)
assert mock_llm.calls[0]["model"] == "gpt-4o"
assert result.findings == []

561
tests/test_cli.py Normal file
View File

@@ -0,0 +1,561 @@
"""Tests for CLI commands."""
import json
from pathlib import Path
from unittest.mock import AsyncMock, patch
from typer.testing import CliRunner
from arbiter.cli import (
_severity_color,
_severity_icon,
_verdict_color,
_verdict_icon,
app,
)
from arbiter.deliberation import DeliberationResult
from arbiter.models import AgentName, Finding, ReviewResult, Severity, Verdict
runner = CliRunner()
def make_mock_return(
findings: list[Finding] | None = None, verdict: Verdict = Verdict.APPROVE
) -> tuple[list[ReviewResult], DeliberationResult]:
"""Create a mock return value for _run_review."""
agent_results = [
ReviewResult(
agent_name=AgentName.SECURITY,
findings=findings or [],
duration_ms=100,
tokens_used=100,
cost_usd=0.001,
)
]
deliberation_result = DeliberationResult(
findings=findings or [],
verdict=verdict,
verdict_confidence=0.9,
verdict_reasoning="Test reasoning",
total_findings=len(findings) if findings else 0,
)
return agent_results, deliberation_result
class TestVersionCommand:
def test_version_output(self) -> None:
result = runner.invoke(app, ["version"])
assert result.exit_code == 0
assert "arbiter" in result.output
assert "0.3.0" in result.output
class TestReviewCommand:
def test_file_not_found(self) -> None:
result = runner.invoke(app, ["review", "nonexistent.diff"])
assert result.exit_code == 1
assert "File not found" in result.output
def test_empty_diff_warning(self, tmp_path: Path) -> None:
diff_file = tmp_path / "empty.diff"
diff_file.write_text("")
result = runner.invoke(app, ["review", str(diff_file)])
assert result.exit_code == 0
assert "Empty diff" in result.output
def test_policy_not_found(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
result = runner.invoke(app, ["review", str(diff_file), "--policy", "nonexistent.yaml"])
assert result.exit_code == 1
assert "Policy file not found" in result.output
def test_reads_from_stdin(self) -> None:
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = make_mock_return()
result = runner.invoke(app, ["review", "-"], input="+ added line\n")
assert result.exit_code == 0
def test_json_output_format(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = make_mock_return()
result = runner.invoke(app, ["review", str(diff_file), "--format", "json"])
assert result.exit_code == 0
assert '"verdict"' in result.output
assert '"findings"' in result.output
def test_markdown_output_format(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = make_mock_return()
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
assert result.exit_code == 0
assert "# Arbiter Review" in result.output
def test_loads_policy_file(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
policy_file = tmp_path / "policy.yaml"
policy_file.write_text("""
version: "1.0"
agents:
security:
enabled: true
style:
enabled: false
complexity:
enabled: false
""")
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = make_mock_return()
result = runner.invoke(app, ["review", str(diff_file), "--policy", str(policy_file)])
assert result.exit_code == 0
# Verify policy was passed to _run_review
call_args = mock_run.call_args
policy = call_args[0][1] # Second positional arg is policy
assert len(policy.get_enabled_agents()) == 1
def test_model_override(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = make_mock_return()
result = runner.invoke(app, ["review", str(diff_file), "--model", "gpt-4o-mini"])
assert result.exit_code == 0
# Verify model was set in policy
call_args = mock_run.call_args
policy = call_args[0][1]
for config in policy.agents.values():
assert config.model == "gpt-4o-mini"
def test_static_analysis_flag(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = make_mock_return()
result = runner.invoke(app, ["review", str(diff_file), "--no-static-analysis"])
assert result.exit_code == 0
# Verify static_analysis was False
call_args = mock_run.call_args
assert call_args.kwargs.get("static_analysis") is False
class TestNoArgsHelp:
def test_no_args_shows_help(self) -> None:
result = runner.invoke(app, [])
assert result.exit_code == 0
assert "A multi-agent code review system" in result.output
class TestOutputFormatting:
def test_severity_color(self) -> None:
assert _severity_color(Severity.CRITICAL) == "red bold"
assert _severity_color(Severity.HIGH) == "red"
assert _severity_color(Severity.MEDIUM) == "yellow"
assert _severity_color(Severity.LOW) == "blue"
assert _severity_color(Severity.INFO) == "dim"
def test_severity_icon(self) -> None:
assert _severity_icon(Severity.CRITICAL) == "!!"
assert _severity_icon(Severity.HIGH) == "!"
assert _severity_icon(Severity.MEDIUM) == "*"
assert _severity_icon(Severity.LOW) == "-"
assert _severity_icon(Severity.INFO) == "i"
def test_verdict_color(self) -> None:
assert _verdict_color(Verdict.APPROVE) == "green"
assert _verdict_color(Verdict.COMMENT) == "yellow"
assert _verdict_color(Verdict.REQUEST_CHANGES) == "red"
def test_verdict_icon(self) -> None:
assert _verdict_icon(Verdict.APPROVE) == "[ok]"
assert _verdict_icon(Verdict.COMMENT) == "[..]"
assert _verdict_icon(Verdict.REQUEST_CHANGES) == "[!!]"
class TestRichOutput:
def test_rich_format_with_findings(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
finding = Finding(
id="test-finding-1",
agent=AgentName.SECURITY,
file="test.py",
line_start=10,
line_end=15,
severity=Severity.HIGH,
confidence=0.9,
title="SQL Injection",
description="User input in query",
reasoning="Direct concatenation",
prompt_version="test-v1.0",
)
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = make_mock_return(findings=[finding])
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
assert result.exit_code == 0
def test_rich_format_no_findings(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = make_mock_return()
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
assert result.exit_code == 0
assert "No issues found" in result.output
def test_rich_format_critical_findings(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
finding = Finding(
id="test-finding-1",
agent=AgentName.SECURITY,
file="test.py",
line_start=10,
line_end=10,
severity=Severity.CRITICAL,
confidence=0.95,
title="Critical Issue",
description="This is critical",
reasoning="Very bad",
suggestion="Fix it immediately",
prompt_version="test-v1.0",
)
deliberation = DeliberationResult(
findings=[finding],
verdict=Verdict.REQUEST_CHANGES,
verdict_confidence=0.95,
verdict_reasoning="Critical issue found",
total_findings=1,
critical_count=1,
)
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = (
[
ReviewResult(
agent_name=AgentName.SECURITY,
findings=[finding],
duration_ms=100,
tokens_used=100,
cost_usd=0.001,
)
],
deliberation,
)
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
assert result.exit_code == 0
class TestMarkdownOutput:
def test_markdown_with_findings(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
finding = Finding(
id="test-finding-1",
agent=AgentName.SECURITY,
file="test.py",
line_start=10,
line_end=15,
severity=Severity.HIGH,
confidence=0.9,
title="SQL Injection",
description="User input in query",
reasoning="Direct concatenation",
suggestion="Use parameterized queries",
prompt_version="test-v1.0",
)
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = make_mock_return(findings=[finding])
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
assert result.exit_code == 0
assert "## Findings" in result.output
assert "SQL Injection" in result.output
def test_markdown_verdict_badges(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
for verdict in [Verdict.APPROVE, Verdict.COMMENT, Verdict.REQUEST_CHANGES]:
deliberation = DeliberationResult(
verdict=verdict,
verdict_confidence=0.9,
verdict_reasoning="Test",
)
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = (
[
ReviewResult(
agent_name=AgentName.SECURITY,
findings=[],
duration_ms=100,
tokens_used=100,
cost_usd=0.001,
)
],
deliberation,
)
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
assert result.exit_code == 0
assert verdict.value.upper() in result.output
class TestJsonOutput:
def test_json_with_conflicts(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
from arbiter.deliberation.conflicts import Conflict, ConflictNature
conflict = Conflict(
id="test-conflict",
finding_ids=["f1", "f2"],
nature=ConflictNature.TRADE_OFF,
description="Test conflict",
severity_weight=0.8,
)
deliberation = DeliberationResult(
verdict=Verdict.COMMENT,
verdict_confidence=0.7,
verdict_reasoning="Conflicts found",
conflicts=[conflict],
)
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = (
[
ReviewResult(
agent_name=AgentName.SECURITY,
findings=[],
duration_ms=100,
tokens_used=100,
cost_usd=0.001,
)
],
deliberation,
)
result = runner.invoke(app, ["review", str(diff_file), "--format", "json"])
assert result.exit_code == 0
output = json.loads(result.output)
assert "conflicts" in output
assert len(output["conflicts"]) == 1
class TestWorkDirHandling:
def test_work_dir_option(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
work_dir = tmp_path / "src"
work_dir.mkdir()
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = make_mock_return()
result = runner.invoke(app, ["review", str(diff_file), "--work-dir", str(work_dir)])
assert result.exit_code == 0
call_args = mock_run.call_args
assert call_args.kwargs.get("work_dir") == work_dir.resolve()
class TestRichOutputWithConflicts:
def test_rich_conflicts(self, tmp_path: Path) -> None:
from arbiter.deliberation.conflicts import Conflict, ConflictNature
from arbiter.deliberation.synthesis import Resolution
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
conflict = Conflict(
id="test-conflict",
finding_ids=["f1", "f2"],
nature=ConflictNature.TRADE_OFF,
description="Security vs complexity trade-off",
severity_weight=0.8,
)
resolution = Resolution(
conflict_id="test-conflict",
decision="prefer_first",
reasoning="Security takes priority",
confidence=0.9,
)
deliberation = DeliberationResult(
verdict=Verdict.COMMENT,
verdict_confidence=0.7,
verdict_reasoning="Conflicts found",
conflicts=[conflict],
resolutions=[resolution],
)
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = (
[
ReviewResult(
agent_name=AgentName.SECURITY,
findings=[],
duration_ms=100,
tokens_used=100,
cost_usd=0.001,
)
],
deliberation,
)
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
assert result.exit_code == 0
class TestMarkdownOutputWithConflicts:
def test_markdown_conflicts(self, tmp_path: Path) -> None:
from arbiter.deliberation.conflicts import Conflict, ConflictNature
from arbiter.deliberation.synthesis import Resolution
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
conflict = Conflict(
id="test-conflict",
finding_ids=["f1", "f2"],
nature=ConflictNature.CONTRADICTORY,
description="Contradictory recommendations",
severity_weight=0.9,
)
resolution = Resolution(
conflict_id="test-conflict",
decision="merge",
reasoning="Both concerns addressed by combined fix",
merged_suggestion="Do both things",
confidence=0.85,
)
deliberation = DeliberationResult(
verdict=Verdict.COMMENT,
verdict_confidence=0.7,
verdict_reasoning="Conflicts found",
conflicts=[conflict],
resolutions=[resolution],
)
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = (
[
ReviewResult(
agent_name=AgentName.SECURITY,
findings=[],
duration_ms=100,
tokens_used=100,
cost_usd=0.001,
)
],
deliberation,
)
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
assert result.exit_code == 0
assert "## Conflicts" in result.output
assert "Contradictory" in result.output
assert "Resolution" in result.output
def test_markdown_findings(self, tmp_path: Path) -> None:
diff_file = tmp_path / "test.diff"
diff_file.write_text("+ some change")
findings = [
Finding(
id="f1",
agent=AgentName.SECURITY,
file="test.py",
line_start=10,
line_end=15,
severity=Severity.HIGH,
confidence=0.9,
title="Security Issue",
description="Vulnerable code",
reasoning="Bad pattern",
suggestion="Fix it this way",
prompt_version="test-v1.0",
),
Finding(
id="f2",
agent=AgentName.STYLE,
file="test.py",
line_start=20,
line_end=25,
severity=Severity.LOW,
confidence=0.8,
title="Style Issue",
description="Could be cleaner",
reasoning="Convention",
prompt_version="test-v1.0",
),
]
deliberation = DeliberationResult(
findings=findings,
verdict=Verdict.COMMENT,
verdict_confidence=0.75,
verdict_reasoning="Issues found",
total_findings=2,
)
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
mock_run.return_value = (
[
ReviewResult(
agent_name=AgentName.SECURITY,
findings=[findings[0]],
duration_ms=100,
tokens_used=100,
cost_usd=0.001,
),
ReviewResult(
agent_name=AgentName.STYLE,
findings=[findings[1]],
duration_ms=100,
tokens_used=100,
cost_usd=0.001,
),
],
deliberation,
)
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
assert result.exit_code == 0
assert "Security Issue" in result.output
assert "Style Issue" in result.output
assert "Fix it this way" in result.output

313
tests/test_llm.py Normal file
View File

@@ -0,0 +1,313 @@
"""Tests for LLM client and prompts."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from arbiter.llm.client import LiteLLMClient, LLMResponse
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
from tests.conftest import MockLLMClient
class TestLiteLLMClient:
def test_init_default_values(self) -> None:
client = LiteLLMClient()
assert client.timeout == 60
assert client.max_retries == 3
def test_init_custom_values(self) -> None:
client = LiteLLMClient(timeout=120, max_retries=5)
assert client.timeout == 120
assert client.max_retries == 5
@pytest.mark.asyncio
async def test_complete_returns_response(self) -> None:
client = LiteLLMClient()
# Mock the litellm.acompletion function
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Test response"
mock_response.model = "gpt-4o"
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 5
with (
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
):
mock_acomp.return_value = mock_response
mock_cost.return_value = 0.001
messages = [{"role": "user", "content": "Hello"}]
response = await client.complete(messages, "gpt-4o")
assert response.content == "Test response"
assert response.model == "gpt-4o"
assert response.tokens_in == 10
assert response.tokens_out == 5
assert response.cost_usd == 0.001
mock_acomp.assert_called_once_with(
model="gpt-4o",
messages=messages,
timeout=60,
num_retries=3,
)
@pytest.mark.asyncio
async def test_complete_handles_empty_content(self) -> None:
client = LiteLLMClient()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = None # None content
mock_response.model = "gpt-4o"
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 5
mock_response.usage.completion_tokens = 0
with (
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
):
mock_acomp.return_value = mock_response
mock_cost.return_value = 0.0
messages = [{"role": "user", "content": "Hello"}]
response = await client.complete(messages, "gpt-4o")
assert response.content == "" # Should be empty string, not None
@pytest.mark.asyncio
async def test_complete_handles_missing_usage(self) -> None:
client = LiteLLMClient()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response"
mock_response.model = "gpt-4o"
mock_response.usage = None # No usage data
with (
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
):
mock_acomp.return_value = mock_response
mock_cost.return_value = 0.0
messages = [{"role": "user", "content": "Hello"}]
response = await client.complete(messages, "gpt-4o")
assert response.tokens_in == 0
assert response.tokens_out == 0
@pytest.mark.asyncio
async def test_complete_uses_fallback_model(self) -> None:
client = LiteLLMClient()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Response"
mock_response.model = None # No model in response
mock_response.usage = MagicMock()
mock_response.usage.prompt_tokens = 5
mock_response.usage.completion_tokens = 3
with (
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
):
mock_acomp.return_value = mock_response
mock_cost.return_value = 0.0
messages = [{"role": "user", "content": "Hello"}]
response = await client.complete(messages, "claude-3-opus")
# Should use the passed model as fallback
assert response.model == "claude-3-opus"
class TestLLMResponse:
def test_response_creation(self) -> None:
response = LLMResponse(
content="Hello, world!",
model="gpt-4o",
tokens_in=10,
tokens_out=5,
cost_usd=0.001,
)
assert response.content == "Hello, world!"
assert response.model == "gpt-4o"
assert response.tokens_in == 10
assert response.tokens_out == 5
assert response.cost_usd == 0.001
class TestMockLLMClient:
@pytest.mark.asyncio
async def test_records_calls(self) -> None:
client = MockLLMClient()
messages = [{"role": "user", "content": "Hello"}]
await client.complete(messages, "gpt-4o")
assert len(client.calls) == 1
assert client.calls[0]["messages"] == messages
assert client.calls[0]["model"] == "gpt-4o"
@pytest.mark.asyncio
async def test_returns_canned_responses(self) -> None:
client = MockLLMClient(responses=["First", "Second"])
messages = [{"role": "user", "content": "Hello"}]
response1 = await client.complete(messages, "gpt-4o")
response2 = await client.complete(messages, "gpt-4o")
response3 = await client.complete(messages, "gpt-4o")
assert response1.content == "First"
assert response2.content == "Second"
assert response3.content == "" # Exhausted
@pytest.mark.asyncio
async def test_reset(self) -> None:
client = MockLLMClient(responses=["Hello"])
messages = [{"role": "user", "content": "Hi"}]
await client.complete(messages, "gpt-4o")
assert len(client.calls) == 1
client.reset()
assert len(client.calls) == 0
response = await client.complete(messages, "gpt-4o")
assert response.content == "Hello" # Responses reset too
class TestPromptTemplate:
def test_template_creation(self) -> None:
template = PromptTemplate(
name="security",
version="1.0",
content="Review: {{diff}}",
)
assert template.name == "security"
assert template.version == "1.0"
assert template.full_name == "security-v1.0"
def test_render_substitution(self) -> None:
template = PromptTemplate(
name="test",
version="1.0",
content="File: {{file}}\nDiff: {{diff}}",
)
result = template.render(file="test.py", diff="+ added line")
assert result == "File: test.py\nDiff: + added line"
def test_render_missing_variable(self) -> None:
template = PromptTemplate(
name="test",
version="1.0",
content="Value: {{value}}",
)
result = template.render()
assert result == "Value: {{value}}"
def test_render_multiple_occurrences(self) -> None:
template = PromptTemplate(
name="test",
version="1.0",
content="{{name}} and {{name}} again",
)
result = template.render(name="test")
assert result == "test and test again"
class TestPromptRegistry:
def test_get_template(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
(templates_dir / "security-v1.0.md").write_text("Security review: {{diff}}")
registry = PromptRegistry(templates_dir)
template = registry.get("security", "1.0")
assert template.name == "security"
assert template.version == "1.0"
assert "{{diff}}" in template.content
def test_get_template_cached(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
(templates_dir / "security-v1.0.md").write_text("Content")
registry = PromptRegistry(templates_dir)
template1 = registry.get("security", "1.0")
template2 = registry.get("security", "1.0")
assert template1 is template2
def test_get_template_not_found(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
registry = PromptRegistry(templates_dir)
with pytest.raises(FileNotFoundError):
registry.get("missing", "1.0")
def test_list_templates(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
(templates_dir / "security-v1.0.md").write_text("Content")
(templates_dir / "style-v2.0.md").write_text("Content")
(templates_dir / "readme.md").write_text("Not a template")
registry = PromptRegistry(templates_dir)
templates = registry.list_templates()
assert len(templates) == 2
assert ("security", "1.0") in templates
assert ("style", "2.0") in templates
def test_list_templates_empty_dir(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
registry = PromptRegistry(templates_dir)
templates = registry.list_templates()
assert templates == []
def test_list_templates_missing_dir(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "missing"
registry = PromptRegistry(templates_dir)
templates = registry.list_templates()
assert templates == []
def test_clear_cache(self, tmp_path: Path) -> None:
templates_dir = tmp_path / "templates"
templates_dir.mkdir()
(templates_dir / "test-v1.0.md").write_text("Original")
registry = PromptRegistry(templates_dir)
template1 = registry.get("test", "1.0")
assert template1.content == "Original"
# Modify file
(templates_dir / "test-v1.0.md").write_text("Modified")
# Still cached
template2 = registry.get("test", "1.0")
assert template2.content == "Original"
# Clear cache
registry.clear_cache()
# Now reads new content
template3 = registry.get("test", "1.0")
assert template3.content == "Modified"

224
tests/test_models.py Normal file
View File

@@ -0,0 +1,224 @@
"""Tests for data models."""
from pathlib import Path
import pytest
from arbiter.models import (
AgentConfig,
AgentName,
Finding,
Policy,
ReviewResult,
Severity,
Verdict,
)
class TestEnums:
def test_severity_values(self) -> None:
assert Severity.CRITICAL == "critical"
assert Severity.HIGH == "high"
assert Severity.MEDIUM == "medium"
assert Severity.LOW == "low"
assert Severity.INFO == "info"
def test_verdict_values(self) -> None:
assert Verdict.APPROVE == "approve"
assert Verdict.REQUEST_CHANGES == "request_changes"
assert Verdict.COMMENT == "comment"
def test_agent_name_values(self) -> None:
assert AgentName.SECURITY == "security"
assert AgentName.STYLE == "style"
assert AgentName.COMPLEXITY == "complexity"
def test_severity_from_string(self) -> None:
assert Severity("critical") == Severity.CRITICAL
assert Severity("high") == Severity.HIGH
class TestFinding:
def test_finding_creation(self) -> None:
finding = Finding(
id="test-123",
agent=AgentName.SECURITY,
file="src/auth.py",
line_start=10,
line_end=15,
severity=Severity.HIGH,
confidence=0.9,
title="SQL Injection",
description="User input concatenated in SQL query",
reasoning="Allows attackers to execute arbitrary SQL",
suggestion="Use parameterized queries",
references=["https://owasp.org"],
prompt_version="security-v1.0",
)
assert finding.id == "test-123"
assert finding.agent == AgentName.SECURITY
assert finding.severity == Severity.HIGH
assert finding.confidence == 0.9
def test_finding_confidence_validation(self) -> None:
with pytest.raises(ValueError):
Finding(
id="test",
agent=AgentName.SECURITY,
file="test.py",
line_start=1,
line_end=1,
severity=Severity.INFO,
confidence=1.5,
title="Test",
description="Test",
reasoning="Test",
prompt_version="test-v1.0",
)
def test_finding_line_validation(self) -> None:
with pytest.raises(ValueError):
Finding(
id="test",
agent=AgentName.SECURITY,
file="test.py",
line_start=0,
line_end=1,
severity=Severity.INFO,
confidence=0.5,
title="Test",
description="Test",
reasoning="Test",
prompt_version="test-v1.0",
)
def test_finding_serialization(self) -> None:
finding = Finding(
id="test-123",
agent=AgentName.SECURITY,
file="src/auth.py",
line_start=10,
line_end=15,
severity=Severity.HIGH,
confidence=0.9,
title="Test",
description="Test desc",
reasoning="Test reason",
prompt_version="security-v1.0",
)
data = finding.model_dump()
assert data["id"] == "test-123"
assert data["agent"] == "security"
assert data["severity"] == "high"
class TestAgentConfig:
def test_default_values(self) -> None:
config = AgentConfig()
assert config.enabled is True
assert config.model is None
assert config.severity_threshold == Severity.INFO
assert config.prompt_additions is None
def test_custom_values(self) -> None:
config = AgentConfig(
enabled=False,
model="gpt-4o",
severity_threshold=Severity.MEDIUM,
prompt_additions="Focus on auth",
)
assert config.enabled is False
assert config.model == "gpt-4o"
assert config.severity_threshold == Severity.MEDIUM
class TestPolicy:
def test_default_policy(self) -> None:
policy = Policy()
assert len(policy.agents) == 3
assert AgentName.SECURITY in policy.agents
assert AgentName.STYLE in policy.agents
assert AgentName.COMPLEXITY in policy.agents
def test_get_enabled_agents(self) -> None:
policy = Policy()
enabled = policy.get_enabled_agents()
assert len(enabled) == 3
assert AgentName.SECURITY in enabled
def test_get_enabled_agents_with_disabled(self) -> None:
policy = Policy(
agents={
AgentName.SECURITY: AgentConfig(enabled=True),
AgentName.STYLE: AgentConfig(enabled=False),
AgentName.COMPLEXITY: AgentConfig(enabled=True),
}
)
enabled = policy.get_enabled_agents()
assert len(enabled) == 2
assert AgentName.STYLE not in enabled
def test_load_from_yaml(self, tmp_path: Path) -> None:
policy_file = tmp_path / "policy.yaml"
policy_file.write_text("""
version: "1.1"
agents:
security:
enabled: true
model: gpt-4o
severity_threshold: high
style:
enabled: false
complexity:
enabled: true
""")
policy = Policy.load(policy_file)
assert policy.version == "1.1"
assert policy.agents[AgentName.SECURITY].model == "gpt-4o"
assert policy.agents[AgentName.SECURITY].severity_threshold == Severity.HIGH
assert policy.agents[AgentName.STYLE].enabled is False
def test_load_empty_yaml(self, tmp_path: Path) -> None:
policy_file = tmp_path / "policy.yaml"
policy_file.write_text("")
policy = Policy.load(policy_file)
assert policy.version == "1.0"
assert len(policy.agents) == 3
class TestReviewResult:
def test_review_result_creation(self) -> None:
result = ReviewResult(
agent_name=AgentName.SECURITY,
findings=[],
duration_ms=1000,
tokens_used=500,
cost_usd=0.01,
)
assert result.agent_name == AgentName.SECURITY
assert result.duration_ms == 1000
assert result.cost_usd == 0.01
def test_review_result_with_findings(self) -> None:
finding = Finding(
id="test-123",
agent=AgentName.SECURITY,
file="test.py",
line_start=1,
line_end=1,
severity=Severity.HIGH,
confidence=0.9,
title="Test",
description="Test",
reasoning="Test",
prompt_version="test-v1.0",
)
result = ReviewResult(
agent_name=AgentName.SECURITY,
findings=[finding],
duration_ms=1000,
tokens_used=500,
cost_usd=0.01,
)
assert len(result.findings) == 1
assert result.findings[0].id == "test-123"