feat(agents): implement agent framework and CLI
This commit is contained in:
111
pyproject.toml
Normal file
111
pyproject.toml
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "arbiter"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "A multi-agent code review system that shows its work"
|
||||||
|
readme = "readme.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
license = "MIT"
|
||||||
|
authors = [{ name = "Kai Chappell", email = "git@kschappell.com" }]
|
||||||
|
keywords = ["code-review", "ai", "llm", "agents"]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 3 - Alpha",
|
||||||
|
"Environment :: Console",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
"Topic :: Software Development :: Quality Assurance",
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"litellm>=1.0.0",
|
||||||
|
"pydantic>=2.0.0",
|
||||||
|
"pydantic-settings>=2.0.0",
|
||||||
|
"typer>=0.9.0",
|
||||||
|
"rich>=13.0.0",
|
||||||
|
"pyyaml>=6.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
"pytest-asyncio>=0.23.0",
|
||||||
|
"pytest-cov>=4.0.0",
|
||||||
|
"ruff>=0.4.0",
|
||||||
|
"mypy>=1.10.0",
|
||||||
|
"types-PyYAML>=6.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
arbiter = "arbiter.cli:app"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/arbiter"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py312"
|
||||||
|
line-length = 100
|
||||||
|
src = ["src", "tests"]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"W", # pycodestyle warnings
|
||||||
|
"F", # Pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"C4", # flake8-comprehensions
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"ARG", # flake8-unused-arguments
|
||||||
|
"SIM", # flake8-simplify
|
||||||
|
"TCH", # flake8-type-checking
|
||||||
|
"PTH", # flake8-use-pathlib
|
||||||
|
"RUF", # Ruff-specific rules
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"E501", # line too long (handled by formatter)
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
known-first-party = ["arbiter"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.12"
|
||||||
|
strict = true
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_ignores = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
disallow_incomplete_defs = true
|
||||||
|
check_untyped_defs = true
|
||||||
|
disallow_untyped_decorators = true
|
||||||
|
no_implicit_optional = true
|
||||||
|
warn_redundant_casts = true
|
||||||
|
warn_unused_configs = true
|
||||||
|
show_error_codes = true
|
||||||
|
files = ["src"]
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = ["litellm.*"]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
|
addopts = "-ra -q"
|
||||||
|
|
||||||
|
[tool.coverage.run]
|
||||||
|
source = ["src/arbiter"]
|
||||||
|
branch = true
|
||||||
|
|
||||||
|
[tool.coverage.report]
|
||||||
|
exclude_lines = [
|
||||||
|
"pragma: no cover",
|
||||||
|
"def __repr__",
|
||||||
|
"raise NotImplementedError",
|
||||||
|
"if TYPE_CHECKING:",
|
||||||
|
"if __name__ == .__main__.:",
|
||||||
|
]
|
||||||
|
fail_under = 85
|
||||||
3
src/arbiter/__init__.py
Normal file
3
src/arbiter/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Arbiter: A multi-agent code review system that shows its work."""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
14
src/arbiter/agents/__init__.py
Normal file
14
src/arbiter/agents/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""Review agents for code analysis."""
|
||||||
|
|
||||||
|
from arbiter.agents.base import Agent, ReviewContext
|
||||||
|
from arbiter.agents.complexity import ComplexityAgent
|
||||||
|
from arbiter.agents.security import SecurityAgent
|
||||||
|
from arbiter.agents.style import StyleAgent
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Agent",
|
||||||
|
"ComplexityAgent",
|
||||||
|
"ReviewContext",
|
||||||
|
"SecurityAgent",
|
||||||
|
"StyleAgent",
|
||||||
|
]
|
||||||
277
src/arbiter/agents/base.py
Normal file
277
src/arbiter/agents/base.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
"""Base agent class and review context."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from arbiter.llm.client import LLMClient
|
||||||
|
from arbiter.llm.prompts import PromptRegistry
|
||||||
|
from arbiter.models import AgentName, Finding, Policy, ReviewResult, Severity
|
||||||
|
|
||||||
|
|
||||||
|
class ReviewContext(BaseModel):
|
||||||
|
"""Context for a code review."""
|
||||||
|
|
||||||
|
diff: str = Field(description="The diff content to review")
|
||||||
|
policy: Policy = Field(description="Review policy configuration")
|
||||||
|
file_path: str | None = Field(default=None, description="Path to the diff file")
|
||||||
|
static_analysis_context: str = Field(
|
||||||
|
default="No static analysis data available.",
|
||||||
|
description="Static analysis results as formatted string",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExplainContext(BaseModel):
|
||||||
|
"""Context for explaining a finding in a follow-up conversation."""
|
||||||
|
|
||||||
|
question: str = Field(description="The user's follow-up question")
|
||||||
|
finding: Finding = Field(description="The finding being asked about")
|
||||||
|
diff: str = Field(description="The relevant code diff")
|
||||||
|
conversation_history: list[dict[str, str]] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Previous messages in this conversation",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExplainResult(BaseModel):
|
||||||
|
"""Result from an agent's explain() method."""
|
||||||
|
|
||||||
|
response: str = Field(description="The explanation response text")
|
||||||
|
tokens_used: int = Field(ge=0, description="Total tokens used for this response")
|
||||||
|
cost_usd: float = Field(ge=0.0, description="Cost in USD for this response")
|
||||||
|
confidence: float = Field(ge=0.0, le=1.0, description="Confidence in the explanation")
|
||||||
|
|
||||||
|
|
||||||
|
class Agent(ABC):
|
||||||
|
"""Abstract base class for review agents."""
|
||||||
|
|
||||||
|
name: AgentName
|
||||||
|
prompt_name: str
|
||||||
|
prompt_version: str = "1.0"
|
||||||
|
default_model: str = "gpt-4o"
|
||||||
|
explain_prompt_name: str = "" # Set by subclasses
|
||||||
|
explain_prompt_version: str = "1.0"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_client: LLMClient,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
self.llm_client = llm_client
|
||||||
|
self.prompt_registry = prompt_registry
|
||||||
|
|
||||||
|
async def review(self, context: ReviewContext) -> ReviewResult:
|
||||||
|
"""Perform a code review.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Review context with diff and policy.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReviewResult with findings and metadata.
|
||||||
|
"""
|
||||||
|
start_time = time.monotonic()
|
||||||
|
|
||||||
|
# Get model from policy or use default
|
||||||
|
agent_config = context.policy.agents.get(self.name)
|
||||||
|
model = agent_config.model if agent_config and agent_config.model else self.default_model
|
||||||
|
|
||||||
|
# Get prompt additions from policy
|
||||||
|
prompt_additions = ""
|
||||||
|
if agent_config and agent_config.prompt_additions:
|
||||||
|
prompt_additions = agent_config.prompt_additions
|
||||||
|
|
||||||
|
# Build and send messages
|
||||||
|
messages = self._build_messages(context, prompt_additions)
|
||||||
|
response = await self._call_llm(messages, model)
|
||||||
|
|
||||||
|
# Parse response into findings
|
||||||
|
findings = self._parse_response(response.content, context)
|
||||||
|
|
||||||
|
# Filter by severity threshold
|
||||||
|
if agent_config:
|
||||||
|
findings = self._filter_by_severity(findings, agent_config.severity_threshold)
|
||||||
|
|
||||||
|
duration_ms = int((time.monotonic() - start_time) * 1000)
|
||||||
|
|
||||||
|
return ReviewResult(
|
||||||
|
agent_name=self.name,
|
||||||
|
findings=findings,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
tokens_used=response.tokens_in + response.tokens_out,
|
||||||
|
cost_usd=response.cost_usd,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_messages(
|
||||||
|
self,
|
||||||
|
context: ReviewContext,
|
||||||
|
prompt_additions: str,
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
template = self.prompt_registry.get(self.prompt_name, self.prompt_version)
|
||||||
|
content = template.render(
|
||||||
|
diff=context.diff,
|
||||||
|
static_analysis_context=context.static_analysis_context,
|
||||||
|
prompt_additions=prompt_additions,
|
||||||
|
)
|
||||||
|
return [{"role": "user", "content": content}]
|
||||||
|
|
||||||
|
async def _call_llm(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
model: str,
|
||||||
|
) -> Any:
|
||||||
|
return await self.llm_client.complete(messages, model)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]: ...
|
||||||
|
|
||||||
|
def _filter_by_severity(
|
||||||
|
self,
|
||||||
|
findings: list[Finding],
|
||||||
|
threshold: Severity,
|
||||||
|
) -> list[Finding]:
|
||||||
|
"""Filter findings by severity threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
findings: List of findings to filter.
|
||||||
|
threshold: Minimum severity to include.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of findings.
|
||||||
|
"""
|
||||||
|
severity_order = [
|
||||||
|
Severity.CRITICAL,
|
||||||
|
Severity.HIGH,
|
||||||
|
Severity.MEDIUM,
|
||||||
|
Severity.LOW,
|
||||||
|
Severity.INFO,
|
||||||
|
]
|
||||||
|
threshold_index = severity_order.index(threshold)
|
||||||
|
return [f for f in findings if severity_order.index(f.severity) <= threshold_index]
|
||||||
|
|
||||||
|
def _parse_json_findings(self, content: str, _context: ReviewContext) -> list[Finding]:
|
||||||
|
"""Parse JSON array of findings from LLM response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Raw LLM response content.
|
||||||
|
context: Review context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Finding objects.
|
||||||
|
"""
|
||||||
|
# Extract JSON from response (handle markdown code blocks)
|
||||||
|
json_content = content.strip()
|
||||||
|
if json_content.startswith("```"):
|
||||||
|
lines = json_content.split("\n")
|
||||||
|
# Remove first and last lines (code block markers)
|
||||||
|
json_lines = []
|
||||||
|
in_block = False
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith("```") and not in_block:
|
||||||
|
in_block = True
|
||||||
|
continue
|
||||||
|
if line.startswith("```") and in_block:
|
||||||
|
break
|
||||||
|
if in_block:
|
||||||
|
json_lines.append(line)
|
||||||
|
json_content = "\n".join(json_lines)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(json_content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not isinstance(data, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
findings = []
|
||||||
|
for item in data:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
finding = Finding(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent=self.name,
|
||||||
|
file=item.get("file", "unknown"),
|
||||||
|
line_start=item.get("line_start", 1),
|
||||||
|
line_end=item.get("line_end") or item.get("line_start", 1),
|
||||||
|
severity=Severity(item.get("severity", "info")),
|
||||||
|
confidence=float(item.get("confidence", 0.5)),
|
||||||
|
title=item.get("title", "Untitled finding"),
|
||||||
|
description=item.get("description", ""),
|
||||||
|
reasoning=item.get("reasoning", ""),
|
||||||
|
suggestion=item.get("suggestion"),
|
||||||
|
references=item.get("references", []),
|
||||||
|
prompt_version=f"{self.prompt_name}-v{self.prompt_version}",
|
||||||
|
static_analysis_context=None,
|
||||||
|
)
|
||||||
|
findings.append(finding)
|
||||||
|
except (ValueError, KeyError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
return findings
|
||||||
|
|
||||||
|
async def explain(
|
||||||
|
self,
|
||||||
|
context: ExplainContext,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> ExplainResult:
|
||||||
|
"""Provide a detailed explanation about a specific finding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The explanation context with question, finding, and history.
|
||||||
|
model: Optional model override. Defaults to agent's default_model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExplainResult with response text and cost tracking.
|
||||||
|
"""
|
||||||
|
if not self.explain_prompt_name:
|
||||||
|
# Derive from prompt_name if not set
|
||||||
|
explain_name = f"{self.prompt_name}-explain"
|
||||||
|
else:
|
||||||
|
explain_name = self.explain_prompt_name
|
||||||
|
|
||||||
|
template = self.prompt_registry.get(explain_name, self.explain_prompt_version)
|
||||||
|
|
||||||
|
# Format conversation history
|
||||||
|
history_text = ""
|
||||||
|
if context.conversation_history:
|
||||||
|
history_lines = []
|
||||||
|
for msg in context.conversation_history:
|
||||||
|
role = msg.get("role", "user").capitalize()
|
||||||
|
content = msg.get("content", "")
|
||||||
|
history_lines.append(f"**{role}:** {content}")
|
||||||
|
history_text = "\n\n".join(history_lines)
|
||||||
|
|
||||||
|
# Build finding context
|
||||||
|
finding_lines = (
|
||||||
|
f"{context.finding.line_start}"
|
||||||
|
if context.finding.line_start == context.finding.line_end
|
||||||
|
else f"{context.finding.line_start}-{context.finding.line_end}"
|
||||||
|
)
|
||||||
|
|
||||||
|
content = template.render(
|
||||||
|
finding_title=context.finding.title,
|
||||||
|
finding_file=context.finding.file,
|
||||||
|
finding_lines=finding_lines,
|
||||||
|
finding_description=context.finding.description,
|
||||||
|
finding_reasoning=context.finding.reasoning,
|
||||||
|
finding_severity=context.finding.severity.value,
|
||||||
|
finding_suggestion=context.finding.suggestion or "No suggestion provided.",
|
||||||
|
diff=context.diff,
|
||||||
|
conversation_history=history_text or "No previous conversation.",
|
||||||
|
question=context.question,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": content}]
|
||||||
|
response = await self._call_llm(messages, model or self.default_model)
|
||||||
|
|
||||||
|
return ExplainResult(
|
||||||
|
response=response.content.strip(),
|
||||||
|
tokens_used=response.tokens_in + response.tokens_out,
|
||||||
|
cost_usd=response.cost_usd,
|
||||||
|
confidence=0.8, # Default confidence for explanations
|
||||||
|
)
|
||||||
17
src/arbiter/agents/complexity.py
Normal file
17
src/arbiter/agents/complexity.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Complexity review agent."""
|
||||||
|
|
||||||
|
from arbiter.agents.base import Agent, ReviewContext
|
||||||
|
from arbiter.models import AgentName, Finding
|
||||||
|
|
||||||
|
|
||||||
|
class ComplexityAgent(Agent):
|
||||||
|
"""Agent focused on code complexity and architecture."""
|
||||||
|
|
||||||
|
name = AgentName.COMPLEXITY
|
||||||
|
prompt_name = "complexity"
|
||||||
|
prompt_version = "1.0"
|
||||||
|
explain_prompt_name = "complexity-explain"
|
||||||
|
default_model = "gpt-4o-mini"
|
||||||
|
|
||||||
|
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]:
|
||||||
|
return self._parse_json_findings(content, context)
|
||||||
17
src/arbiter/agents/security.py
Normal file
17
src/arbiter/agents/security.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Security review agent."""
|
||||||
|
|
||||||
|
from arbiter.agents.base import Agent, ReviewContext
|
||||||
|
from arbiter.models import AgentName, Finding
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityAgent(Agent):
|
||||||
|
"""Agent focused on security vulnerabilities and risks."""
|
||||||
|
|
||||||
|
name = AgentName.SECURITY
|
||||||
|
prompt_name = "security"
|
||||||
|
prompt_version = "1.0"
|
||||||
|
explain_prompt_name = "security-explain"
|
||||||
|
default_model = "gpt-4o"
|
||||||
|
|
||||||
|
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]:
|
||||||
|
return self._parse_json_findings(content, context)
|
||||||
17
src/arbiter/agents/style.py
Normal file
17
src/arbiter/agents/style.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Style review agent."""
|
||||||
|
|
||||||
|
from arbiter.agents.base import Agent, ReviewContext
|
||||||
|
from arbiter.models import AgentName, Finding
|
||||||
|
|
||||||
|
|
||||||
|
class StyleAgent(Agent):
|
||||||
|
"""Agent focused on code style, naming, and readability."""
|
||||||
|
|
||||||
|
name = AgentName.STYLE
|
||||||
|
prompt_name = "style"
|
||||||
|
prompt_version = "1.0"
|
||||||
|
explain_prompt_name = "style-explain"
|
||||||
|
default_model = "gpt-4o-mini"
|
||||||
|
|
||||||
|
def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]:
|
||||||
|
return self._parse_json_findings(content, context)
|
||||||
407
src/arbiter/cli.py
Normal file
407
src/arbiter/cli.py
Normal file
@@ -0,0 +1,407 @@
|
|||||||
|
"""Command-line interface for Arbiter."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
from arbiter.agents import ComplexityAgent, ReviewContext, SecurityAgent, StyleAgent
|
||||||
|
from arbiter.analysis import DiffParser, StaticAnalysisRunner
|
||||||
|
from arbiter.config import get_settings
|
||||||
|
from arbiter.deliberation import Coordinator, DeliberationResult
|
||||||
|
from arbiter.llm import LiteLLMClient, PromptRegistry
|
||||||
|
from arbiter.models import AgentName, Finding, Policy, ReviewResult, Severity, Verdict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from arbiter.agents.base import Agent
|
||||||
|
|
||||||
|
app = typer.Typer(
|
||||||
|
name="arbiter",
|
||||||
|
help="A multi-agent code review system that shows its work.",
|
||||||
|
no_args_is_help=True,
|
||||||
|
)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
def _severity_color(severity: Severity) -> str:
|
||||||
|
colors = {
|
||||||
|
Severity.CRITICAL: "red bold",
|
||||||
|
Severity.HIGH: "red",
|
||||||
|
Severity.MEDIUM: "yellow",
|
||||||
|
Severity.LOW: "blue",
|
||||||
|
Severity.INFO: "dim",
|
||||||
|
}
|
||||||
|
return colors.get(severity, "white")
|
||||||
|
|
||||||
|
|
||||||
|
def _severity_icon(severity: Severity) -> str:
|
||||||
|
icons = {
|
||||||
|
Severity.CRITICAL: "!!",
|
||||||
|
Severity.HIGH: "!",
|
||||||
|
Severity.MEDIUM: "*",
|
||||||
|
Severity.LOW: "-",
|
||||||
|
Severity.INFO: "i",
|
||||||
|
}
|
||||||
|
return icons.get(severity, " ")
|
||||||
|
|
||||||
|
|
||||||
|
def _verdict_color(verdict: Verdict) -> str:
|
||||||
|
colors = {
|
||||||
|
Verdict.APPROVE: "green",
|
||||||
|
Verdict.COMMENT: "yellow",
|
||||||
|
Verdict.REQUEST_CHANGES: "red",
|
||||||
|
}
|
||||||
|
return colors.get(verdict, "white")
|
||||||
|
|
||||||
|
|
||||||
|
def _verdict_icon(verdict: Verdict) -> str:
|
||||||
|
icons = {
|
||||||
|
Verdict.APPROVE: "[ok]",
|
||||||
|
Verdict.COMMENT: "[..]",
|
||||||
|
Verdict.REQUEST_CHANGES: "[!!]",
|
||||||
|
}
|
||||||
|
return icons.get(verdict, "")
|
||||||
|
|
||||||
|
|
||||||
|
def _format_rich(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
|
||||||
|
"""Format results using Rich for terminal output."""
|
||||||
|
total_tokens = sum(r.tokens_used for r in agent_results) + result.tokens_used
|
||||||
|
total_cost = sum(r.cost_usd for r in agent_results) + result.cost_usd
|
||||||
|
|
||||||
|
# Verdict panel
|
||||||
|
verdict_style = _verdict_color(result.verdict)
|
||||||
|
verdict_icon = _verdict_icon(result.verdict)
|
||||||
|
verdict_text = f"[{verdict_style} bold]{verdict_icon} {result.verdict.value.upper()}[/]"
|
||||||
|
verdict_text += f" (confidence: {result.verdict_confidence:.0%})"
|
||||||
|
|
||||||
|
summary = f"{verdict_text}\n\n"
|
||||||
|
summary += f"[bold]{result.total_findings}[/bold] findings from [bold]{len(agent_results)}[/bold] agents"
|
||||||
|
if result.conflicts:
|
||||||
|
summary += f" | [yellow]{len(result.conflicts)}[/yellow] conflicts"
|
||||||
|
summary += f"\nTokens: {total_tokens:,} | Cost: ${total_cost:.4f}"
|
||||||
|
|
||||||
|
console.print(Panel(summary, title="Arbiter Review", border_style="blue"))
|
||||||
|
|
||||||
|
if result.verdict_reasoning:
|
||||||
|
console.print(f"\n[dim]Reason:[/dim] {result.verdict_reasoning}")
|
||||||
|
|
||||||
|
if result.total_findings == 0:
|
||||||
|
console.print("\n[green]No issues found![/green]")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Findings table
|
||||||
|
console.print("\n[bold]Findings[/bold]")
|
||||||
|
|
||||||
|
table = Table(show_header=True, header_style="bold")
|
||||||
|
table.add_column("Sev", width=4)
|
||||||
|
table.add_column("Agent", width=12)
|
||||||
|
table.add_column("File", width=30)
|
||||||
|
table.add_column("Line", width=8)
|
||||||
|
table.add_column("Title", width=50)
|
||||||
|
|
||||||
|
for finding in result.findings:
|
||||||
|
sev_icon = _severity_icon(finding.severity)
|
||||||
|
sev_style = _severity_color(finding.severity)
|
||||||
|
line_range = (
|
||||||
|
f"{finding.line_start}-{finding.line_end}"
|
||||||
|
if finding.line_start != finding.line_end
|
||||||
|
else str(finding.line_start)
|
||||||
|
)
|
||||||
|
table.add_row(
|
||||||
|
f"[{sev_style}]{sev_icon}[/{sev_style}]",
|
||||||
|
finding.agent.value,
|
||||||
|
finding.file,
|
||||||
|
line_range,
|
||||||
|
finding.title,
|
||||||
|
)
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
# Conflicts section
|
||||||
|
if result.conflicts:
|
||||||
|
console.print("\n[bold yellow]Conflicts[/bold yellow]")
|
||||||
|
for i, conflict in enumerate(result.conflicts, 1):
|
||||||
|
console.print(f"\n{i}. [{conflict.nature.value}] {conflict.description}")
|
||||||
|
if result.resolutions:
|
||||||
|
for resolution in result.resolutions:
|
||||||
|
if resolution.conflict_id == conflict.id:
|
||||||
|
console.print(f" [dim]Resolution:[/dim] {resolution.reasoning}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Details for critical/high findings
|
||||||
|
critical_high = [f for f in result.findings if f.severity in (Severity.CRITICAL, Severity.HIGH)]
|
||||||
|
if critical_high:
|
||||||
|
console.print("\n[bold]Details[/bold]")
|
||||||
|
for finding in critical_high:
|
||||||
|
console.print(
|
||||||
|
f"\n[{_severity_color(finding.severity)}]{finding.title}[/]"
|
||||||
|
f" ({finding.file}:{finding.line_start})"
|
||||||
|
)
|
||||||
|
console.print(f" {finding.description}")
|
||||||
|
if finding.suggestion:
|
||||||
|
console.print(f" [dim]Suggestion:[/dim] {finding.suggestion}")
|
||||||
|
|
||||||
|
# Deliberation log
|
||||||
|
console.print("\n[bold dim]Deliberation Log[/bold dim]")
|
||||||
|
for step in result.steps:
|
||||||
|
console.print(f" [{step.step_type.value}] {step.description}")
|
||||||
|
|
||||||
|
|
||||||
|
def _format_json(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
|
||||||
|
"""Format results as JSON."""
|
||||||
|
output = {
|
||||||
|
"verdict": result.verdict.value,
|
||||||
|
"verdict_confidence": result.verdict_confidence,
|
||||||
|
"verdict_reasoning": result.verdict_reasoning,
|
||||||
|
"findings": [f.model_dump() for f in result.findings],
|
||||||
|
"conflicts": [c.model_dump() for c in result.conflicts],
|
||||||
|
"resolutions": [r.model_dump() for r in result.resolutions],
|
||||||
|
"deliberation_steps": [s.model_dump() for s in result.steps],
|
||||||
|
"agents": [r.model_dump() for r in agent_results],
|
||||||
|
"summary": {
|
||||||
|
"total_findings": result.total_findings,
|
||||||
|
"critical_count": result.critical_count,
|
||||||
|
"high_count": result.high_count,
|
||||||
|
"conflicts_count": len(result.conflicts),
|
||||||
|
"total_tokens": sum(r.tokens_used for r in agent_results) + result.tokens_used,
|
||||||
|
"total_cost_usd": sum(r.cost_usd for r in agent_results) + result.cost_usd,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
console.print(json.dumps(output, indent=2, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
def _format_markdown(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
|
||||||
|
"""Format results as Markdown."""
|
||||||
|
lines = ["# Arbiter Review", ""]
|
||||||
|
|
||||||
|
# Verdict
|
||||||
|
verdict_badge = {
|
||||||
|
Verdict.APPROVE: ":white_check_mark:",
|
||||||
|
Verdict.COMMENT: ":speech_balloon:",
|
||||||
|
Verdict.REQUEST_CHANGES: ":x:",
|
||||||
|
}.get(result.verdict, "")
|
||||||
|
|
||||||
|
lines.append(f"## Verdict: {verdict_badge} {result.verdict.value.upper()}")
|
||||||
|
lines.append("")
|
||||||
|
lines.append(f"**Confidence:** {result.verdict_confidence:.0%}")
|
||||||
|
if result.verdict_reasoning:
|
||||||
|
lines.append(f"**Reason:** {result.verdict_reasoning}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
lines.append(f"**{result.total_findings} findings** from {len(agent_results)} agents")
|
||||||
|
if result.conflicts:
|
||||||
|
lines.append(f"**{len(result.conflicts)} conflicts** detected")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Findings by severity
|
||||||
|
if result.findings:
|
||||||
|
lines.append("## Findings")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
for finding in result.findings:
|
||||||
|
severity_badge = f"**{finding.severity.value.upper()}**"
|
||||||
|
lines.append(
|
||||||
|
f"- {severity_badge} [{finding.agent.value}] `{finding.file}:{finding.line_start}` - {finding.title}"
|
||||||
|
)
|
||||||
|
lines.append(f" - {finding.description}")
|
||||||
|
if finding.suggestion:
|
||||||
|
lines.append(f" - *Suggestion:* {finding.suggestion}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Conflicts
|
||||||
|
if result.conflicts:
|
||||||
|
lines.append("## Conflicts")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
for conflict in result.conflicts:
|
||||||
|
lines.append(f"### {conflict.nature.value.title()}")
|
||||||
|
lines.append(conflict.description)
|
||||||
|
for resolution in result.resolutions:
|
||||||
|
if resolution.conflict_id == conflict.id:
|
||||||
|
lines.append(f"**Resolution:** {resolution.reasoning}")
|
||||||
|
break
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Deliberation log
|
||||||
|
lines.append("## Deliberation Log")
|
||||||
|
lines.append("")
|
||||||
|
for step in result.steps:
|
||||||
|
lines.append(f"1. **{step.step_type.value}**: {step.description}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
console.print("\n".join(lines))
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_review(
|
||||||
|
diff: str,
|
||||||
|
policy: Policy,
|
||||||
|
templates_dir: Path,
|
||||||
|
static_analysis: bool = True,
|
||||||
|
work_dir: Path | None = None,
|
||||||
|
) -> tuple[list[ReviewResult], DeliberationResult]:
|
||||||
|
"""Run all enabled agents on the diff with deliberation."""
|
||||||
|
settings = get_settings()
|
||||||
|
llm_client = LiteLLMClient(
|
||||||
|
timeout=settings.llm_timeout,
|
||||||
|
max_retries=settings.llm_max_retries,
|
||||||
|
)
|
||||||
|
prompt_registry = PromptRegistry(templates_dir)
|
||||||
|
|
||||||
|
# Parse the diff
|
||||||
|
parser = DiffParser()
|
||||||
|
parsed_diff = parser.parse(diff)
|
||||||
|
|
||||||
|
# Run static analysis if enabled
|
||||||
|
static_findings: list[Finding] = []
|
||||||
|
if static_analysis and work_dir:
|
||||||
|
runner = StaticAnalysisRunner()
|
||||||
|
static_result = await runner.run(parsed_diff, work_dir)
|
||||||
|
for sf in static_result.findings:
|
||||||
|
static_findings.append(runner.convert_to_finding(sf))
|
||||||
|
|
||||||
|
context = ReviewContext(
|
||||||
|
diff=diff,
|
||||||
|
policy=policy,
|
||||||
|
static_analysis_context=f"Found {len(static_findings)} static analysis findings."
|
||||||
|
if static_findings
|
||||||
|
else "No static analysis data available.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create agents for enabled agent types
|
||||||
|
agent_classes: dict[AgentName, type[Agent]] = {
|
||||||
|
AgentName.SECURITY: SecurityAgent,
|
||||||
|
AgentName.STYLE: StyleAgent,
|
||||||
|
AgentName.COMPLEXITY: ComplexityAgent,
|
||||||
|
}
|
||||||
|
|
||||||
|
agents: list[Agent] = []
|
||||||
|
for agent_name in policy.get_enabled_agents():
|
||||||
|
if agent_name in agent_classes:
|
||||||
|
agents.append(agent_classes[agent_name](llm_client, prompt_registry))
|
||||||
|
|
||||||
|
# Run agents in parallel
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[agent.review(context) for agent in agents],
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter out exceptions and log them
|
||||||
|
valid_results: list[ReviewResult] = []
|
||||||
|
for result in results:
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
console.print(f"[red]Agent error:[/red] {result}")
|
||||||
|
else:
|
||||||
|
valid_results.append(result)
|
||||||
|
|
||||||
|
# Run deliberation
|
||||||
|
coordinator = Coordinator(llm_client=llm_client)
|
||||||
|
deliberation_result = await coordinator.deliberate(
|
||||||
|
valid_results,
|
||||||
|
static_findings=static_findings if static_findings else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return valid_results, deliberation_result
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def review(
|
||||||
|
diff_file: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Argument(help="Path to diff file, or '-' for stdin"),
|
||||||
|
],
|
||||||
|
policy: Annotated[
|
||||||
|
Path | None,
|
||||||
|
typer.Option("--policy", "-p", help="Path to policy YAML file"),
|
||||||
|
] = None,
|
||||||
|
model: Annotated[
|
||||||
|
str | None,
|
||||||
|
typer.Option("--model", "-m", help="Override LLM model for all agents"),
|
||||||
|
] = None,
|
||||||
|
output_format: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Option("--format", "-f", help="Output format: rich, json, markdown"),
|
||||||
|
] = "rich",
|
||||||
|
static_analysis: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option("--static-analysis/--no-static-analysis", help="Run static analysis"),
|
||||||
|
] = True,
|
||||||
|
work_dir: Annotated[
|
||||||
|
Path | None,
|
||||||
|
typer.Option("--work-dir", "-w", help="Working directory for static analysis"),
|
||||||
|
] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Review a diff file using AI agents."""
|
||||||
|
# Read diff content
|
||||||
|
if diff_file == "-":
|
||||||
|
diff_content = sys.stdin.read()
|
||||||
|
else:
|
||||||
|
diff_path = Path(diff_file)
|
||||||
|
if not diff_path.exists():
|
||||||
|
console.print(f"[red]Error:[/red] File not found: {diff_file}")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
diff_content = diff_path.read_text()
|
||||||
|
|
||||||
|
if not diff_content.strip():
|
||||||
|
console.print("[yellow]Warning:[/yellow] Empty diff provided")
|
||||||
|
raise typer.Exit(0)
|
||||||
|
|
||||||
|
# Load policy
|
||||||
|
review_policy = Policy()
|
||||||
|
if policy:
|
||||||
|
if not policy.exists():
|
||||||
|
console.print(f"[red]Error:[/red] Policy file not found: {policy}")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
review_policy = Policy.load(policy)
|
||||||
|
|
||||||
|
# Apply model override
|
||||||
|
if model:
|
||||||
|
for agent_config in review_policy.agents.values():
|
||||||
|
agent_config.model = model
|
||||||
|
|
||||||
|
# Determine prompts directory
|
||||||
|
settings = get_settings()
|
||||||
|
templates_dir = settings.templates_dir
|
||||||
|
if not templates_dir.is_absolute():
|
||||||
|
templates_dir = Path.cwd() / templates_dir
|
||||||
|
|
||||||
|
# Resolve work directory for static analysis
|
||||||
|
resolved_work_dir = (
|
||||||
|
(work_dir.resolve() if work_dir else Path.cwd()) if static_analysis else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run review
|
||||||
|
agent_results, deliberation_result = asyncio.run(
|
||||||
|
_run_review(
|
||||||
|
diff_content,
|
||||||
|
review_policy,
|
||||||
|
templates_dir,
|
||||||
|
static_analysis=static_analysis,
|
||||||
|
work_dir=resolved_work_dir,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Format output
|
||||||
|
if output_format == "json":
|
||||||
|
_format_json(deliberation_result, agent_results)
|
||||||
|
elif output_format == "markdown":
|
||||||
|
_format_markdown(deliberation_result, agent_results)
|
||||||
|
else:
|
||||||
|
_format_rich(deliberation_result, agent_results)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def version() -> None:
|
||||||
|
from arbiter import __version__
|
||||||
|
|
||||||
|
console.print(f"arbiter {__version__}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
121
src/arbiter/config.py
Normal file
121
src/arbiter/config.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Configuration settings for Arbiter."""
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pydantic import Field, SecretStr
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Application settings loaded from environment variables."""
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_prefix="ARBITER_",
|
||||||
|
env_file=".env",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
)
|
||||||
|
|
||||||
|
# LLM settings
|
||||||
|
default_model: str = Field(default="gpt-4o", description="Default LLM model for agents")
|
||||||
|
llm_timeout: int = Field(default=60, description="LLM request timeout in seconds")
|
||||||
|
llm_max_retries: int = Field(default=3, description="Maximum LLM retry attempts")
|
||||||
|
|
||||||
|
# Cost controls
|
||||||
|
max_tokens_per_review: int = Field(
|
||||||
|
default=50000, description="Maximum tokens allowed per review"
|
||||||
|
)
|
||||||
|
max_cost_per_review_usd: float = Field(
|
||||||
|
default=0.50, description="Maximum cost per review in USD"
|
||||||
|
)
|
||||||
|
cache_ttl_hours: int = Field(default=24, description="Cache TTL in hours")
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
templates_dir: Path = Field(
|
||||||
|
default=Path("templates"), description="Directory containing prompt templates"
|
||||||
|
)
|
||||||
|
policy_path: Path | None = Field(default=None, description="Default policy file path")
|
||||||
|
|
||||||
|
# Output settings
|
||||||
|
output_format: str = Field(default="rich", description="Output format: rich, json, or markdown")
|
||||||
|
|
||||||
|
# Database settings
|
||||||
|
database_url: str = Field(
|
||||||
|
default="postgresql+asyncpg://arbiter:arbiter@localhost:5432/arbiter",
|
||||||
|
description="PostgreSQL connection URL",
|
||||||
|
)
|
||||||
|
database_pool_size: int = Field(default=5, description="Database connection pool size")
|
||||||
|
database_max_overflow: int = Field(default=10, description="Max overflow connections")
|
||||||
|
|
||||||
|
# Redis settings
|
||||||
|
redis_url: str = Field(
|
||||||
|
default="redis://localhost:6379/0",
|
||||||
|
description="Redis connection URL",
|
||||||
|
)
|
||||||
|
redis_max_connections: int = Field(default=10, description="Redis max connections")
|
||||||
|
|
||||||
|
# Webhook secrets
|
||||||
|
github_webhook_secret: SecretStr | None = Field(
|
||||||
|
default=None, description="GitHub webhook secret for HMAC verification"
|
||||||
|
)
|
||||||
|
gitlab_webhook_token: SecretStr | None = Field(
|
||||||
|
default=None, description="GitLab webhook token for verification"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Platform integration settings
|
||||||
|
github_token: SecretStr | None = Field(
|
||||||
|
default=None, description="GitHub API token for fetching diffs and posting comments"
|
||||||
|
)
|
||||||
|
github_base_url: str = Field(
|
||||||
|
default="https://api.github.com", description="GitHub API base URL"
|
||||||
|
)
|
||||||
|
gitlab_token: SecretStr | None = Field(
|
||||||
|
default=None, description="GitLab API token for fetching diffs and posting comments"
|
||||||
|
)
|
||||||
|
gitlab_base_url: str = Field(
|
||||||
|
default="https://gitlab.com", description="GitLab instance base URL"
|
||||||
|
)
|
||||||
|
integration_timeout: int = Field(
|
||||||
|
default=30, description="Integration API request timeout in seconds"
|
||||||
|
)
|
||||||
|
integration_max_retries: int = Field(
|
||||||
|
default=3, description="Maximum retry attempts for integration API calls"
|
||||||
|
)
|
||||||
|
status_check_context: str = Field(
|
||||||
|
default="arbiter", description="Context name for commit status checks"
|
||||||
|
)
|
||||||
|
post_comments: bool = Field(default=True, description="Whether to post review comments on PRs")
|
||||||
|
update_status: bool = Field(default=True, description="Whether to update commit status checks")
|
||||||
|
|
||||||
|
# API settings
|
||||||
|
api_title: str = Field(default="Arbiter API", description="API title for OpenAPI")
|
||||||
|
api_version: str = Field(default="0.5.0", description="API version")
|
||||||
|
cors_origins: list[str] = Field(
|
||||||
|
default=["http://localhost:3000"], description="Allowed CORS origins"
|
||||||
|
)
|
||||||
|
api_rate_limit_per_minute: int = Field(
|
||||||
|
default=60, description="API rate limit per minute per client"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Worker settings
|
||||||
|
worker_max_jobs: int = Field(default=10, description="Max concurrent worker jobs")
|
||||||
|
worker_job_timeout: int = Field(default=300, description="Job timeout in seconds")
|
||||||
|
worker_retry_attempts: int = Field(default=3, description="Number of retry attempts")
|
||||||
|
|
||||||
|
# Follow-up conversation settings
|
||||||
|
followup_enabled: bool = Field(default=True, description="Enable follow-up question handling")
|
||||||
|
followup_confidence_threshold: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Minimum confidence to process a follow-up question",
|
||||||
|
)
|
||||||
|
followup_max_tokens_per_response: int = Field(
|
||||||
|
default=2000, description="Maximum tokens per follow-up response"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
12
src/arbiter/llm/__init__.py
Normal file
12
src/arbiter/llm/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""LLM client and prompt management."""
|
||||||
|
|
||||||
|
from arbiter.llm.client import LiteLLMClient, LLMClient, LLMResponse
|
||||||
|
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LLMClient",
|
||||||
|
"LLMResponse",
|
||||||
|
"LiteLLMClient",
|
||||||
|
"PromptRegistry",
|
||||||
|
"PromptTemplate",
|
||||||
|
]
|
||||||
66
src/arbiter/llm/client.py
Normal file
66
src/arbiter/llm/client.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""LLM client abstraction."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class LLMResponse(BaseModel):
|
||||||
|
"""Response from an LLM completion."""
|
||||||
|
|
||||||
|
content: str = Field(description="The generated text content")
|
||||||
|
model: str = Field(description="Model that generated the response")
|
||||||
|
tokens_in: int = Field(ge=0, description="Number of input tokens")
|
||||||
|
tokens_out: int = Field(ge=0, description="Number of output tokens")
|
||||||
|
cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient(ABC):
|
||||||
|
"""Abstract base class for LLM clients."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def complete(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
model: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResponse: ...
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLMClient(LLMClient):
|
||||||
|
"""LLM client implementation using LiteLLM."""
|
||||||
|
|
||||||
|
def __init__(self, timeout: int = 60, max_retries: int = 3) -> None:
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
|
||||||
|
async def complete(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
model: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""Generate a completion using LiteLLM."""
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
timeout=self.timeout,
|
||||||
|
num_retries=self.max_retries,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
content = response.choices[0].message.content or ""
|
||||||
|
usage = response.usage
|
||||||
|
|
||||||
|
# Calculate cost using litellm's cost tracking
|
||||||
|
cost = litellm.completion_cost(completion_response=response)
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
model=response.model or model,
|
||||||
|
tokens_in=usage.prompt_tokens if usage else 0,
|
||||||
|
tokens_out=usage.completion_tokens if usage else 0,
|
||||||
|
cost_usd=cost,
|
||||||
|
)
|
||||||
81
src/arbiter/llm/prompts.py
Normal file
81
src/arbiter/llm/prompts.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""Prompt template management."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplate(BaseModel):
|
||||||
|
"""A versioned prompt template."""
|
||||||
|
|
||||||
|
name: str = Field(description="Template name (e.g., 'security')")
|
||||||
|
version: str = Field(description="Semantic version (e.g., '1.0')")
|
||||||
|
content: str = Field(description="Raw template content with placeholders")
|
||||||
|
|
||||||
|
def render(self, **kwargs: Any) -> str:
|
||||||
|
result = self.content
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
placeholder = f"{{{{{key}}}}}"
|
||||||
|
result = result.replace(placeholder, str(value))
|
||||||
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def full_name(self) -> str:
|
||||||
|
return f"{self.name}-v{self.version}"
|
||||||
|
|
||||||
|
|
||||||
|
class PromptRegistry:
|
||||||
|
"""Registry for loading and caching prompt templates."""
|
||||||
|
|
||||||
|
_cache: ClassVar[dict[str, PromptTemplate]] = {}
|
||||||
|
_pattern: ClassVar[re.Pattern[str]] = re.compile(r"^(.+)-v(\d+\.\d+)\.md$")
|
||||||
|
|
||||||
|
def __init__(self, templates_dir: Path) -> None:
|
||||||
|
self.templates_dir = templates_dir
|
||||||
|
|
||||||
|
def get(self, name: str, version: str) -> PromptTemplate:
|
||||||
|
"""Get a prompt template by name and version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Template name (e.g., 'security').
|
||||||
|
version: Template version (e.g., '1.0').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PromptTemplate instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the template file doesn't exist.
|
||||||
|
"""
|
||||||
|
cache_key = f"{name}-v{version}"
|
||||||
|
if cache_key in self._cache:
|
||||||
|
return self._cache[cache_key]
|
||||||
|
|
||||||
|
file_path = self.templates_dir / f"{name}-v{version}.md"
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundError(f"Prompt template not found: {file_path}")
|
||||||
|
|
||||||
|
content = file_path.read_text()
|
||||||
|
template = PromptTemplate(name=name, version=version, content=content)
|
||||||
|
self._cache[cache_key] = template
|
||||||
|
return template
|
||||||
|
|
||||||
|
def list_templates(self) -> list[tuple[str, str]]:
|
||||||
|
"""List all available templates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (name, version) tuples.
|
||||||
|
"""
|
||||||
|
templates: list[tuple[str, str]] = []
|
||||||
|
if not self.templates_dir.exists():
|
||||||
|
return templates
|
||||||
|
|
||||||
|
for file_path in self.templates_dir.glob("*.md"):
|
||||||
|
match = self._pattern.match(file_path.name)
|
||||||
|
if match:
|
||||||
|
templates.append((match.group(1), match.group(2)))
|
||||||
|
return sorted(templates)
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
self._cache.clear()
|
||||||
16
src/arbiter/models/__init__.py
Normal file
16
src/arbiter/models/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Arbiter data models."""
|
||||||
|
|
||||||
|
from arbiter.models.enums import AgentName, Severity, Verdict
|
||||||
|
from arbiter.models.finding import Finding
|
||||||
|
from arbiter.models.policy import AgentConfig, Policy
|
||||||
|
from arbiter.models.review import ReviewResult
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentConfig",
|
||||||
|
"AgentName",
|
||||||
|
"Finding",
|
||||||
|
"Policy",
|
||||||
|
"ReviewResult",
|
||||||
|
"Severity",
|
||||||
|
"Verdict",
|
||||||
|
]
|
||||||
29
src/arbiter/models/enums.py
Normal file
29
src/arbiter/models/enums.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""Core enumerations for Arbiter."""
|
||||||
|
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class Severity(StrEnum):
|
||||||
|
"""Finding severity levels."""
|
||||||
|
|
||||||
|
CRITICAL = "critical"
|
||||||
|
HIGH = "high"
|
||||||
|
MEDIUM = "medium"
|
||||||
|
LOW = "low"
|
||||||
|
INFO = "info"
|
||||||
|
|
||||||
|
|
||||||
|
class Verdict(StrEnum):
|
||||||
|
"""Review verdict types."""
|
||||||
|
|
||||||
|
APPROVE = "approve"
|
||||||
|
REQUEST_CHANGES = "request_changes"
|
||||||
|
COMMENT = "comment"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentName(StrEnum):
|
||||||
|
"""Available review agent names."""
|
||||||
|
|
||||||
|
SECURITY = "security"
|
||||||
|
STYLE = "style"
|
||||||
|
COMPLEXITY = "complexity"
|
||||||
28
src/arbiter/models/finding.py
Normal file
28
src/arbiter/models/finding.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""Finding model for review results."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from arbiter.models.enums import AgentName, Severity
|
||||||
|
|
||||||
|
|
||||||
|
class Finding(BaseModel):
|
||||||
|
"""A single finding from an agent's review."""
|
||||||
|
|
||||||
|
id: str = Field(description="Unique identifier for this finding")
|
||||||
|
agent: AgentName = Field(description="Agent that produced this finding")
|
||||||
|
file: str = Field(description="File path where the issue was found")
|
||||||
|
line_start: int = Field(ge=1, description="Starting line number")
|
||||||
|
line_end: int = Field(ge=1, description="Ending line number")
|
||||||
|
severity: Severity = Field(description="Severity level of the finding")
|
||||||
|
confidence: float = Field(ge=0.0, le=1.0, description="Confidence score (0.0-1.0)")
|
||||||
|
title: str = Field(description="Short title summarizing the finding")
|
||||||
|
description: str = Field(description="Detailed description of the issue")
|
||||||
|
reasoning: str = Field(description="Explanation of why this was flagged")
|
||||||
|
suggestion: str | None = Field(default=None, description="Optional fix recommendation")
|
||||||
|
references: list[str] = Field(default_factory=list, description="Links to docs, OWASP, etc.")
|
||||||
|
prompt_version: str = Field(description="Version of the prompt that produced this finding")
|
||||||
|
static_analysis_context: dict[str, Any] | None = Field(
|
||||||
|
default=None, description="Related static analysis data"
|
||||||
|
)
|
||||||
58
src/arbiter/models/policy.py
Normal file
58
src/arbiter/models/policy.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""Policy configuration models."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from arbiter.models.enums import AgentName, Severity
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig(BaseModel):
|
||||||
|
"""Configuration for a single agent."""
|
||||||
|
|
||||||
|
enabled: bool = Field(default=True, description="Whether the agent is enabled")
|
||||||
|
model: str | None = Field(default=None, description="LLM model override for this agent")
|
||||||
|
severity_threshold: Severity = Field(
|
||||||
|
default=Severity.INFO, description="Minimum severity to report"
|
||||||
|
)
|
||||||
|
prompt_additions: str | None = Field(
|
||||||
|
default=None, description="Additional instructions to append to the prompt"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Policy(BaseModel):
|
||||||
|
"""Review policy configuration."""
|
||||||
|
|
||||||
|
version: str = Field(default="1.0", description="Policy schema version")
|
||||||
|
agents: dict[AgentName, AgentConfig] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
AgentName.SECURITY: AgentConfig(),
|
||||||
|
AgentName.STYLE: AgentConfig(),
|
||||||
|
AgentName.COMPLEXITY: AgentConfig(),
|
||||||
|
},
|
||||||
|
description="Agent configurations",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path: Path) -> Self:
|
||||||
|
"""Load policy from a YAML file."""
|
||||||
|
with path.open() as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
# Convert string keys to AgentName enums
|
||||||
|
if "agents" in data and isinstance(data["agents"], dict):
|
||||||
|
agents = {}
|
||||||
|
for key, value in data["agents"].items():
|
||||||
|
agent_name = AgentName(key) if isinstance(key, str) else key
|
||||||
|
agents[agent_name] = AgentConfig(**value) if isinstance(value, dict) else value
|
||||||
|
data["agents"] = agents
|
||||||
|
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
def get_enabled_agents(self) -> list[AgentName]:
|
||||||
|
return [name for name, config in self.agents.items() if config.enabled]
|
||||||
16
src/arbiter/models/review.py
Normal file
16
src/arbiter/models/review.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Review result models."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from arbiter.models.enums import AgentName
|
||||||
|
from arbiter.models.finding import Finding
|
||||||
|
|
||||||
|
|
||||||
|
class ReviewResult(BaseModel):
|
||||||
|
"""Result from a single agent's review."""
|
||||||
|
|
||||||
|
agent_name: AgentName = Field(description="Name of the agent that performed the review")
|
||||||
|
findings: list[Finding] = Field(default_factory=list, description="Findings from the review")
|
||||||
|
duration_ms: int = Field(ge=0, description="Time taken in milliseconds")
|
||||||
|
tokens_used: int = Field(ge=0, description="Total tokens consumed")
|
||||||
|
cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")
|
||||||
0
src/arbiter/py.typed
Normal file
0
src/arbiter/py.typed
Normal file
56
templates/complexity-v1.0.md
Normal file
56
templates/complexity-v1.0.md
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# Complexity Review Agent
|
||||||
|
|
||||||
|
You are a code complexity reviewer focused on architecture and maintainability. Analyze the provided diff for complexity issues that impact long-term code health.
|
||||||
|
|
||||||
|
## Focus Areas
|
||||||
|
|
||||||
|
- **Cyclomatic complexity**: Functions with too many branches or paths
|
||||||
|
- **Cognitive complexity**: Code that is hard to understand or follow
|
||||||
|
- **Function length**: Functions doing too many things
|
||||||
|
- **Class design**: God objects, tight coupling, missing abstractions
|
||||||
|
- **Dependency management**: Circular dependencies, excessive coupling
|
||||||
|
- **Over-engineering**: Unnecessary abstractions, premature optimization
|
||||||
|
- **Under-engineering**: Missing error handling, ignored edge cases
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
{{static_analysis_context}}
|
||||||
|
|
||||||
|
## Diff to Review
|
||||||
|
|
||||||
|
```diff
|
||||||
|
{{diff}}
|
||||||
|
```
|
||||||
|
|
||||||
|
{{prompt_additions}}
|
||||||
|
|
||||||
|
## Output Format
|
||||||
|
|
||||||
|
Respond with a JSON array of findings. Each finding must have this structure:
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"file": "path/to/file.py",
|
||||||
|
"line_start": 10,
|
||||||
|
"line_end": 50,
|
||||||
|
"severity": "critical|high|medium|low|info",
|
||||||
|
"confidence": 0.80,
|
||||||
|
"title": "Short title describing the issue",
|
||||||
|
"description": "Detailed description of the complexity concern",
|
||||||
|
"reasoning": "Why this complexity is problematic",
|
||||||
|
"suggestion": "How to simplify or refactor (optional)",
|
||||||
|
"references": []
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
If no complexity issues are found, return an empty array: `[]`
|
||||||
|
|
||||||
|
## Guidelines
|
||||||
|
|
||||||
|
1. Consider context - complex code may be justified for complex problems
|
||||||
|
2. Flag over-engineering as readily as under-engineering
|
||||||
|
3. High complexity is only critical if it's likely to cause bugs
|
||||||
|
4. Suggest specific refactoring strategies when possible
|
||||||
|
5. Reference static analysis metrics (cyclomatic complexity, etc.) when available
|
||||||
55
templates/security-v1.0.md
Normal file
55
templates/security-v1.0.md
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
# Security Review Agent
|
||||||
|
|
||||||
|
You are a security-focused code reviewer. Analyze the provided diff for security vulnerabilities and potential risks.
|
||||||
|
|
||||||
|
## Focus Areas
|
||||||
|
|
||||||
|
- **Injection vulnerabilities**: SQL injection, command injection, XSS, template injection
|
||||||
|
- **Authentication/Authorization**: Missing auth checks, privilege escalation, insecure session handling
|
||||||
|
- **Data exposure**: Hardcoded secrets, PII leaks, sensitive data in logs
|
||||||
|
- **Cryptographic issues**: Weak algorithms, improper key management, missing encryption
|
||||||
|
- **Input validation**: Missing or insufficient validation, type confusion
|
||||||
|
- **OWASP Top 10**: All categories including broken access control, security misconfiguration
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
{{static_analysis_context}}
|
||||||
|
|
||||||
|
## Diff to Review
|
||||||
|
|
||||||
|
```diff
|
||||||
|
{{diff}}
|
||||||
|
```
|
||||||
|
|
||||||
|
{{prompt_additions}}
|
||||||
|
|
||||||
|
## Output Format
|
||||||
|
|
||||||
|
Respond with a JSON array of findings. Each finding must have this structure:
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"file": "path/to/file.py",
|
||||||
|
"line_start": 10,
|
||||||
|
"line_end": 15,
|
||||||
|
"severity": "critical|high|medium|low|info",
|
||||||
|
"confidence": 0.95,
|
||||||
|
"title": "Short title describing the issue",
|
||||||
|
"description": "Detailed description of the vulnerability",
|
||||||
|
"reasoning": "Why this is a security concern",
|
||||||
|
"suggestion": "How to fix this issue (optional)",
|
||||||
|
"references": ["https://owasp.org/..."]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
If no security issues are found, return an empty array: `[]`
|
||||||
|
|
||||||
|
## Guidelines
|
||||||
|
|
||||||
|
1. Only report genuine security concerns, not style or performance issues
|
||||||
|
2. Assign appropriate severity based on exploitability and impact
|
||||||
|
3. Set confidence based on how certain you are this is a real vulnerability
|
||||||
|
4. Provide actionable suggestions when possible
|
||||||
|
5. Include relevant OWASP or CWE references
|
||||||
55
templates/style-v1.0.md
Normal file
55
templates/style-v1.0.md
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
# Style Review Agent
|
||||||
|
|
||||||
|
You are a code style reviewer focused on readability and consistency. Analyze the provided diff for style issues that impact code maintainability.
|
||||||
|
|
||||||
|
## Focus Areas
|
||||||
|
|
||||||
|
- **Naming conventions**: Variable, function, class, and file naming consistency
|
||||||
|
- **Code organisation**: Logical grouping, import ordering, module structure
|
||||||
|
- **Readability**: Clear variable names, appropriate comments, self-documenting code
|
||||||
|
- **Consistency**: Adherence to existing patterns in the codebase
|
||||||
|
- **Best practices**: Language-specific idioms and conventions
|
||||||
|
- **Documentation**: Missing or outdated docstrings, misleading comments
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
{{static_analysis_context}}
|
||||||
|
|
||||||
|
## Diff to Review
|
||||||
|
|
||||||
|
```diff
|
||||||
|
{{diff}}
|
||||||
|
```
|
||||||
|
|
||||||
|
{{prompt_additions}}
|
||||||
|
|
||||||
|
## Output Format
|
||||||
|
|
||||||
|
Respond with a JSON array of findings. Each finding must have this structure:
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"file": "path/to/file.py",
|
||||||
|
"line_start": 10,
|
||||||
|
"line_end": 15,
|
||||||
|
"severity": "critical|high|medium|low|info",
|
||||||
|
"confidence": 0.85,
|
||||||
|
"title": "Short title describing the issue",
|
||||||
|
"description": "Detailed description of the style concern",
|
||||||
|
"reasoning": "Why this matters for maintainability",
|
||||||
|
"suggestion": "How to improve this (optional)",
|
||||||
|
"references": []
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
If no style issues are found, return an empty array: `[]`
|
||||||
|
|
||||||
|
## Guidelines
|
||||||
|
|
||||||
|
1. Focus on readability and maintainability, not personal preferences
|
||||||
|
2. Respect existing codebase conventions even if they differ from common standards
|
||||||
|
3. Most style issues should be low or info severity unless they significantly impact readability
|
||||||
|
4. Only flag high severity for style issues that could cause confusion or bugs
|
||||||
|
5. Provide concrete suggestions with example code when helpful
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Arbiter test suite."""
|
||||||
490
tests/conftest.py
Normal file
490
tests/conftest.py
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
"""Pytest configuration and fixtures."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from arbiter.db.models import Base, FindingModel, ReviewModel
|
||||||
|
from arbiter.integrations.base import Comment, CommitStatus, Platform, PlatformClient
|
||||||
|
from arbiter.llm.client import LLMClient, LLMResponse
|
||||||
|
from arbiter.llm.prompts import PromptRegistry
|
||||||
|
from arbiter.models import Policy
|
||||||
|
from arbiter.models.enums import AgentName, Severity, Verdict
|
||||||
|
|
||||||
|
|
||||||
|
class MockLLMClient(LLMClient):
|
||||||
|
"""Mock LLM client for testing."""
|
||||||
|
|
||||||
|
def __init__(self, responses: list[str] | None = None) -> None:
|
||||||
|
self.responses = responses or []
|
||||||
|
self.calls: list[dict[str, Any]] = []
|
||||||
|
self._call_index = 0
|
||||||
|
|
||||||
|
async def complete(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
model: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""Record the call and return a canned response."""
|
||||||
|
self.calls.append(
|
||||||
|
{
|
||||||
|
"messages": messages,
|
||||||
|
"model": model,
|
||||||
|
"kwargs": kwargs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
content = ""
|
||||||
|
if self._call_index < len(self.responses):
|
||||||
|
content = self.responses[self._call_index]
|
||||||
|
self._call_index += 1
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
model=model,
|
||||||
|
tokens_in=100,
|
||||||
|
tokens_out=50,
|
||||||
|
cost_usd=0.001,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.calls = []
|
||||||
|
self._call_index = 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm() -> MockLLMClient:
|
||||||
|
return MockLLMClient()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm_with_findings() -> MockLLMClient:
|
||||||
|
response = """```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"file": "src/auth.py",
|
||||||
|
"line_start": 10,
|
||||||
|
"line_end": 15,
|
||||||
|
"severity": "high",
|
||||||
|
"confidence": 0.9,
|
||||||
|
"title": "SQL Injection vulnerability",
|
||||||
|
"description": "User input is directly concatenated into SQL query",
|
||||||
|
"reasoning": "String concatenation in SQL queries allows attackers to inject malicious SQL",
|
||||||
|
"suggestion": "Use parameterized queries instead",
|
||||||
|
"references": ["https://owasp.org/www-community/attacks/SQL_Injection"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```"""
|
||||||
|
return MockLLMClient(responses=[response])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_diff() -> str:
|
||||||
|
return """diff --git a/src/auth.py b/src/auth.py
|
||||||
|
index 1234567..abcdefg 100644
|
||||||
|
--- a/src/auth.py
|
||||||
|
+++ b/src/auth.py
|
||||||
|
@@ -8,6 +8,12 @@ def authenticate(username, password):
|
||||||
|
if not username or not password:
|
||||||
|
return None
|
||||||
|
|
||||||
|
+ # Check user in database
|
||||||
|
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
|
||||||
|
+ cursor.execute(query)
|
||||||
|
+ user = cursor.fetchone()
|
||||||
|
+ if user and user.password == password:
|
||||||
|
+ return user
|
||||||
|
return None
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_policy() -> Policy:
|
||||||
|
return Policy()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def prompt_registry(tmp_path: Path) -> PromptRegistry:
|
||||||
|
templates_dir = tmp_path / "templates"
|
||||||
|
templates_dir.mkdir()
|
||||||
|
|
||||||
|
(templates_dir / "security-v1.0.md").write_text(
|
||||||
|
"Review for security: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
|
||||||
|
)
|
||||||
|
(templates_dir / "style-v1.0.md").write_text(
|
||||||
|
"Review for style: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
|
||||||
|
)
|
||||||
|
(templates_dir / "complexity-v1.0.md").write_text(
|
||||||
|
"Review for complexity: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return PromptRegistry(templates_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# Database fixtures for integration tests
|
||||||
|
@pytest.fixture
|
||||||
|
async def async_engine():
|
||||||
|
engine = create_async_engine(
|
||||||
|
"sqlite+aiosqlite:///:memory:",
|
||||||
|
echo=False,
|
||||||
|
)
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
yield engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
session_factory = async_sessionmaker(
|
||||||
|
bind=async_engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
async with session_factory() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def test_client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
from arbiter.api.deps import get_db, get_redis
|
||||||
|
from arbiter.main import app
|
||||||
|
|
||||||
|
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
async def override_get_redis() -> AsyncGenerator[MockRedis, None]:
|
||||||
|
yield MockRedis()
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
app.dependency_overrides[get_redis] = override_get_redis
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class MockRedis:
|
||||||
|
"""Mock Redis client for testing."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._data: dict[str, Any] = {}
|
||||||
|
|
||||||
|
async def get(self, key: str) -> str | None:
|
||||||
|
return self._data.get(key)
|
||||||
|
|
||||||
|
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
||||||
|
self._data[key] = value
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> int:
|
||||||
|
if key in self._data:
|
||||||
|
del self._data[key]
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def ping(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def llen(self, key: str) -> int:
|
||||||
|
data = self._data.get(key, [])
|
||||||
|
return len(data) if isinstance(data, list) else 0
|
||||||
|
|
||||||
|
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: # noqa: ARG002
|
||||||
|
return type("Job", (), {"job_id": "test-job-id"})()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_redis() -> MockRedis:
|
||||||
|
return MockRedis()
|
||||||
|
|
||||||
|
|
||||||
|
class MockPlatformClient(PlatformClient):
|
||||||
|
"""Mock platform client for testing."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._comments: list[Comment] = []
|
||||||
|
self._posted_comments: list[dict[str, Any]] = []
|
||||||
|
self._status_updates: list[dict[str, Any]] = []
|
||||||
|
self._diff = "mock diff content"
|
||||||
|
self._closed = False
|
||||||
|
self._fail_on: set[str] = set() # Methods to fail on
|
||||||
|
|
||||||
|
@property
|
||||||
|
def platform(self) -> Platform:
|
||||||
|
return Platform.GITHUB
|
||||||
|
|
||||||
|
async def get_pr_diff(
|
||||||
|
self,
|
||||||
|
repository: str, # noqa: ARG002
|
||||||
|
pr_number: int, # noqa: ARG002
|
||||||
|
) -> str:
|
||||||
|
if "get_pr_diff" in self._fail_on:
|
||||||
|
from arbiter.integrations import IntegrationError
|
||||||
|
|
||||||
|
raise IntegrationError("Mock failure: get_pr_diff")
|
||||||
|
return self._diff
|
||||||
|
|
||||||
|
async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
|
||||||
|
if "post_comment" in self._fail_on:
|
||||||
|
from arbiter.integrations import IntegrationError
|
||||||
|
|
||||||
|
raise IntegrationError("Mock failure: post_comment")
|
||||||
|
comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-123"
|
||||||
|
self._posted_comments.append(
|
||||||
|
{"repository": repository, "pr_number": pr_number, "body": body, "url": comment_url}
|
||||||
|
)
|
||||||
|
return comment_url
|
||||||
|
|
||||||
|
async def update_commit_status(
|
||||||
|
self,
|
||||||
|
repository: str,
|
||||||
|
sha: str,
|
||||||
|
status: CommitStatus,
|
||||||
|
description: str,
|
||||||
|
context: str,
|
||||||
|
target_url: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
if "update_commit_status" in self._fail_on:
|
||||||
|
from arbiter.integrations import IntegrationError
|
||||||
|
|
||||||
|
raise IntegrationError("Mock failure: update_commit_status")
|
||||||
|
self._status_updates.append(
|
||||||
|
{
|
||||||
|
"repository": repository,
|
||||||
|
"sha": sha,
|
||||||
|
"status": status,
|
||||||
|
"description": description,
|
||||||
|
"context": context,
|
||||||
|
"target_url": target_url,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_pr_info(self, repository: str, pr_number: int) -> Any:
|
||||||
|
if "get_pr_info" in self._fail_on:
|
||||||
|
from arbiter.integrations import IntegrationError
|
||||||
|
|
||||||
|
raise IntegrationError("Mock failure: get_pr_info")
|
||||||
|
from arbiter.integrations.base import PullRequestInfo
|
||||||
|
|
||||||
|
return PullRequestInfo(
|
||||||
|
platform=Platform.GITHUB,
|
||||||
|
repository=repository,
|
||||||
|
pr_number=pr_number,
|
||||||
|
head_sha="abc123",
|
||||||
|
base_sha="def456",
|
||||||
|
head_ref="feature",
|
||||||
|
base_ref="main",
|
||||||
|
title="Test PR",
|
||||||
|
author="testuser",
|
||||||
|
url=f"https://github.com/{repository}/pull/{pr_number}",
|
||||||
|
is_draft=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_comments(
|
||||||
|
self,
|
||||||
|
repository: str, # noqa: ARG002
|
||||||
|
pr_number: int, # noqa: ARG002
|
||||||
|
) -> list[Comment]:
|
||||||
|
if "get_comments" in self._fail_on:
|
||||||
|
from arbiter.integrations import IntegrationError
|
||||||
|
|
||||||
|
raise IntegrationError("Mock failure: get_comments")
|
||||||
|
return self._comments
|
||||||
|
|
||||||
|
async def update_comment(
|
||||||
|
self, repository: str, pr_number: int, comment_id: str, body: str
|
||||||
|
) -> str:
|
||||||
|
if "update_comment" in self._fail_on:
|
||||||
|
from arbiter.integrations import IntegrationError
|
||||||
|
|
||||||
|
raise IntegrationError("Mock failure: update_comment")
|
||||||
|
comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-{comment_id}"
|
||||||
|
self._posted_comments.append(
|
||||||
|
{
|
||||||
|
"repository": repository,
|
||||||
|
"pr_number": pr_number,
|
||||||
|
"comment_id": comment_id,
|
||||||
|
"body": body,
|
||||||
|
"url": comment_url,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return comment_url
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
self._closed = True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_platform_client() -> MockPlatformClient:
|
||||||
|
return MockPlatformClient()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def completed_review_fixture(db_session: AsyncSession) -> ReviewModel:
|
||||||
|
review = ReviewModel(
|
||||||
|
id=str(uuid4()),
|
||||||
|
repository="owner/repo",
|
||||||
|
pr_number=42,
|
||||||
|
pr_title="Test PR",
|
||||||
|
base_sha="abc1234567890",
|
||||||
|
head_sha="def0987654321",
|
||||||
|
author="testuser",
|
||||||
|
is_draft=False,
|
||||||
|
status="completed",
|
||||||
|
verdict=Verdict.COMMENT,
|
||||||
|
verdict_confidence=0.75,
|
||||||
|
verdict_reasoning="Found some issues to discuss",
|
||||||
|
total_tokens=1500,
|
||||||
|
total_cost_usd=0.015,
|
||||||
|
tokens_by_agent={"security": 500, "style": 500, "complexity": 500},
|
||||||
|
cost_by_agent={"security": 0.005, "style": 0.005, "complexity": 0.005},
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
started_at=datetime.now(UTC),
|
||||||
|
completed_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
db_session.add(review)
|
||||||
|
|
||||||
|
# Add findings
|
||||||
|
finding1 = FindingModel(
|
||||||
|
id=str(uuid4()),
|
||||||
|
review_id=review.id,
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="src/auth.py",
|
||||||
|
line_start=10,
|
||||||
|
line_end=15,
|
||||||
|
severity=Severity.HIGH,
|
||||||
|
confidence=0.9,
|
||||||
|
title="SQL Injection vulnerability",
|
||||||
|
description="User input is directly concatenated into SQL query",
|
||||||
|
reasoning="String concatenation in SQL queries allows attackers to inject malicious SQL",
|
||||||
|
suggestion="Use parameterized queries instead",
|
||||||
|
references=["https://owasp.org/www-community/attacks/SQL_Injection"],
|
||||||
|
prompt_version="security-v1.0",
|
||||||
|
)
|
||||||
|
finding2 = FindingModel(
|
||||||
|
id=str(uuid4()),
|
||||||
|
review_id=review.id,
|
||||||
|
agent=AgentName.STYLE,
|
||||||
|
file="src/auth.py",
|
||||||
|
line_start=20,
|
||||||
|
line_end=25,
|
||||||
|
severity=Severity.MEDIUM,
|
||||||
|
confidence=0.85,
|
||||||
|
title="Long function",
|
||||||
|
description="Function exceeds 50 lines",
|
||||||
|
reasoning="Long functions are harder to test and maintain",
|
||||||
|
suggestion="Consider breaking into smaller functions",
|
||||||
|
references=[],
|
||||||
|
prompt_version="style-v1.0",
|
||||||
|
)
|
||||||
|
db_session.add(finding1)
|
||||||
|
db_session.add(finding2)
|
||||||
|
|
||||||
|
await db_session.commit()
|
||||||
|
await db_session.refresh(review)
|
||||||
|
return review
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def sample_reviews_fixture(db_session: AsyncSession) -> list[ReviewModel]:
|
||||||
|
reviews = []
|
||||||
|
for i in range(5):
|
||||||
|
review = ReviewModel(
|
||||||
|
id=str(uuid4()),
|
||||||
|
repository="owner/repo" if i < 3 else "other/repo",
|
||||||
|
pr_number=i + 1,
|
||||||
|
pr_title=f"Test PR #{i + 1}",
|
||||||
|
base_sha=f"base{i:040d}",
|
||||||
|
head_sha=f"head{i:040d}",
|
||||||
|
author="testuser" if i % 2 == 0 else "otheruser",
|
||||||
|
is_draft=False,
|
||||||
|
status="completed" if i < 4 else "failed",
|
||||||
|
verdict=Verdict.APPROVE if i == 0 else (Verdict.COMMENT if i < 4 else None),
|
||||||
|
verdict_confidence=0.8 if i < 4 else None,
|
||||||
|
total_tokens=1000 * (i + 1),
|
||||||
|
total_cost_usd=0.01 * (i + 1),
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
completed_at=datetime.now(UTC) if i < 4 else None,
|
||||||
|
)
|
||||||
|
db_session.add(review)
|
||||||
|
reviews.append(review)
|
||||||
|
|
||||||
|
# Add a finding to each completed review
|
||||||
|
if i < 4:
|
||||||
|
finding = FindingModel(
|
||||||
|
id=str(uuid4()),
|
||||||
|
review_id=review.id,
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="src/test.py",
|
||||||
|
line_start=10,
|
||||||
|
line_end=15,
|
||||||
|
severity=Severity.CRITICAL if i == 0 else Severity.HIGH,
|
||||||
|
confidence=0.9,
|
||||||
|
title=f"Finding {i + 1}",
|
||||||
|
description="Test finding",
|
||||||
|
reasoning="Test reasoning",
|
||||||
|
prompt_version="security-v1.0",
|
||||||
|
)
|
||||||
|
db_session.add(finding)
|
||||||
|
|
||||||
|
await db_session.commit()
|
||||||
|
for review in reviews:
|
||||||
|
await db_session.refresh(review)
|
||||||
|
return reviews
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings() -> Any:
|
||||||
|
class MockSecretStr:
|
||||||
|
def __init__(self, value: str) -> None:
|
||||||
|
self._value = value
|
||||||
|
|
||||||
|
def get_secret_value(self) -> str:
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
class MockSettings:
|
||||||
|
github_token = MockSecretStr("ghp_test_token")
|
||||||
|
gitlab_token = MockSecretStr("glpat_test_token")
|
||||||
|
github_base_url = "https://api.github.com"
|
||||||
|
gitlab_base_url = "https://gitlab.com/api/v4"
|
||||||
|
github_webhook_secret = MockSecretStr("webhook_secret")
|
||||||
|
gitlab_webhook_token = MockSecretStr("gitlab_token")
|
||||||
|
integration_timeout = 30
|
||||||
|
integration_max_retries = 3
|
||||||
|
llm_timeout = 60
|
||||||
|
llm_max_retries = 3
|
||||||
|
post_comments = True
|
||||||
|
update_status = True
|
||||||
|
status_check_context = "arbiter"
|
||||||
|
templates_dir = Path("templates")
|
||||||
|
followup_enabled = True
|
||||||
|
followup_confidence_threshold = 0.5
|
||||||
|
|
||||||
|
return MockSettings()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings_no_github() -> Any:
|
||||||
|
class MockSettings:
|
||||||
|
github_token = None
|
||||||
|
gitlab_token = None
|
||||||
|
github_base_url = "https://api.github.com"
|
||||||
|
gitlab_base_url = "https://gitlab.com/api/v4"
|
||||||
|
integration_timeout = 30
|
||||||
|
integration_max_retries = 3
|
||||||
|
post_comments = False
|
||||||
|
update_status = False
|
||||||
|
status_check_context = "arbiter"
|
||||||
|
templates_dir = Path("templates")
|
||||||
|
|
||||||
|
return MockSettings()
|
||||||
69
tests/fixtures/complex-function.diff
vendored
Normal file
69
tests/fixtures/complex-function.diff
vendored
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
diff --git a/src/processor.py b/src/processor.py
|
||||||
|
index 1234567..abcdefg 100644
|
||||||
|
--- a/src/processor.py
|
||||||
|
+++ b/src/processor.py
|
||||||
|
@@ -1,5 +1,65 @@
|
||||||
|
"""Data processor module."""
|
||||||
|
|
||||||
|
|
||||||
|
+def process_data(data: dict, config: dict, options: dict | None = None) -> dict:
|
||||||
|
+ """Process data with many nested conditions."""
|
||||||
|
+ result = {}
|
||||||
|
+ options = options or {}
|
||||||
|
+
|
||||||
|
+ if data.get("type") == "A":
|
||||||
|
+ if config.get("mode") == "strict":
|
||||||
|
+ if options.get("validate"):
|
||||||
|
+ if data.get("value") > 100:
|
||||||
|
+ if config.get("transform"):
|
||||||
|
+ result["processed"] = data["value"] * 2
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = data["value"]
|
||||||
|
+ else:
|
||||||
|
+ if options.get("default"):
|
||||||
|
+ result["processed"] = options["default"]
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = 0
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = data.get("value", 0)
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = data.get("value", 0)
|
||||||
|
+ elif data.get("type") == "B":
|
||||||
|
+ if config.get("mode") == "strict":
|
||||||
|
+ if options.get("validate"):
|
||||||
|
+ if data.get("items"):
|
||||||
|
+ result["processed"] = len(data["items"])
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = 0
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = len(data.get("items", []))
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = len(data.get("items", []))
|
||||||
|
+ elif data.get("type") == "C":
|
||||||
|
+ if config.get("mode") == "strict":
|
||||||
|
+ if options.get("validate"):
|
||||||
|
+ if data.get("text"):
|
||||||
|
+ result["processed"] = data["text"].upper()
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = ""
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = data.get("text", "").upper()
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = data.get("text", "").upper()
|
||||||
|
+ else:
|
||||||
|
+ if config.get("fallback"):
|
||||||
|
+ result["processed"] = config["fallback"]
|
||||||
|
+ else:
|
||||||
|
+ result["processed"] = None
|
||||||
|
+
|
||||||
|
+ if options.get("timestamp"):
|
||||||
|
+ result["timestamp"] = options["timestamp"]
|
||||||
|
+ if options.get("source"):
|
||||||
|
+ result["source"] = options["source"]
|
||||||
|
+
|
||||||
|
+ return result
|
||||||
|
+
|
||||||
|
+
|
||||||
|
def simple_function(x: int) -> int:
|
||||||
|
"""A simple function."""
|
||||||
|
return x * 2
|
||||||
31
tests/fixtures/security-issue.diff
vendored
Normal file
31
tests/fixtures/security-issue.diff
vendored
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
diff --git a/src/auth.py b/src/auth.py
|
||||||
|
index 1234567..abcdefg 100644
|
||||||
|
--- a/src/auth.py
|
||||||
|
+++ b/src/auth.py
|
||||||
|
@@ -1,10 +1,25 @@
|
||||||
|
"""Authentication module."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
+import os
|
||||||
|
|
||||||
|
|
||||||
|
def get_user(username: str) -> dict | None:
|
||||||
|
"""Get user from database."""
|
||||||
|
conn = sqlite3.connect("users.db")
|
||||||
|
cursor = conn.cursor()
|
||||||
|
- cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
|
||||||
|
+ # FIXME: this is vulnerable to SQL injection
|
||||||
|
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
|
||||||
|
+ cursor.execute(query)
|
||||||
|
return cursor.fetchone()
|
||||||
|
+
|
||||||
|
+
|
||||||
|
+def run_command(cmd: str) -> str:
|
||||||
|
+ """Run a shell command."""
|
||||||
|
+ # Command injection vulnerability
|
||||||
|
+ return os.popen(cmd).read()
|
||||||
|
+
|
||||||
|
+
|
||||||
|
+# Hardcoded credentials
|
||||||
|
+API_KEY = "sk-1234567890abcdef"
|
||||||
|
+DB_PASSWORD = "admin123"
|
||||||
16
tests/fixtures/simple.diff
vendored
Normal file
16
tests/fixtures/simple.diff
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
diff --git a/src/utils.py b/src/utils.py
|
||||||
|
index 1234567..abcdefg 100644
|
||||||
|
--- a/src/utils.py
|
||||||
|
+++ b/src/utils.py
|
||||||
|
@@ -1,5 +1,8 @@
|
||||||
|
"""Utility functions."""
|
||||||
|
|
||||||
|
|
||||||
|
+def add(a: int, b: int) -> int:
|
||||||
|
+ """Add two numbers."""
|
||||||
|
+ return a + b
|
||||||
|
+
|
||||||
|
+
|
||||||
|
def subtract(a: int, b: int) -> int:
|
||||||
|
"""Subtract two numbers."""
|
||||||
|
return a - b
|
||||||
305
tests/test_agents.py
Normal file
305
tests/test_agents.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
"""Tests for review agents."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from arbiter.agents import ComplexityAgent, ReviewContext, SecurityAgent, StyleAgent
|
||||||
|
from arbiter.llm.prompts import PromptRegistry
|
||||||
|
from arbiter.models import AgentConfig, AgentName, Policy, Severity
|
||||||
|
from tests.conftest import MockLLMClient
|
||||||
|
|
||||||
|
|
||||||
|
class TestSecurityAgent:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_review_returns_result(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=["[]"])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ some code", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
assert result.agent_name == AgentName.SECURITY
|
||||||
|
assert result.findings == []
|
||||||
|
assert result.duration_ms >= 0
|
||||||
|
assert result.tokens_used == 150 # 100 in + 50 out from mock
|
||||||
|
assert result.cost_usd == 0.001
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parses_json_findings(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
response = """```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"file": "src/auth.py",
|
||||||
|
"line_start": 10,
|
||||||
|
"line_end": 15,
|
||||||
|
"severity": "high",
|
||||||
|
"confidence": 0.9,
|
||||||
|
"title": "SQL Injection",
|
||||||
|
"description": "User input concatenated",
|
||||||
|
"reasoning": "Allows SQL injection",
|
||||||
|
"suggestion": "Use parameterized queries",
|
||||||
|
"references": ["https://owasp.org"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```"""
|
||||||
|
mock_llm = MockLLMClient(responses=[response])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ query = ...", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
assert len(result.findings) == 1
|
||||||
|
finding = result.findings[0]
|
||||||
|
assert finding.file == "src/auth.py"
|
||||||
|
assert finding.severity == Severity.HIGH
|
||||||
|
assert finding.confidence == 0.9
|
||||||
|
assert finding.title == "SQL Injection"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_uses_configured_model(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=["[]"])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
policy = Policy(
|
||||||
|
agents={
|
||||||
|
AgentName.SECURITY: AgentConfig(model="gpt-4o-mini"),
|
||||||
|
AgentName.STYLE: AgentConfig(),
|
||||||
|
AgentName.COMPLEXITY: AgentConfig(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
context = ReviewContext(diff="+ code", policy=policy)
|
||||||
|
|
||||||
|
await agent.review(context)
|
||||||
|
|
||||||
|
assert mock_llm.calls[0]["model"] == "gpt-4o-mini"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_filters_by_severity(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
response = """[
|
||||||
|
{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "high", "confidence": 0.9, "title": "High", "description": "", "reasoning": ""},
|
||||||
|
{"file": "b.py", "line_start": 1, "line_end": 1, "severity": "low", "confidence": 0.9, "title": "Low", "description": "", "reasoning": ""},
|
||||||
|
{"file": "c.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.9, "title": "Info", "description": "", "reasoning": ""}
|
||||||
|
]"""
|
||||||
|
mock_llm = MockLLMClient(responses=[response])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
policy = Policy(
|
||||||
|
agents={
|
||||||
|
AgentName.SECURITY: AgentConfig(severity_threshold=Severity.MEDIUM),
|
||||||
|
AgentName.STYLE: AgentConfig(),
|
||||||
|
AgentName.COMPLEXITY: AgentConfig(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
context = ReviewContext(diff="+ code", policy=policy)
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
# Only high severity should pass (medium threshold filters low and info)
|
||||||
|
assert len(result.findings) == 1
|
||||||
|
assert result.findings[0].severity == Severity.HIGH
|
||||||
|
|
||||||
|
|
||||||
|
class TestStyleAgent:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_review_returns_result(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=["[]"])
|
||||||
|
agent = StyleAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ some code", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
assert result.agent_name == AgentName.STYLE
|
||||||
|
assert result.findings == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_uses_default_model(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=["[]"])
|
||||||
|
agent = StyleAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ code", policy=Policy())
|
||||||
|
|
||||||
|
await agent.review(context)
|
||||||
|
|
||||||
|
assert mock_llm.calls[0]["model"] == "gpt-4o-mini"
|
||||||
|
|
||||||
|
|
||||||
|
class TestComplexityAgent:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_review_returns_result(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=["[]"])
|
||||||
|
agent = ComplexityAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ some code", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
assert result.agent_name == AgentName.COMPLEXITY
|
||||||
|
assert result.findings == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parses_complexity_findings(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
response = """[
|
||||||
|
{
|
||||||
|
"file": "processor.py",
|
||||||
|
"line_start": 1,
|
||||||
|
"line_end": 50,
|
||||||
|
"severity": "medium",
|
||||||
|
"confidence": 0.8,
|
||||||
|
"title": "High cyclomatic complexity",
|
||||||
|
"description": "Function has 15 branches",
|
||||||
|
"reasoning": "Makes testing and maintenance difficult"
|
||||||
|
}
|
||||||
|
]"""
|
||||||
|
mock_llm = MockLLMClient(responses=[response])
|
||||||
|
agent = ComplexityAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ complex code", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
assert len(result.findings) == 1
|
||||||
|
assert result.findings[0].severity == Severity.MEDIUM
|
||||||
|
assert "complexity" in result.findings[0].title.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentResponseParsing:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_empty_response(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=[""])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ code", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
assert result.findings == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_invalid_json(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=["not valid json"])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ code", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
assert result.findings == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_json_without_code_block(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
response = '[{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Test", "description": "", "reasoning": ""}]'
|
||||||
|
mock_llm = MockLLMClient(responses=[response])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ code", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
assert len(result.findings) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_malformed_finding(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
response = """[
|
||||||
|
{"file": "a.py", "line_start": 1, "severity": "invalid_severity", "confidence": 0.5, "title": "Bad", "description": "", "reasoning": ""},
|
||||||
|
{"file": "b.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Valid", "description": "", "reasoning": ""}
|
||||||
|
]"""
|
||||||
|
mock_llm = MockLLMClient(responses=[response])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ code", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
# Only the valid finding should be included (first has invalid severity)
|
||||||
|
assert len(result.findings) == 1
|
||||||
|
assert result.findings[0].title == "Valid"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_includes_prompt_additions(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=["[]"])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
policy = Policy(
|
||||||
|
agents={
|
||||||
|
AgentName.SECURITY: AgentConfig(prompt_additions="Focus on authentication"),
|
||||||
|
AgentName.STYLE: AgentConfig(),
|
||||||
|
AgentName.COMPLEXITY: AgentConfig(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
context = ReviewContext(diff="+ code", policy=policy)
|
||||||
|
|
||||||
|
await agent.review(context)
|
||||||
|
|
||||||
|
message_content = mock_llm.calls[0]["messages"][0]["content"]
|
||||||
|
assert "Focus on authentication" in message_content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_non_list_json(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=['{"not": "a list"}'])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ code", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
assert result.findings == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handles_non_dict_items(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=['["string", 123, null]'])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
context = ReviewContext(diff="+ code", policy=Policy())
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
assert result.findings == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_without_config_uses_defaults(
|
||||||
|
self,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
mock_llm = MockLLMClient(responses=["[]"])
|
||||||
|
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||||
|
# Create policy with empty agents dict
|
||||||
|
policy = Policy(agents={})
|
||||||
|
context = ReviewContext(diff="+ code", policy=policy)
|
||||||
|
|
||||||
|
result = await agent.review(context)
|
||||||
|
|
||||||
|
# Should use default model (gpt-4o for security)
|
||||||
|
assert mock_llm.calls[0]["model"] == "gpt-4o"
|
||||||
|
assert result.findings == []
|
||||||
561
tests/test_cli.py
Normal file
561
tests/test_cli.py
Normal file
@@ -0,0 +1,561 @@
|
|||||||
|
"""Tests for CLI commands."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from arbiter.cli import (
|
||||||
|
_severity_color,
|
||||||
|
_severity_icon,
|
||||||
|
_verdict_color,
|
||||||
|
_verdict_icon,
|
||||||
|
app,
|
||||||
|
)
|
||||||
|
from arbiter.deliberation import DeliberationResult
|
||||||
|
from arbiter.models import AgentName, Finding, ReviewResult, Severity, Verdict
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
def make_mock_return(
|
||||||
|
findings: list[Finding] | None = None, verdict: Verdict = Verdict.APPROVE
|
||||||
|
) -> tuple[list[ReviewResult], DeliberationResult]:
|
||||||
|
"""Create a mock return value for _run_review."""
|
||||||
|
agent_results = [
|
||||||
|
ReviewResult(
|
||||||
|
agent_name=AgentName.SECURITY,
|
||||||
|
findings=findings or [],
|
||||||
|
duration_ms=100,
|
||||||
|
tokens_used=100,
|
||||||
|
cost_usd=0.001,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
deliberation_result = DeliberationResult(
|
||||||
|
findings=findings or [],
|
||||||
|
verdict=verdict,
|
||||||
|
verdict_confidence=0.9,
|
||||||
|
verdict_reasoning="Test reasoning",
|
||||||
|
total_findings=len(findings) if findings else 0,
|
||||||
|
)
|
||||||
|
return agent_results, deliberation_result
|
||||||
|
|
||||||
|
|
||||||
|
class TestVersionCommand:
|
||||||
|
def test_version_output(self) -> None:
|
||||||
|
result = runner.invoke(app, ["version"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "arbiter" in result.output
|
||||||
|
assert "0.3.0" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestReviewCommand:
|
||||||
|
def test_file_not_found(self) -> None:
|
||||||
|
result = runner.invoke(app, ["review", "nonexistent.diff"])
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "File not found" in result.output
|
||||||
|
|
||||||
|
def test_empty_diff_warning(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "empty.diff"
|
||||||
|
diff_file.write_text("")
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file)])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Empty diff" in result.output
|
||||||
|
|
||||||
|
def test_policy_not_found(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--policy", "nonexistent.yaml"])
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "Policy file not found" in result.output
|
||||||
|
|
||||||
|
def test_reads_from_stdin(self) -> None:
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = make_mock_return()
|
||||||
|
result = runner.invoke(app, ["review", "-"], input="+ added line\n")
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
def test_json_output_format(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = make_mock_return()
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "json"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert '"verdict"' in result.output
|
||||||
|
assert '"findings"' in result.output
|
||||||
|
|
||||||
|
def test_markdown_output_format(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = make_mock_return()
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "# Arbiter Review" in result.output
|
||||||
|
|
||||||
|
def test_loads_policy_file(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
policy_file = tmp_path / "policy.yaml"
|
||||||
|
policy_file.write_text("""
|
||||||
|
version: "1.0"
|
||||||
|
agents:
|
||||||
|
security:
|
||||||
|
enabled: true
|
||||||
|
style:
|
||||||
|
enabled: false
|
||||||
|
complexity:
|
||||||
|
enabled: false
|
||||||
|
""")
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = make_mock_return()
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--policy", str(policy_file)])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Verify policy was passed to _run_review
|
||||||
|
call_args = mock_run.call_args
|
||||||
|
policy = call_args[0][1] # Second positional arg is policy
|
||||||
|
assert len(policy.get_enabled_agents()) == 1
|
||||||
|
|
||||||
|
def test_model_override(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = make_mock_return()
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--model", "gpt-4o-mini"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Verify model was set in policy
|
||||||
|
call_args = mock_run.call_args
|
||||||
|
policy = call_args[0][1]
|
||||||
|
for config in policy.agents.values():
|
||||||
|
assert config.model == "gpt-4o-mini"
|
||||||
|
|
||||||
|
def test_static_analysis_flag(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = make_mock_return()
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--no-static-analysis"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Verify static_analysis was False
|
||||||
|
call_args = mock_run.call_args
|
||||||
|
assert call_args.kwargs.get("static_analysis") is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoArgsHelp:
|
||||||
|
def test_no_args_shows_help(self) -> None:
|
||||||
|
result = runner.invoke(app, [])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "A multi-agent code review system" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestOutputFormatting:
|
||||||
|
def test_severity_color(self) -> None:
|
||||||
|
assert _severity_color(Severity.CRITICAL) == "red bold"
|
||||||
|
assert _severity_color(Severity.HIGH) == "red"
|
||||||
|
assert _severity_color(Severity.MEDIUM) == "yellow"
|
||||||
|
assert _severity_color(Severity.LOW) == "blue"
|
||||||
|
assert _severity_color(Severity.INFO) == "dim"
|
||||||
|
|
||||||
|
def test_severity_icon(self) -> None:
|
||||||
|
assert _severity_icon(Severity.CRITICAL) == "!!"
|
||||||
|
assert _severity_icon(Severity.HIGH) == "!"
|
||||||
|
assert _severity_icon(Severity.MEDIUM) == "*"
|
||||||
|
assert _severity_icon(Severity.LOW) == "-"
|
||||||
|
assert _severity_icon(Severity.INFO) == "i"
|
||||||
|
|
||||||
|
def test_verdict_color(self) -> None:
|
||||||
|
assert _verdict_color(Verdict.APPROVE) == "green"
|
||||||
|
assert _verdict_color(Verdict.COMMENT) == "yellow"
|
||||||
|
assert _verdict_color(Verdict.REQUEST_CHANGES) == "red"
|
||||||
|
|
||||||
|
def test_verdict_icon(self) -> None:
|
||||||
|
assert _verdict_icon(Verdict.APPROVE) == "[ok]"
|
||||||
|
assert _verdict_icon(Verdict.COMMENT) == "[..]"
|
||||||
|
assert _verdict_icon(Verdict.REQUEST_CHANGES) == "[!!]"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRichOutput:
|
||||||
|
def test_rich_format_with_findings(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
finding = Finding(
|
||||||
|
id="test-finding-1",
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="test.py",
|
||||||
|
line_start=10,
|
||||||
|
line_end=15,
|
||||||
|
severity=Severity.HIGH,
|
||||||
|
confidence=0.9,
|
||||||
|
title="SQL Injection",
|
||||||
|
description="User input in query",
|
||||||
|
reasoning="Direct concatenation",
|
||||||
|
prompt_version="test-v1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = make_mock_return(findings=[finding])
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
def test_rich_format_no_findings(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = make_mock_return()
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "No issues found" in result.output
|
||||||
|
|
||||||
|
def test_rich_format_critical_findings(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
finding = Finding(
|
||||||
|
id="test-finding-1",
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="test.py",
|
||||||
|
line_start=10,
|
||||||
|
line_end=10,
|
||||||
|
severity=Severity.CRITICAL,
|
||||||
|
confidence=0.95,
|
||||||
|
title="Critical Issue",
|
||||||
|
description="This is critical",
|
||||||
|
reasoning="Very bad",
|
||||||
|
suggestion="Fix it immediately",
|
||||||
|
prompt_version="test-v1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
deliberation = DeliberationResult(
|
||||||
|
findings=[finding],
|
||||||
|
verdict=Verdict.REQUEST_CHANGES,
|
||||||
|
verdict_confidence=0.95,
|
||||||
|
verdict_reasoning="Critical issue found",
|
||||||
|
total_findings=1,
|
||||||
|
critical_count=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = (
|
||||||
|
[
|
||||||
|
ReviewResult(
|
||||||
|
agent_name=AgentName.SECURITY,
|
||||||
|
findings=[finding],
|
||||||
|
duration_ms=100,
|
||||||
|
tokens_used=100,
|
||||||
|
cost_usd=0.001,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
deliberation,
|
||||||
|
)
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarkdownOutput:
|
||||||
|
def test_markdown_with_findings(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
finding = Finding(
|
||||||
|
id="test-finding-1",
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="test.py",
|
||||||
|
line_start=10,
|
||||||
|
line_end=15,
|
||||||
|
severity=Severity.HIGH,
|
||||||
|
confidence=0.9,
|
||||||
|
title="SQL Injection",
|
||||||
|
description="User input in query",
|
||||||
|
reasoning="Direct concatenation",
|
||||||
|
suggestion="Use parameterized queries",
|
||||||
|
prompt_version="test-v1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = make_mock_return(findings=[finding])
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "## Findings" in result.output
|
||||||
|
assert "SQL Injection" in result.output
|
||||||
|
|
||||||
|
def test_markdown_verdict_badges(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
for verdict in [Verdict.APPROVE, Verdict.COMMENT, Verdict.REQUEST_CHANGES]:
|
||||||
|
deliberation = DeliberationResult(
|
||||||
|
verdict=verdict,
|
||||||
|
verdict_confidence=0.9,
|
||||||
|
verdict_reasoning="Test",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = (
|
||||||
|
[
|
||||||
|
ReviewResult(
|
||||||
|
agent_name=AgentName.SECURITY,
|
||||||
|
findings=[],
|
||||||
|
duration_ms=100,
|
||||||
|
tokens_used=100,
|
||||||
|
cost_usd=0.001,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
deliberation,
|
||||||
|
)
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert verdict.value.upper() in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonOutput:
|
||||||
|
def test_json_with_conflicts(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
from arbiter.deliberation.conflicts import Conflict, ConflictNature
|
||||||
|
|
||||||
|
conflict = Conflict(
|
||||||
|
id="test-conflict",
|
||||||
|
finding_ids=["f1", "f2"],
|
||||||
|
nature=ConflictNature.TRADE_OFF,
|
||||||
|
description="Test conflict",
|
||||||
|
severity_weight=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
deliberation = DeliberationResult(
|
||||||
|
verdict=Verdict.COMMENT,
|
||||||
|
verdict_confidence=0.7,
|
||||||
|
verdict_reasoning="Conflicts found",
|
||||||
|
conflicts=[conflict],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = (
|
||||||
|
[
|
||||||
|
ReviewResult(
|
||||||
|
agent_name=AgentName.SECURITY,
|
||||||
|
findings=[],
|
||||||
|
duration_ms=100,
|
||||||
|
tokens_used=100,
|
||||||
|
cost_usd=0.001,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
deliberation,
|
||||||
|
)
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "json"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
output = json.loads(result.output)
|
||||||
|
assert "conflicts" in output
|
||||||
|
assert len(output["conflicts"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkDirHandling:
|
||||||
|
def test_work_dir_option(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
work_dir = tmp_path / "src"
|
||||||
|
work_dir.mkdir()
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = make_mock_return()
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--work-dir", str(work_dir)])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
call_args = mock_run.call_args
|
||||||
|
assert call_args.kwargs.get("work_dir") == work_dir.resolve()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRichOutputWithConflicts:
|
||||||
|
def test_rich_conflicts(self, tmp_path: Path) -> None:
|
||||||
|
from arbiter.deliberation.conflicts import Conflict, ConflictNature
|
||||||
|
from arbiter.deliberation.synthesis import Resolution
|
||||||
|
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
conflict = Conflict(
|
||||||
|
id="test-conflict",
|
||||||
|
finding_ids=["f1", "f2"],
|
||||||
|
nature=ConflictNature.TRADE_OFF,
|
||||||
|
description="Security vs complexity trade-off",
|
||||||
|
severity_weight=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolution = Resolution(
|
||||||
|
conflict_id="test-conflict",
|
||||||
|
decision="prefer_first",
|
||||||
|
reasoning="Security takes priority",
|
||||||
|
confidence=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
deliberation = DeliberationResult(
|
||||||
|
verdict=Verdict.COMMENT,
|
||||||
|
verdict_confidence=0.7,
|
||||||
|
verdict_reasoning="Conflicts found",
|
||||||
|
conflicts=[conflict],
|
||||||
|
resolutions=[resolution],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = (
|
||||||
|
[
|
||||||
|
ReviewResult(
|
||||||
|
agent_name=AgentName.SECURITY,
|
||||||
|
findings=[],
|
||||||
|
duration_ms=100,
|
||||||
|
tokens_used=100,
|
||||||
|
cost_usd=0.001,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
deliberation,
|
||||||
|
)
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarkdownOutputWithConflicts:
|
||||||
|
def test_markdown_conflicts(self, tmp_path: Path) -> None:
|
||||||
|
from arbiter.deliberation.conflicts import Conflict, ConflictNature
|
||||||
|
from arbiter.deliberation.synthesis import Resolution
|
||||||
|
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
conflict = Conflict(
|
||||||
|
id="test-conflict",
|
||||||
|
finding_ids=["f1", "f2"],
|
||||||
|
nature=ConflictNature.CONTRADICTORY,
|
||||||
|
description="Contradictory recommendations",
|
||||||
|
severity_weight=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolution = Resolution(
|
||||||
|
conflict_id="test-conflict",
|
||||||
|
decision="merge",
|
||||||
|
reasoning="Both concerns addressed by combined fix",
|
||||||
|
merged_suggestion="Do both things",
|
||||||
|
confidence=0.85,
|
||||||
|
)
|
||||||
|
|
||||||
|
deliberation = DeliberationResult(
|
||||||
|
verdict=Verdict.COMMENT,
|
||||||
|
verdict_confidence=0.7,
|
||||||
|
verdict_reasoning="Conflicts found",
|
||||||
|
conflicts=[conflict],
|
||||||
|
resolutions=[resolution],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = (
|
||||||
|
[
|
||||||
|
ReviewResult(
|
||||||
|
agent_name=AgentName.SECURITY,
|
||||||
|
findings=[],
|
||||||
|
duration_ms=100,
|
||||||
|
tokens_used=100,
|
||||||
|
cost_usd=0.001,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
deliberation,
|
||||||
|
)
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "## Conflicts" in result.output
|
||||||
|
assert "Contradictory" in result.output
|
||||||
|
assert "Resolution" in result.output
|
||||||
|
|
||||||
|
def test_markdown_findings(self, tmp_path: Path) -> None:
|
||||||
|
diff_file = tmp_path / "test.diff"
|
||||||
|
diff_file.write_text("+ some change")
|
||||||
|
|
||||||
|
findings = [
|
||||||
|
Finding(
|
||||||
|
id="f1",
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="test.py",
|
||||||
|
line_start=10,
|
||||||
|
line_end=15,
|
||||||
|
severity=Severity.HIGH,
|
||||||
|
confidence=0.9,
|
||||||
|
title="Security Issue",
|
||||||
|
description="Vulnerable code",
|
||||||
|
reasoning="Bad pattern",
|
||||||
|
suggestion="Fix it this way",
|
||||||
|
prompt_version="test-v1.0",
|
||||||
|
),
|
||||||
|
Finding(
|
||||||
|
id="f2",
|
||||||
|
agent=AgentName.STYLE,
|
||||||
|
file="test.py",
|
||||||
|
line_start=20,
|
||||||
|
line_end=25,
|
||||||
|
severity=Severity.LOW,
|
||||||
|
confidence=0.8,
|
||||||
|
title="Style Issue",
|
||||||
|
description="Could be cleaner",
|
||||||
|
reasoning="Convention",
|
||||||
|
prompt_version="test-v1.0",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
deliberation = DeliberationResult(
|
||||||
|
findings=findings,
|
||||||
|
verdict=Verdict.COMMENT,
|
||||||
|
verdict_confidence=0.75,
|
||||||
|
verdict_reasoning="Issues found",
|
||||||
|
total_findings=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_run.return_value = (
|
||||||
|
[
|
||||||
|
ReviewResult(
|
||||||
|
agent_name=AgentName.SECURITY,
|
||||||
|
findings=[findings[0]],
|
||||||
|
duration_ms=100,
|
||||||
|
tokens_used=100,
|
||||||
|
cost_usd=0.001,
|
||||||
|
),
|
||||||
|
ReviewResult(
|
||||||
|
agent_name=AgentName.STYLE,
|
||||||
|
findings=[findings[1]],
|
||||||
|
duration_ms=100,
|
||||||
|
tokens_used=100,
|
||||||
|
cost_usd=0.001,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
deliberation,
|
||||||
|
)
|
||||||
|
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Security Issue" in result.output
|
||||||
|
assert "Style Issue" in result.output
|
||||||
|
assert "Fix it this way" in result.output
|
||||||
313
tests/test_llm.py
Normal file
313
tests/test_llm.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
"""Tests for LLM client and prompts."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from arbiter.llm.client import LiteLLMClient, LLMResponse
|
||||||
|
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
|
||||||
|
from tests.conftest import MockLLMClient
|
||||||
|
|
||||||
|
|
||||||
|
class TestLiteLLMClient:
|
||||||
|
def test_init_default_values(self) -> None:
|
||||||
|
client = LiteLLMClient()
|
||||||
|
assert client.timeout == 60
|
||||||
|
assert client.max_retries == 3
|
||||||
|
|
||||||
|
def test_init_custom_values(self) -> None:
|
||||||
|
client = LiteLLMClient(timeout=120, max_retries=5)
|
||||||
|
assert client.timeout == 120
|
||||||
|
assert client.max_retries == 5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_returns_response(self) -> None:
|
||||||
|
client = LiteLLMClient()
|
||||||
|
|
||||||
|
# Mock the litellm.acompletion function
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = "Test response"
|
||||||
|
mock_response.model = "gpt-4o"
|
||||||
|
mock_response.usage = MagicMock()
|
||||||
|
mock_response.usage.prompt_tokens = 10
|
||||||
|
mock_response.usage.completion_tokens = 5
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||||
|
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||||
|
):
|
||||||
|
mock_acomp.return_value = mock_response
|
||||||
|
mock_cost.return_value = 0.001
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
response = await client.complete(messages, "gpt-4o")
|
||||||
|
|
||||||
|
assert response.content == "Test response"
|
||||||
|
assert response.model == "gpt-4o"
|
||||||
|
assert response.tokens_in == 10
|
||||||
|
assert response.tokens_out == 5
|
||||||
|
assert response.cost_usd == 0.001
|
||||||
|
|
||||||
|
mock_acomp.assert_called_once_with(
|
||||||
|
model="gpt-4o",
|
||||||
|
messages=messages,
|
||||||
|
timeout=60,
|
||||||
|
num_retries=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_handles_empty_content(self) -> None:
|
||||||
|
client = LiteLLMClient()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = None # None content
|
||||||
|
mock_response.model = "gpt-4o"
|
||||||
|
mock_response.usage = MagicMock()
|
||||||
|
mock_response.usage.prompt_tokens = 5
|
||||||
|
mock_response.usage.completion_tokens = 0
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||||
|
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||||
|
):
|
||||||
|
mock_acomp.return_value = mock_response
|
||||||
|
mock_cost.return_value = 0.0
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
response = await client.complete(messages, "gpt-4o")
|
||||||
|
|
||||||
|
assert response.content == "" # Should be empty string, not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_handles_missing_usage(self) -> None:
|
||||||
|
client = LiteLLMClient()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = "Response"
|
||||||
|
mock_response.model = "gpt-4o"
|
||||||
|
mock_response.usage = None # No usage data
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||||
|
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||||
|
):
|
||||||
|
mock_acomp.return_value = mock_response
|
||||||
|
mock_cost.return_value = 0.0
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
response = await client.complete(messages, "gpt-4o")
|
||||||
|
|
||||||
|
assert response.tokens_in == 0
|
||||||
|
assert response.tokens_out == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_uses_fallback_model(self) -> None:
|
||||||
|
client = LiteLLMClient()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [MagicMock()]
|
||||||
|
mock_response.choices[0].message.content = "Response"
|
||||||
|
mock_response.model = None # No model in response
|
||||||
|
mock_response.usage = MagicMock()
|
||||||
|
mock_response.usage.prompt_tokens = 5
|
||||||
|
mock_response.usage.completion_tokens = 3
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||||
|
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||||
|
):
|
||||||
|
mock_acomp.return_value = mock_response
|
||||||
|
mock_cost.return_value = 0.0
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
response = await client.complete(messages, "claude-3-opus")
|
||||||
|
|
||||||
|
# Should use the passed model as fallback
|
||||||
|
assert response.model == "claude-3-opus"
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMResponse:
|
||||||
|
def test_response_creation(self) -> None:
|
||||||
|
response = LLMResponse(
|
||||||
|
content="Hello, world!",
|
||||||
|
model="gpt-4o",
|
||||||
|
tokens_in=10,
|
||||||
|
tokens_out=5,
|
||||||
|
cost_usd=0.001,
|
||||||
|
)
|
||||||
|
assert response.content == "Hello, world!"
|
||||||
|
assert response.model == "gpt-4o"
|
||||||
|
assert response.tokens_in == 10
|
||||||
|
assert response.tokens_out == 5
|
||||||
|
assert response.cost_usd == 0.001
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockLLMClient:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_records_calls(self) -> None:
|
||||||
|
client = MockLLMClient()
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
|
||||||
|
await client.complete(messages, "gpt-4o")
|
||||||
|
|
||||||
|
assert len(client.calls) == 1
|
||||||
|
assert client.calls[0]["messages"] == messages
|
||||||
|
assert client.calls[0]["model"] == "gpt-4o"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_canned_responses(self) -> None:
|
||||||
|
client = MockLLMClient(responses=["First", "Second"])
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
|
||||||
|
response1 = await client.complete(messages, "gpt-4o")
|
||||||
|
response2 = await client.complete(messages, "gpt-4o")
|
||||||
|
response3 = await client.complete(messages, "gpt-4o")
|
||||||
|
|
||||||
|
assert response1.content == "First"
|
||||||
|
assert response2.content == "Second"
|
||||||
|
assert response3.content == "" # Exhausted
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset(self) -> None:
|
||||||
|
client = MockLLMClient(responses=["Hello"])
|
||||||
|
messages = [{"role": "user", "content": "Hi"}]
|
||||||
|
|
||||||
|
await client.complete(messages, "gpt-4o")
|
||||||
|
assert len(client.calls) == 1
|
||||||
|
|
||||||
|
client.reset()
|
||||||
|
assert len(client.calls) == 0
|
||||||
|
|
||||||
|
response = await client.complete(messages, "gpt-4o")
|
||||||
|
assert response.content == "Hello" # Responses reset too
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplate:
|
||||||
|
def test_template_creation(self) -> None:
|
||||||
|
template = PromptTemplate(
|
||||||
|
name="security",
|
||||||
|
version="1.0",
|
||||||
|
content="Review: {{diff}}",
|
||||||
|
)
|
||||||
|
assert template.name == "security"
|
||||||
|
assert template.version == "1.0"
|
||||||
|
assert template.full_name == "security-v1.0"
|
||||||
|
|
||||||
|
def test_render_substitution(self) -> None:
|
||||||
|
template = PromptTemplate(
|
||||||
|
name="test",
|
||||||
|
version="1.0",
|
||||||
|
content="File: {{file}}\nDiff: {{diff}}",
|
||||||
|
)
|
||||||
|
result = template.render(file="test.py", diff="+ added line")
|
||||||
|
assert result == "File: test.py\nDiff: + added line"
|
||||||
|
|
||||||
|
def test_render_missing_variable(self) -> None:
|
||||||
|
template = PromptTemplate(
|
||||||
|
name="test",
|
||||||
|
version="1.0",
|
||||||
|
content="Value: {{value}}",
|
||||||
|
)
|
||||||
|
result = template.render()
|
||||||
|
assert result == "Value: {{value}}"
|
||||||
|
|
||||||
|
def test_render_multiple_occurrences(self) -> None:
|
||||||
|
template = PromptTemplate(
|
||||||
|
name="test",
|
||||||
|
version="1.0",
|
||||||
|
content="{{name}} and {{name}} again",
|
||||||
|
)
|
||||||
|
result = template.render(name="test")
|
||||||
|
assert result == "test and test again"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptRegistry:
|
||||||
|
def test_get_template(self, tmp_path: Path) -> None:
|
||||||
|
templates_dir = tmp_path / "templates"
|
||||||
|
templates_dir.mkdir()
|
||||||
|
(templates_dir / "security-v1.0.md").write_text("Security review: {{diff}}")
|
||||||
|
|
||||||
|
registry = PromptRegistry(templates_dir)
|
||||||
|
template = registry.get("security", "1.0")
|
||||||
|
|
||||||
|
assert template.name == "security"
|
||||||
|
assert template.version == "1.0"
|
||||||
|
assert "{{diff}}" in template.content
|
||||||
|
|
||||||
|
def test_get_template_cached(self, tmp_path: Path) -> None:
|
||||||
|
templates_dir = tmp_path / "templates"
|
||||||
|
templates_dir.mkdir()
|
||||||
|
(templates_dir / "security-v1.0.md").write_text("Content")
|
||||||
|
|
||||||
|
registry = PromptRegistry(templates_dir)
|
||||||
|
template1 = registry.get("security", "1.0")
|
||||||
|
template2 = registry.get("security", "1.0")
|
||||||
|
|
||||||
|
assert template1 is template2
|
||||||
|
|
||||||
|
def test_get_template_not_found(self, tmp_path: Path) -> None:
|
||||||
|
templates_dir = tmp_path / "templates"
|
||||||
|
templates_dir.mkdir()
|
||||||
|
|
||||||
|
registry = PromptRegistry(templates_dir)
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
registry.get("missing", "1.0")
|
||||||
|
|
||||||
|
def test_list_templates(self, tmp_path: Path) -> None:
|
||||||
|
templates_dir = tmp_path / "templates"
|
||||||
|
templates_dir.mkdir()
|
||||||
|
(templates_dir / "security-v1.0.md").write_text("Content")
|
||||||
|
(templates_dir / "style-v2.0.md").write_text("Content")
|
||||||
|
(templates_dir / "readme.md").write_text("Not a template")
|
||||||
|
|
||||||
|
registry = PromptRegistry(templates_dir)
|
||||||
|
templates = registry.list_templates()
|
||||||
|
|
||||||
|
assert len(templates) == 2
|
||||||
|
assert ("security", "1.0") in templates
|
||||||
|
assert ("style", "2.0") in templates
|
||||||
|
|
||||||
|
def test_list_templates_empty_dir(self, tmp_path: Path) -> None:
|
||||||
|
templates_dir = tmp_path / "templates"
|
||||||
|
templates_dir.mkdir()
|
||||||
|
|
||||||
|
registry = PromptRegistry(templates_dir)
|
||||||
|
templates = registry.list_templates()
|
||||||
|
|
||||||
|
assert templates == []
|
||||||
|
|
||||||
|
def test_list_templates_missing_dir(self, tmp_path: Path) -> None:
|
||||||
|
templates_dir = tmp_path / "missing"
|
||||||
|
|
||||||
|
registry = PromptRegistry(templates_dir)
|
||||||
|
templates = registry.list_templates()
|
||||||
|
|
||||||
|
assert templates == []
|
||||||
|
|
||||||
|
def test_clear_cache(self, tmp_path: Path) -> None:
|
||||||
|
templates_dir = tmp_path / "templates"
|
||||||
|
templates_dir.mkdir()
|
||||||
|
(templates_dir / "test-v1.0.md").write_text("Original")
|
||||||
|
|
||||||
|
registry = PromptRegistry(templates_dir)
|
||||||
|
template1 = registry.get("test", "1.0")
|
||||||
|
assert template1.content == "Original"
|
||||||
|
|
||||||
|
# Modify file
|
||||||
|
(templates_dir / "test-v1.0.md").write_text("Modified")
|
||||||
|
|
||||||
|
# Still cached
|
||||||
|
template2 = registry.get("test", "1.0")
|
||||||
|
assert template2.content == "Original"
|
||||||
|
|
||||||
|
# Clear cache
|
||||||
|
registry.clear_cache()
|
||||||
|
|
||||||
|
# Now reads new content
|
||||||
|
template3 = registry.get("test", "1.0")
|
||||||
|
assert template3.content == "Modified"
|
||||||
224
tests/test_models.py
Normal file
224
tests/test_models.py
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
"""Tests for data models."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from arbiter.models import (
|
||||||
|
AgentConfig,
|
||||||
|
AgentName,
|
||||||
|
Finding,
|
||||||
|
Policy,
|
||||||
|
ReviewResult,
|
||||||
|
Severity,
|
||||||
|
Verdict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnums:
|
||||||
|
def test_severity_values(self) -> None:
|
||||||
|
assert Severity.CRITICAL == "critical"
|
||||||
|
assert Severity.HIGH == "high"
|
||||||
|
assert Severity.MEDIUM == "medium"
|
||||||
|
assert Severity.LOW == "low"
|
||||||
|
assert Severity.INFO == "info"
|
||||||
|
|
||||||
|
def test_verdict_values(self) -> None:
|
||||||
|
assert Verdict.APPROVE == "approve"
|
||||||
|
assert Verdict.REQUEST_CHANGES == "request_changes"
|
||||||
|
assert Verdict.COMMENT == "comment"
|
||||||
|
|
||||||
|
def test_agent_name_values(self) -> None:
|
||||||
|
assert AgentName.SECURITY == "security"
|
||||||
|
assert AgentName.STYLE == "style"
|
||||||
|
assert AgentName.COMPLEXITY == "complexity"
|
||||||
|
|
||||||
|
def test_severity_from_string(self) -> None:
|
||||||
|
assert Severity("critical") == Severity.CRITICAL
|
||||||
|
assert Severity("high") == Severity.HIGH
|
||||||
|
|
||||||
|
|
||||||
|
class TestFinding:
|
||||||
|
def test_finding_creation(self) -> None:
|
||||||
|
finding = Finding(
|
||||||
|
id="test-123",
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="src/auth.py",
|
||||||
|
line_start=10,
|
||||||
|
line_end=15,
|
||||||
|
severity=Severity.HIGH,
|
||||||
|
confidence=0.9,
|
||||||
|
title="SQL Injection",
|
||||||
|
description="User input concatenated in SQL query",
|
||||||
|
reasoning="Allows attackers to execute arbitrary SQL",
|
||||||
|
suggestion="Use parameterized queries",
|
||||||
|
references=["https://owasp.org"],
|
||||||
|
prompt_version="security-v1.0",
|
||||||
|
)
|
||||||
|
assert finding.id == "test-123"
|
||||||
|
assert finding.agent == AgentName.SECURITY
|
||||||
|
assert finding.severity == Severity.HIGH
|
||||||
|
assert finding.confidence == 0.9
|
||||||
|
|
||||||
|
def test_finding_confidence_validation(self) -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
Finding(
|
||||||
|
id="test",
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="test.py",
|
||||||
|
line_start=1,
|
||||||
|
line_end=1,
|
||||||
|
severity=Severity.INFO,
|
||||||
|
confidence=1.5,
|
||||||
|
title="Test",
|
||||||
|
description="Test",
|
||||||
|
reasoning="Test",
|
||||||
|
prompt_version="test-v1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_finding_line_validation(self) -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
Finding(
|
||||||
|
id="test",
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="test.py",
|
||||||
|
line_start=0,
|
||||||
|
line_end=1,
|
||||||
|
severity=Severity.INFO,
|
||||||
|
confidence=0.5,
|
||||||
|
title="Test",
|
||||||
|
description="Test",
|
||||||
|
reasoning="Test",
|
||||||
|
prompt_version="test-v1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_finding_serialization(self) -> None:
|
||||||
|
finding = Finding(
|
||||||
|
id="test-123",
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="src/auth.py",
|
||||||
|
line_start=10,
|
||||||
|
line_end=15,
|
||||||
|
severity=Severity.HIGH,
|
||||||
|
confidence=0.9,
|
||||||
|
title="Test",
|
||||||
|
description="Test desc",
|
||||||
|
reasoning="Test reason",
|
||||||
|
prompt_version="security-v1.0",
|
||||||
|
)
|
||||||
|
data = finding.model_dump()
|
||||||
|
assert data["id"] == "test-123"
|
||||||
|
assert data["agent"] == "security"
|
||||||
|
assert data["severity"] == "high"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentConfig:
|
||||||
|
def test_default_values(self) -> None:
|
||||||
|
config = AgentConfig()
|
||||||
|
assert config.enabled is True
|
||||||
|
assert config.model is None
|
||||||
|
assert config.severity_threshold == Severity.INFO
|
||||||
|
assert config.prompt_additions is None
|
||||||
|
|
||||||
|
def test_custom_values(self) -> None:
|
||||||
|
config = AgentConfig(
|
||||||
|
enabled=False,
|
||||||
|
model="gpt-4o",
|
||||||
|
severity_threshold=Severity.MEDIUM,
|
||||||
|
prompt_additions="Focus on auth",
|
||||||
|
)
|
||||||
|
assert config.enabled is False
|
||||||
|
assert config.model == "gpt-4o"
|
||||||
|
assert config.severity_threshold == Severity.MEDIUM
|
||||||
|
|
||||||
|
|
||||||
|
class TestPolicy:
|
||||||
|
def test_default_policy(self) -> None:
|
||||||
|
policy = Policy()
|
||||||
|
assert len(policy.agents) == 3
|
||||||
|
assert AgentName.SECURITY in policy.agents
|
||||||
|
assert AgentName.STYLE in policy.agents
|
||||||
|
assert AgentName.COMPLEXITY in policy.agents
|
||||||
|
|
||||||
|
def test_get_enabled_agents(self) -> None:
|
||||||
|
policy = Policy()
|
||||||
|
enabled = policy.get_enabled_agents()
|
||||||
|
assert len(enabled) == 3
|
||||||
|
assert AgentName.SECURITY in enabled
|
||||||
|
|
||||||
|
def test_get_enabled_agents_with_disabled(self) -> None:
|
||||||
|
policy = Policy(
|
||||||
|
agents={
|
||||||
|
AgentName.SECURITY: AgentConfig(enabled=True),
|
||||||
|
AgentName.STYLE: AgentConfig(enabled=False),
|
||||||
|
AgentName.COMPLEXITY: AgentConfig(enabled=True),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
enabled = policy.get_enabled_agents()
|
||||||
|
assert len(enabled) == 2
|
||||||
|
assert AgentName.STYLE not in enabled
|
||||||
|
|
||||||
|
def test_load_from_yaml(self, tmp_path: Path) -> None:
|
||||||
|
policy_file = tmp_path / "policy.yaml"
|
||||||
|
policy_file.write_text("""
|
||||||
|
version: "1.1"
|
||||||
|
agents:
|
||||||
|
security:
|
||||||
|
enabled: true
|
||||||
|
model: gpt-4o
|
||||||
|
severity_threshold: high
|
||||||
|
style:
|
||||||
|
enabled: false
|
||||||
|
complexity:
|
||||||
|
enabled: true
|
||||||
|
""")
|
||||||
|
policy = Policy.load(policy_file)
|
||||||
|
assert policy.version == "1.1"
|
||||||
|
assert policy.agents[AgentName.SECURITY].model == "gpt-4o"
|
||||||
|
assert policy.agents[AgentName.SECURITY].severity_threshold == Severity.HIGH
|
||||||
|
assert policy.agents[AgentName.STYLE].enabled is False
|
||||||
|
|
||||||
|
def test_load_empty_yaml(self, tmp_path: Path) -> None:
|
||||||
|
policy_file = tmp_path / "policy.yaml"
|
||||||
|
policy_file.write_text("")
|
||||||
|
policy = Policy.load(policy_file)
|
||||||
|
assert policy.version == "1.0"
|
||||||
|
assert len(policy.agents) == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestReviewResult:
|
||||||
|
def test_review_result_creation(self) -> None:
|
||||||
|
result = ReviewResult(
|
||||||
|
agent_name=AgentName.SECURITY,
|
||||||
|
findings=[],
|
||||||
|
duration_ms=1000,
|
||||||
|
tokens_used=500,
|
||||||
|
cost_usd=0.01,
|
||||||
|
)
|
||||||
|
assert result.agent_name == AgentName.SECURITY
|
||||||
|
assert result.duration_ms == 1000
|
||||||
|
assert result.cost_usd == 0.01
|
||||||
|
|
||||||
|
def test_review_result_with_findings(self) -> None:
|
||||||
|
finding = Finding(
|
||||||
|
id="test-123",
|
||||||
|
agent=AgentName.SECURITY,
|
||||||
|
file="test.py",
|
||||||
|
line_start=1,
|
||||||
|
line_end=1,
|
||||||
|
severity=Severity.HIGH,
|
||||||
|
confidence=0.9,
|
||||||
|
title="Test",
|
||||||
|
description="Test",
|
||||||
|
reasoning="Test",
|
||||||
|
prompt_version="test-v1.0",
|
||||||
|
)
|
||||||
|
result = ReviewResult(
|
||||||
|
agent_name=AgentName.SECURITY,
|
||||||
|
findings=[finding],
|
||||||
|
duration_ms=1000,
|
||||||
|
tokens_used=500,
|
||||||
|
cost_usd=0.01,
|
||||||
|
)
|
||||||
|
assert len(result.findings) == 1
|
||||||
|
assert result.findings[0].id == "test-123"
|
||||||
Reference in New Issue
Block a user