feat(agents): implement agent framework and CLI
This commit is contained in:
111
pyproject.toml
Normal file
111
pyproject.toml
Normal file
@@ -0,0 +1,111 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "arbiter"
|
||||
version = "0.1.0"
|
||||
description = "A multi-agent code review system that shows its work"
|
||||
readme = "readme.md"
|
||||
requires-python = ">=3.12"
|
||||
license = "MIT"
|
||||
authors = [{ name = "Kai Chappell", email = "git@kschappell.com" }]
|
||||
keywords = ["code-review", "ai", "llm", "agents"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Environment :: Console",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Software Development :: Quality Assurance",
|
||||
]
|
||||
dependencies = [
|
||||
"litellm>=1.0.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic-settings>=2.0.0",
|
||||
"typer>=0.9.0",
|
||||
"rich>=13.0.0",
|
||||
"pyyaml>=6.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"pytest-cov>=4.0.0",
|
||||
"ruff>=0.4.0",
|
||||
"mypy>=1.10.0",
|
||||
"types-PyYAML>=6.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
arbiter = "arbiter.cli:app"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/arbiter"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
line-length = 100
|
||||
src = ["src", "tests"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # Pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
"ARG", # flake8-unused-arguments
|
||||
"SIM", # flake8-simplify
|
||||
"TCH", # flake8-type-checking
|
||||
"PTH", # flake8-use-pathlib
|
||||
"RUF", # Ruff-specific rules
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["arbiter"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
strict = true
|
||||
warn_return_any = true
|
||||
warn_unused_ignores = true
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
no_implicit_optional = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_configs = true
|
||||
show_error_codes = true
|
||||
files = ["src"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["litellm.*"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
addopts = "-ra -q"
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["src/arbiter"]
|
||||
branch = true
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"def __repr__",
|
||||
"raise NotImplementedError",
|
||||
"if TYPE_CHECKING:",
|
||||
"if __name__ == .__main__.:",
|
||||
]
|
||||
fail_under = 85
|
||||
3
src/arbiter/__init__.py
Normal file
3
src/arbiter/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Arbiter: A multi-agent code review system that shows its work."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
14
src/arbiter/agents/__init__.py
Normal file
14
src/arbiter/agents/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Review agents for code analysis."""
|
||||
|
||||
from arbiter.agents.base import Agent, ReviewContext
|
||||
from arbiter.agents.complexity import ComplexityAgent
|
||||
from arbiter.agents.security import SecurityAgent
|
||||
from arbiter.agents.style import StyleAgent
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"ComplexityAgent",
|
||||
"ReviewContext",
|
||||
"SecurityAgent",
|
||||
"StyleAgent",
|
||||
]
|
||||
277
src/arbiter/agents/base.py
Normal file
277
src/arbiter/agents/base.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Base agent class and review context."""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from arbiter.llm.client import LLMClient
|
||||
from arbiter.llm.prompts import PromptRegistry
|
||||
from arbiter.models import AgentName, Finding, Policy, ReviewResult, Severity
|
||||
|
||||
|
||||
class ReviewContext(BaseModel):
|
||||
"""Context for a code review."""
|
||||
|
||||
diff: str = Field(description="The diff content to review")
|
||||
policy: Policy = Field(description="Review policy configuration")
|
||||
file_path: str | None = Field(default=None, description="Path to the diff file")
|
||||
static_analysis_context: str = Field(
|
||||
default="No static analysis data available.",
|
||||
description="Static analysis results as formatted string",
|
||||
)
|
||||
|
||||
|
||||
class ExplainContext(BaseModel):
|
||||
"""Context for explaining a finding in a follow-up conversation."""
|
||||
|
||||
question: str = Field(description="The user's follow-up question")
|
||||
finding: Finding = Field(description="The finding being asked about")
|
||||
diff: str = Field(description="The relevant code diff")
|
||||
conversation_history: list[dict[str, str]] = Field(
|
||||
default_factory=list,
|
||||
description="Previous messages in this conversation",
|
||||
)
|
||||
|
||||
|
||||
class ExplainResult(BaseModel):
|
||||
"""Result from an agent's explain() method."""
|
||||
|
||||
response: str = Field(description="The explanation response text")
|
||||
tokens_used: int = Field(ge=0, description="Total tokens used for this response")
|
||||
cost_usd: float = Field(ge=0.0, description="Cost in USD for this response")
|
||||
confidence: float = Field(ge=0.0, le=1.0, description="Confidence in the explanation")
|
||||
|
||||
|
||||
class Agent(ABC):
|
||||
"""Abstract base class for review agents."""
|
||||
|
||||
name: AgentName
|
||||
prompt_name: str
|
||||
prompt_version: str = "1.0"
|
||||
default_model: str = "gpt-4o"
|
||||
explain_prompt_name: str = "" # Set by subclasses
|
||||
explain_prompt_version: str = "1.0"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: LLMClient,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
self.llm_client = llm_client
|
||||
self.prompt_registry = prompt_registry
|
||||
|
||||
async def review(self, context: ReviewContext) -> ReviewResult:
|
||||
"""Perform a code review.
|
||||
|
||||
Args:
|
||||
context: Review context with diff and policy.
|
||||
|
||||
Returns:
|
||||
ReviewResult with findings and metadata.
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
|
||||
# Get model from policy or use default
|
||||
agent_config = context.policy.agents.get(self.name)
|
||||
model = agent_config.model if agent_config and agent_config.model else self.default_model
|
||||
|
||||
# Get prompt additions from policy
|
||||
prompt_additions = ""
|
||||
if agent_config and agent_config.prompt_additions:
|
||||
prompt_additions = agent_config.prompt_additions
|
||||
|
||||
# Build and send messages
|
||||
messages = self._build_messages(context, prompt_additions)
|
||||
response = await self._call_llm(messages, model)
|
||||
|
||||
# Parse response into findings
|
||||
findings = self._parse_response(response.content, context)
|
||||
|
||||
# Filter by severity threshold
|
||||
if agent_config:
|
||||
findings = self._filter_by_severity(findings, agent_config.severity_threshold)
|
||||
|
||||
duration_ms = int((time.monotonic() - start_time) * 1000)
|
||||
|
||||
return ReviewResult(
|
||||
agent_name=self.name,
|
||||
findings=findings,
|
||||
duration_ms=duration_ms,
|
||||
tokens_used=response.tokens_in + response.tokens_out,
|
||||
cost_usd=response.cost_usd,
|
||||
)
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
context: ReviewContext,
|
||||
prompt_additions: str,
|
||||
) -> list[dict[str, str]]:
|
||||
template = self.prompt_registry.get(self.prompt_name, self.prompt_version)
|
||||
content = template.render(
|
||||
diff=context.diff,
|
||||
static_analysis_context=context.static_analysis_context,
|
||||
prompt_additions=prompt_additions,
|
||||
)
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
async def _call_llm(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
) -> Any:
|
||||
return await self.llm_client.complete(messages, model)
|
||||
|
||||
@abstractmethod
|
||||
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]: ...
|
||||
|
||||
def _filter_by_severity(
|
||||
self,
|
||||
findings: list[Finding],
|
||||
threshold: Severity,
|
||||
) -> list[Finding]:
|
||||
"""Filter findings by severity threshold.
|
||||
|
||||
Args:
|
||||
findings: List of findings to filter.
|
||||
threshold: Minimum severity to include.
|
||||
|
||||
Returns:
|
||||
Filtered list of findings.
|
||||
"""
|
||||
severity_order = [
|
||||
Severity.CRITICAL,
|
||||
Severity.HIGH,
|
||||
Severity.MEDIUM,
|
||||
Severity.LOW,
|
||||
Severity.INFO,
|
||||
]
|
||||
threshold_index = severity_order.index(threshold)
|
||||
return [f for f in findings if severity_order.index(f.severity) <= threshold_index]
|
||||
|
||||
def _parse_json_findings(self, content: str, _context: ReviewContext) -> list[Finding]:
|
||||
"""Parse JSON array of findings from LLM response.
|
||||
|
||||
Args:
|
||||
content: Raw LLM response content.
|
||||
context: Review context.
|
||||
|
||||
Returns:
|
||||
List of Finding objects.
|
||||
"""
|
||||
# Extract JSON from response (handle markdown code blocks)
|
||||
json_content = content.strip()
|
||||
if json_content.startswith("```"):
|
||||
lines = json_content.split("\n")
|
||||
# Remove first and last lines (code block markers)
|
||||
json_lines = []
|
||||
in_block = False
|
||||
for line in lines:
|
||||
if line.startswith("```") and not in_block:
|
||||
in_block = True
|
||||
continue
|
||||
if line.startswith("```") and in_block:
|
||||
break
|
||||
if in_block:
|
||||
json_lines.append(line)
|
||||
json_content = "\n".join(json_lines)
|
||||
|
||||
try:
|
||||
data = json.loads(json_content)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
if not isinstance(data, list):
|
||||
return []
|
||||
|
||||
findings = []
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
try:
|
||||
finding = Finding(
|
||||
id=str(uuid.uuid4()),
|
||||
agent=self.name,
|
||||
file=item.get("file", "unknown"),
|
||||
line_start=item.get("line_start", 1),
|
||||
line_end=item.get("line_end") or item.get("line_start", 1),
|
||||
severity=Severity(item.get("severity", "info")),
|
||||
confidence=float(item.get("confidence", 0.5)),
|
||||
title=item.get("title", "Untitled finding"),
|
||||
description=item.get("description", ""),
|
||||
reasoning=item.get("reasoning", ""),
|
||||
suggestion=item.get("suggestion"),
|
||||
references=item.get("references", []),
|
||||
prompt_version=f"{self.prompt_name}-v{self.prompt_version}",
|
||||
static_analysis_context=None,
|
||||
)
|
||||
findings.append(finding)
|
||||
except (ValueError, KeyError):
|
||||
continue
|
||||
|
||||
return findings
|
||||
|
||||
async def explain(
|
||||
self,
|
||||
context: ExplainContext,
|
||||
model: str | None = None,
|
||||
) -> ExplainResult:
|
||||
"""Provide a detailed explanation about a specific finding.
|
||||
|
||||
Args:
|
||||
context: The explanation context with question, finding, and history.
|
||||
model: Optional model override. Defaults to agent's default_model.
|
||||
|
||||
Returns:
|
||||
ExplainResult with response text and cost tracking.
|
||||
"""
|
||||
if not self.explain_prompt_name:
|
||||
# Derive from prompt_name if not set
|
||||
explain_name = f"{self.prompt_name}-explain"
|
||||
else:
|
||||
explain_name = self.explain_prompt_name
|
||||
|
||||
template = self.prompt_registry.get(explain_name, self.explain_prompt_version)
|
||||
|
||||
# Format conversation history
|
||||
history_text = ""
|
||||
if context.conversation_history:
|
||||
history_lines = []
|
||||
for msg in context.conversation_history:
|
||||
role = msg.get("role", "user").capitalize()
|
||||
content = msg.get("content", "")
|
||||
history_lines.append(f"**{role}:** {content}")
|
||||
history_text = "\n\n".join(history_lines)
|
||||
|
||||
# Build finding context
|
||||
finding_lines = (
|
||||
f"{context.finding.line_start}"
|
||||
if context.finding.line_start == context.finding.line_end
|
||||
else f"{context.finding.line_start}-{context.finding.line_end}"
|
||||
)
|
||||
|
||||
content = template.render(
|
||||
finding_title=context.finding.title,
|
||||
finding_file=context.finding.file,
|
||||
finding_lines=finding_lines,
|
||||
finding_description=context.finding.description,
|
||||
finding_reasoning=context.finding.reasoning,
|
||||
finding_severity=context.finding.severity.value,
|
||||
finding_suggestion=context.finding.suggestion or "No suggestion provided.",
|
||||
diff=context.diff,
|
||||
conversation_history=history_text or "No previous conversation.",
|
||||
question=context.question,
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": content}]
|
||||
response = await self._call_llm(messages, model or self.default_model)
|
||||
|
||||
return ExplainResult(
|
||||
response=response.content.strip(),
|
||||
tokens_used=response.tokens_in + response.tokens_out,
|
||||
cost_usd=response.cost_usd,
|
||||
confidence=0.8, # Default confidence for explanations
|
||||
)
|
||||
17
src/arbiter/agents/complexity.py
Normal file
17
src/arbiter/agents/complexity.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Complexity review agent."""
|
||||
|
||||
from arbiter.agents.base import Agent, ReviewContext
|
||||
from arbiter.models import AgentName, Finding
|
||||
|
||||
|
||||
class ComplexityAgent(Agent):
|
||||
"""Agent focused on code complexity and architecture."""
|
||||
|
||||
name = AgentName.COMPLEXITY
|
||||
prompt_name = "complexity"
|
||||
prompt_version = "1.0"
|
||||
explain_prompt_name = "complexity-explain"
|
||||
default_model = "gpt-4o-mini"
|
||||
|
||||
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]:
|
||||
return self._parse_json_findings(content, context)
|
||||
17
src/arbiter/agents/security.py
Normal file
17
src/arbiter/agents/security.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Security review agent."""
|
||||
|
||||
from arbiter.agents.base import Agent, ReviewContext
|
||||
from arbiter.models import AgentName, Finding
|
||||
|
||||
|
||||
class SecurityAgent(Agent):
|
||||
"""Agent focused on security vulnerabilities and risks."""
|
||||
|
||||
name = AgentName.SECURITY
|
||||
prompt_name = "security"
|
||||
prompt_version = "1.0"
|
||||
explain_prompt_name = "security-explain"
|
||||
default_model = "gpt-4o"
|
||||
|
||||
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]:
|
||||
return self._parse_json_findings(content, context)
|
||||
17
src/arbiter/agents/style.py
Normal file
17
src/arbiter/agents/style.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Style review agent."""
|
||||
|
||||
from arbiter.agents.base import Agent, ReviewContext
|
||||
from arbiter.models import AgentName, Finding
|
||||
|
||||
|
||||
class StyleAgent(Agent):
|
||||
"""Agent focused on code style, naming, and readability."""
|
||||
|
||||
name = AgentName.STYLE
|
||||
prompt_name = "style"
|
||||
prompt_version = "1.0"
|
||||
explain_prompt_name = "style-explain"
|
||||
default_model = "gpt-4o-mini"
|
||||
|
||||
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]:
|
||||
return self._parse_json_findings(content, context)
|
||||
407
src/arbiter/cli.py
Normal file
407
src/arbiter/cli.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""Command-line interface for Arbiter."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from arbiter.agents import ComplexityAgent, ReviewContext, SecurityAgent, StyleAgent
|
||||
from arbiter.analysis import DiffParser, StaticAnalysisRunner
|
||||
from arbiter.config import get_settings
|
||||
from arbiter.deliberation import Coordinator, DeliberationResult
|
||||
from arbiter.llm import LiteLLMClient, PromptRegistry
|
||||
from arbiter.models import AgentName, Finding, Policy, ReviewResult, Severity, Verdict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from arbiter.agents.base import Agent
|
||||
|
||||
app = typer.Typer(
|
||||
name="arbiter",
|
||||
help="A multi-agent code review system that shows its work.",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
console = Console()
|
||||
|
||||
|
||||
def _severity_color(severity: Severity) -> str:
|
||||
colors = {
|
||||
Severity.CRITICAL: "red bold",
|
||||
Severity.HIGH: "red",
|
||||
Severity.MEDIUM: "yellow",
|
||||
Severity.LOW: "blue",
|
||||
Severity.INFO: "dim",
|
||||
}
|
||||
return colors.get(severity, "white")
|
||||
|
||||
|
||||
def _severity_icon(severity: Severity) -> str:
|
||||
icons = {
|
||||
Severity.CRITICAL: "!!",
|
||||
Severity.HIGH: "!",
|
||||
Severity.MEDIUM: "*",
|
||||
Severity.LOW: "-",
|
||||
Severity.INFO: "i",
|
||||
}
|
||||
return icons.get(severity, " ")
|
||||
|
||||
|
||||
def _verdict_color(verdict: Verdict) -> str:
|
||||
colors = {
|
||||
Verdict.APPROVE: "green",
|
||||
Verdict.COMMENT: "yellow",
|
||||
Verdict.REQUEST_CHANGES: "red",
|
||||
}
|
||||
return colors.get(verdict, "white")
|
||||
|
||||
|
||||
def _verdict_icon(verdict: Verdict) -> str:
|
||||
icons = {
|
||||
Verdict.APPROVE: "[ok]",
|
||||
Verdict.COMMENT: "[..]",
|
||||
Verdict.REQUEST_CHANGES: "[!!]",
|
||||
}
|
||||
return icons.get(verdict, "")
|
||||
|
||||
|
||||
def _format_rich(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
|
||||
"""Format results using Rich for terminal output."""
|
||||
total_tokens = sum(r.tokens_used for r in agent_results) + result.tokens_used
|
||||
total_cost = sum(r.cost_usd for r in agent_results) + result.cost_usd
|
||||
|
||||
# Verdict panel
|
||||
verdict_style = _verdict_color(result.verdict)
|
||||
verdict_icon = _verdict_icon(result.verdict)
|
||||
verdict_text = f"[{verdict_style} bold]{verdict_icon} {result.verdict.value.upper()}[/]"
|
||||
verdict_text += f" (confidence: {result.verdict_confidence:.0%})"
|
||||
|
||||
summary = f"{verdict_text}\n\n"
|
||||
summary += f"[bold]{result.total_findings}[/bold] findings from [bold]{len(agent_results)}[/bold] agents"
|
||||
if result.conflicts:
|
||||
summary += f" | [yellow]{len(result.conflicts)}[/yellow] conflicts"
|
||||
summary += f"\nTokens: {total_tokens:,} | Cost: ${total_cost:.4f}"
|
||||
|
||||
console.print(Panel(summary, title="Arbiter Review", border_style="blue"))
|
||||
|
||||
if result.verdict_reasoning:
|
||||
console.print(f"\n[dim]Reason:[/dim] {result.verdict_reasoning}")
|
||||
|
||||
if result.total_findings == 0:
|
||||
console.print("\n[green]No issues found![/green]")
|
||||
return
|
||||
|
||||
# Findings table
|
||||
console.print("\n[bold]Findings[/bold]")
|
||||
|
||||
table = Table(show_header=True, header_style="bold")
|
||||
table.add_column("Sev", width=4)
|
||||
table.add_column("Agent", width=12)
|
||||
table.add_column("File", width=30)
|
||||
table.add_column("Line", width=8)
|
||||
table.add_column("Title", width=50)
|
||||
|
||||
for finding in result.findings:
|
||||
sev_icon = _severity_icon(finding.severity)
|
||||
sev_style = _severity_color(finding.severity)
|
||||
line_range = (
|
||||
f"{finding.line_start}-{finding.line_end}"
|
||||
if finding.line_start != finding.line_end
|
||||
else str(finding.line_start)
|
||||
)
|
||||
table.add_row(
|
||||
f"[{sev_style}]{sev_icon}[/{sev_style}]",
|
||||
finding.agent.value,
|
||||
finding.file,
|
||||
line_range,
|
||||
finding.title,
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
# Conflicts section
|
||||
if result.conflicts:
|
||||
console.print("\n[bold yellow]Conflicts[/bold yellow]")
|
||||
for i, conflict in enumerate(result.conflicts, 1):
|
||||
console.print(f"\n{i}. [{conflict.nature.value}] {conflict.description}")
|
||||
if result.resolutions:
|
||||
for resolution in result.resolutions:
|
||||
if resolution.conflict_id == conflict.id:
|
||||
console.print(f" [dim]Resolution:[/dim] {resolution.reasoning}")
|
||||
break
|
||||
|
||||
# Details for critical/high findings
|
||||
critical_high = [f for f in result.findings if f.severity in (Severity.CRITICAL, Severity.HIGH)]
|
||||
if critical_high:
|
||||
console.print("\n[bold]Details[/bold]")
|
||||
for finding in critical_high:
|
||||
console.print(
|
||||
f"\n[{_severity_color(finding.severity)}]{finding.title}[/]"
|
||||
f" ({finding.file}:{finding.line_start})"
|
||||
)
|
||||
console.print(f" {finding.description}")
|
||||
if finding.suggestion:
|
||||
console.print(f" [dim]Suggestion:[/dim] {finding.suggestion}")
|
||||
|
||||
# Deliberation log
|
||||
console.print("\n[bold dim]Deliberation Log[/bold dim]")
|
||||
for step in result.steps:
|
||||
console.print(f" [{step.step_type.value}] {step.description}")
|
||||
|
||||
|
||||
def _format_json(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
|
||||
"""Format results as JSON."""
|
||||
output = {
|
||||
"verdict": result.verdict.value,
|
||||
"verdict_confidence": result.verdict_confidence,
|
||||
"verdict_reasoning": result.verdict_reasoning,
|
||||
"findings": [f.model_dump() for f in result.findings],
|
||||
"conflicts": [c.model_dump() for c in result.conflicts],
|
||||
"resolutions": [r.model_dump() for r in result.resolutions],
|
||||
"deliberation_steps": [s.model_dump() for s in result.steps],
|
||||
"agents": [r.model_dump() for r in agent_results],
|
||||
"summary": {
|
||||
"total_findings": result.total_findings,
|
||||
"critical_count": result.critical_count,
|
||||
"high_count": result.high_count,
|
||||
"conflicts_count": len(result.conflicts),
|
||||
"total_tokens": sum(r.tokens_used for r in agent_results) + result.tokens_used,
|
||||
"total_cost_usd": sum(r.cost_usd for r in agent_results) + result.cost_usd,
|
||||
},
|
||||
}
|
||||
console.print(json.dumps(output, indent=2, default=str))
|
||||
|
||||
|
||||
def _format_markdown(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
|
||||
"""Format results as Markdown."""
|
||||
lines = ["# Arbiter Review", ""]
|
||||
|
||||
# Verdict
|
||||
verdict_badge = {
|
||||
Verdict.APPROVE: ":white_check_mark:",
|
||||
Verdict.COMMENT: ":speech_balloon:",
|
||||
Verdict.REQUEST_CHANGES: ":x:",
|
||||
}.get(result.verdict, "")
|
||||
|
||||
lines.append(f"## Verdict: {verdict_badge} {result.verdict.value.upper()}")
|
||||
lines.append("")
|
||||
lines.append(f"**Confidence:** {result.verdict_confidence:.0%}")
|
||||
if result.verdict_reasoning:
|
||||
lines.append(f"**Reason:** {result.verdict_reasoning}")
|
||||
lines.append("")
|
||||
|
||||
# Summary
|
||||
lines.append(f"**{result.total_findings} findings** from {len(agent_results)} agents")
|
||||
if result.conflicts:
|
||||
lines.append(f"**{len(result.conflicts)} conflicts** detected")
|
||||
lines.append("")
|
||||
|
||||
# Findings by severity
|
||||
if result.findings:
|
||||
lines.append("## Findings")
|
||||
lines.append("")
|
||||
|
||||
for finding in result.findings:
|
||||
severity_badge = f"**{finding.severity.value.upper()}**"
|
||||
lines.append(
|
||||
f"- {severity_badge} [{finding.agent.value}] `{finding.file}:{finding.line_start}` - {finding.title}"
|
||||
)
|
||||
lines.append(f" - {finding.description}")
|
||||
if finding.suggestion:
|
||||
lines.append(f" - *Suggestion:* {finding.suggestion}")
|
||||
lines.append("")
|
||||
|
||||
# Conflicts
|
||||
if result.conflicts:
|
||||
lines.append("## Conflicts")
|
||||
lines.append("")
|
||||
|
||||
for conflict in result.conflicts:
|
||||
lines.append(f"### {conflict.nature.value.title()}")
|
||||
lines.append(conflict.description)
|
||||
for resolution in result.resolutions:
|
||||
if resolution.conflict_id == conflict.id:
|
||||
lines.append(f"**Resolution:** {resolution.reasoning}")
|
||||
break
|
||||
lines.append("")
|
||||
|
||||
# Deliberation log
|
||||
lines.append("## Deliberation Log")
|
||||
lines.append("")
|
||||
for step in result.steps:
|
||||
lines.append(f"1. **{step.step_type.value}**: {step.description}")
|
||||
lines.append("")
|
||||
|
||||
console.print("\n".join(lines))
|
||||
|
||||
|
||||
async def _run_review(
|
||||
diff: str,
|
||||
policy: Policy,
|
||||
templates_dir: Path,
|
||||
static_analysis: bool = True,
|
||||
work_dir: Path | None = None,
|
||||
) -> tuple[list[ReviewResult], DeliberationResult]:
|
||||
"""Run all enabled agents on the diff with deliberation."""
|
||||
settings = get_settings()
|
||||
llm_client = LiteLLMClient(
|
||||
timeout=settings.llm_timeout,
|
||||
max_retries=settings.llm_max_retries,
|
||||
)
|
||||
prompt_registry = PromptRegistry(templates_dir)
|
||||
|
||||
# Parse the diff
|
||||
parser = DiffParser()
|
||||
parsed_diff = parser.parse(diff)
|
||||
|
||||
# Run static analysis if enabled
|
||||
static_findings: list[Finding] = []
|
||||
if static_analysis and work_dir:
|
||||
runner = StaticAnalysisRunner()
|
||||
static_result = await runner.run(parsed_diff, work_dir)
|
||||
for sf in static_result.findings:
|
||||
static_findings.append(runner.convert_to_finding(sf))
|
||||
|
||||
context = ReviewContext(
|
||||
diff=diff,
|
||||
policy=policy,
|
||||
static_analysis_context=f"Found {len(static_findings)} static analysis findings."
|
||||
if static_findings
|
||||
else "No static analysis data available.",
|
||||
)
|
||||
|
||||
# Create agents for enabled agent types
|
||||
agent_classes: dict[AgentName, type[Agent]] = {
|
||||
AgentName.SECURITY: SecurityAgent,
|
||||
AgentName.STYLE: StyleAgent,
|
||||
AgentName.COMPLEXITY: ComplexityAgent,
|
||||
}
|
||||
|
||||
agents: list[Agent] = []
|
||||
for agent_name in policy.get_enabled_agents():
|
||||
if agent_name in agent_classes:
|
||||
agents.append(agent_classes[agent_name](llm_client, prompt_registry))
|
||||
|
||||
# Run agents in parallel
|
||||
results = await asyncio.gather(
|
||||
*[agent.review(context) for agent in agents],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Filter out exceptions and log them
|
||||
valid_results: list[ReviewResult] = []
|
||||
for result in results:
|
||||
if isinstance(result, BaseException):
|
||||
console.print(f"[red]Agent error:[/red] {result}")
|
||||
else:
|
||||
valid_results.append(result)
|
||||
|
||||
# Run deliberation
|
||||
coordinator = Coordinator(llm_client=llm_client)
|
||||
deliberation_result = await coordinator.deliberate(
|
||||
valid_results,
|
||||
static_findings=static_findings if static_findings else None,
|
||||
)
|
||||
|
||||
return valid_results, deliberation_result
|
||||
|
||||
|
||||
@app.command()
|
||||
def review(
|
||||
diff_file: Annotated[
|
||||
str,
|
||||
typer.Argument(help="Path to diff file, or '-' for stdin"),
|
||||
],
|
||||
policy: Annotated[
|
||||
Path | None,
|
||||
typer.Option("--policy", "-p", help="Path to policy YAML file"),
|
||||
] = None,
|
||||
model: Annotated[
|
||||
str | None,
|
||||
typer.Option("--model", "-m", help="Override LLM model for all agents"),
|
||||
] = None,
|
||||
output_format: Annotated[
|
||||
str,
|
||||
typer.Option("--format", "-f", help="Output format: rich, json, markdown"),
|
||||
] = "rich",
|
||||
static_analysis: Annotated[
|
||||
bool,
|
||||
typer.Option("--static-analysis/--no-static-analysis", help="Run static analysis"),
|
||||
] = True,
|
||||
work_dir: Annotated[
|
||||
Path | None,
|
||||
typer.Option("--work-dir", "-w", help="Working directory for static analysis"),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Review a diff file using AI agents."""
|
||||
# Read diff content
|
||||
if diff_file == "-":
|
||||
diff_content = sys.stdin.read()
|
||||
else:
|
||||
diff_path = Path(diff_file)
|
||||
if not diff_path.exists():
|
||||
console.print(f"[red]Error:[/red] File not found: {diff_file}")
|
||||
raise typer.Exit(1)
|
||||
diff_content = diff_path.read_text()
|
||||
|
||||
if not diff_content.strip():
|
||||
console.print("[yellow]Warning:[/yellow] Empty diff provided")
|
||||
raise typer.Exit(0)
|
||||
|
||||
# Load policy
|
||||
review_policy = Policy()
|
||||
if policy:
|
||||
if not policy.exists():
|
||||
console.print(f"[red]Error:[/red] Policy file not found: {policy}")
|
||||
raise typer.Exit(1)
|
||||
review_policy = Policy.load(policy)
|
||||
|
||||
# Apply model override
|
||||
if model:
|
||||
for agent_config in review_policy.agents.values():
|
||||
agent_config.model = model
|
||||
|
||||
# Determine prompts directory
|
||||
settings = get_settings()
|
||||
templates_dir = settings.templates_dir
|
||||
if not templates_dir.is_absolute():
|
||||
templates_dir = Path.cwd() / templates_dir
|
||||
|
||||
# Resolve work directory for static analysis
|
||||
resolved_work_dir = (
|
||||
(work_dir.resolve() if work_dir else Path.cwd()) if static_analysis else None
|
||||
)
|
||||
|
||||
# Run review
|
||||
agent_results, deliberation_result = asyncio.run(
|
||||
_run_review(
|
||||
diff_content,
|
||||
review_policy,
|
||||
templates_dir,
|
||||
static_analysis=static_analysis,
|
||||
work_dir=resolved_work_dir,
|
||||
)
|
||||
)
|
||||
|
||||
# Format output
|
||||
if output_format == "json":
|
||||
_format_json(deliberation_result, agent_results)
|
||||
elif output_format == "markdown":
|
||||
_format_markdown(deliberation_result, agent_results)
|
||||
else:
|
||||
_format_rich(deliberation_result, agent_results)
|
||||
|
||||
|
||||
@app.command()
|
||||
def version() -> None:
|
||||
from arbiter import __version__
|
||||
|
||||
console.print(f"arbiter {__version__}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
121
src/arbiter/config.py
Normal file
121
src/arbiter/config.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Configuration settings for Arbiter."""
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="ARBITER_",
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# LLM settings
|
||||
default_model: str = Field(default="gpt-4o", description="Default LLM model for agents")
|
||||
llm_timeout: int = Field(default=60, description="LLM request timeout in seconds")
|
||||
llm_max_retries: int = Field(default=3, description="Maximum LLM retry attempts")
|
||||
|
||||
# Cost controls
|
||||
max_tokens_per_review: int = Field(
|
||||
default=50000, description="Maximum tokens allowed per review"
|
||||
)
|
||||
max_cost_per_review_usd: float = Field(
|
||||
default=0.50, description="Maximum cost per review in USD"
|
||||
)
|
||||
cache_ttl_hours: int = Field(default=24, description="Cache TTL in hours")
|
||||
|
||||
# Paths
|
||||
templates_dir: Path = Field(
|
||||
default=Path("templates"), description="Directory containing prompt templates"
|
||||
)
|
||||
policy_path: Path | None = Field(default=None, description="Default policy file path")
|
||||
|
||||
# Output settings
|
||||
output_format: str = Field(default="rich", description="Output format: rich, json, or markdown")
|
||||
|
||||
# Database settings
|
||||
database_url: str = Field(
|
||||
default="postgresql+asyncpg://arbiter:arbiter@localhost:5432/arbiter",
|
||||
description="PostgreSQL connection URL",
|
||||
)
|
||||
database_pool_size: int = Field(default=5, description="Database connection pool size")
|
||||
database_max_overflow: int = Field(default=10, description="Max overflow connections")
|
||||
|
||||
# Redis settings
|
||||
redis_url: str = Field(
|
||||
default="redis://localhost:6379/0",
|
||||
description="Redis connection URL",
|
||||
)
|
||||
redis_max_connections: int = Field(default=10, description="Redis max connections")
|
||||
|
||||
# Webhook secrets
|
||||
github_webhook_secret: SecretStr | None = Field(
|
||||
default=None, description="GitHub webhook secret for HMAC verification"
|
||||
)
|
||||
gitlab_webhook_token: SecretStr | None = Field(
|
||||
default=None, description="GitLab webhook token for verification"
|
||||
)
|
||||
|
||||
# Platform integration settings
|
||||
github_token: SecretStr | None = Field(
|
||||
default=None, description="GitHub API token for fetching diffs and posting comments"
|
||||
)
|
||||
github_base_url: str = Field(
|
||||
default="https://api.github.com", description="GitHub API base URL"
|
||||
)
|
||||
gitlab_token: SecretStr | None = Field(
|
||||
default=None, description="GitLab API token for fetching diffs and posting comments"
|
||||
)
|
||||
gitlab_base_url: str = Field(
|
||||
default="https://gitlab.com", description="GitLab instance base URL"
|
||||
)
|
||||
integration_timeout: int = Field(
|
||||
default=30, description="Integration API request timeout in seconds"
|
||||
)
|
||||
integration_max_retries: int = Field(
|
||||
default=3, description="Maximum retry attempts for integration API calls"
|
||||
)
|
||||
status_check_context: str = Field(
|
||||
default="arbiter", description="Context name for commit status checks"
|
||||
)
|
||||
post_comments: bool = Field(default=True, description="Whether to post review comments on PRs")
|
||||
update_status: bool = Field(default=True, description="Whether to update commit status checks")
|
||||
|
||||
# API settings
|
||||
api_title: str = Field(default="Arbiter API", description="API title for OpenAPI")
|
||||
api_version: str = Field(default="0.5.0", description="API version")
|
||||
cors_origins: list[str] = Field(
|
||||
default=["http://localhost:3000"], description="Allowed CORS origins"
|
||||
)
|
||||
api_rate_limit_per_minute: int = Field(
|
||||
default=60, description="API rate limit per minute per client"
|
||||
)
|
||||
|
||||
# Worker settings
|
||||
worker_max_jobs: int = Field(default=10, description="Max concurrent worker jobs")
|
||||
worker_job_timeout: int = Field(default=300, description="Job timeout in seconds")
|
||||
worker_retry_attempts: int = Field(default=3, description="Number of retry attempts")
|
||||
|
||||
# Follow-up conversation settings
|
||||
followup_enabled: bool = Field(default=True, description="Enable follow-up question handling")
|
||||
followup_confidence_threshold: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum confidence to process a follow-up question",
|
||||
)
|
||||
followup_max_tokens_per_response: int = Field(
|
||||
default=2000, description="Maximum tokens per follow-up response"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
12
src/arbiter/llm/__init__.py
Normal file
12
src/arbiter/llm/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""LLM client and prompt management."""
|
||||
|
||||
from arbiter.llm.client import LiteLLMClient, LLMClient, LLMResponse
|
||||
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
|
||||
|
||||
__all__ = [
|
||||
"LLMClient",
|
||||
"LLMResponse",
|
||||
"LiteLLMClient",
|
||||
"PromptRegistry",
|
||||
"PromptTemplate",
|
||||
]
|
||||
66
src/arbiter/llm/client.py
Normal file
66
src/arbiter/llm/client.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""LLM client abstraction."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LLMResponse(BaseModel):
|
||||
"""Response from an LLM completion."""
|
||||
|
||||
content: str = Field(description="The generated text content")
|
||||
model: str = Field(description="Model that generated the response")
|
||||
tokens_in: int = Field(ge=0, description="Number of input tokens")
|
||||
tokens_out: int = Field(ge=0, description="Number of output tokens")
|
||||
cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")
|
||||
|
||||
|
||||
class LLMClient(ABC):
|
||||
"""Abstract base class for LLM clients."""
|
||||
|
||||
@abstractmethod
|
||||
async def complete(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse: ...
|
||||
|
||||
|
||||
class LiteLLMClient(LLMClient):
|
||||
"""LLM client implementation using LiteLLM."""
|
||||
|
||||
def __init__(self, timeout: int = 60, max_retries: int = 3) -> None:
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""Generate a completion using LiteLLM."""
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
timeout=self.timeout,
|
||||
num_retries=self.max_retries,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content or ""
|
||||
usage = response.usage
|
||||
|
||||
# Calculate cost using litellm's cost tracking
|
||||
cost = litellm.completion_cost(completion_response=response)
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=response.model or model,
|
||||
tokens_in=usage.prompt_tokens if usage else 0,
|
||||
tokens_out=usage.completion_tokens if usage else 0,
|
||||
cost_usd=cost,
|
||||
)
|
||||
81
src/arbiter/llm/prompts.py
Normal file
81
src/arbiter/llm/prompts.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Prompt template management."""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PromptTemplate(BaseModel):
|
||||
"""A versioned prompt template."""
|
||||
|
||||
name: str = Field(description="Template name (e.g., 'security')")
|
||||
version: str = Field(description="Semantic version (e.g., '1.0')")
|
||||
content: str = Field(description="Raw template content with placeholders")
|
||||
|
||||
def render(self, **kwargs: Any) -> str:
|
||||
result = self.content
|
||||
for key, value in kwargs.items():
|
||||
placeholder = f"{{{{{key}}}}}"
|
||||
result = result.replace(placeholder, str(value))
|
||||
return result
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
return f"{self.name}-v{self.version}"
|
||||
|
||||
|
||||
class PromptRegistry:
|
||||
"""Registry for loading and caching prompt templates."""
|
||||
|
||||
_cache: ClassVar[dict[str, PromptTemplate]] = {}
|
||||
_pattern: ClassVar[re.Pattern[str]] = re.compile(r"^(.+)-v(\d+\.\d+)\.md$")
|
||||
|
||||
def __init__(self, templates_dir: Path) -> None:
|
||||
self.templates_dir = templates_dir
|
||||
|
||||
def get(self, name: str, version: str) -> PromptTemplate:
|
||||
"""Get a prompt template by name and version.
|
||||
|
||||
Args:
|
||||
name: Template name (e.g., 'security').
|
||||
version: Template version (e.g., '1.0').
|
||||
|
||||
Returns:
|
||||
PromptTemplate instance.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the template file doesn't exist.
|
||||
"""
|
||||
cache_key = f"{name}-v{version}"
|
||||
if cache_key in self._cache:
|
||||
return self._cache[cache_key]
|
||||
|
||||
file_path = self.templates_dir / f"{name}-v{version}.md"
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"Prompt template not found: {file_path}")
|
||||
|
||||
content = file_path.read_text()
|
||||
template = PromptTemplate(name=name, version=version, content=content)
|
||||
self._cache[cache_key] = template
|
||||
return template
|
||||
|
||||
def list_templates(self) -> list[tuple[str, str]]:
|
||||
"""List all available templates.
|
||||
|
||||
Returns:
|
||||
List of (name, version) tuples.
|
||||
"""
|
||||
templates: list[tuple[str, str]] = []
|
||||
if not self.templates_dir.exists():
|
||||
return templates
|
||||
|
||||
for file_path in self.templates_dir.glob("*.md"):
|
||||
match = self._pattern.match(file_path.name)
|
||||
if match:
|
||||
templates.append((match.group(1), match.group(2)))
|
||||
return sorted(templates)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
16
src/arbiter/models/__init__.py
Normal file
16
src/arbiter/models/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Arbiter data models."""
|
||||
|
||||
from arbiter.models.enums import AgentName, Severity, Verdict
|
||||
from arbiter.models.finding import Finding
|
||||
from arbiter.models.policy import AgentConfig, Policy
|
||||
from arbiter.models.review import ReviewResult
|
||||
|
||||
__all__ = [
|
||||
"AgentConfig",
|
||||
"AgentName",
|
||||
"Finding",
|
||||
"Policy",
|
||||
"ReviewResult",
|
||||
"Severity",
|
||||
"Verdict",
|
||||
]
|
||||
29
src/arbiter/models/enums.py
Normal file
29
src/arbiter/models/enums.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Core enumerations for Arbiter."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class Severity(StrEnum):
|
||||
"""Finding severity levels."""
|
||||
|
||||
CRITICAL = "critical"
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
INFO = "info"
|
||||
|
||||
|
||||
class Verdict(StrEnum):
|
||||
"""Review verdict types."""
|
||||
|
||||
APPROVE = "approve"
|
||||
REQUEST_CHANGES = "request_changes"
|
||||
COMMENT = "comment"
|
||||
|
||||
|
||||
class AgentName(StrEnum):
|
||||
"""Available review agent names."""
|
||||
|
||||
SECURITY = "security"
|
||||
STYLE = "style"
|
||||
COMPLEXITY = "complexity"
|
||||
28
src/arbiter/models/finding.py
Normal file
28
src/arbiter/models/finding.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Finding model for review results."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from arbiter.models.enums import AgentName, Severity
|
||||
|
||||
|
||||
class Finding(BaseModel):
|
||||
"""A single finding from an agent's review."""
|
||||
|
||||
id: str = Field(description="Unique identifier for this finding")
|
||||
agent: AgentName = Field(description="Agent that produced this finding")
|
||||
file: str = Field(description="File path where the issue was found")
|
||||
line_start: int = Field(ge=1, description="Starting line number")
|
||||
line_end: int = Field(ge=1, description="Ending line number")
|
||||
severity: Severity = Field(description="Severity level of the finding")
|
||||
confidence: float = Field(ge=0.0, le=1.0, description="Confidence score (0.0-1.0)")
|
||||
title: str = Field(description="Short title summarizing the finding")
|
||||
description: str = Field(description="Detailed description of the issue")
|
||||
reasoning: str = Field(description="Explanation of why this was flagged")
|
||||
suggestion: str | None = Field(default=None, description="Optional fix recommendation")
|
||||
references: list[str] = Field(default_factory=list, description="Links to docs, OWASP, etc.")
|
||||
prompt_version: str = Field(description="Version of the prompt that produced this finding")
|
||||
static_analysis_context: dict[str, Any] | None = Field(
|
||||
default=None, description="Related static analysis data"
|
||||
)
|
||||
58
src/arbiter/models/policy.py
Normal file
58
src/arbiter/models/policy.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Policy configuration models."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from arbiter.models.enums import AgentName, Severity
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
"""Configuration for a single agent."""
|
||||
|
||||
enabled: bool = Field(default=True, description="Whether the agent is enabled")
|
||||
model: str | None = Field(default=None, description="LLM model override for this agent")
|
||||
severity_threshold: Severity = Field(
|
||||
default=Severity.INFO, description="Minimum severity to report"
|
||||
)
|
||||
prompt_additions: str | None = Field(
|
||||
default=None, description="Additional instructions to append to the prompt"
|
||||
)
|
||||
|
||||
|
||||
class Policy(BaseModel):
|
||||
"""Review policy configuration."""
|
||||
|
||||
version: str = Field(default="1.0", description="Policy schema version")
|
||||
agents: dict[AgentName, AgentConfig] = Field(
|
||||
default_factory=lambda: {
|
||||
AgentName.SECURITY: AgentConfig(),
|
||||
AgentName.STYLE: AgentConfig(),
|
||||
AgentName.COMPLEXITY: AgentConfig(),
|
||||
},
|
||||
description="Agent configurations",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Path) -> Self:
|
||||
"""Load policy from a YAML file."""
|
||||
with path.open() as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if data is None:
|
||||
return cls()
|
||||
|
||||
# Convert string keys to AgentName enums
|
||||
if "agents" in data and isinstance(data["agents"], dict):
|
||||
agents = {}
|
||||
for key, value in data["agents"].items():
|
||||
agent_name = AgentName(key) if isinstance(key, str) else key
|
||||
agents[agent_name] = AgentConfig(**value) if isinstance(value, dict) else value
|
||||
data["agents"] = agents
|
||||
|
||||
return cls(**data)
|
||||
|
||||
def get_enabled_agents(self) -> list[AgentName]:
|
||||
return [name for name, config in self.agents.items() if config.enabled]
|
||||
16
src/arbiter/models/review.py
Normal file
16
src/arbiter/models/review.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Review result models."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from arbiter.models.enums import AgentName
|
||||
from arbiter.models.finding import Finding
|
||||
|
||||
|
||||
class ReviewResult(BaseModel):
|
||||
"""Result from a single agent's review."""
|
||||
|
||||
agent_name: AgentName = Field(description="Name of the agent that performed the review")
|
||||
findings: list[Finding] = Field(default_factory=list, description="Findings from the review")
|
||||
duration_ms: int = Field(ge=0, description="Time taken in milliseconds")
|
||||
tokens_used: int = Field(ge=0, description="Total tokens consumed")
|
||||
cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")
|
||||
0
src/arbiter/py.typed
Normal file
0
src/arbiter/py.typed
Normal file
56
templates/complexity-v1.0.md
Normal file
56
templates/complexity-v1.0.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# Complexity Review Agent
|
||||
|
||||
You are a code complexity reviewer focused on architecture and maintainability. Analyze the provided diff for complexity issues that impact long-term code health.
|
||||
|
||||
## Focus Areas
|
||||
|
||||
- **Cyclomatic complexity**: Functions with too many branches or paths
|
||||
- **Cognitive complexity**: Code that is hard to understand or follow
|
||||
- **Function length**: Functions doing too many things
|
||||
- **Class design**: God objects, tight coupling, missing abstractions
|
||||
- **Dependency management**: Circular dependencies, excessive coupling
|
||||
- **Over-engineering**: Unnecessary abstractions, premature optimization
|
||||
- **Under-engineering**: Missing error handling, ignored edge cases
|
||||
|
||||
## Context
|
||||
|
||||
{{static_analysis_context}}
|
||||
|
||||
## Diff to Review
|
||||
|
||||
```diff
|
||||
{{diff}}
|
||||
```
|
||||
|
||||
{{prompt_additions}}
|
||||
|
||||
## Output Format
|
||||
|
||||
Respond with a JSON array of findings. Each finding must have this structure:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"file": "path/to/file.py",
|
||||
"line_start": 10,
|
||||
"line_end": 50,
|
||||
"severity": "critical|high|medium|low|info",
|
||||
"confidence": 0.80,
|
||||
"title": "Short title describing the issue",
|
||||
"description": "Detailed description of the complexity concern",
|
||||
"reasoning": "Why this complexity is problematic",
|
||||
"suggestion": "How to simplify or refactor (optional)",
|
||||
"references": []
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
If no complexity issues are found, return an empty array: `[]`
|
||||
|
||||
## Guidelines
|
||||
|
||||
1. Consider context - complex code may be justified for complex problems
|
||||
2. Flag over-engineering as readily as under-engineering
|
||||
3. High complexity is only critical if it's likely to cause bugs
|
||||
4. Suggest specific refactoring strategies when possible
|
||||
5. Reference static analysis metrics (cyclomatic complexity, etc.) when available
|
||||
55
templates/security-v1.0.md
Normal file
55
templates/security-v1.0.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Security Review Agent
|
||||
|
||||
You are a security-focused code reviewer. Analyze the provided diff for security vulnerabilities and potential risks.
|
||||
|
||||
## Focus Areas
|
||||
|
||||
- **Injection vulnerabilities**: SQL injection, command injection, XSS, template injection
|
||||
- **Authentication/Authorization**: Missing auth checks, privilege escalation, insecure session handling
|
||||
- **Data exposure**: Hardcoded secrets, PII leaks, sensitive data in logs
|
||||
- **Cryptographic issues**: Weak algorithms, improper key management, missing encryption
|
||||
- **Input validation**: Missing or insufficient validation, type confusion
|
||||
- **OWASP Top 10**: All categories including broken access control, security misconfiguration
|
||||
|
||||
## Context
|
||||
|
||||
{{static_analysis_context}}
|
||||
|
||||
## Diff to Review
|
||||
|
||||
```diff
|
||||
{{diff}}
|
||||
```
|
||||
|
||||
{{prompt_additions}}
|
||||
|
||||
## Output Format
|
||||
|
||||
Respond with a JSON array of findings. Each finding must have this structure:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"file": "path/to/file.py",
|
||||
"line_start": 10,
|
||||
"line_end": 15,
|
||||
"severity": "critical|high|medium|low|info",
|
||||
"confidence": 0.95,
|
||||
"title": "Short title describing the issue",
|
||||
"description": "Detailed description of the vulnerability",
|
||||
"reasoning": "Why this is a security concern",
|
||||
"suggestion": "How to fix this issue (optional)",
|
||||
"references": ["https://owasp.org/..."]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
If no security issues are found, return an empty array: `[]`
|
||||
|
||||
## Guidelines
|
||||
|
||||
1. Only report genuine security concerns, not style or performance issues
|
||||
2. Assign appropriate severity based on exploitability and impact
|
||||
3. Set confidence based on how certain you are this is a real vulnerability
|
||||
4. Provide actionable suggestions when possible
|
||||
5. Include relevant OWASP or CWE references
|
||||
55
templates/style-v1.0.md
Normal file
55
templates/style-v1.0.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Style Review Agent
|
||||
|
||||
You are a code style reviewer focused on readability and consistency. Analyze the provided diff for style issues that impact code maintainability.
|
||||
|
||||
## Focus Areas
|
||||
|
||||
- **Naming conventions**: Variable, function, class, and file naming consistency
|
||||
- **Code organisation**: Logical grouping, import ordering, module structure
|
||||
- **Readability**: Clear variable names, appropriate comments, self-documenting code
|
||||
- **Consistency**: Adherence to existing patterns in the codebase
|
||||
- **Best practices**: Language-specific idioms and conventions
|
||||
- **Documentation**: Missing or outdated docstrings, misleading comments
|
||||
|
||||
## Context
|
||||
|
||||
{{static_analysis_context}}
|
||||
|
||||
## Diff to Review
|
||||
|
||||
```diff
|
||||
{{diff}}
|
||||
```
|
||||
|
||||
{{prompt_additions}}
|
||||
|
||||
## Output Format
|
||||
|
||||
Respond with a JSON array of findings. Each finding must have this structure:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"file": "path/to/file.py",
|
||||
"line_start": 10,
|
||||
"line_end": 15,
|
||||
"severity": "critical|high|medium|low|info",
|
||||
"confidence": 0.85,
|
||||
"title": "Short title describing the issue",
|
||||
"description": "Detailed description of the style concern",
|
||||
"reasoning": "Why this matters for maintainability",
|
||||
"suggestion": "How to improve this (optional)",
|
||||
"references": []
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
If no style issues are found, return an empty array: `[]`
|
||||
|
||||
## Guidelines
|
||||
|
||||
1. Focus on readability and maintainability, not personal preferences
|
||||
2. Respect existing codebase conventions even if they differ from common standards
|
||||
3. Most style issues should be low or info severity unless they significantly impact readability
|
||||
4. Only flag high severity for style issues that could cause confusion or bugs
|
||||
5. Provide concrete suggestions with example code when helpful
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Arbiter test suite."""
|
||||
490
tests/conftest.py
Normal file
490
tests/conftest.py
Normal file
@@ -0,0 +1,490 @@
|
||||
"""Pytest configuration and fixtures."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from arbiter.db.models import Base, FindingModel, ReviewModel
|
||||
from arbiter.integrations.base import Comment, CommitStatus, Platform, PlatformClient
|
||||
from arbiter.llm.client import LLMClient, LLMResponse
|
||||
from arbiter.llm.prompts import PromptRegistry
|
||||
from arbiter.models import Policy
|
||||
from arbiter.models.enums import AgentName, Severity, Verdict
|
||||
|
||||
|
||||
class MockLLMClient(LLMClient):
|
||||
"""Mock LLM client for testing."""
|
||||
|
||||
def __init__(self, responses: list[str] | None = None) -> None:
|
||||
self.responses = responses or []
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
self._call_index = 0
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""Record the call and return a canned response."""
|
||||
self.calls.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
content = ""
|
||||
if self._call_index < len(self.responses):
|
||||
content = self.responses[self._call_index]
|
||||
self._call_index += 1
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=model,
|
||||
tokens_in=100,
|
||||
tokens_out=50,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.calls = []
|
||||
self._call_index = 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm() -> MockLLMClient:
|
||||
return MockLLMClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_with_findings() -> MockLLMClient:
|
||||
response = """```json
|
||||
[
|
||||
{
|
||||
"file": "src/auth.py",
|
||||
"line_start": 10,
|
||||
"line_end": 15,
|
||||
"severity": "high",
|
||||
"confidence": 0.9,
|
||||
"title": "SQL Injection vulnerability",
|
||||
"description": "User input is directly concatenated into SQL query",
|
||||
"reasoning": "String concatenation in SQL queries allows attackers to inject malicious SQL",
|
||||
"suggestion": "Use parameterized queries instead",
|
||||
"references": ["https://owasp.org/www-community/attacks/SQL_Injection"]
|
||||
}
|
||||
]
|
||||
```"""
|
||||
return MockLLMClient(responses=[response])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_diff() -> str:
|
||||
return """diff --git a/src/auth.py b/src/auth.py
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/src/auth.py
|
||||
+++ b/src/auth.py
|
||||
@@ -8,6 +8,12 @@ def authenticate(username, password):
|
||||
if not username or not password:
|
||||
return None
|
||||
|
||||
+ # Check user in database
|
||||
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
|
||||
+ cursor.execute(query)
|
||||
+ user = cursor.fetchone()
|
||||
+ if user and user.password == password:
|
||||
+ return user
|
||||
return None
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_policy() -> Policy:
|
||||
return Policy()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_registry(tmp_path: Path) -> PromptRegistry:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
|
||||
(templates_dir / "security-v1.0.md").write_text(
|
||||
"Review for security: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
|
||||
)
|
||||
(templates_dir / "style-v1.0.md").write_text(
|
||||
"Review for style: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
|
||||
)
|
||||
(templates_dir / "complexity-v1.0.md").write_text(
|
||||
"Review for complexity: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
|
||||
)
|
||||
|
||||
return PromptRegistry(templates_dir)
|
||||
|
||||
|
||||
# Database fixtures for integration tests
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
session_factory = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
|
||||
from arbiter.api.deps import get_db, get_redis
|
||||
from arbiter.main import app
|
||||
|
||||
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
async def override_get_redis() -> AsyncGenerator[MockRedis, None]:
|
||||
yield MockRedis()
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_redis] = override_get_redis
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class MockRedis:
|
||||
"""Mock Redis client for testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data: dict[str, Any] = {}
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self._data.get(key)
|
||||
|
||||
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
||||
self._data[key] = value
|
||||
return True
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def ping(self) -> bool:
|
||||
return True
|
||||
|
||||
async def llen(self, key: str) -> int:
|
||||
data = self._data.get(key, [])
|
||||
return len(data) if isinstance(data, list) else 0
|
||||
|
||||
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: # noqa: ARG002
|
||||
return type("Job", (), {"job_id": "test-job-id"})()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis() -> MockRedis:
|
||||
return MockRedis()
|
||||
|
||||
|
||||
class MockPlatformClient(PlatformClient):
|
||||
"""Mock platform client for testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._comments: list[Comment] = []
|
||||
self._posted_comments: list[dict[str, Any]] = []
|
||||
self._status_updates: list[dict[str, Any]] = []
|
||||
self._diff = "mock diff content"
|
||||
self._closed = False
|
||||
self._fail_on: set[str] = set() # Methods to fail on
|
||||
|
||||
@property
|
||||
def platform(self) -> Platform:
|
||||
return Platform.GITHUB
|
||||
|
||||
async def get_pr_diff(
|
||||
self,
|
||||
repository: str, # noqa: ARG002
|
||||
pr_number: int, # noqa: ARG002
|
||||
) -> str:
|
||||
if "get_pr_diff" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: get_pr_diff")
|
||||
return self._diff
|
||||
|
||||
async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
|
||||
if "post_comment" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: post_comment")
|
||||
comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-123"
|
||||
self._posted_comments.append(
|
||||
{"repository": repository, "pr_number": pr_number, "body": body, "url": comment_url}
|
||||
)
|
||||
return comment_url
|
||||
|
||||
async def update_commit_status(
|
||||
self,
|
||||
repository: str,
|
||||
sha: str,
|
||||
status: CommitStatus,
|
||||
description: str,
|
||||
context: str,
|
||||
target_url: str | None = None,
|
||||
) -> None:
|
||||
if "update_commit_status" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: update_commit_status")
|
||||
self._status_updates.append(
|
||||
{
|
||||
"repository": repository,
|
||||
"sha": sha,
|
||||
"status": status,
|
||||
"description": description,
|
||||
"context": context,
|
||||
"target_url": target_url,
|
||||
}
|
||||
)
|
||||
|
||||
async def get_pr_info(self, repository: str, pr_number: int) -> Any:
|
||||
if "get_pr_info" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: get_pr_info")
|
||||
from arbiter.integrations.base import PullRequestInfo
|
||||
|
||||
return PullRequestInfo(
|
||||
platform=Platform.GITHUB,
|
||||
repository=repository,
|
||||
pr_number=pr_number,
|
||||
head_sha="abc123",
|
||||
base_sha="def456",
|
||||
head_ref="feature",
|
||||
base_ref="main",
|
||||
title="Test PR",
|
||||
author="testuser",
|
||||
url=f"https://github.com/{repository}/pull/{pr_number}",
|
||||
is_draft=False,
|
||||
)
|
||||
|
||||
async def get_comments(
|
||||
self,
|
||||
repository: str, # noqa: ARG002
|
||||
pr_number: int, # noqa: ARG002
|
||||
) -> list[Comment]:
|
||||
if "get_comments" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: get_comments")
|
||||
return self._comments
|
||||
|
||||
async def update_comment(
|
||||
self, repository: str, pr_number: int, comment_id: str, body: str
|
||||
) -> str:
|
||||
if "update_comment" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: update_comment")
|
||||
comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-{comment_id}"
|
||||
self._posted_comments.append(
|
||||
{
|
||||
"repository": repository,
|
||||
"pr_number": pr_number,
|
||||
"comment_id": comment_id,
|
||||
"body": body,
|
||||
"url": comment_url,
|
||||
}
|
||||
)
|
||||
return comment_url
|
||||
|
||||
async def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_platform_client() -> MockPlatformClient:
|
||||
return MockPlatformClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def completed_review_fixture(db_session: AsyncSession) -> ReviewModel:
|
||||
review = ReviewModel(
|
||||
id=str(uuid4()),
|
||||
repository="owner/repo",
|
||||
pr_number=42,
|
||||
pr_title="Test PR",
|
||||
base_sha="abc1234567890",
|
||||
head_sha="def0987654321",
|
||||
author="testuser",
|
||||
is_draft=False,
|
||||
status="completed",
|
||||
verdict=Verdict.COMMENT,
|
||||
verdict_confidence=0.75,
|
||||
verdict_reasoning="Found some issues to discuss",
|
||||
total_tokens=1500,
|
||||
total_cost_usd=0.015,
|
||||
tokens_by_agent={"security": 500, "style": 500, "complexity": 500},
|
||||
cost_by_agent={"security": 0.005, "style": 0.005, "complexity": 0.005},
|
||||
created_at=datetime.now(UTC),
|
||||
started_at=datetime.now(UTC),
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
db_session.add(review)
|
||||
|
||||
# Add findings
|
||||
finding1 = FindingModel(
|
||||
id=str(uuid4()),
|
||||
review_id=review.id,
|
||||
agent=AgentName.SECURITY,
|
||||
file="src/auth.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="SQL Injection vulnerability",
|
||||
description="User input is directly concatenated into SQL query",
|
||||
reasoning="String concatenation in SQL queries allows attackers to inject malicious SQL",
|
||||
suggestion="Use parameterized queries instead",
|
||||
references=["https://owasp.org/www-community/attacks/SQL_Injection"],
|
||||
prompt_version="security-v1.0",
|
||||
)
|
||||
finding2 = FindingModel(
|
||||
id=str(uuid4()),
|
||||
review_id=review.id,
|
||||
agent=AgentName.STYLE,
|
||||
file="src/auth.py",
|
||||
line_start=20,
|
||||
line_end=25,
|
||||
severity=Severity.MEDIUM,
|
||||
confidence=0.85,
|
||||
title="Long function",
|
||||
description="Function exceeds 50 lines",
|
||||
reasoning="Long functions are harder to test and maintain",
|
||||
suggestion="Consider breaking into smaller functions",
|
||||
references=[],
|
||||
prompt_version="style-v1.0",
|
||||
)
|
||||
db_session.add(finding1)
|
||||
db_session.add(finding2)
|
||||
|
||||
await db_session.commit()
|
||||
await db_session.refresh(review)
|
||||
return review
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def sample_reviews_fixture(db_session: AsyncSession) -> list[ReviewModel]:
|
||||
reviews = []
|
||||
for i in range(5):
|
||||
review = ReviewModel(
|
||||
id=str(uuid4()),
|
||||
repository="owner/repo" if i < 3 else "other/repo",
|
||||
pr_number=i + 1,
|
||||
pr_title=f"Test PR #{i + 1}",
|
||||
base_sha=f"base{i:040d}",
|
||||
head_sha=f"head{i:040d}",
|
||||
author="testuser" if i % 2 == 0 else "otheruser",
|
||||
is_draft=False,
|
||||
status="completed" if i < 4 else "failed",
|
||||
verdict=Verdict.APPROVE if i == 0 else (Verdict.COMMENT if i < 4 else None),
|
||||
verdict_confidence=0.8 if i < 4 else None,
|
||||
total_tokens=1000 * (i + 1),
|
||||
total_cost_usd=0.01 * (i + 1),
|
||||
created_at=datetime.now(UTC),
|
||||
completed_at=datetime.now(UTC) if i < 4 else None,
|
||||
)
|
||||
db_session.add(review)
|
||||
reviews.append(review)
|
||||
|
||||
# Add a finding to each completed review
|
||||
if i < 4:
|
||||
finding = FindingModel(
|
||||
id=str(uuid4()),
|
||||
review_id=review.id,
|
||||
agent=AgentName.SECURITY,
|
||||
file="src/test.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.CRITICAL if i == 0 else Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title=f"Finding {i + 1}",
|
||||
description="Test finding",
|
||||
reasoning="Test reasoning",
|
||||
prompt_version="security-v1.0",
|
||||
)
|
||||
db_session.add(finding)
|
||||
|
||||
await db_session.commit()
|
||||
for review in reviews:
|
||||
await db_session.refresh(review)
|
||||
return reviews
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings() -> Any:
|
||||
class MockSecretStr:
|
||||
def __init__(self, value: str) -> None:
|
||||
self._value = value
|
||||
|
||||
def get_secret_value(self) -> str:
|
||||
return self._value
|
||||
|
||||
class MockSettings:
|
||||
github_token = MockSecretStr("ghp_test_token")
|
||||
gitlab_token = MockSecretStr("glpat_test_token")
|
||||
github_base_url = "https://api.github.com"
|
||||
gitlab_base_url = "https://gitlab.com/api/v4"
|
||||
github_webhook_secret = MockSecretStr("webhook_secret")
|
||||
gitlab_webhook_token = MockSecretStr("gitlab_token")
|
||||
integration_timeout = 30
|
||||
integration_max_retries = 3
|
||||
llm_timeout = 60
|
||||
llm_max_retries = 3
|
||||
post_comments = True
|
||||
update_status = True
|
||||
status_check_context = "arbiter"
|
||||
templates_dir = Path("templates")
|
||||
followup_enabled = True
|
||||
followup_confidence_threshold = 0.5
|
||||
|
||||
return MockSettings()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings_no_github() -> Any:
|
||||
class MockSettings:
|
||||
github_token = None
|
||||
gitlab_token = None
|
||||
github_base_url = "https://api.github.com"
|
||||
gitlab_base_url = "https://gitlab.com/api/v4"
|
||||
integration_timeout = 30
|
||||
integration_max_retries = 3
|
||||
post_comments = False
|
||||
update_status = False
|
||||
status_check_context = "arbiter"
|
||||
templates_dir = Path("templates")
|
||||
|
||||
return MockSettings()
|
||||
69
tests/fixtures/complex-function.diff
vendored
Normal file
69
tests/fixtures/complex-function.diff
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
diff --git a/src/processor.py b/src/processor.py
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/src/processor.py
|
||||
+++ b/src/processor.py
|
||||
@@ -1,5 +1,65 @@
|
||||
"""Data processor module."""
|
||||
|
||||
|
||||
+def process_data(data: dict, config: dict, options: dict | None = None) -> dict:
|
||||
+ """Process data with many nested conditions."""
|
||||
+ result = {}
|
||||
+ options = options or {}
|
||||
+
|
||||
+ if data.get("type") == "A":
|
||||
+ if config.get("mode") == "strict":
|
||||
+ if options.get("validate"):
|
||||
+ if data.get("value") > 100:
|
||||
+ if config.get("transform"):
|
||||
+ result["processed"] = data["value"] * 2
|
||||
+ else:
|
||||
+ result["processed"] = data["value"]
|
||||
+ else:
|
||||
+ if options.get("default"):
|
||||
+ result["processed"] = options["default"]
|
||||
+ else:
|
||||
+ result["processed"] = 0
|
||||
+ else:
|
||||
+ result["processed"] = data.get("value", 0)
|
||||
+ else:
|
||||
+ result["processed"] = data.get("value", 0)
|
||||
+ elif data.get("type") == "B":
|
||||
+ if config.get("mode") == "strict":
|
||||
+ if options.get("validate"):
|
||||
+ if data.get("items"):
|
||||
+ result["processed"] = len(data["items"])
|
||||
+ else:
|
||||
+ result["processed"] = 0
|
||||
+ else:
|
||||
+ result["processed"] = len(data.get("items", []))
|
||||
+ else:
|
||||
+ result["processed"] = len(data.get("items", []))
|
||||
+ elif data.get("type") == "C":
|
||||
+ if config.get("mode") == "strict":
|
||||
+ if options.get("validate"):
|
||||
+ if data.get("text"):
|
||||
+ result["processed"] = data["text"].upper()
|
||||
+ else:
|
||||
+ result["processed"] = ""
|
||||
+ else:
|
||||
+ result["processed"] = data.get("text", "").upper()
|
||||
+ else:
|
||||
+ result["processed"] = data.get("text", "").upper()
|
||||
+ else:
|
||||
+ if config.get("fallback"):
|
||||
+ result["processed"] = config["fallback"]
|
||||
+ else:
|
||||
+ result["processed"] = None
|
||||
+
|
||||
+ if options.get("timestamp"):
|
||||
+ result["timestamp"] = options["timestamp"]
|
||||
+ if options.get("source"):
|
||||
+ result["source"] = options["source"]
|
||||
+
|
||||
+ return result
|
||||
+
|
||||
+
|
||||
def simple_function(x: int) -> int:
|
||||
"""A simple function."""
|
||||
return x * 2
|
||||
31
tests/fixtures/security-issue.diff
vendored
Normal file
31
tests/fixtures/security-issue.diff
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
diff --git a/src/auth.py b/src/auth.py
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/src/auth.py
|
||||
+++ b/src/auth.py
|
||||
@@ -1,10 +1,25 @@
|
||||
"""Authentication module."""
|
||||
|
||||
import sqlite3
|
||||
+import os
|
||||
|
||||
|
||||
def get_user(username: str) -> dict | None:
|
||||
"""Get user from database."""
|
||||
conn = sqlite3.connect("users.db")
|
||||
cursor = conn.cursor()
|
||||
- cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
|
||||
+ # FIXME: this is vulnerable to SQL injection
|
||||
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
|
||||
+ cursor.execute(query)
|
||||
return cursor.fetchone()
|
||||
+
|
||||
+
|
||||
+def run_command(cmd: str) -> str:
|
||||
+ """Run a shell command."""
|
||||
+ # Command injection vulnerability
|
||||
+ return os.popen(cmd).read()
|
||||
+
|
||||
+
|
||||
+# Hardcoded credentials
|
||||
+API_KEY = "sk-1234567890abcdef"
|
||||
+DB_PASSWORD = "admin123"
|
||||
16
tests/fixtures/simple.diff
vendored
Normal file
16
tests/fixtures/simple.diff
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
diff --git a/src/utils.py b/src/utils.py
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/src/utils.py
|
||||
+++ b/src/utils.py
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Utility functions."""
|
||||
|
||||
|
||||
+def add(a: int, b: int) -> int:
|
||||
+ """Add two numbers."""
|
||||
+ return a + b
|
||||
+
|
||||
+
|
||||
def subtract(a: int, b: int) -> int:
|
||||
"""Subtract two numbers."""
|
||||
return a - b
|
||||
305
tests/test_agents.py
Normal file
305
tests/test_agents.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""Tests for review agents."""
|
||||
|
||||
import pytest
|
||||
|
||||
from arbiter.agents import ComplexityAgent, ReviewContext, SecurityAgent, StyleAgent
|
||||
from arbiter.llm.prompts import PromptRegistry
|
||||
from arbiter.models import AgentConfig, AgentName, Policy, Severity
|
||||
from tests.conftest import MockLLMClient
|
||||
|
||||
|
||||
class TestSecurityAgent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_review_returns_result(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ some code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.agent_name == AgentName.SECURITY
|
||||
assert result.findings == []
|
||||
assert result.duration_ms >= 0
|
||||
assert result.tokens_used == 150 # 100 in + 50 out from mock
|
||||
assert result.cost_usd == 0.001
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parses_json_findings(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
response = """```json
|
||||
[
|
||||
{
|
||||
"file": "src/auth.py",
|
||||
"line_start": 10,
|
||||
"line_end": 15,
|
||||
"severity": "high",
|
||||
"confidence": 0.9,
|
||||
"title": "SQL Injection",
|
||||
"description": "User input concatenated",
|
||||
"reasoning": "Allows SQL injection",
|
||||
"suggestion": "Use parameterized queries",
|
||||
"references": ["https://owasp.org"]
|
||||
}
|
||||
]
|
||||
```"""
|
||||
mock_llm = MockLLMClient(responses=[response])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ query = ...", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert len(result.findings) == 1
|
||||
finding = result.findings[0]
|
||||
assert finding.file == "src/auth.py"
|
||||
assert finding.severity == Severity.HIGH
|
||||
assert finding.confidence == 0.9
|
||||
assert finding.title == "SQL Injection"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_configured_model(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
policy = Policy(
|
||||
agents={
|
||||
AgentName.SECURITY: AgentConfig(model="gpt-4o-mini"),
|
||||
AgentName.STYLE: AgentConfig(),
|
||||
AgentName.COMPLEXITY: AgentConfig(),
|
||||
}
|
||||
)
|
||||
context = ReviewContext(diff="+ code", policy=policy)
|
||||
|
||||
await agent.review(context)
|
||||
|
||||
assert mock_llm.calls[0]["model"] == "gpt-4o-mini"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filters_by_severity(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
response = """[
|
||||
{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "high", "confidence": 0.9, "title": "High", "description": "", "reasoning": ""},
|
||||
{"file": "b.py", "line_start": 1, "line_end": 1, "severity": "low", "confidence": 0.9, "title": "Low", "description": "", "reasoning": ""},
|
||||
{"file": "c.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.9, "title": "Info", "description": "", "reasoning": ""}
|
||||
]"""
|
||||
mock_llm = MockLLMClient(responses=[response])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
policy = Policy(
|
||||
agents={
|
||||
AgentName.SECURITY: AgentConfig(severity_threshold=Severity.MEDIUM),
|
||||
AgentName.STYLE: AgentConfig(),
|
||||
AgentName.COMPLEXITY: AgentConfig(),
|
||||
}
|
||||
)
|
||||
context = ReviewContext(diff="+ code", policy=policy)
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
# Only high severity should pass (medium threshold filters low and info)
|
||||
assert len(result.findings) == 1
|
||||
assert result.findings[0].severity == Severity.HIGH
|
||||
|
||||
|
||||
class TestStyleAgent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_review_returns_result(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = StyleAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ some code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.agent_name == AgentName.STYLE
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_default_model(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = StyleAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
await agent.review(context)
|
||||
|
||||
assert mock_llm.calls[0]["model"] == "gpt-4o-mini"
|
||||
|
||||
|
||||
class TestComplexityAgent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_review_returns_result(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = ComplexityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ some code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.agent_name == AgentName.COMPLEXITY
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parses_complexity_findings(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
response = """[
|
||||
{
|
||||
"file": "processor.py",
|
||||
"line_start": 1,
|
||||
"line_end": 50,
|
||||
"severity": "medium",
|
||||
"confidence": 0.8,
|
||||
"title": "High cyclomatic complexity",
|
||||
"description": "Function has 15 branches",
|
||||
"reasoning": "Makes testing and maintenance difficult"
|
||||
}
|
||||
]"""
|
||||
mock_llm = MockLLMClient(responses=[response])
|
||||
agent = ComplexityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ complex code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert len(result.findings) == 1
|
||||
assert result.findings[0].severity == Severity.MEDIUM
|
||||
assert "complexity" in result.findings[0].title.lower()
|
||||
|
||||
|
||||
class TestAgentResponseParsing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_empty_response(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=[""])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_invalid_json(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["not valid json"])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_json_without_code_block(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
response = '[{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Test", "description": "", "reasoning": ""}]'
|
||||
mock_llm = MockLLMClient(responses=[response])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert len(result.findings) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_malformed_finding(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
response = """[
|
||||
{"file": "a.py", "line_start": 1, "severity": "invalid_severity", "confidence": 0.5, "title": "Bad", "description": "", "reasoning": ""},
|
||||
{"file": "b.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Valid", "description": "", "reasoning": ""}
|
||||
]"""
|
||||
mock_llm = MockLLMClient(responses=[response])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
# Only the valid finding should be included (first has invalid severity)
|
||||
assert len(result.findings) == 1
|
||||
assert result.findings[0].title == "Valid"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_includes_prompt_additions(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
policy = Policy(
|
||||
agents={
|
||||
AgentName.SECURITY: AgentConfig(prompt_additions="Focus on authentication"),
|
||||
AgentName.STYLE: AgentConfig(),
|
||||
AgentName.COMPLEXITY: AgentConfig(),
|
||||
}
|
||||
)
|
||||
context = ReviewContext(diff="+ code", policy=policy)
|
||||
|
||||
await agent.review(context)
|
||||
|
||||
message_content = mock_llm.calls[0]["messages"][0]["content"]
|
||||
assert "Focus on authentication" in message_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_non_list_json(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=['{"not": "a list"}'])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_non_dict_items(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=['["string", 123, null]'])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_without_config_uses_defaults(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
# Create policy with empty agents dict
|
||||
policy = Policy(agents={})
|
||||
context = ReviewContext(diff="+ code", policy=policy)
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
# Should use default model (gpt-4o for security)
|
||||
assert mock_llm.calls[0]["model"] == "gpt-4o"
|
||||
assert result.findings == []
|
||||
561
tests/test_cli.py
Normal file
561
tests/test_cli.py
Normal file
@@ -0,0 +1,561 @@
|
||||
"""Tests for CLI commands."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from arbiter.cli import (
|
||||
_severity_color,
|
||||
_severity_icon,
|
||||
_verdict_color,
|
||||
_verdict_icon,
|
||||
app,
|
||||
)
|
||||
from arbiter.deliberation import DeliberationResult
|
||||
from arbiter.models import AgentName, Finding, ReviewResult, Severity, Verdict
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def make_mock_return(
|
||||
findings: list[Finding] | None = None, verdict: Verdict = Verdict.APPROVE
|
||||
) -> tuple[list[ReviewResult], DeliberationResult]:
|
||||
"""Create a mock return value for _run_review."""
|
||||
agent_results = [
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=findings or [],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
]
|
||||
deliberation_result = DeliberationResult(
|
||||
findings=findings or [],
|
||||
verdict=verdict,
|
||||
verdict_confidence=0.9,
|
||||
verdict_reasoning="Test reasoning",
|
||||
total_findings=len(findings) if findings else 0,
|
||||
)
|
||||
return agent_results, deliberation_result
|
||||
|
||||
|
||||
class TestVersionCommand:
|
||||
def test_version_output(self) -> None:
|
||||
result = runner.invoke(app, ["version"])
|
||||
assert result.exit_code == 0
|
||||
assert "arbiter" in result.output
|
||||
assert "0.3.0" in result.output
|
||||
|
||||
|
||||
class TestReviewCommand:
|
||||
def test_file_not_found(self) -> None:
|
||||
result = runner.invoke(app, ["review", "nonexistent.diff"])
|
||||
assert result.exit_code == 1
|
||||
assert "File not found" in result.output
|
||||
|
||||
def test_empty_diff_warning(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "empty.diff"
|
||||
diff_file.write_text("")
|
||||
|
||||
result = runner.invoke(app, ["review", str(diff_file)])
|
||||
assert result.exit_code == 0
|
||||
assert "Empty diff" in result.output
|
||||
|
||||
def test_policy_not_found(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--policy", "nonexistent.yaml"])
|
||||
assert result.exit_code == 1
|
||||
assert "Policy file not found" in result.output
|
||||
|
||||
def test_reads_from_stdin(self) -> None:
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", "-"], input="+ added line\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_json_output_format(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "json"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert '"verdict"' in result.output
|
||||
assert '"findings"' in result.output
|
||||
|
||||
def test_markdown_output_format(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "# Arbiter Review" in result.output
|
||||
|
||||
def test_loads_policy_file(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
policy_file = tmp_path / "policy.yaml"
|
||||
policy_file.write_text("""
|
||||
version: "1.0"
|
||||
agents:
|
||||
security:
|
||||
enabled: true
|
||||
style:
|
||||
enabled: false
|
||||
complexity:
|
||||
enabled: false
|
||||
""")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--policy", str(policy_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Verify policy was passed to _run_review
|
||||
call_args = mock_run.call_args
|
||||
policy = call_args[0][1] # Second positional arg is policy
|
||||
assert len(policy.get_enabled_agents()) == 1
|
||||
|
||||
def test_model_override(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--model", "gpt-4o-mini"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Verify model was set in policy
|
||||
call_args = mock_run.call_args
|
||||
policy = call_args[0][1]
|
||||
for config in policy.agents.values():
|
||||
assert config.model == "gpt-4o-mini"
|
||||
|
||||
def test_static_analysis_flag(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--no-static-analysis"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Verify static_analysis was False
|
||||
call_args = mock_run.call_args
|
||||
assert call_args.kwargs.get("static_analysis") is False
|
||||
|
||||
|
||||
class TestNoArgsHelp:
|
||||
def test_no_args_shows_help(self) -> None:
|
||||
result = runner.invoke(app, [])
|
||||
assert result.exit_code == 0
|
||||
assert "A multi-agent code review system" in result.output
|
||||
|
||||
|
||||
class TestOutputFormatting:
|
||||
def test_severity_color(self) -> None:
|
||||
assert _severity_color(Severity.CRITICAL) == "red bold"
|
||||
assert _severity_color(Severity.HIGH) == "red"
|
||||
assert _severity_color(Severity.MEDIUM) == "yellow"
|
||||
assert _severity_color(Severity.LOW) == "blue"
|
||||
assert _severity_color(Severity.INFO) == "dim"
|
||||
|
||||
def test_severity_icon(self) -> None:
|
||||
assert _severity_icon(Severity.CRITICAL) == "!!"
|
||||
assert _severity_icon(Severity.HIGH) == "!"
|
||||
assert _severity_icon(Severity.MEDIUM) == "*"
|
||||
assert _severity_icon(Severity.LOW) == "-"
|
||||
assert _severity_icon(Severity.INFO) == "i"
|
||||
|
||||
def test_verdict_color(self) -> None:
|
||||
assert _verdict_color(Verdict.APPROVE) == "green"
|
||||
assert _verdict_color(Verdict.COMMENT) == "yellow"
|
||||
assert _verdict_color(Verdict.REQUEST_CHANGES) == "red"
|
||||
|
||||
def test_verdict_icon(self) -> None:
|
||||
assert _verdict_icon(Verdict.APPROVE) == "[ok]"
|
||||
assert _verdict_icon(Verdict.COMMENT) == "[..]"
|
||||
assert _verdict_icon(Verdict.REQUEST_CHANGES) == "[!!]"
|
||||
|
||||
|
||||
class TestRichOutput:
|
||||
def test_rich_format_with_findings(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
finding = Finding(
|
||||
id="test-finding-1",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="SQL Injection",
|
||||
description="User input in query",
|
||||
reasoning="Direct concatenation",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return(findings=[finding])
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_rich_format_no_findings(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "No issues found" in result.output
|
||||
|
||||
def test_rich_format_critical_findings(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
finding = Finding(
|
||||
id="test-finding-1",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=10,
|
||||
line_end=10,
|
||||
severity=Severity.CRITICAL,
|
||||
confidence=0.95,
|
||||
title="Critical Issue",
|
||||
description="This is critical",
|
||||
reasoning="Very bad",
|
||||
suggestion="Fix it immediately",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
|
||||
deliberation = DeliberationResult(
|
||||
findings=[finding],
|
||||
verdict=Verdict.REQUEST_CHANGES,
|
||||
verdict_confidence=0.95,
|
||||
verdict_reasoning="Critical issue found",
|
||||
total_findings=1,
|
||||
critical_count=1,
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[finding],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestMarkdownOutput:
|
||||
def test_markdown_with_findings(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
finding = Finding(
|
||||
id="test-finding-1",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="SQL Injection",
|
||||
description="User input in query",
|
||||
reasoning="Direct concatenation",
|
||||
suggestion="Use parameterized queries",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return(findings=[finding])
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "## Findings" in result.output
|
||||
assert "SQL Injection" in result.output
|
||||
|
||||
def test_markdown_verdict_badges(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
for verdict in [Verdict.APPROVE, Verdict.COMMENT, Verdict.REQUEST_CHANGES]:
|
||||
deliberation = DeliberationResult(
|
||||
verdict=verdict,
|
||||
verdict_confidence=0.9,
|
||||
verdict_reasoning="Test",
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert verdict.value.upper() in result.output
|
||||
|
||||
|
||||
class TestJsonOutput:
|
||||
def test_json_with_conflicts(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
from arbiter.deliberation.conflicts import Conflict, ConflictNature
|
||||
|
||||
conflict = Conflict(
|
||||
id="test-conflict",
|
||||
finding_ids=["f1", "f2"],
|
||||
nature=ConflictNature.TRADE_OFF,
|
||||
description="Test conflict",
|
||||
severity_weight=0.8,
|
||||
)
|
||||
|
||||
deliberation = DeliberationResult(
|
||||
verdict=Verdict.COMMENT,
|
||||
verdict_confidence=0.7,
|
||||
verdict_reasoning="Conflicts found",
|
||||
conflicts=[conflict],
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "json"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
output = json.loads(result.output)
|
||||
assert "conflicts" in output
|
||||
assert len(output["conflicts"]) == 1
|
||||
|
||||
|
||||
class TestWorkDirHandling:
|
||||
def test_work_dir_option(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
work_dir = tmp_path / "src"
|
||||
work_dir.mkdir()
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--work-dir", str(work_dir)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
call_args = mock_run.call_args
|
||||
assert call_args.kwargs.get("work_dir") == work_dir.resolve()
|
||||
|
||||
|
||||
class TestRichOutputWithConflicts:
|
||||
def test_rich_conflicts(self, tmp_path: Path) -> None:
|
||||
from arbiter.deliberation.conflicts import Conflict, ConflictNature
|
||||
from arbiter.deliberation.synthesis import Resolution
|
||||
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
conflict = Conflict(
|
||||
id="test-conflict",
|
||||
finding_ids=["f1", "f2"],
|
||||
nature=ConflictNature.TRADE_OFF,
|
||||
description="Security vs complexity trade-off",
|
||||
severity_weight=0.8,
|
||||
)
|
||||
|
||||
resolution = Resolution(
|
||||
conflict_id="test-conflict",
|
||||
decision="prefer_first",
|
||||
reasoning="Security takes priority",
|
||||
confidence=0.9,
|
||||
)
|
||||
|
||||
deliberation = DeliberationResult(
|
||||
verdict=Verdict.COMMENT,
|
||||
verdict_confidence=0.7,
|
||||
verdict_reasoning="Conflicts found",
|
||||
conflicts=[conflict],
|
||||
resolutions=[resolution],
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestMarkdownOutputWithConflicts:
|
||||
def test_markdown_conflicts(self, tmp_path: Path) -> None:
|
||||
from arbiter.deliberation.conflicts import Conflict, ConflictNature
|
||||
from arbiter.deliberation.synthesis import Resolution
|
||||
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
conflict = Conflict(
|
||||
id="test-conflict",
|
||||
finding_ids=["f1", "f2"],
|
||||
nature=ConflictNature.CONTRADICTORY,
|
||||
description="Contradictory recommendations",
|
||||
severity_weight=0.9,
|
||||
)
|
||||
|
||||
resolution = Resolution(
|
||||
conflict_id="test-conflict",
|
||||
decision="merge",
|
||||
reasoning="Both concerns addressed by combined fix",
|
||||
merged_suggestion="Do both things",
|
||||
confidence=0.85,
|
||||
)
|
||||
|
||||
deliberation = DeliberationResult(
|
||||
verdict=Verdict.COMMENT,
|
||||
verdict_confidence=0.7,
|
||||
verdict_reasoning="Conflicts found",
|
||||
conflicts=[conflict],
|
||||
resolutions=[resolution],
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "## Conflicts" in result.output
|
||||
assert "Contradictory" in result.output
|
||||
assert "Resolution" in result.output
|
||||
|
||||
def test_markdown_findings(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
findings = [
|
||||
Finding(
|
||||
id="f1",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="Security Issue",
|
||||
description="Vulnerable code",
|
||||
reasoning="Bad pattern",
|
||||
suggestion="Fix it this way",
|
||||
prompt_version="test-v1.0",
|
||||
),
|
||||
Finding(
|
||||
id="f2",
|
||||
agent=AgentName.STYLE,
|
||||
file="test.py",
|
||||
line_start=20,
|
||||
line_end=25,
|
||||
severity=Severity.LOW,
|
||||
confidence=0.8,
|
||||
title="Style Issue",
|
||||
description="Could be cleaner",
|
||||
reasoning="Convention",
|
||||
prompt_version="test-v1.0",
|
||||
),
|
||||
]
|
||||
|
||||
deliberation = DeliberationResult(
|
||||
findings=findings,
|
||||
verdict=Verdict.COMMENT,
|
||||
verdict_confidence=0.75,
|
||||
verdict_reasoning="Issues found",
|
||||
total_findings=2,
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[findings[0]],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
),
|
||||
ReviewResult(
|
||||
agent_name=AgentName.STYLE,
|
||||
findings=[findings[1]],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
),
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Security Issue" in result.output
|
||||
assert "Style Issue" in result.output
|
||||
assert "Fix it this way" in result.output
|
||||
313
tests/test_llm.py
Normal file
313
tests/test_llm.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""Tests for LLM client and prompts."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from arbiter.llm.client import LiteLLMClient, LLMResponse
|
||||
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
|
||||
from tests.conftest import MockLLMClient
|
||||
|
||||
|
||||
class TestLiteLLMClient:
|
||||
def test_init_default_values(self) -> None:
|
||||
client = LiteLLMClient()
|
||||
assert client.timeout == 60
|
||||
assert client.max_retries == 3
|
||||
|
||||
def test_init_custom_values(self) -> None:
|
||||
client = LiteLLMClient(timeout=120, max_retries=5)
|
||||
assert client.timeout == 120
|
||||
assert client.max_retries == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_returns_response(self) -> None:
|
||||
client = LiteLLMClient()
|
||||
|
||||
# Mock the litellm.acompletion function
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response"
|
||||
mock_response.model = "gpt-4o"
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 5
|
||||
|
||||
with (
|
||||
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||
):
|
||||
mock_acomp.return_value = mock_response
|
||||
mock_cost.return_value = 0.001
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = await client.complete(messages, "gpt-4o")
|
||||
|
||||
assert response.content == "Test response"
|
||||
assert response.model == "gpt-4o"
|
||||
assert response.tokens_in == 10
|
||||
assert response.tokens_out == 5
|
||||
assert response.cost_usd == 0.001
|
||||
|
||||
mock_acomp.assert_called_once_with(
|
||||
model="gpt-4o",
|
||||
messages=messages,
|
||||
timeout=60,
|
||||
num_retries=3,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_handles_empty_content(self) -> None:
|
||||
client = LiteLLMClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = None # None content
|
||||
mock_response.model = "gpt-4o"
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.prompt_tokens = 5
|
||||
mock_response.usage.completion_tokens = 0
|
||||
|
||||
with (
|
||||
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||
):
|
||||
mock_acomp.return_value = mock_response
|
||||
mock_cost.return_value = 0.0
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = await client.complete(messages, "gpt-4o")
|
||||
|
||||
assert response.content == "" # Should be empty string, not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_handles_missing_usage(self) -> None:
|
||||
client = LiteLLMClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response"
|
||||
mock_response.model = "gpt-4o"
|
||||
mock_response.usage = None # No usage data
|
||||
|
||||
with (
|
||||
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||
):
|
||||
mock_acomp.return_value = mock_response
|
||||
mock_cost.return_value = 0.0
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = await client.complete(messages, "gpt-4o")
|
||||
|
||||
assert response.tokens_in == 0
|
||||
assert response.tokens_out == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_uses_fallback_model(self) -> None:
|
||||
client = LiteLLMClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response"
|
||||
mock_response.model = None # No model in response
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.prompt_tokens = 5
|
||||
mock_response.usage.completion_tokens = 3
|
||||
|
||||
with (
|
||||
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||
):
|
||||
mock_acomp.return_value = mock_response
|
||||
mock_cost.return_value = 0.0
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = await client.complete(messages, "claude-3-opus")
|
||||
|
||||
# Should use the passed model as fallback
|
||||
assert response.model == "claude-3-opus"
|
||||
|
||||
|
||||
class TestLLMResponse:
|
||||
def test_response_creation(self) -> None:
|
||||
response = LLMResponse(
|
||||
content="Hello, world!",
|
||||
model="gpt-4o",
|
||||
tokens_in=10,
|
||||
tokens_out=5,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
assert response.content == "Hello, world!"
|
||||
assert response.model == "gpt-4o"
|
||||
assert response.tokens_in == 10
|
||||
assert response.tokens_out == 5
|
||||
assert response.cost_usd == 0.001
|
||||
|
||||
|
||||
class TestMockLLMClient:
|
||||
@pytest.mark.asyncio
|
||||
async def test_records_calls(self) -> None:
|
||||
client = MockLLMClient()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
await client.complete(messages, "gpt-4o")
|
||||
|
||||
assert len(client.calls) == 1
|
||||
assert client.calls[0]["messages"] == messages
|
||||
assert client.calls[0]["model"] == "gpt-4o"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_canned_responses(self) -> None:
|
||||
client = MockLLMClient(responses=["First", "Second"])
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
response1 = await client.complete(messages, "gpt-4o")
|
||||
response2 = await client.complete(messages, "gpt-4o")
|
||||
response3 = await client.complete(messages, "gpt-4o")
|
||||
|
||||
assert response1.content == "First"
|
||||
assert response2.content == "Second"
|
||||
assert response3.content == "" # Exhausted
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset(self) -> None:
|
||||
client = MockLLMClient(responses=["Hello"])
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
|
||||
await client.complete(messages, "gpt-4o")
|
||||
assert len(client.calls) == 1
|
||||
|
||||
client.reset()
|
||||
assert len(client.calls) == 0
|
||||
|
||||
response = await client.complete(messages, "gpt-4o")
|
||||
assert response.content == "Hello" # Responses reset too
|
||||
|
||||
|
||||
class TestPromptTemplate:
|
||||
def test_template_creation(self) -> None:
|
||||
template = PromptTemplate(
|
||||
name="security",
|
||||
version="1.0",
|
||||
content="Review: {{diff}}",
|
||||
)
|
||||
assert template.name == "security"
|
||||
assert template.version == "1.0"
|
||||
assert template.full_name == "security-v1.0"
|
||||
|
||||
def test_render_substitution(self) -> None:
|
||||
template = PromptTemplate(
|
||||
name="test",
|
||||
version="1.0",
|
||||
content="File: {{file}}\nDiff: {{diff}}",
|
||||
)
|
||||
result = template.render(file="test.py", diff="+ added line")
|
||||
assert result == "File: test.py\nDiff: + added line"
|
||||
|
||||
def test_render_missing_variable(self) -> None:
|
||||
template = PromptTemplate(
|
||||
name="test",
|
||||
version="1.0",
|
||||
content="Value: {{value}}",
|
||||
)
|
||||
result = template.render()
|
||||
assert result == "Value: {{value}}"
|
||||
|
||||
def test_render_multiple_occurrences(self) -> None:
|
||||
template = PromptTemplate(
|
||||
name="test",
|
||||
version="1.0",
|
||||
content="{{name}} and {{name}} again",
|
||||
)
|
||||
result = template.render(name="test")
|
||||
assert result == "test and test again"
|
||||
|
||||
|
||||
class TestPromptRegistry:
|
||||
def test_get_template(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
(templates_dir / "security-v1.0.md").write_text("Security review: {{diff}}")
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
template = registry.get("security", "1.0")
|
||||
|
||||
assert template.name == "security"
|
||||
assert template.version == "1.0"
|
||||
assert "{{diff}}" in template.content
|
||||
|
||||
def test_get_template_cached(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
(templates_dir / "security-v1.0.md").write_text("Content")
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
template1 = registry.get("security", "1.0")
|
||||
template2 = registry.get("security", "1.0")
|
||||
|
||||
assert template1 is template2
|
||||
|
||||
def test_get_template_not_found(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
registry.get("missing", "1.0")
|
||||
|
||||
def test_list_templates(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
(templates_dir / "security-v1.0.md").write_text("Content")
|
||||
(templates_dir / "style-v2.0.md").write_text("Content")
|
||||
(templates_dir / "readme.md").write_text("Not a template")
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
templates = registry.list_templates()
|
||||
|
||||
assert len(templates) == 2
|
||||
assert ("security", "1.0") in templates
|
||||
assert ("style", "2.0") in templates
|
||||
|
||||
def test_list_templates_empty_dir(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
templates = registry.list_templates()
|
||||
|
||||
assert templates == []
|
||||
|
||||
def test_list_templates_missing_dir(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "missing"
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
templates = registry.list_templates()
|
||||
|
||||
assert templates == []
|
||||
|
||||
def test_clear_cache(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
(templates_dir / "test-v1.0.md").write_text("Original")
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
template1 = registry.get("test", "1.0")
|
||||
assert template1.content == "Original"
|
||||
|
||||
# Modify file
|
||||
(templates_dir / "test-v1.0.md").write_text("Modified")
|
||||
|
||||
# Still cached
|
||||
template2 = registry.get("test", "1.0")
|
||||
assert template2.content == "Original"
|
||||
|
||||
# Clear cache
|
||||
registry.clear_cache()
|
||||
|
||||
# Now reads new content
|
||||
template3 = registry.get("test", "1.0")
|
||||
assert template3.content == "Modified"
|
||||
224
tests/test_models.py
Normal file
224
tests/test_models.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Tests for data models."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from arbiter.models import (
|
||||
AgentConfig,
|
||||
AgentName,
|
||||
Finding,
|
||||
Policy,
|
||||
ReviewResult,
|
||||
Severity,
|
||||
Verdict,
|
||||
)
|
||||
|
||||
|
||||
class TestEnums:
|
||||
def test_severity_values(self) -> None:
|
||||
assert Severity.CRITICAL == "critical"
|
||||
assert Severity.HIGH == "high"
|
||||
assert Severity.MEDIUM == "medium"
|
||||
assert Severity.LOW == "low"
|
||||
assert Severity.INFO == "info"
|
||||
|
||||
def test_verdict_values(self) -> None:
|
||||
assert Verdict.APPROVE == "approve"
|
||||
assert Verdict.REQUEST_CHANGES == "request_changes"
|
||||
assert Verdict.COMMENT == "comment"
|
||||
|
||||
def test_agent_name_values(self) -> None:
|
||||
assert AgentName.SECURITY == "security"
|
||||
assert AgentName.STYLE == "style"
|
||||
assert AgentName.COMPLEXITY == "complexity"
|
||||
|
||||
def test_severity_from_string(self) -> None:
|
||||
assert Severity("critical") == Severity.CRITICAL
|
||||
assert Severity("high") == Severity.HIGH
|
||||
|
||||
|
||||
class TestFinding:
|
||||
def test_finding_creation(self) -> None:
|
||||
finding = Finding(
|
||||
id="test-123",
|
||||
agent=AgentName.SECURITY,
|
||||
file="src/auth.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="SQL Injection",
|
||||
description="User input concatenated in SQL query",
|
||||
reasoning="Allows attackers to execute arbitrary SQL",
|
||||
suggestion="Use parameterized queries",
|
||||
references=["https://owasp.org"],
|
||||
prompt_version="security-v1.0",
|
||||
)
|
||||
assert finding.id == "test-123"
|
||||
assert finding.agent == AgentName.SECURITY
|
||||
assert finding.severity == Severity.HIGH
|
||||
assert finding.confidence == 0.9
|
||||
|
||||
def test_finding_confidence_validation(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
Finding(
|
||||
id="test",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=1,
|
||||
line_end=1,
|
||||
severity=Severity.INFO,
|
||||
confidence=1.5,
|
||||
title="Test",
|
||||
description="Test",
|
||||
reasoning="Test",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
|
||||
def test_finding_line_validation(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
Finding(
|
||||
id="test",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=0,
|
||||
line_end=1,
|
||||
severity=Severity.INFO,
|
||||
confidence=0.5,
|
||||
title="Test",
|
||||
description="Test",
|
||||
reasoning="Test",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
|
||||
def test_finding_serialization(self) -> None:
|
||||
finding = Finding(
|
||||
id="test-123",
|
||||
agent=AgentName.SECURITY,
|
||||
file="src/auth.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="Test",
|
||||
description="Test desc",
|
||||
reasoning="Test reason",
|
||||
prompt_version="security-v1.0",
|
||||
)
|
||||
data = finding.model_dump()
|
||||
assert data["id"] == "test-123"
|
||||
assert data["agent"] == "security"
|
||||
assert data["severity"] == "high"
|
||||
|
||||
|
||||
class TestAgentConfig:
|
||||
def test_default_values(self) -> None:
|
||||
config = AgentConfig()
|
||||
assert config.enabled is True
|
||||
assert config.model is None
|
||||
assert config.severity_threshold == Severity.INFO
|
||||
assert config.prompt_additions is None
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
config = AgentConfig(
|
||||
enabled=False,
|
||||
model="gpt-4o",
|
||||
severity_threshold=Severity.MEDIUM,
|
||||
prompt_additions="Focus on auth",
|
||||
)
|
||||
assert config.enabled is False
|
||||
assert config.model == "gpt-4o"
|
||||
assert config.severity_threshold == Severity.MEDIUM
|
||||
|
||||
|
||||
class TestPolicy:
|
||||
def test_default_policy(self) -> None:
|
||||
policy = Policy()
|
||||
assert len(policy.agents) == 3
|
||||
assert AgentName.SECURITY in policy.agents
|
||||
assert AgentName.STYLE in policy.agents
|
||||
assert AgentName.COMPLEXITY in policy.agents
|
||||
|
||||
def test_get_enabled_agents(self) -> None:
|
||||
policy = Policy()
|
||||
enabled = policy.get_enabled_agents()
|
||||
assert len(enabled) == 3
|
||||
assert AgentName.SECURITY in enabled
|
||||
|
||||
def test_get_enabled_agents_with_disabled(self) -> None:
|
||||
policy = Policy(
|
||||
agents={
|
||||
AgentName.SECURITY: AgentConfig(enabled=True),
|
||||
AgentName.STYLE: AgentConfig(enabled=False),
|
||||
AgentName.COMPLEXITY: AgentConfig(enabled=True),
|
||||
}
|
||||
)
|
||||
enabled = policy.get_enabled_agents()
|
||||
assert len(enabled) == 2
|
||||
assert AgentName.STYLE not in enabled
|
||||
|
||||
def test_load_from_yaml(self, tmp_path: Path) -> None:
|
||||
policy_file = tmp_path / "policy.yaml"
|
||||
policy_file.write_text("""
|
||||
version: "1.1"
|
||||
agents:
|
||||
security:
|
||||
enabled: true
|
||||
model: gpt-4o
|
||||
severity_threshold: high
|
||||
style:
|
||||
enabled: false
|
||||
complexity:
|
||||
enabled: true
|
||||
""")
|
||||
policy = Policy.load(policy_file)
|
||||
assert policy.version == "1.1"
|
||||
assert policy.agents[AgentName.SECURITY].model == "gpt-4o"
|
||||
assert policy.agents[AgentName.SECURITY].severity_threshold == Severity.HIGH
|
||||
assert policy.agents[AgentName.STYLE].enabled is False
|
||||
|
||||
def test_load_empty_yaml(self, tmp_path: Path) -> None:
|
||||
policy_file = tmp_path / "policy.yaml"
|
||||
policy_file.write_text("")
|
||||
policy = Policy.load(policy_file)
|
||||
assert policy.version == "1.0"
|
||||
assert len(policy.agents) == 3
|
||||
|
||||
|
||||
class TestReviewResult:
|
||||
def test_review_result_creation(self) -> None:
|
||||
result = ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[],
|
||||
duration_ms=1000,
|
||||
tokens_used=500,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
assert result.agent_name == AgentName.SECURITY
|
||||
assert result.duration_ms == 1000
|
||||
assert result.cost_usd == 0.01
|
||||
|
||||
def test_review_result_with_findings(self) -> None:
|
||||
finding = Finding(
|
||||
id="test-123",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=1,
|
||||
line_end=1,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="Test",
|
||||
description="Test",
|
||||
reasoning="Test",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
result = ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[finding],
|
||||
duration_ms=1000,
|
||||
tokens_used=500,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
assert len(result.findings) == 1
|
||||
assert result.findings[0].id == "test-123"
|
||||
Reference in New Issue
Block a user