From f22ca1d5bd865e6f4f43eeddbd3cb60efda1e70f Mon Sep 17 00:00:00 2001 From: Kai Chappell Date: Sat, 8 Mar 2025 15:52:29 +0000 Subject: [PATCH] feat(agents): implement agent framework and CLI --- pyproject.toml | 111 ++++++ src/arbiter/__init__.py | 3 + src/arbiter/agents/__init__.py | 14 + src/arbiter/agents/base.py | 277 +++++++++++++ src/arbiter/agents/complexity.py | 17 + src/arbiter/agents/security.py | 17 + src/arbiter/agents/style.py | 17 + src/arbiter/cli.py | 407 +++++++++++++++++++ src/arbiter/config.py | 121 ++++++ src/arbiter/llm/__init__.py | 12 + src/arbiter/llm/client.py | 66 ++++ src/arbiter/llm/prompts.py | 81 ++++ src/arbiter/models/__init__.py | 16 + src/arbiter/models/enums.py | 29 ++ src/arbiter/models/finding.py | 28 ++ src/arbiter/models/policy.py | 58 +++ src/arbiter/models/review.py | 16 + src/arbiter/py.typed | 0 templates/complexity-v1.0.md | 56 +++ templates/security-v1.0.md | 55 +++ templates/style-v1.0.md | 55 +++ tests/__init__.py | 1 + tests/conftest.py | 490 +++++++++++++++++++++++ tests/fixtures/complex-function.diff | 69 ++++ tests/fixtures/security-issue.diff | 31 ++ tests/fixtures/simple.diff | 16 + tests/test_agents.py | 305 +++++++++++++++ tests/test_cli.py | 561 +++++++++++++++++++++++++++ tests/test_llm.py | 313 +++++++++++++++ tests/test_models.py | 224 +++++++++++ 30 files changed, 3466 insertions(+) create mode 100644 pyproject.toml create mode 100644 src/arbiter/__init__.py create mode 100644 src/arbiter/agents/__init__.py create mode 100644 src/arbiter/agents/base.py create mode 100644 src/arbiter/agents/complexity.py create mode 100644 src/arbiter/agents/security.py create mode 100644 src/arbiter/agents/style.py create mode 100644 src/arbiter/cli.py create mode 100644 src/arbiter/config.py create mode 100644 src/arbiter/llm/__init__.py create mode 100644 src/arbiter/llm/client.py create mode 100644 src/arbiter/llm/prompts.py create mode 100644 src/arbiter/models/__init__.py create mode 100644 
src/arbiter/models/enums.py create mode 100644 src/arbiter/models/finding.py create mode 100644 src/arbiter/models/policy.py create mode 100644 src/arbiter/models/review.py create mode 100644 src/arbiter/py.typed create mode 100644 templates/complexity-v1.0.md create mode 100644 templates/security-v1.0.md create mode 100644 templates/style-v1.0.md create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/fixtures/complex-function.diff create mode 100644 tests/fixtures/security-issue.diff create mode 100644 tests/fixtures/simple.diff create mode 100644 tests/test_agents.py create mode 100644 tests/test_cli.py create mode 100644 tests/test_llm.py create mode 100644 tests/test_models.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..52b8da2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,111 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "arbiter" +version = "0.1.0" +description = "A multi-agent code review system that shows its work" +readme = "readme.md" +requires-python = ">=3.12" +license = "MIT" +authors = [{ name = "Kai Chappell", email = "git@kschappell.com" }] +keywords = ["code-review", "ai", "llm", "agents"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.12", + "Topic :: Software Development :: Quality Assurance", +] +dependencies = [ + "litellm>=1.0.0", + "pydantic>=2.0.0", + "pydantic-settings>=2.0.0", + "typer>=0.9.0", + "rich>=13.0.0", + "pyyaml>=6.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=4.0.0", + "ruff>=0.4.0", + "mypy>=1.10.0", + "types-PyYAML>=6.0.0", +] + +[project.scripts] +arbiter = "arbiter.cli:app" + +[tool.hatch.build.targets.wheel] +packages = ["src/arbiter"] + +[tool.ruff] +target-version = 
"py312" +line-length = 100 +src = ["src", "tests"] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade + "ARG", # flake8-unused-arguments + "SIM", # flake8-simplify + "TCH", # flake8-type-checking + "PTH", # flake8-use-pathlib + "RUF", # Ruff-specific rules +] +ignore = [ + "E501", # line too long (handled by formatter) +] + +[tool.ruff.lint.isort] +known-first-party = ["arbiter"] + +[tool.mypy] +python_version = "3.12" +strict = true +warn_return_any = true +warn_unused_ignores = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +disallow_untyped_decorators = true +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_configs = true +show_error_codes = true +files = ["src"] + +[[tool.mypy.overrides]] +module = ["litellm.*"] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +addopts = "-ra -q" + +[tool.coverage.run] +source = ["src/arbiter"] +branch = true + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise NotImplementedError", + "if TYPE_CHECKING:", + "if __name__ == .__main__.:", +] +fail_under = 85 diff --git a/src/arbiter/__init__.py b/src/arbiter/__init__.py new file mode 100644 index 0000000..873b421 --- /dev/null +++ b/src/arbiter/__init__.py @@ -0,0 +1,3 @@ +"""Arbiter: A multi-agent code review system that shows its work.""" + +__version__ = "0.1.0" diff --git a/src/arbiter/agents/__init__.py b/src/arbiter/agents/__init__.py new file mode 100644 index 0000000..ad2e995 --- /dev/null +++ b/src/arbiter/agents/__init__.py @@ -0,0 +1,14 @@ +"""Review agents for code analysis.""" + +from arbiter.agents.base import Agent, ReviewContext +from arbiter.agents.complexity import ComplexityAgent +from 
arbiter.agents.security import SecurityAgent +from arbiter.agents.style import StyleAgent + +__all__ = [ + "Agent", + "ComplexityAgent", + "ReviewContext", + "SecurityAgent", + "StyleAgent", +] diff --git a/src/arbiter/agents/base.py b/src/arbiter/agents/base.py new file mode 100644 index 0000000..b3995f8 --- /dev/null +++ b/src/arbiter/agents/base.py @@ -0,0 +1,277 @@ +"""Base agent class and review context.""" + +import json +import time +import uuid +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, Field + +from arbiter.llm.client import LLMClient +from arbiter.llm.prompts import PromptRegistry +from arbiter.models import AgentName, Finding, Policy, ReviewResult, Severity + + +class ReviewContext(BaseModel): + """Context for a code review.""" + + diff: str = Field(description="The diff content to review") + policy: Policy = Field(description="Review policy configuration") + file_path: str | None = Field(default=None, description="Path to the diff file") + static_analysis_context: str = Field( + default="No static analysis data available.", + description="Static analysis results as formatted string", + ) + + +class ExplainContext(BaseModel): + """Context for explaining a finding in a follow-up conversation.""" + + question: str = Field(description="The user's follow-up question") + finding: Finding = Field(description="The finding being asked about") + diff: str = Field(description="The relevant code diff") + conversation_history: list[dict[str, str]] = Field( + default_factory=list, + description="Previous messages in this conversation", + ) + + +class ExplainResult(BaseModel): + """Result from an agent's explain() method.""" + + response: str = Field(description="The explanation response text") + tokens_used: int = Field(ge=0, description="Total tokens used for this response") + cost_usd: float = Field(ge=0.0, description="Cost in USD for this response") + confidence: float = Field(ge=0.0, le=1.0, 
class Agent(ABC):
    """Abstract base class for review agents.

    A concrete agent sets the class attributes below and implements
    ``_parse_response``. ``review`` renders the agent's prompt, sends it to
    the LLM client, parses the reply into findings and filters them by the
    policy's severity threshold. ``explain`` answers a follow-up question
    about one finding.
    """

    name: AgentName
    prompt_name: str
    prompt_version: str = "1.0"
    default_model: str = "gpt-4o"
    explain_prompt_name: str = ""  # Set by subclasses
    explain_prompt_version: str = "1.0"

    def __init__(
        self,
        llm_client: LLMClient,
        prompt_registry: PromptRegistry,
    ) -> None:
        """Store the collaborators used to render prompts and call the LLM."""
        self.llm_client = llm_client
        self.prompt_registry = prompt_registry

    async def review(self, context: ReviewContext) -> ReviewResult:
        """Perform a code review.

        Args:
            context: Review context with diff and policy.

        Returns:
            ReviewResult with findings and metadata.
        """
        start_time = time.monotonic()

        # Per-agent policy entry (may be absent): supplies the model override,
        # extra prompt text and the severity threshold.
        agent_config = context.policy.agents.get(self.name)
        model = agent_config.model if agent_config and agent_config.model else self.default_model

        prompt_additions = ""
        if agent_config and agent_config.prompt_additions:
            prompt_additions = agent_config.prompt_additions

        messages = self._build_messages(context, prompt_additions)
        response = await self._call_llm(messages, model)

        findings = self._parse_response(response.content, context)

        # Without a policy entry there is no threshold, so nothing is filtered.
        if agent_config:
            findings = self._filter_by_severity(findings, agent_config.severity_threshold)

        duration_ms = int((time.monotonic() - start_time) * 1000)

        return ReviewResult(
            agent_name=self.name,
            findings=findings,
            duration_ms=duration_ms,
            tokens_used=response.tokens_in + response.tokens_out,
            cost_usd=response.cost_usd,
        )

    def _build_messages(
        self,
        context: ReviewContext,
        prompt_additions: str,
    ) -> list[dict[str, str]]:
        """Render this agent's review prompt into a single user message."""
        template = self.prompt_registry.get(self.prompt_name, self.prompt_version)
        content = template.render(
            diff=context.diff,
            static_analysis_context=context.static_analysis_context,
            prompt_additions=prompt_additions,
        )
        return [{"role": "user", "content": content}]

    async def _call_llm(
        self,
        messages: list[dict[str, str]],
        model: str,
    ) -> Any:
        """Send *messages* to the LLM client and return its response object."""
        return await self.llm_client.complete(messages, model)

    @abstractmethod
    def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]: ...

    def _filter_by_severity(
        self,
        findings: list[Finding],
        threshold: Severity,
    ) -> list[Finding]:
        """Filter findings by severity threshold.

        Args:
            findings: List of findings to filter.
            threshold: Minimum severity to include.

        Returns:
            Findings whose severity is at least as severe as ``threshold``,
            in their original order.
        """
        # Ordered most severe first, so a smaller index means more severe.
        severity_order = [
            Severity.CRITICAL,
            Severity.HIGH,
            Severity.MEDIUM,
            Severity.LOW,
            Severity.INFO,
        ]
        threshold_index = severity_order.index(threshold)
        return [f for f in findings if severity_order.index(f.severity) <= threshold_index]

    @staticmethod
    def _load_json_payload(content: str) -> Any:
        """Decode the JSON document embedded in an LLM reply.

        Candidates are tried in order: the body of a leading Markdown code
        fence (or the whole reply when there is no leading fence), then the
        text between the first ``[`` and the last ``]`` so that a reply with
        prose around the array is still understood.

        Returns:
            The decoded JSON value, or ``None`` when no candidate parses.
        """
        text = content.strip()
        candidates: list[str] = []

        if text.startswith("```"):
            # Collect the lines between the opening fence and the next fence.
            fenced: list[str] = []
            in_block = False
            for line in text.split("\n"):
                if line.startswith("```"):
                    if in_block:
                        break
                    in_block = True
                    continue
                if in_block:
                    fenced.append(line)
            candidates.append("\n".join(fenced))
        else:
            candidates.append(text)

        start, end = text.find("["), text.rfind("]")
        if 0 <= start < end:
            candidates.append(text[start : end + 1])

        for candidate in candidates:
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                continue
        return None

    def _parse_json_findings(self, content: str, _context: ReviewContext) -> list[Finding]:
        """Parse JSON array of findings from LLM response.

        Malformed output never raises: an unparseable reply yields an empty
        list and an individual malformed item is skipped, each with a log
        record so the loss is diagnosable.

        Args:
            content: Raw LLM response content.
            _context: Review context (currently unused).

        Returns:
            List of Finding objects.
        """
        import logging  # local import: this block cannot edit the module's import list

        logger = logging.getLogger(__name__)

        data = self._load_json_payload(content)
        if not isinstance(data, list):
            logger.warning(
                "Agent %s: response is not a JSON array of findings; ignoring it",
                self.name,
            )
            return []

        prompt_version = f"{self.prompt_name}-v{self.prompt_version}"
        findings: list[Finding] = []
        for item in data:
            if not isinstance(item, dict):
                continue
            try:
                finding = Finding(
                    id=str(uuid.uuid4()),
                    agent=self.name,
                    file=item.get("file", "unknown"),
                    line_start=item.get("line_start", 1),
                    line_end=item.get("line_end") or item.get("line_start", 1),
                    severity=Severity(item.get("severity", "info")),
                    confidence=float(item.get("confidence", 0.5)),
                    title=item.get("title", "Untitled finding"),
                    description=item.get("description", ""),
                    reasoning=item.get("reasoning", ""),
                    suggestion=item.get("suggestion"),
                    references=item.get("references", []),
                    prompt_version=prompt_version,
                    static_analysis_context=None,
                )
            except (TypeError, ValueError, KeyError) as exc:
                # TypeError: float(None) / float([...]) for a null or
                # non-scalar confidence. ValueError: unknown severity or a
                # model validation failure.
                logger.debug("Agent %s: skipping malformed finding: %s", self.name, exc)
                continue
            findings.append(finding)

        return findings

    async def explain(
        self,
        context: ExplainContext,
        model: str | None = None,
    ) -> ExplainResult:
        """Provide a detailed explanation about a specific finding.

        Args:
            context: The explanation context with question, finding, and history.
            model: Optional model override. Defaults to agent's default_model.

        Returns:
            ExplainResult with response text and cost tracking.
        """
        # Fall back to "<prompt_name>-explain" when the subclass set no name.
        explain_name = self.explain_prompt_name or f"{self.prompt_name}-explain"
        template = self.prompt_registry.get(explain_name, self.explain_prompt_version)

        # Flatten prior turns into "**Role:** text" paragraphs.
        history_text = "\n\n".join(
            f"**{msg.get('role', 'user').capitalize()}:** {msg.get('content', '')}"
            for msg in context.conversation_history
        )

        finding = context.finding
        finding_lines = (
            f"{finding.line_start}"
            if finding.line_start == finding.line_end
            else f"{finding.line_start}-{finding.line_end}"
        )

        content = template.render(
            finding_title=finding.title,
            finding_file=finding.file,
            finding_lines=finding_lines,
            finding_description=finding.description,
            finding_reasoning=finding.reasoning,
            finding_severity=finding.severity.value,
            finding_suggestion=finding.suggestion or "No suggestion provided.",
            diff=context.diff,
            conversation_history=history_text or "No previous conversation.",
            question=context.question,
        )

        messages = [{"role": "user", "content": content}]
        response = await self._call_llm(messages, model or self.default_model)

        return ExplainResult(
            response=response.content.strip(),
            tokens_used=response.tokens_in + response.tokens_out,
            cost_usd=response.cost_usd,
            confidence=0.8,  # Default confidence for explanations
        )
focused on code complexity and architecture.""" + + name = AgentName.COMPLEXITY + prompt_name = "complexity" + prompt_version = "1.0" + explain_prompt_name = "complexity-explain" + default_model = "gpt-4o-mini" + + def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]: + return self._parse_json_findings(content, context) diff --git a/src/arbiter/agents/security.py b/src/arbiter/agents/security.py new file mode 100644 index 0000000..9bc5f1e --- /dev/null +++ b/src/arbiter/agents/security.py @@ -0,0 +1,17 @@ +"""Security review agent.""" + +from arbiter.agents.base import Agent, ReviewContext +from arbiter.models import AgentName, Finding + + +class SecurityAgent(Agent): + """Agent focused on security vulnerabilities and risks.""" + + name = AgentName.SECURITY + prompt_name = "security" + prompt_version = "1.0" + explain_prompt_name = "security-explain" + default_model = "gpt-4o" + + def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]: + return self._parse_json_findings(content, context) diff --git a/src/arbiter/agents/style.py b/src/arbiter/agents/style.py new file mode 100644 index 0000000..524ff10 --- /dev/null +++ b/src/arbiter/agents/style.py @@ -0,0 +1,17 @@ +"""Style review agent.""" + +from arbiter.agents.base import Agent, ReviewContext +from arbiter.models import AgentName, Finding + + +class StyleAgent(Agent): + """Agent focused on code style, naming, and readability.""" + + name = AgentName.STYLE + prompt_name = "style" + prompt_version = "1.0" + explain_prompt_name = "style-explain" + default_model = "gpt-4o-mini" + + def _parse_response(self, content: str, context: ReviewContext) -> list[Finding]: + return self._parse_json_findings(content, context) diff --git a/src/arbiter/cli.py b/src/arbiter/cli.py new file mode 100644 index 0000000..f2a4937 --- /dev/null +++ b/src/arbiter/cli.py @@ -0,0 +1,407 @@ +"""Command-line interface for Arbiter.""" + +import asyncio +import json +import sys +from pathlib 
def _escape_markup(value: object) -> str:
    """Return ``str(value)`` with Rich markup escaped so it prints literally."""
    from rich.markup import escape  # local import: not in this module's import list

    return escape(str(value))


def _print_plain(text: str) -> None:
    """Print machine-readable text verbatim.

    Markup, emoji-shortcode replacement, syntax highlighting and line
    wrapping are all disabled so JSON stays parseable and Markdown keeps
    its ``[label]`` and ``:shortcode:`` text.
    """
    console.print(text, markup=False, highlight=False, emoji=False, soft_wrap=True)


def _severity_color(severity: Severity) -> str:
    """Return the Rich style used to render *severity* ("white" if unknown)."""
    colors = {
        Severity.CRITICAL: "red bold",
        Severity.HIGH: "red",
        Severity.MEDIUM: "yellow",
        Severity.LOW: "blue",
        Severity.INFO: "dim",
    }
    return colors.get(severity, "white")


def _severity_icon(severity: Severity) -> str:
    """Return the short text marker for *severity* (a space if unknown)."""
    icons = {
        Severity.CRITICAL: "!!",
        Severity.HIGH: "!",
        Severity.MEDIUM: "*",
        Severity.LOW: "-",
        Severity.INFO: "i",
    }
    return icons.get(severity, " ")


def _verdict_color(verdict: Verdict) -> str:
    """Return the Rich style used to render *verdict* ("white" if unknown)."""
    colors = {
        Verdict.APPROVE: "green",
        Verdict.COMMENT: "yellow",
        Verdict.REQUEST_CHANGES: "red",
    }
    return colors.get(verdict, "white")


def _verdict_icon(verdict: Verdict) -> str:
    """Return the bracketed text marker for *verdict* (empty if unknown).

    The result contains square brackets, so escape it before embedding it
    in a Rich markup string.
    """
    icons = {
        Verdict.APPROVE: "[ok]",
        Verdict.COMMENT: "[..]",
        Verdict.REQUEST_CHANGES: "[!!]",
    }
    return icons.get(verdict, "")


def _format_rich(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
    """Format results using Rich for terminal output.

    All dynamic text (LLM-written titles, descriptions and reasoning, file
    names, enum labels shown in brackets) is escaped before being placed in
    a markup string. Unescaped, ``[ok]`` or ``[security]`` would be parsed
    as a style tag and disappear, and a stray ``[/]`` would raise.
    """
    total_tokens = sum(r.tokens_used for r in agent_results) + result.tokens_used
    total_cost = sum(r.cost_usd for r in agent_results) + result.cost_usd

    # Verdict panel
    verdict_style = _verdict_color(result.verdict)
    verdict_icon = _escape_markup(_verdict_icon(result.verdict))
    verdict_text = f"[{verdict_style} bold]{verdict_icon} {result.verdict.value.upper()}[/]"
    verdict_text += f" (confidence: {result.verdict_confidence:.0%})"

    summary = f"{verdict_text}\n\n"
    summary += (
        f"[bold]{result.total_findings}[/bold] findings "
        f"from [bold]{len(agent_results)}[/bold] agents"
    )
    if result.conflicts:
        summary += f" | [yellow]{len(result.conflicts)}[/yellow] conflicts"
    summary += f"\nTokens: {total_tokens:,} | Cost: ${total_cost:.4f}"

    console.print(Panel(summary, title="Arbiter Review", border_style="blue"))

    if result.verdict_reasoning:
        console.print(f"\n[dim]Reason:[/dim] {_escape_markup(result.verdict_reasoning)}")

    if result.total_findings == 0:
        console.print("\n[green]No issues found![/green]")
        return

    # Findings table
    console.print("\n[bold]Findings[/bold]")

    table = Table(show_header=True, header_style="bold")
    table.add_column("Sev", width=4)
    table.add_column("Agent", width=12)
    table.add_column("File", width=30)
    table.add_column("Line", width=8)
    table.add_column("Title", width=50)

    for finding in result.findings:
        sev_icon = _severity_icon(finding.severity)
        sev_style = _severity_color(finding.severity)
        line_range = (
            f"{finding.line_start}-{finding.line_end}"
            if finding.line_start != finding.line_end
            else str(finding.line_start)
        )
        table.add_row(
            f"[{sev_style}]{sev_icon}[/{sev_style}]",
            _escape_markup(finding.agent.value),
            _escape_markup(finding.file),
            line_range,
            _escape_markup(finding.title),
        )

    console.print(table)

    # Conflicts section
    if result.conflicts:
        console.print("\n[bold yellow]Conflicts[/bold yellow]")
        for i, conflict in enumerate(result.conflicts, 1):
            nature_label = _escape_markup(f"[{conflict.nature.value}]")
            console.print(f"\n{i}. {nature_label} {_escape_markup(conflict.description)}")
            # Show the first resolution recorded for this conflict, if any.
            resolution = next(
                (r for r in result.resolutions if r.conflict_id == conflict.id),
                None,
            )
            if resolution is not None:
                console.print(
                    f" [dim]Resolution:[/dim] {_escape_markup(resolution.reasoning)}"
                )

    # Details for critical/high findings
    critical_high = [f for f in result.findings if f.severity in (Severity.CRITICAL, Severity.HIGH)]
    if critical_high:
        console.print("\n[bold]Details[/bold]")
        for finding in critical_high:
            console.print(
                f"\n[{_severity_color(finding.severity)}]{_escape_markup(finding.title)}[/]"
                f" ({_escape_markup(finding.file)}:{finding.line_start})"
            )
            console.print(f" {_escape_markup(finding.description)}")
            if finding.suggestion:
                console.print(f" [dim]Suggestion:[/dim] {_escape_markup(finding.suggestion)}")

    # Deliberation log
    console.print("\n[bold dim]Deliberation Log[/bold dim]")
    for step in result.steps:
        step_label = _escape_markup(f"[{step.step_type.value}]")
        console.print(f" {step_label} {_escape_markup(step.description)}")


def _format_json(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
    """Format results as JSON.

    The document is printed verbatim (see ``_print_plain``) so the output
    can be piped straight into a JSON parser.
    """
    output = {
        "verdict": result.verdict.value,
        "verdict_confidence": result.verdict_confidence,
        "verdict_reasoning": result.verdict_reasoning,
        "findings": [f.model_dump() for f in result.findings],
        "conflicts": [c.model_dump() for c in result.conflicts],
        "resolutions": [r.model_dump() for r in result.resolutions],
        "deliberation_steps": [s.model_dump() for s in result.steps],
        "agents": [r.model_dump() for r in agent_results],
        "summary": {
            "total_findings": result.total_findings,
            "critical_count": result.critical_count,
            "high_count": result.high_count,
            "conflicts_count": len(result.conflicts),
            "total_tokens": sum(r.tokens_used for r in agent_results) + result.tokens_used,
            "total_cost_usd": sum(r.cost_usd for r in agent_results) + result.cost_usd,
        },
    }
    # default=str stringifies any value json cannot encode natively.
    _print_plain(json.dumps(output, indent=2, default=str))


def _format_markdown(result: DeliberationResult, agent_results: list[ReviewResult]) -> None:
    """Format results as Markdown.

    The text is printed verbatim (see ``_print_plain``) so ``[agent]``
    labels and ``:emoji:`` shortcodes reach the output unchanged.
    """
    lines = ["# Arbiter Review", ""]

    # Verdict
    verdict_badge = {
        Verdict.APPROVE: ":white_check_mark:",
        Verdict.COMMENT: ":speech_balloon:",
        Verdict.REQUEST_CHANGES: ":x:",
    }.get(result.verdict, "")

    lines.append(f"## Verdict: {verdict_badge} {result.verdict.value.upper()}")
    lines.append("")
    lines.append(f"**Confidence:** {result.verdict_confidence:.0%}")
    if result.verdict_reasoning:
        lines.append(f"**Reason:** {result.verdict_reasoning}")
    lines.append("")

    # Summary
    lines.append(f"**{result.total_findings} findings** from {len(agent_results)} agents")
    if result.conflicts:
        lines.append(f"**{len(result.conflicts)} conflicts** detected")
    lines.append("")

    # Findings, in the order the deliberation result lists them
    if result.findings:
        lines.append("## Findings")
        lines.append("")

        for finding in result.findings:
            severity_badge = f"**{finding.severity.value.upper()}**"
            lines.append(
                f"- {severity_badge} [{finding.agent.value}] "
                f"`{finding.file}:{finding.line_start}` - {finding.title}"
            )
            lines.append(f" - {finding.description}")
            if finding.suggestion:
                lines.append(f" - *Suggestion:* {finding.suggestion}")
        lines.append("")

    # Conflicts
    if result.conflicts:
        lines.append("## Conflicts")
        lines.append("")

        for conflict in result.conflicts:
            lines.append(f"### {conflict.nature.value.title()}")
            lines.append(conflict.description)
            resolution = next(
                (r for r in result.resolutions if r.conflict_id == conflict.id),
                None,
            )
            if resolution is not None:
                lines.append(f"**Resolution:** {resolution.reasoning}")
            lines.append("")

    # Deliberation log, numbered explicitly so the raw text reads correctly
    lines.append("## Deliberation Log")
    lines.append("")
    for number, step in enumerate(result.steps, 1):
        lines.append(f"{number}. **{step.step_type.value}**: {step.description}")
    lines.append("")

    _print_plain("\n".join(lines))
@app.command()
def review(
    diff_file: Annotated[
        str,
        typer.Argument(help="Path to diff file, or '-' for stdin"),
    ],
    policy: Annotated[
        Path | None,
        typer.Option("--policy", "-p", help="Path to policy YAML file"),
    ] = None,
    model: Annotated[
        str | None,
        typer.Option("--model", "-m", help="Override LLM model for all agents"),
    ] = None,
    output_format: Annotated[
        str,
        typer.Option("--format", "-f", help="Output format: rich, json, markdown"),
    ] = "rich",
    static_analysis: Annotated[
        bool,
        typer.Option("--static-analysis/--no-static-analysis", help="Run static analysis"),
    ] = True,
    work_dir: Annotated[
        Path | None,
        typer.Option("--work-dir", "-w", help="Working directory for static analysis"),
    ] = None,
) -> None:
    """Review a diff file using AI agents."""
    # NOTE: the docstring above is what typer shows as the command's help
    # text, so it stays a one-liner; details live in comments.
    from rich.markup import escape  # local import: not in this module's import list

    # Resolve the formatter before any LLM call so a mistyped --format is
    # reported instead of silently producing rich output. Unknown values
    # still fall back to rich, as before.
    formatters = {
        "rich": _format_rich,
        "json": _format_json,
        "markdown": _format_markdown,
    }
    format_key = output_format.strip().lower()
    if format_key not in formatters:
        console.print(
            f"[yellow]Warning:[/yellow] Unknown format '{escape(output_format)}', using 'rich'"
        )
        format_key = "rich"

    # Read diff content
    if diff_file == "-":
        diff_content = sys.stdin.read()
    else:
        diff_path = Path(diff_file)
        # is_file() rather than exists(): a directory would pass exists()
        # and then make read_text() raise.
        if not diff_path.is_file():
            console.print(f"[red]Error:[/red] File not found: {escape(diff_file)}")
            raise typer.Exit(1)
        # Decode explicitly so the result does not depend on the locale;
        # undecodable bytes are replaced rather than aborting the review.
        diff_content = diff_path.read_text(encoding="utf-8", errors="replace")

    if not diff_content.strip():
        console.print("[yellow]Warning:[/yellow] Empty diff provided")
        raise typer.Exit(0)

    # Load policy
    review_policy = Policy()
    if policy:
        if not policy.is_file():
            console.print(f"[red]Error:[/red] Policy file not found: {escape(str(policy))}")
            raise typer.Exit(1)
        review_policy = Policy.load(policy)

    # Apply model override.
    # NOTE(review): only agents that have an entry in review_policy.agents
    # are affected; confirm every enabled agent has one.
    if model:
        for agent_config in review_policy.agents.values():
            agent_config.model = model

    # Determine prompts directory (relative paths resolve against the CWD)
    settings = get_settings()
    templates_dir = settings.templates_dir
    if not templates_dir.is_absolute():
        templates_dir = Path.cwd() / templates_dir

    # Resolve work directory for static analysis (None disables it downstream)
    resolved_work_dir: Path | None = None
    if static_analysis:
        resolved_work_dir = work_dir.resolve() if work_dir else Path.cwd()

    # Run review
    agent_results, deliberation_result = asyncio.run(
        _run_review(
            diff_content,
            review_policy,
            templates_dir,
            static_analysis=static_analysis,
            work_dir=resolved_work_dir,
        )
    )

    # Format output
    formatters[format_key](deliberation_result, agent_results)
Settings(BaseSettings): + """Application settings loaded from environment variables.""" + + model_config = SettingsConfigDict( + env_prefix="ARBITER_", + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + ) + + # LLM settings + default_model: str = Field(default="gpt-4o", description="Default LLM model for agents") + llm_timeout: int = Field(default=60, description="LLM request timeout in seconds") + llm_max_retries: int = Field(default=3, description="Maximum LLM retry attempts") + + # Cost controls + max_tokens_per_review: int = Field( + default=50000, description="Maximum tokens allowed per review" + ) + max_cost_per_review_usd: float = Field( + default=0.50, description="Maximum cost per review in USD" + ) + cache_ttl_hours: int = Field(default=24, description="Cache TTL in hours") + + # Paths + templates_dir: Path = Field( + default=Path("templates"), description="Directory containing prompt templates" + ) + policy_path: Path | None = Field(default=None, description="Default policy file path") + + # Output settings + output_format: str = Field(default="rich", description="Output format: rich, json, or markdown") + + # Database settings + database_url: str = Field( + default="postgresql+asyncpg://arbiter:arbiter@localhost:5432/arbiter", + description="PostgreSQL connection URL", + ) + database_pool_size: int = Field(default=5, description="Database connection pool size") + database_max_overflow: int = Field(default=10, description="Max overflow connections") + + # Redis settings + redis_url: str = Field( + default="redis://localhost:6379/0", + description="Redis connection URL", + ) + redis_max_connections: int = Field(default=10, description="Redis max connections") + + # Webhook secrets + github_webhook_secret: SecretStr | None = Field( + default=None, description="GitHub webhook secret for HMAC verification" + ) + gitlab_webhook_token: SecretStr | None = Field( + default=None, description="GitLab webhook token for verification" + ) + + # 
Platform integration settings + github_token: SecretStr | None = Field( + default=None, description="GitHub API token for fetching diffs and posting comments" + ) + github_base_url: str = Field( + default="https://api.github.com", description="GitHub API base URL" + ) + gitlab_token: SecretStr | None = Field( + default=None, description="GitLab API token for fetching diffs and posting comments" + ) + gitlab_base_url: str = Field( + default="https://gitlab.com", description="GitLab instance base URL" + ) + integration_timeout: int = Field( + default=30, description="Integration API request timeout in seconds" + ) + integration_max_retries: int = Field( + default=3, description="Maximum retry attempts for integration API calls" + ) + status_check_context: str = Field( + default="arbiter", description="Context name for commit status checks" + ) + post_comments: bool = Field(default=True, description="Whether to post review comments on PRs") + update_status: bool = Field(default=True, description="Whether to update commit status checks") + + # API settings + api_title: str = Field(default="Arbiter API", description="API title for OpenAPI") + api_version: str = Field(default="0.5.0", description="API version") + cors_origins: list[str] = Field( + default=["http://localhost:3000"], description="Allowed CORS origins" + ) + api_rate_limit_per_minute: int = Field( + default=60, description="API rate limit per minute per client" + ) + + # Worker settings + worker_max_jobs: int = Field(default=10, description="Max concurrent worker jobs") + worker_job_timeout: int = Field(default=300, description="Job timeout in seconds") + worker_retry_attempts: int = Field(default=3, description="Number of retry attempts") + + # Follow-up conversation settings + followup_enabled: bool = Field(default=True, description="Enable follow-up question handling") + followup_confidence_threshold: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Minimum confidence to process a 
class LLMClient(ABC):
    """Provider-agnostic interface that every LLM backend must implement.

    Agents depend on this abstraction rather than on a concrete SDK, which
    lets tests substitute a scripted client for the real one.
    """

    @abstractmethod
    async def complete(
        self,
        messages: list[dict[str, str]],
        model: str,
        **kwargs: Any,
    ) -> LLMResponse:
        """Run one completion request and return its result.

        Args:
            messages: Conversation to send, as a list of string-to-string
                mappings.
            model: Identifier of the model that should serve the request.
            **kwargs: Extra options; their meaning is defined by the
                concrete implementation.

        Returns:
            The generated content together with token and cost accounting.
        """
        ...
class PromptRegistry:
    """Load versioned prompt templates from a directory and cache them.

    Template files are named ``<name>-v<major>.<minor>.md``. Loaded templates
    are stored in a class-level cache shared by every registry instance. The
    cache is keyed by the absolute path of the template file, so registries
    that point at different directories never serve each other's templates.
    """

    _cache: ClassVar[dict[str, PromptTemplate]] = {}
    _pattern: ClassVar[re.Pattern[str]] = re.compile(r"^(.+)-v(\d+\.\d+)\.md$")

    def __init__(self, templates_dir: Path) -> None:
        """Create a registry that reads templates from ``templates_dir``."""
        self.templates_dir = templates_dir

    def get(self, name: str, version: str) -> PromptTemplate:
        """Get a prompt template by name and version.

        Args:
            name: Template name (e.g., 'security').
            version: Template version (e.g., '1.0').

        Returns:
            PromptTemplate instance (served from the cache after first load).

        Raises:
            FileNotFoundError: If the template file doesn't exist.
        """
        file_path = self.templates_dir / f"{name}-v{version}.md"
        # Key on the absolute file path rather than just "name-vversion":
        # the cache is shared across instances, and two registries may be
        # rooted at different directories that contain same-named templates.
        cache_key = str(file_path.absolute())

        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        try:
            # Templates are authored as UTF-8; do not depend on the locale.
            content = file_path.read_text(encoding="utf-8")
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Prompt template not found: {file_path}") from err

        template = PromptTemplate(name=name, version=version, content=content)
        self._cache[cache_key] = template
        return template

    def list_templates(self) -> list[tuple[str, str]]:
        """List all available templates.

        Files in the templates directory that do not follow the
        ``<name>-v<major>.<minor>.md`` naming scheme are ignored.

        Returns:
            Sorted list of (name, version) tuples; empty if the templates
            directory does not exist.
        """
        if not self.templates_dir.exists():
            return []

        matches = (self._pattern.match(path.name) for path in self.templates_dir.glob("*.md"))
        return sorted((m.group(1), m.group(2)) for m in matches if m)

    def clear_cache(self) -> None:
        """Drop every cached template (the cache is shared by all instances)."""
        self._cache.clear()
agent's review.""" + + id: str = Field(description="Unique identifier for this finding") + agent: AgentName = Field(description="Agent that produced this finding") + file: str = Field(description="File path where the issue was found") + line_start: int = Field(ge=1, description="Starting line number") + line_end: int = Field(ge=1, description="Ending line number") + severity: Severity = Field(description="Severity level of the finding") + confidence: float = Field(ge=0.0, le=1.0, description="Confidence score (0.0-1.0)") + title: str = Field(description="Short title summarizing the finding") + description: str = Field(description="Detailed description of the issue") + reasoning: str = Field(description="Explanation of why this was flagged") + suggestion: str | None = Field(default=None, description="Optional fix recommendation") + references: list[str] = Field(default_factory=list, description="Links to docs, OWASP, etc.") + prompt_version: str = Field(description="Version of the prompt that produced this finding") + static_analysis_context: dict[str, Any] | None = Field( + default=None, description="Related static analysis data" + ) diff --git a/src/arbiter/models/policy.py b/src/arbiter/models/policy.py new file mode 100644 index 0000000..a7c59f9 --- /dev/null +++ b/src/arbiter/models/policy.py @@ -0,0 +1,58 @@ +"""Policy configuration models.""" + +from pathlib import Path +from typing import Self + +import yaml +from pydantic import BaseModel, Field + +from arbiter.models.enums import AgentName, Severity + + +class AgentConfig(BaseModel): + """Configuration for a single agent.""" + + enabled: bool = Field(default=True, description="Whether the agent is enabled") + model: str | None = Field(default=None, description="LLM model override for this agent") + severity_threshold: Severity = Field( + default=Severity.INFO, description="Minimum severity to report" + ) + prompt_additions: str | None = Field( + default=None, description="Additional instructions to append 
class Policy(BaseModel):
    """Review policy configuration.

    Holds the policy schema version and one ``AgentConfig`` per agent. When
    no agents are specified, every known agent is present with its default
    configuration.
    """

    version: str = Field(default="1.0", description="Policy schema version")
    agents: dict[AgentName, AgentConfig] = Field(
        default_factory=lambda: {
            AgentName.SECURITY: AgentConfig(),
            AgentName.STYLE: AgentConfig(),
            AgentName.COMPLEXITY: AgentConfig(),
        },
        description="Agent configurations",
    )

    @classmethod
    def load(cls, path: Path) -> Self:
        """Load policy from a YAML file.

        An empty file yields the default policy. A bare ``agents:`` key falls
        back to the default agent set, and a bare per-agent key (for example
        ``security:`` with no body) yields that agent's default configuration.

        Args:
            path: Location of the YAML policy file.

        Returns:
            The parsed policy.

        Raises:
            ValueError: If the document's top level is not a mapping, or an
                agent key is not a known agent name.
        """
        # Policy files are UTF-8; do not depend on the platform's locale.
        with path.open(encoding="utf-8") as f:
            data = yaml.safe_load(f)

        if data is None:
            return cls()

        if not isinstance(data, dict):
            raise ValueError(
                f"Policy file {path} must contain a mapping at the top level, "
                f"got {type(data).__name__}"
            )

        raw_agents = data.get("agents")
        if raw_agents is None:
            # YAML parses a key with no value as null; treat it as "not set"
            # so the default agent set applies.
            data.pop("agents", None)
        elif isinstance(raw_agents, dict):
            # Convert string keys to AgentName enums
            agents = {}
            for key, value in raw_agents.items():
                agent_name = AgentName(key) if isinstance(key, str) else key
                if value is None:
                    # Bare "<agent>:" entry: enable the agent with defaults.
                    agents[agent_name] = AgentConfig()
                elif isinstance(value, dict):
                    agents[agent_name] = AgentConfig(**value)
                else:
                    # Leave anything else for model validation to accept or reject.
                    agents[agent_name] = value
            data["agents"] = agents

        return cls(**data)

    def get_enabled_agents(self) -> list[AgentName]:
        """Return the names of all agents whose configuration is enabled."""
        return [name for name, config in self.agents.items() if config.enabled]
--git a/templates/complexity-v1.0.md b/templates/complexity-v1.0.md new file mode 100644 index 0000000..20fd9e0 --- /dev/null +++ b/templates/complexity-v1.0.md @@ -0,0 +1,56 @@ +# Complexity Review Agent + +You are a code complexity reviewer focused on architecture and maintainability. Analyze the provided diff for complexity issues that impact long-term code health. + +## Focus Areas + +- **Cyclomatic complexity**: Functions with too many branches or paths +- **Cognitive complexity**: Code that is hard to understand or follow +- **Function length**: Functions doing too many things +- **Class design**: God objects, tight coupling, missing abstractions +- **Dependency management**: Circular dependencies, excessive coupling +- **Over-engineering**: Unnecessary abstractions, premature optimization +- **Under-engineering**: Missing error handling, ignored edge cases + +## Context + +{{static_analysis_context}} + +## Diff to Review + +```diff +{{diff}} +``` + +{{prompt_additions}} + +## Output Format + +Respond with a JSON array of findings. Each finding must have this structure: + +```json +[ + { + "file": "path/to/file.py", + "line_start": 10, + "line_end": 50, + "severity": "critical|high|medium|low|info", + "confidence": 0.80, + "title": "Short title describing the issue", + "description": "Detailed description of the complexity concern", + "reasoning": "Why this complexity is problematic", + "suggestion": "How to simplify or refactor (optional)", + "references": [] + } +] +``` + +If no complexity issues are found, return an empty array: `[]` + +## Guidelines + +1. Consider context - complex code may be justified for complex problems +2. Flag over-engineering as readily as under-engineering +3. High complexity is only critical if it's likely to cause bugs +4. Suggest specific refactoring strategies when possible +5. Reference static analysis metrics (cyclomatic complexity, etc.) 
when available diff --git a/templates/security-v1.0.md b/templates/security-v1.0.md new file mode 100644 index 0000000..52e2bfc --- /dev/null +++ b/templates/security-v1.0.md @@ -0,0 +1,55 @@ +# Security Review Agent + +You are a security-focused code reviewer. Analyze the provided diff for security vulnerabilities and potential risks. + +## Focus Areas + +- **Injection vulnerabilities**: SQL injection, command injection, XSS, template injection +- **Authentication/Authorization**: Missing auth checks, privilege escalation, insecure session handling +- **Data exposure**: Hardcoded secrets, PII leaks, sensitive data in logs +- **Cryptographic issues**: Weak algorithms, improper key management, missing encryption +- **Input validation**: Missing or insufficient validation, type confusion +- **OWASP Top 10**: All categories including broken access control, security misconfiguration + +## Context + +{{static_analysis_context}} + +## Diff to Review + +```diff +{{diff}} +``` + +{{prompt_additions}} + +## Output Format + +Respond with a JSON array of findings. Each finding must have this structure: + +```json +[ + { + "file": "path/to/file.py", + "line_start": 10, + "line_end": 15, + "severity": "critical|high|medium|low|info", + "confidence": 0.95, + "title": "Short title describing the issue", + "description": "Detailed description of the vulnerability", + "reasoning": "Why this is a security concern", + "suggestion": "How to fix this issue (optional)", + "references": ["https://owasp.org/..."] + } +] +``` + +If no security issues are found, return an empty array: `[]` + +## Guidelines + +1. Only report genuine security concerns, not style or performance issues +2. Assign appropriate severity based on exploitability and impact +3. Set confidence based on how certain you are this is a real vulnerability +4. Provide actionable suggestions when possible +5. 
Include relevant OWASP or CWE references diff --git a/templates/style-v1.0.md b/templates/style-v1.0.md new file mode 100644 index 0000000..5bed43b --- /dev/null +++ b/templates/style-v1.0.md @@ -0,0 +1,55 @@ +# Style Review Agent + +You are a code style reviewer focused on readability and consistency. Analyze the provided diff for style issues that impact code maintainability. + +## Focus Areas + +- **Naming conventions**: Variable, function, class, and file naming consistency +- **Code organisation**: Logical grouping, import ordering, module structure +- **Readability**: Clear variable names, appropriate comments, self-documenting code +- **Consistency**: Adherence to existing patterns in the codebase +- **Best practices**: Language-specific idioms and conventions +- **Documentation**: Missing or outdated docstrings, misleading comments + +## Context + +{{static_analysis_context}} + +## Diff to Review + +```diff +{{diff}} +``` + +{{prompt_additions}} + +## Output Format + +Respond with a JSON array of findings. Each finding must have this structure: + +```json +[ + { + "file": "path/to/file.py", + "line_start": 10, + "line_end": 15, + "severity": "critical|high|medium|low|info", + "confidence": 0.85, + "title": "Short title describing the issue", + "description": "Detailed description of the style concern", + "reasoning": "Why this matters for maintainability", + "suggestion": "How to improve this (optional)", + "references": [] + } +] +``` + +If no style issues are found, return an empty array: `[]` + +## Guidelines + +1. Focus on readability and maintainability, not personal preferences +2. Respect existing codebase conventions even if they differ from common standards +3. Most style issues should be low or info severity unless they significantly impact readability +4. Only flag high severity for style issues that could cause confusion or bugs +5. 
class MockLLMClient(LLMClient):
    """Scripted LLM client for tests.

    Replays ``responses`` in order, one per call, and returns an empty
    string once they run out. Every call is recorded in ``calls`` so tests
    can inspect the messages, model and keyword arguments that were sent.
    """

    def __init__(self, responses: list[str] | None = None) -> None:
        self.responses = responses if responses else []
        self.calls: list[dict[str, Any]] = []
        self._call_index = 0

    async def complete(
        self,
        messages: list[dict[str, str]],
        model: str,
        **kwargs: Any,
    ) -> LLMResponse:
        """Log the request, then hand back the next scripted response."""
        self.calls.append({"messages": messages, "model": model, "kwargs": kwargs})

        # The cursor only advances while scripted responses remain.
        if self._call_index < len(self.responses):
            content = self.responses[self._call_index]
            self._call_index += 1
        else:
            content = ""

        # Usage figures are fixed so tests can assert on exact totals.
        return LLMResponse(
            content=content,
            model=model,
            tokens_in=100,
            tokens_out=50,
            cost_usd=0.001,
        )

    def reset(self) -> None:
        """Forget recorded calls and rewind to the first scripted response."""
        self.calls, self._call_index = [], 0
@pytest.fixture
def prompt_registry(tmp_path: Path) -> PromptRegistry:
    """Build a registry backed by minimal v1.0 templates in a temp directory.

    One template is written per agent (security, style, complexity). Each
    contains the three placeholders: diff, prompt_additions and
    static_analysis_context.
    """
    templates_dir = tmp_path / "templates"
    templates_dir.mkdir()

    placeholders = "{{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
    for agent in ("security", "style", "complexity"):
        template_file = templates_dir / f"{agent}-v1.0.md"
        template_file.write_text(f"Review for {agent}: {placeholders}")

    return PromptRegistry(templates_dir)
conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest.fixture +async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]: + session_factory = async_sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + async with session_factory() as session: + yield session + + +@pytest.fixture +async def test_client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: + from arbiter.api.deps import get_db, get_redis + from arbiter.main import app + + async def override_get_db() -> AsyncGenerator[AsyncSession, None]: + yield db_session + + async def override_get_redis() -> AsyncGenerator[MockRedis, None]: + yield MockRedis() + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_redis] = override_get_redis + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class MockRedis: + """Mock Redis client for testing.""" + + def __init__(self) -> None: + self._data: dict[str, Any] = {} + + async def get(self, key: str) -> str | None: + return self._data.get(key) + + async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002 + self._data[key] = value + return True + + async def delete(self, key: str) -> int: + if key in self._data: + del self._data[key] + return 1 + return 0 + + async def ping(self) -> bool: + return True + + async def llen(self, key: str) -> int: + data = self._data.get(key, []) + return len(data) if isinstance(data, list) else 0 + + async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: # noqa: ARG002 + return type("Job", (), {"job_id": "test-job-id"})() + + +@pytest.fixture +def mock_redis() -> MockRedis: + return MockRedis() + + +class MockPlatformClient(PlatformClient): + """Mock platform client for testing.""" + + def __init__(self) -> None: + self._comments: list[Comment] = 
[] + self._posted_comments: list[dict[str, Any]] = [] + self._status_updates: list[dict[str, Any]] = [] + self._diff = "mock diff content" + self._closed = False + self._fail_on: set[str] = set() # Methods to fail on + + @property + def platform(self) -> Platform: + return Platform.GITHUB + + async def get_pr_diff( + self, + repository: str, # noqa: ARG002 + pr_number: int, # noqa: ARG002 + ) -> str: + if "get_pr_diff" in self._fail_on: + from arbiter.integrations import IntegrationError + + raise IntegrationError("Mock failure: get_pr_diff") + return self._diff + + async def post_comment(self, repository: str, pr_number: int, body: str) -> str: + if "post_comment" in self._fail_on: + from arbiter.integrations import IntegrationError + + raise IntegrationError("Mock failure: post_comment") + comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-123" + self._posted_comments.append( + {"repository": repository, "pr_number": pr_number, "body": body, "url": comment_url} + ) + return comment_url + + async def update_commit_status( + self, + repository: str, + sha: str, + status: CommitStatus, + description: str, + context: str, + target_url: str | None = None, + ) -> None: + if "update_commit_status" in self._fail_on: + from arbiter.integrations import IntegrationError + + raise IntegrationError("Mock failure: update_commit_status") + self._status_updates.append( + { + "repository": repository, + "sha": sha, + "status": status, + "description": description, + "context": context, + "target_url": target_url, + } + ) + + async def get_pr_info(self, repository: str, pr_number: int) -> Any: + if "get_pr_info" in self._fail_on: + from arbiter.integrations import IntegrationError + + raise IntegrationError("Mock failure: get_pr_info") + from arbiter.integrations.base import PullRequestInfo + + return PullRequestInfo( + platform=Platform.GITHUB, + repository=repository, + pr_number=pr_number, + head_sha="abc123", + base_sha="def456", + head_ref="feature", + 
@pytest.fixture
async def completed_review_fixture(db_session: AsyncSession) -> ReviewModel:
    """Persist a completed review with two findings and return it.

    The review has a COMMENT verdict for PR #42 of ``owner/repo``. It carries
    one high-severity security finding and one medium-severity style finding,
    both in ``src/auth.py``. The session is committed and the review is
    refreshed before it is returned.
    """
    tokens_by_agent = {"security": 500, "style": 500, "complexity": 500}
    cost_by_agent = {"security": 0.005, "style": 0.005, "complexity": 0.005}

    review = ReviewModel(
        id=str(uuid4()),
        repository="owner/repo",
        pr_number=42,
        pr_title="Test PR",
        base_sha="abc1234567890",
        head_sha="def0987654321",
        author="testuser",
        is_draft=False,
        status="completed",
        verdict=Verdict.COMMENT,
        verdict_confidence=0.75,
        verdict_reasoning="Found some issues to discuss",
        total_tokens=1500,
        total_cost_usd=0.015,
        tokens_by_agent=tokens_by_agent,
        cost_by_agent=cost_by_agent,
        created_at=datetime.now(UTC),
        started_at=datetime.now(UTC),
        completed_at=datetime.now(UTC),
    )
    db_session.add(review)

    # Findings attached to the review above, added in declaration order.
    findings = [
        FindingModel(
            id=str(uuid4()),
            review_id=review.id,
            agent=AgentName.SECURITY,
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection vulnerability",
            description="User input is directly concatenated into SQL query",
            reasoning="String concatenation in SQL queries allows attackers to inject malicious SQL",
            suggestion="Use parameterized queries instead",
            references=["https://owasp.org/www-community/attacks/SQL_Injection"],
            prompt_version="security-v1.0",
        ),
        FindingModel(
            id=str(uuid4()),
            review_id=review.id,
            agent=AgentName.STYLE,
            file="src/auth.py",
            line_start=20,
            line_end=25,
            severity=Severity.MEDIUM,
            confidence=0.85,
            title="Long function",
            description="Function exceeds 50 lines",
            reasoning="Long functions are harder to test and maintain",
            suggestion="Consider breaking into smaller functions",
            references=[],
            prompt_version="style-v1.0",
        ),
    ]
    for finding in findings:
        db_session.add(finding)

    await db_session.commit()
    await db_session.refresh(review)
    return review
@pytest.fixture
def mock_settings() -> Any:
    """Provide a settings stand-in with GitHub and GitLab credentials set.

    Secret-valued attributes are wrapped in a small object exposing
    ``get_secret_value()``, so code under test can read them that way.
    """

    class MockSecretStr:
        """Wrapper that returns its value only through ``get_secret_value()``."""

        def __init__(self, value: str) -> None:
            self._value = value

        def get_secret_value(self) -> str:
            return self._value

    class MockSettings:
        # GitHub
        github_token = MockSecretStr("ghp_test_token")
        github_base_url = "https://api.github.com"
        github_webhook_secret = MockSecretStr("webhook_secret")

        # GitLab
        gitlab_token = MockSecretStr("glpat_test_token")
        gitlab_base_url = "https://gitlab.com/api/v4"
        gitlab_webhook_token = MockSecretStr("gitlab_token")

        # Timeouts and retries
        integration_timeout = 30
        integration_max_retries = 3
        llm_timeout = 60
        llm_max_retries = 3

        # Review reporting behaviour
        post_comments = True
        update_status = True
        status_check_context = "arbiter"

        # Prompts and follow-ups
        templates_dir = Path("templates")
        followup_enabled = True
        followup_confidence_threshold = 0.5

    settings = MockSettings()
    return settings
module.""" + + ++def process_data(data: dict, config: dict, options: dict | None = None) -> dict: ++ """Process data with many nested conditions.""" ++ result = {} ++ options = options or {} ++ ++ if data.get("type") == "A": ++ if config.get("mode") == "strict": ++ if options.get("validate"): ++ if data.get("value") > 100: ++ if config.get("transform"): ++ result["processed"] = data["value"] * 2 ++ else: ++ result["processed"] = data["value"] ++ else: ++ if options.get("default"): ++ result["processed"] = options["default"] ++ else: ++ result["processed"] = 0 ++ else: ++ result["processed"] = data.get("value", 0) ++ else: ++ result["processed"] = data.get("value", 0) ++ elif data.get("type") == "B": ++ if config.get("mode") == "strict": ++ if options.get("validate"): ++ if data.get("items"): ++ result["processed"] = len(data["items"]) ++ else: ++ result["processed"] = 0 ++ else: ++ result["processed"] = len(data.get("items", [])) ++ else: ++ result["processed"] = len(data.get("items", [])) ++ elif data.get("type") == "C": ++ if config.get("mode") == "strict": ++ if options.get("validate"): ++ if data.get("text"): ++ result["processed"] = data["text"].upper() ++ else: ++ result["processed"] = "" ++ else: ++ result["processed"] = data.get("text", "").upper() ++ else: ++ result["processed"] = data.get("text", "").upper() ++ else: ++ if config.get("fallback"): ++ result["processed"] = config["fallback"] ++ else: ++ result["processed"] = None ++ ++ if options.get("timestamp"): ++ result["timestamp"] = options["timestamp"] ++ if options.get("source"): ++ result["source"] = options["source"] ++ ++ return result ++ ++ + def simple_function(x: int) -> int: + """A simple function.""" + return x * 2 diff --git a/tests/fixtures/security-issue.diff b/tests/fixtures/security-issue.diff new file mode 100644 index 0000000..7115bfe --- /dev/null +++ b/tests/fixtures/security-issue.diff @@ -0,0 +1,31 @@ +diff --git a/src/auth.py b/src/auth.py +index 1234567..abcdefg 100644 +--- 
a/src/auth.py ++++ b/src/auth.py +@@ -1,10 +1,25 @@ + """Authentication module.""" + + import sqlite3 ++import os + + + def get_user(username: str) -> dict | None: + """Get user from database.""" + conn = sqlite3.connect("users.db") + cursor = conn.cursor() +- cursor.execute("SELECT * FROM users WHERE username = ?", (username,)) ++ # FIXME: this is vulnerable to SQL injection ++ query = "SELECT * FROM users WHERE username = '" + username + "'" ++ cursor.execute(query) + return cursor.fetchone() ++ ++ ++def run_command(cmd: str) -> str: ++ """Run a shell command.""" ++ # Command injection vulnerability ++ return os.popen(cmd).read() ++ ++ ++# Hardcoded credentials ++API_KEY = "sk-1234567890abcdef" ++DB_PASSWORD = "admin123" diff --git a/tests/fixtures/simple.diff b/tests/fixtures/simple.diff new file mode 100644 index 0000000..5360eac --- /dev/null +++ b/tests/fixtures/simple.diff @@ -0,0 +1,16 @@ +diff --git a/src/utils.py b/src/utils.py +index 1234567..abcdefg 100644 +--- a/src/utils.py ++++ b/src/utils.py +@@ -1,5 +1,8 @@ + """Utility functions.""" + + ++def add(a: int, b: int) -> int: ++ """Add two numbers.""" ++ return a + b ++ ++ + def subtract(a: int, b: int) -> int: + """Subtract two numbers.""" + return a - b diff --git a/tests/test_agents.py b/tests/test_agents.py new file mode 100644 index 0000000..3a2fb7e --- /dev/null +++ b/tests/test_agents.py @@ -0,0 +1,305 @@ +"""Tests for review agents.""" + +import pytest + +from arbiter.agents import ComplexityAgent, ReviewContext, SecurityAgent, StyleAgent +from arbiter.llm.prompts import PromptRegistry +from arbiter.models import AgentConfig, AgentName, Policy, Severity +from tests.conftest import MockLLMClient + + +class TestSecurityAgent: + @pytest.mark.asyncio + async def test_review_returns_result( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm = MockLLMClient(responses=["[]"]) + agent = SecurityAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ some code", 
policy=Policy()) + + result = await agent.review(context) + + assert result.agent_name == AgentName.SECURITY + assert result.findings == [] + assert result.duration_ms >= 0 + assert result.tokens_used == 150 # 100 in + 50 out from mock + assert result.cost_usd == 0.001 + + @pytest.mark.asyncio + async def test_parses_json_findings( + self, + prompt_registry: PromptRegistry, + ) -> None: + response = """```json +[ + { + "file": "src/auth.py", + "line_start": 10, + "line_end": 15, + "severity": "high", + "confidence": 0.9, + "title": "SQL Injection", + "description": "User input concatenated", + "reasoning": "Allows SQL injection", + "suggestion": "Use parameterized queries", + "references": ["https://owasp.org"] + } +] +```""" + mock_llm = MockLLMClient(responses=[response]) + agent = SecurityAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ query = ...", policy=Policy()) + + result = await agent.review(context) + + assert len(result.findings) == 1 + finding = result.findings[0] + assert finding.file == "src/auth.py" + assert finding.severity == Severity.HIGH + assert finding.confidence == 0.9 + assert finding.title == "SQL Injection" + + @pytest.mark.asyncio + async def test_uses_configured_model( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm = MockLLMClient(responses=["[]"]) + agent = SecurityAgent(mock_llm, prompt_registry) + policy = Policy( + agents={ + AgentName.SECURITY: AgentConfig(model="gpt-4o-mini"), + AgentName.STYLE: AgentConfig(), + AgentName.COMPLEXITY: AgentConfig(), + } + ) + context = ReviewContext(diff="+ code", policy=policy) + + await agent.review(context) + + assert mock_llm.calls[0]["model"] == "gpt-4o-mini" + + @pytest.mark.asyncio + async def test_filters_by_severity( + self, + prompt_registry: PromptRegistry, + ) -> None: + response = """[ + {"file": "a.py", "line_start": 1, "line_end": 1, "severity": "high", "confidence": 0.9, "title": "High", "description": "", "reasoning": ""}, + {"file": "b.py", 
"line_start": 1, "line_end": 1, "severity": "low", "confidence": 0.9, "title": "Low", "description": "", "reasoning": ""}, + {"file": "c.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.9, "title": "Info", "description": "", "reasoning": ""} +]""" + mock_llm = MockLLMClient(responses=[response]) + agent = SecurityAgent(mock_llm, prompt_registry) + policy = Policy( + agents={ + AgentName.SECURITY: AgentConfig(severity_threshold=Severity.MEDIUM), + AgentName.STYLE: AgentConfig(), + AgentName.COMPLEXITY: AgentConfig(), + } + ) + context = ReviewContext(diff="+ code", policy=policy) + + result = await agent.review(context) + + # Only high severity should pass (medium threshold filters low and info) + assert len(result.findings) == 1 + assert result.findings[0].severity == Severity.HIGH + + +class TestStyleAgent: + @pytest.mark.asyncio + async def test_review_returns_result( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm = MockLLMClient(responses=["[]"]) + agent = StyleAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ some code", policy=Policy()) + + result = await agent.review(context) + + assert result.agent_name == AgentName.STYLE + assert result.findings == [] + + @pytest.mark.asyncio + async def test_uses_default_model( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm = MockLLMClient(responses=["[]"]) + agent = StyleAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ code", policy=Policy()) + + await agent.review(context) + + assert mock_llm.calls[0]["model"] == "gpt-4o-mini" + + +class TestComplexityAgent: + @pytest.mark.asyncio + async def test_review_returns_result( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm = MockLLMClient(responses=["[]"]) + agent = ComplexityAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ some code", policy=Policy()) + + result = await agent.review(context) + + assert result.agent_name == 
AgentName.COMPLEXITY + assert result.findings == [] + + @pytest.mark.asyncio + async def test_parses_complexity_findings( + self, + prompt_registry: PromptRegistry, + ) -> None: + response = """[ + { + "file": "processor.py", + "line_start": 1, + "line_end": 50, + "severity": "medium", + "confidence": 0.8, + "title": "High cyclomatic complexity", + "description": "Function has 15 branches", + "reasoning": "Makes testing and maintenance difficult" + } +]""" + mock_llm = MockLLMClient(responses=[response]) + agent = ComplexityAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ complex code", policy=Policy()) + + result = await agent.review(context) + + assert len(result.findings) == 1 + assert result.findings[0].severity == Severity.MEDIUM + assert "complexity" in result.findings[0].title.lower() + + +class TestAgentResponseParsing: + @pytest.mark.asyncio + async def test_handles_empty_response( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm = MockLLMClient(responses=[""]) + agent = SecurityAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ code", policy=Policy()) + + result = await agent.review(context) + + assert result.findings == [] + + @pytest.mark.asyncio + async def test_handles_invalid_json( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm = MockLLMClient(responses=["not valid json"]) + agent = SecurityAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ code", policy=Policy()) + + result = await agent.review(context) + + assert result.findings == [] + + @pytest.mark.asyncio + async def test_handles_json_without_code_block( + self, + prompt_registry: PromptRegistry, + ) -> None: + response = '[{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Test", "description": "", "reasoning": ""}]' + mock_llm = MockLLMClient(responses=[response]) + agent = SecurityAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ code", 
policy=Policy()) + + result = await agent.review(context) + + assert len(result.findings) == 1 + + @pytest.mark.asyncio + async def test_handles_malformed_finding( + self, + prompt_registry: PromptRegistry, + ) -> None: + response = """[ + {"file": "a.py", "line_start": 1, "severity": "invalid_severity", "confidence": 0.5, "title": "Bad", "description": "", "reasoning": ""}, + {"file": "b.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Valid", "description": "", "reasoning": ""} +]""" + mock_llm = MockLLMClient(responses=[response]) + agent = SecurityAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ code", policy=Policy()) + + result = await agent.review(context) + + # Only the valid finding should be included (first has invalid severity) + assert len(result.findings) == 1 + assert result.findings[0].title == "Valid" + + @pytest.mark.asyncio + async def test_includes_prompt_additions( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm = MockLLMClient(responses=["[]"]) + agent = SecurityAgent(mock_llm, prompt_registry) + policy = Policy( + agents={ + AgentName.SECURITY: AgentConfig(prompt_additions="Focus on authentication"), + AgentName.STYLE: AgentConfig(), + AgentName.COMPLEXITY: AgentConfig(), + } + ) + context = ReviewContext(diff="+ code", policy=policy) + + await agent.review(context) + + message_content = mock_llm.calls[0]["messages"][0]["content"] + assert "Focus on authentication" in message_content + + @pytest.mark.asyncio + async def test_handles_non_list_json( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm = MockLLMClient(responses=['{"not": "a list"}']) + agent = SecurityAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ code", policy=Policy()) + + result = await agent.review(context) + + assert result.findings == [] + + @pytest.mark.asyncio + async def test_handles_non_dict_items( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm 
= MockLLMClient(responses=['["string", 123, null]']) + agent = SecurityAgent(mock_llm, prompt_registry) + context = ReviewContext(diff="+ code", policy=Policy()) + + result = await agent.review(context) + + assert result.findings == [] + + @pytest.mark.asyncio + async def test_agent_without_config_uses_defaults( + self, + prompt_registry: PromptRegistry, + ) -> None: + mock_llm = MockLLMClient(responses=["[]"]) + agent = SecurityAgent(mock_llm, prompt_registry) + # Create policy with empty agents dict + policy = Policy(agents={}) + context = ReviewContext(diff="+ code", policy=policy) + + result = await agent.review(context) + + # Should use default model (gpt-4o for security) + assert mock_llm.calls[0]["model"] == "gpt-4o" + assert result.findings == [] diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..2879b7b --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,561 @@ +"""Tests for CLI commands.""" + +import json +from pathlib import Path +from unittest.mock import AsyncMock, patch + +from typer.testing import CliRunner + +from arbiter.cli import ( + _severity_color, + _severity_icon, + _verdict_color, + _verdict_icon, + app, +) +from arbiter.deliberation import DeliberationResult +from arbiter.models import AgentName, Finding, ReviewResult, Severity, Verdict + +runner = CliRunner() + + +def make_mock_return( + findings: list[Finding] | None = None, verdict: Verdict = Verdict.APPROVE +) -> tuple[list[ReviewResult], DeliberationResult]: + """Create a mock return value for _run_review.""" + agent_results = [ + ReviewResult( + agent_name=AgentName.SECURITY, + findings=findings or [], + duration_ms=100, + tokens_used=100, + cost_usd=0.001, + ) + ] + deliberation_result = DeliberationResult( + findings=findings or [], + verdict=verdict, + verdict_confidence=0.9, + verdict_reasoning="Test reasoning", + total_findings=len(findings) if findings else 0, + ) + return agent_results, deliberation_result + + +class TestVersionCommand: + def 
test_version_output(self) -> None: + result = runner.invoke(app, ["version"]) + assert result.exit_code == 0 + assert "arbiter" in result.output + assert "0.3.0" in result.output + + +class TestReviewCommand: + def test_file_not_found(self) -> None: + result = runner.invoke(app, ["review", "nonexistent.diff"]) + assert result.exit_code == 1 + assert "File not found" in result.output + + def test_empty_diff_warning(self, tmp_path: Path) -> None: + diff_file = tmp_path / "empty.diff" + diff_file.write_text("") + + result = runner.invoke(app, ["review", str(diff_file)]) + assert result.exit_code == 0 + assert "Empty diff" in result.output + + def test_policy_not_found(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + result = runner.invoke(app, ["review", str(diff_file), "--policy", "nonexistent.yaml"]) + assert result.exit_code == 1 + assert "Policy file not found" in result.output + + def test_reads_from_stdin(self) -> None: + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = make_mock_return() + result = runner.invoke(app, ["review", "-"], input="+ added line\n") + + assert result.exit_code == 0 + + def test_json_output_format(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = make_mock_return() + result = runner.invoke(app, ["review", str(diff_file), "--format", "json"]) + + assert result.exit_code == 0 + assert '"verdict"' in result.output + assert '"findings"' in result.output + + def test_markdown_output_format(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = make_mock_return() + result = runner.invoke(app, ["review", 
str(diff_file), "--format", "markdown"]) + + assert result.exit_code == 0 + assert "# Arbiter Review" in result.output + + def test_loads_policy_file(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + policy_file = tmp_path / "policy.yaml" + policy_file.write_text(""" +version: "1.0" +agents: + security: + enabled: true + style: + enabled: false + complexity: + enabled: false +""") + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = make_mock_return() + result = runner.invoke(app, ["review", str(diff_file), "--policy", str(policy_file)]) + + assert result.exit_code == 0 + # Verify policy was passed to _run_review + call_args = mock_run.call_args + policy = call_args[0][1] # Second positional arg is policy + assert len(policy.get_enabled_agents()) == 1 + + def test_model_override(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = make_mock_return() + result = runner.invoke(app, ["review", str(diff_file), "--model", "gpt-4o-mini"]) + + assert result.exit_code == 0 + # Verify model was set in policy + call_args = mock_run.call_args + policy = call_args[0][1] + for config in policy.agents.values(): + assert config.model == "gpt-4o-mini" + + def test_static_analysis_flag(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = make_mock_return() + result = runner.invoke(app, ["review", str(diff_file), "--no-static-analysis"]) + + assert result.exit_code == 0 + # Verify static_analysis was False + call_args = mock_run.call_args + assert call_args.kwargs.get("static_analysis") is False + + +class TestNoArgsHelp: + def test_no_args_shows_help(self) 
-> None: + result = runner.invoke(app, []) + assert result.exit_code == 0 + assert "A multi-agent code review system" in result.output + + +class TestOutputFormatting: + def test_severity_color(self) -> None: + assert _severity_color(Severity.CRITICAL) == "red bold" + assert _severity_color(Severity.HIGH) == "red" + assert _severity_color(Severity.MEDIUM) == "yellow" + assert _severity_color(Severity.LOW) == "blue" + assert _severity_color(Severity.INFO) == "dim" + + def test_severity_icon(self) -> None: + assert _severity_icon(Severity.CRITICAL) == "!!" + assert _severity_icon(Severity.HIGH) == "!" + assert _severity_icon(Severity.MEDIUM) == "*" + assert _severity_icon(Severity.LOW) == "-" + assert _severity_icon(Severity.INFO) == "i" + + def test_verdict_color(self) -> None: + assert _verdict_color(Verdict.APPROVE) == "green" + assert _verdict_color(Verdict.COMMENT) == "yellow" + assert _verdict_color(Verdict.REQUEST_CHANGES) == "red" + + def test_verdict_icon(self) -> None: + assert _verdict_icon(Verdict.APPROVE) == "[ok]" + assert _verdict_icon(Verdict.COMMENT) == "[..]" + assert _verdict_icon(Verdict.REQUEST_CHANGES) == "[!!]" + + +class TestRichOutput: + def test_rich_format_with_findings(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + finding = Finding( + id="test-finding-1", + agent=AgentName.SECURITY, + file="test.py", + line_start=10, + line_end=15, + severity=Severity.HIGH, + confidence=0.9, + title="SQL Injection", + description="User input in query", + reasoning="Direct concatenation", + prompt_version="test-v1.0", + ) + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = make_mock_return(findings=[finding]) + result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"]) + + assert result.exit_code == 0 + + def test_rich_format_no_findings(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + 
diff_file.write_text("+ some change") + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = make_mock_return() + result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"]) + + assert result.exit_code == 0 + assert "No issues found" in result.output + + def test_rich_format_critical_findings(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + finding = Finding( + id="test-finding-1", + agent=AgentName.SECURITY, + file="test.py", + line_start=10, + line_end=10, + severity=Severity.CRITICAL, + confidence=0.95, + title="Critical Issue", + description="This is critical", + reasoning="Very bad", + suggestion="Fix it immediately", + prompt_version="test-v1.0", + ) + + deliberation = DeliberationResult( + findings=[finding], + verdict=Verdict.REQUEST_CHANGES, + verdict_confidence=0.95, + verdict_reasoning="Critical issue found", + total_findings=1, + critical_count=1, + ) + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = ( + [ + ReviewResult( + agent_name=AgentName.SECURITY, + findings=[finding], + duration_ms=100, + tokens_used=100, + cost_usd=0.001, + ) + ], + deliberation, + ) + result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"]) + + assert result.exit_code == 0 + + +class TestMarkdownOutput: + def test_markdown_with_findings(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + finding = Finding( + id="test-finding-1", + agent=AgentName.SECURITY, + file="test.py", + line_start=10, + line_end=15, + severity=Severity.HIGH, + confidence=0.9, + title="SQL Injection", + description="User input in query", + reasoning="Direct concatenation", + suggestion="Use parameterized queries", + prompt_version="test-v1.0", + ) + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + 
mock_run.return_value = make_mock_return(findings=[finding]) + result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"]) + + assert result.exit_code == 0 + assert "## Findings" in result.output + assert "SQL Injection" in result.output + + def test_markdown_verdict_badges(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + for verdict in [Verdict.APPROVE, Verdict.COMMENT, Verdict.REQUEST_CHANGES]: + deliberation = DeliberationResult( + verdict=verdict, + verdict_confidence=0.9, + verdict_reasoning="Test", + ) + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = ( + [ + ReviewResult( + agent_name=AgentName.SECURITY, + findings=[], + duration_ms=100, + tokens_used=100, + cost_usd=0.001, + ) + ], + deliberation, + ) + result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"]) + + assert result.exit_code == 0 + assert verdict.value.upper() in result.output + + +class TestJsonOutput: + def test_json_with_conflicts(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + from arbiter.deliberation.conflicts import Conflict, ConflictNature + + conflict = Conflict( + id="test-conflict", + finding_ids=["f1", "f2"], + nature=ConflictNature.TRADE_OFF, + description="Test conflict", + severity_weight=0.8, + ) + + deliberation = DeliberationResult( + verdict=Verdict.COMMENT, + verdict_confidence=0.7, + verdict_reasoning="Conflicts found", + conflicts=[conflict], + ) + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = ( + [ + ReviewResult( + agent_name=AgentName.SECURITY, + findings=[], + duration_ms=100, + tokens_used=100, + cost_usd=0.001, + ) + ], + deliberation, + ) + result = runner.invoke(app, ["review", str(diff_file), "--format", "json"]) + + assert result.exit_code == 0 + output = json.loads(result.output) + 
assert "conflicts" in output + assert len(output["conflicts"]) == 1 + + +class TestWorkDirHandling: + def test_work_dir_option(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + work_dir = tmp_path / "src" + work_dir.mkdir() + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = make_mock_return() + result = runner.invoke(app, ["review", str(diff_file), "--work-dir", str(work_dir)]) + + assert result.exit_code == 0 + call_args = mock_run.call_args + assert call_args.kwargs.get("work_dir") == work_dir.resolve() + + +class TestRichOutputWithConflicts: + def test_rich_conflicts(self, tmp_path: Path) -> None: + from arbiter.deliberation.conflicts import Conflict, ConflictNature + from arbiter.deliberation.synthesis import Resolution + + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + conflict = Conflict( + id="test-conflict", + finding_ids=["f1", "f2"], + nature=ConflictNature.TRADE_OFF, + description="Security vs complexity trade-off", + severity_weight=0.8, + ) + + resolution = Resolution( + conflict_id="test-conflict", + decision="prefer_first", + reasoning="Security takes priority", + confidence=0.9, + ) + + deliberation = DeliberationResult( + verdict=Verdict.COMMENT, + verdict_confidence=0.7, + verdict_reasoning="Conflicts found", + conflicts=[conflict], + resolutions=[resolution], + ) + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = ( + [ + ReviewResult( + agent_name=AgentName.SECURITY, + findings=[], + duration_ms=100, + tokens_used=100, + cost_usd=0.001, + ) + ], + deliberation, + ) + result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"]) + + assert result.exit_code == 0 + + +class TestMarkdownOutputWithConflicts: + def test_markdown_conflicts(self, tmp_path: Path) -> None: + from arbiter.deliberation.conflicts import Conflict, ConflictNature + 
from arbiter.deliberation.synthesis import Resolution + + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + conflict = Conflict( + id="test-conflict", + finding_ids=["f1", "f2"], + nature=ConflictNature.CONTRADICTORY, + description="Contradictory recommendations", + severity_weight=0.9, + ) + + resolution = Resolution( + conflict_id="test-conflict", + decision="merge", + reasoning="Both concerns addressed by combined fix", + merged_suggestion="Do both things", + confidence=0.85, + ) + + deliberation = DeliberationResult( + verdict=Verdict.COMMENT, + verdict_confidence=0.7, + verdict_reasoning="Conflicts found", + conflicts=[conflict], + resolutions=[resolution], + ) + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = ( + [ + ReviewResult( + agent_name=AgentName.SECURITY, + findings=[], + duration_ms=100, + tokens_used=100, + cost_usd=0.001, + ) + ], + deliberation, + ) + result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"]) + + assert result.exit_code == 0 + assert "## Conflicts" in result.output + assert "Contradictory" in result.output + assert "Resolution" in result.output + + def test_markdown_findings(self, tmp_path: Path) -> None: + diff_file = tmp_path / "test.diff" + diff_file.write_text("+ some change") + + findings = [ + Finding( + id="f1", + agent=AgentName.SECURITY, + file="test.py", + line_start=10, + line_end=15, + severity=Severity.HIGH, + confidence=0.9, + title="Security Issue", + description="Vulnerable code", + reasoning="Bad pattern", + suggestion="Fix it this way", + prompt_version="test-v1.0", + ), + Finding( + id="f2", + agent=AgentName.STYLE, + file="test.py", + line_start=20, + line_end=25, + severity=Severity.LOW, + confidence=0.8, + title="Style Issue", + description="Could be cleaner", + reasoning="Convention", + prompt_version="test-v1.0", + ), + ] + + deliberation = DeliberationResult( + findings=findings, + 
verdict=Verdict.COMMENT, + verdict_confidence=0.75, + verdict_reasoning="Issues found", + total_findings=2, + ) + + with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run: + mock_run.return_value = ( + [ + ReviewResult( + agent_name=AgentName.SECURITY, + findings=[findings[0]], + duration_ms=100, + tokens_used=100, + cost_usd=0.001, + ), + ReviewResult( + agent_name=AgentName.STYLE, + findings=[findings[1]], + duration_ms=100, + tokens_used=100, + cost_usd=0.001, + ), + ], + deliberation, + ) + result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"]) + + assert result.exit_code == 0 + assert "Security Issue" in result.output + assert "Style Issue" in result.output + assert "Fix it this way" in result.output diff --git a/tests/test_llm.py b/tests/test_llm.py new file mode 100644 index 0000000..9df58de --- /dev/null +++ b/tests/test_llm.py @@ -0,0 +1,313 @@ +"""Tests for LLM client and prompts.""" + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from arbiter.llm.client import LiteLLMClient, LLMResponse +from arbiter.llm.prompts import PromptRegistry, PromptTemplate +from tests.conftest import MockLLMClient + + +class TestLiteLLMClient: + def test_init_default_values(self) -> None: + client = LiteLLMClient() + assert client.timeout == 60 + assert client.max_retries == 3 + + def test_init_custom_values(self) -> None: + client = LiteLLMClient(timeout=120, max_retries=5) + assert client.timeout == 120 + assert client.max_retries == 5 + + @pytest.mark.asyncio + async def test_complete_returns_response(self) -> None: + client = LiteLLMClient() + + # Mock the litellm.acompletion function + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Test response" + mock_response.model = "gpt-4o" + mock_response.usage = MagicMock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + + with ( + 
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp, + patch("arbiter.llm.client.litellm.completion_cost") as mock_cost, + ): + mock_acomp.return_value = mock_response + mock_cost.return_value = 0.001 + + messages = [{"role": "user", "content": "Hello"}] + response = await client.complete(messages, "gpt-4o") + + assert response.content == "Test response" + assert response.model == "gpt-4o" + assert response.tokens_in == 10 + assert response.tokens_out == 5 + assert response.cost_usd == 0.001 + + mock_acomp.assert_called_once_with( + model="gpt-4o", + messages=messages, + timeout=60, + num_retries=3, + ) + + @pytest.mark.asyncio + async def test_complete_handles_empty_content(self) -> None: + client = LiteLLMClient() + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = None # None content + mock_response.model = "gpt-4o" + mock_response.usage = MagicMock() + mock_response.usage.prompt_tokens = 5 + mock_response.usage.completion_tokens = 0 + + with ( + patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp, + patch("arbiter.llm.client.litellm.completion_cost") as mock_cost, + ): + mock_acomp.return_value = mock_response + mock_cost.return_value = 0.0 + + messages = [{"role": "user", "content": "Hello"}] + response = await client.complete(messages, "gpt-4o") + + assert response.content == "" # Should be empty string, not None + + @pytest.mark.asyncio + async def test_complete_handles_missing_usage(self) -> None: + client = LiteLLMClient() + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + mock_response.model = "gpt-4o" + mock_response.usage = None # No usage data + + with ( + patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp, + patch("arbiter.llm.client.litellm.completion_cost") as mock_cost, + ): + mock_acomp.return_value = 
mock_response
+            mock_cost.return_value = 0.0
+
+            messages = [{"role": "user", "content": "Hello"}]
+            response = await client.complete(messages, "gpt-4o")
+
+        assert response.tokens_in == 0
+        assert response.tokens_out == 0
+
+    @pytest.mark.asyncio
+    async def test_complete_uses_fallback_model(self) -> None:  # requested model is reported when the response has none
+        client = LiteLLMClient()
+
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock()]
+        mock_response.choices[0].message.content = "Response"
+        mock_response.model = None  # No model in response
+        mock_response.usage = MagicMock()
+        mock_response.usage.prompt_tokens = 5
+        mock_response.usage.completion_tokens = 3
+
+        with (
+            patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
+            patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
+        ):
+            mock_acomp.return_value = mock_response
+            mock_cost.return_value = 0.0
+
+            messages = [{"role": "user", "content": "Hello"}]
+            response = await client.complete(messages, "claude-3-opus")
+
+        # Should use the passed model as fallback
+        assert response.model == "claude-3-opus"
+
+
+class TestLLMResponse:  # LLMResponse should expose its constructor arguments unchanged
+    def test_response_creation(self) -> None:  # every field reads back exactly as passed in
+        response = LLMResponse(
+            content="Hello, world!",
+            model="gpt-4o",
+            tokens_in=10,
+            tokens_out=5,
+            cost_usd=0.001,
+        )
+        assert response.content == "Hello, world!"
+        assert response.model == "gpt-4o"
+        assert response.tokens_in == 10
+        assert response.tokens_out == 5
+        assert response.cost_usd == 0.001
+
+
+class TestMockLLMClient:  # call recording, canned replies and reset() of MockLLMClient
+    @pytest.mark.asyncio
+    async def test_records_calls(self) -> None:  # one complete() call -> one entry in client.calls
+        client = MockLLMClient()
+        messages = [{"role": "user", "content": "Hello"}]
+
+        await client.complete(messages, "gpt-4o")
+
+        assert len(client.calls) == 1
+        assert client.calls[0]["messages"] == messages
+        assert client.calls[0]["model"] == "gpt-4o"
+
+    @pytest.mark.asyncio
+    async def test_returns_canned_responses(self) -> None:  # replies come back in the order supplied
+        client = MockLLMClient(responses=["First", "Second"])
+        messages = [{"role": "user", "content": "Hello"}]
+
+        response1 = await client.complete(messages, "gpt-4o")
+        response2 = await client.complete(messages, "gpt-4o")
+        response3 = await client.complete(messages, "gpt-4o")
+
+        assert response1.content == "First"
+        assert response2.content == "Second"
+        assert response3.content == ""  # Exhausted: only two canned replies were supplied
+
+    @pytest.mark.asyncio
+    async def test_reset(self) -> None:  # reset() empties calls and replays the canned replies
+        client = MockLLMClient(responses=["Hello"])
+        messages = [{"role": "user", "content": "Hi"}]
+
+        await client.complete(messages, "gpt-4o")
+        assert len(client.calls) == 1
+
+        client.reset()
+        assert len(client.calls) == 0
+
+        response = await client.complete(messages, "gpt-4o")
+        assert response.content == "Hello"  # Responses reset too
+
+
+class TestPromptTemplate:  # PromptTemplate naming and {{var}} rendering
+    def test_template_creation(self) -> None:  # fields are stored and full_name is derived from them
+        template = PromptTemplate(
+            name="security",
+            version="1.0",
+            content="Review: {{diff}}",
+        )
+        assert template.name == "security"
+        assert template.version == "1.0"
+        assert template.full_name == "security-v1.0"  # "<name>-v<version>"
+
+    def test_render_substitution(self) -> None:  # each keyword argument fills its {{placeholder}}
+        template = PromptTemplate(
+            name="test",
+            version="1.0",
+            content="File: {{file}}\nDiff: {{diff}}",
+        )
+        result = template.render(file="test.py", diff="+ added line")
+        assert result == "File: test.py\nDiff: + added line"
+
+    def test_render_missing_variable(self) -> None:  # no kwargs -> placeholder is expected to stay as-is
+        template = PromptTemplate(
+            name="test",
+            version="1.0",
+            content="Value: {{value}}",
+        )
+        result = template.render()
+        assert result == "Value: {{value}}"
+
+    def test_render_multiple_occurrences(self) -> None:  # every occurrence of a placeholder is replaced
+        template = PromptTemplate(
+            name="test",
+            version="1.0",
+            content="{{name}} and {{name}} again",
+        )
+        result = template.render(name="test")
+        assert result == "test and test again"
+
+
+class TestPromptRegistry:  # PromptRegistry lookup, caching and listing against a tmp_path directory
+    def test_get_template(self, tmp_path: Path) -> None:  # file "security-v1.0.md" resolves via ("security", "1.0")
+        templates_dir = tmp_path / "templates"
+        templates_dir.mkdir()
+        (templates_dir / "security-v1.0.md").write_text("Security review: {{diff}}")
+
+        registry = PromptRegistry(templates_dir)
+        template = registry.get("security", "1.0")
+
+        assert template.name == "security"
+        assert template.version == "1.0"
+        assert "{{diff}}" in template.content
+
+    def test_get_template_cached(self, tmp_path: Path) -> None:  # repeated get() returns the same object
+        templates_dir = tmp_path / "templates"
+        templates_dir.mkdir()
+        (templates_dir / "security-v1.0.md").write_text("Content")
+
+        registry = PromptRegistry(templates_dir)
+        template1 = registry.get("security", "1.0")
+        template2 = registry.get("security", "1.0")
+
+        assert template1 is template2  # identity, not just equality
+
+    def test_get_template_not_found(self, tmp_path: Path) -> None:  # no matching file -> FileNotFoundError
+        templates_dir = tmp_path / "templates"
+        templates_dir.mkdir()
+
+        registry = PromptRegistry(templates_dir)
+
+        with pytest.raises(FileNotFoundError):
+            registry.get("missing", "1.0")
+
+    def test_list_templates(self, tmp_path: Path) -> None:  # listing yields (name, version) pairs
+        templates_dir = tmp_path / "templates"
+        templates_dir.mkdir()
+        (templates_dir / "security-v1.0.md").write_text("Content")
+        (templates_dir / "style-v2.0.md").write_text("Content")
+        (templates_dir / "readme.md").write_text("Not a template")  # expected to be left out (only 2 entries asserted)
+
+        registry = PromptRegistry(templates_dir)
+        templates = registry.list_templates()
+
+        assert len(templates) == 2
+        assert ("security", "1.0") in templates
+        assert ("style", "2.0") in templates
+
+    def test_list_templates_empty_dir(self, tmp_path: Path) -> None:  # existing but empty directory -> []
+        templates_dir = tmp_path / "templates"
+        templates_dir.mkdir()
+
+        registry = PromptRegistry(templates_dir)
+        templates = registry.list_templates()
+
+        assert templates == []
+
+    def test_list_templates_missing_dir(self, tmp_path: Path) -> None:  # nonexistent directory -> [] rather than an error
+        templates_dir = tmp_path / "missing"  # never created on disk
+
+        registry = PromptRegistry(templates_dir)
+        templates = registry.list_templates()
+
+        assert templates == []
+
+    def test_clear_cache(self, tmp_path: Path) -> None:  # edits on disk are only seen after clear_cache()
+        templates_dir = tmp_path / "templates"
+        templates_dir.mkdir()
+        (templates_dir / "test-v1.0.md").write_text("Original")
+
+        registry = PromptRegistry(templates_dir)
+        template1 = registry.get("test", "1.0")
+        assert template1.content == "Original"
+
+        # Modify file
+        (templates_dir / "test-v1.0.md").write_text("Modified")
+
+        # Still cached
+        template2 = registry.get("test", "1.0")
+        assert template2.content == "Original"
+
+        # Clear cache
+        registry.clear_cache()
+
+        # Now reads new content
+        template3 = registry.get("test", "1.0")
+        assert template3.content == "Modified"
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..4e20099
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,224 @@
+"""Tests for data models."""
+
+from pathlib import Path
+
+import pytest
+
+from arbiter.models import (
+    AgentConfig,
+    AgentName,
+    Finding,
+    Policy,
+    ReviewResult,
+    Severity,
+    Verdict,
+)
+
+
+class TestEnums:  # enum members are compared against their plain string values
+    def test_severity_values(self) -> None:  # five severity levels, critical through info
+        assert Severity.CRITICAL == "critical"
+        assert Severity.HIGH == "high"
+        assert Severity.MEDIUM == "medium"
+        assert Severity.LOW == "low"
+        assert Severity.INFO == "info"
+
+    def test_verdict_values(self) -> None:  # three verdicts with snake_case string values
+        assert Verdict.APPROVE == "approve"
+        assert Verdict.REQUEST_CHANGES == "request_changes"
+        assert Verdict.COMMENT == "comment"
+
+    def test_agent_name_values(self) -> None:  # the three agent names used by the policy tests below
+        assert AgentName.SECURITY == "security"
+        assert AgentName.STYLE == "style"
+        assert AgentName.COMPLEXITY == "complexity"
+
+    def test_severity_from_string(self) -> None:  # lookup by value returns the matching member
+        assert Severity("critical") == Severity.CRITICAL
+        assert Severity("high") == Severity.HIGH
+
+
+class TestFinding:  # construction, validation and serialization of Finding
+    def test_finding_creation(self) -> None:  # fully populated finding, optional fields included
+        finding = Finding(
+            id="test-123",
+            agent=AgentName.SECURITY,
+            file="src/auth.py",
+            line_start=10,
+            line_end=15,
+            severity=Severity.HIGH,
+            confidence=0.9,
+            title="SQL Injection",
+            description="User input concatenated in SQL query",
+            reasoning="Allows attackers to execute arbitrary SQL",
+            suggestion="Use parameterized queries",
+            references=["https://owasp.org"],
+            prompt_version="security-v1.0",
+        )
+        assert finding.id == "test-123"
+        assert finding.agent == AgentName.SECURITY
+        assert finding.severity == Severity.HIGH
+        assert finding.confidence == 0.9
+
+    def test_finding_confidence_validation(self) -> None:  # construction must raise ValueError
+        with pytest.raises(ValueError):
+            Finding(
+                id="test",
+                agent=AgentName.SECURITY,
+                file="test.py",
+                line_start=1,
+                line_end=1,
+                severity=Severity.INFO,
+                confidence=1.5,  # presumably the offending value (above an upper bound) - confirm against Finding
+                title="Test",
+                description="Test",
+                reasoning="Test",
+                prompt_version="test-v1.0",
+            )
+
+    def test_finding_line_validation(self) -> None:  # construction must raise ValueError
+        with pytest.raises(ValueError):
+            Finding(
+                id="test",
+                agent=AgentName.SECURITY,
+                file="test.py",
+                line_start=0,  # presumably the offending value (lines look 1-based) - confirm against Finding
+                line_end=1,
+                severity=Severity.INFO,
+                confidence=0.5,
+                title="Test",
+                description="Test",
+                reasoning="Test",
+                prompt_version="test-v1.0",
+            )
+
+    def test_finding_serialization(self) -> None:  # model_dump() output for id and enum fields
+        finding = Finding(
+            id="test-123",
+            agent=AgentName.SECURITY,
+            file="src/auth.py",
+            line_start=10,
+            line_end=15,
+            severity=Severity.HIGH,
+            confidence=0.9,
+            title="Test",
+            description="Test desc",
+            reasoning="Test reason",
+            prompt_version="security-v1.0",
+        )
+        data = finding.model_dump()
+        assert data["id"] == "test-123"
+        assert data["agent"] == "security"  # enum fields compare equal to their string values in the dump
+        assert data["severity"] == "high"
+
+
+class TestAgentConfig:  # defaults and overrides of the per-agent configuration
+    def test_default_values(self) -> None:  # AgentConfig() with no arguments
+        config = AgentConfig()
+        assert config.enabled is True
+        assert config.model is None
+        assert config.severity_threshold == Severity.INFO
+        assert config.prompt_additions is None
+
+    def test_custom_values(self) -> None:  # explicit arguments override the defaults
+        config = AgentConfig(
+            enabled=False,
+            model="gpt-4o",
+            severity_threshold=Severity.MEDIUM,
+            prompt_additions="Focus on auth",
+        )
+        assert config.enabled is False
+        assert config.model == "gpt-4o"
+        assert config.severity_threshold == Severity.MEDIUM
+
+
+class TestPolicy:  # default policy, enabled-agent filtering and YAML loading
+    def test_default_policy(self) -> None:  # Policy() is expected to configure all three agents
+        policy = Policy()
+        assert len(policy.agents) == 3
+        assert AgentName.SECURITY in policy.agents
+        assert AgentName.STYLE in policy.agents
+        assert AgentName.COMPLEXITY in policy.agents
+
+    def test_get_enabled_agents(self) -> None:  # with defaults, all three agents count as enabled
+        policy = Policy()
+        enabled = policy.get_enabled_agents()
+        assert len(enabled) == 3
+        assert AgentName.SECURITY in enabled
+
+    def test_get_enabled_agents_with_disabled(self) -> None:  # a disabled agent is left out of the result
+        policy = Policy(
+            agents={
+                AgentName.SECURITY: AgentConfig(enabled=True),
+                AgentName.STYLE: AgentConfig(enabled=False),
+                AgentName.COMPLEXITY: AgentConfig(enabled=True),
+            }
+        )
+        enabled = policy.get_enabled_agents()
+        assert len(enabled) == 2
+        assert AgentName.STYLE not in enabled
+
+    def test_load_from_yaml(self, tmp_path: Path) -> None:  # Policy.load() parses version and per-agent settings
+        policy_file = tmp_path / "policy.yaml"
+        policy_file.write_text("""
+version: "1.1"
+agents:
+  security:
+    enabled: true
+    model: gpt-4o
+    severity_threshold: high
+  style:
+    enabled: false
+  complexity:
+    enabled: true
+""")
+        policy = Policy.load(policy_file)
+        assert policy.version == "1.1"
+        assert policy.agents[AgentName.SECURITY].model == "gpt-4o"
+        assert policy.agents[AgentName.SECURITY].severity_threshold == Severity.HIGH
+        assert policy.agents[AgentName.STYLE].enabled is False
+
+    def test_load_empty_yaml(self, tmp_path: Path) -> None:  # an empty file is expected to yield the defaults
+        policy_file = tmp_path / "policy.yaml"
+        policy_file.write_text("")
+        policy = Policy.load(policy_file)
+        assert policy.version == "1.0"
+        assert len(policy.agents) == 3
+
+
+class TestReviewResult:  # ReviewResult construction with and without findings
+    def test_review_result_creation(self) -> None:  # an empty findings list is accepted
+        result = ReviewResult(
+            agent_name=AgentName.SECURITY,
+            findings=[],
+            duration_ms=1000,
+            tokens_used=500,
+            cost_usd=0.01,
+        )
+        assert result.agent_name == AgentName.SECURITY
+        assert result.duration_ms == 1000
+        assert result.cost_usd == 0.01
+
+    def test_review_result_with_findings(self) -> None:  # findings passed in are readable from the result
+        finding = Finding(
+            id="test-123",
+            agent=AgentName.SECURITY,
+            file="test.py",
+            line_start=1,
+            line_end=1,
+            severity=Severity.HIGH,
+            confidence=0.9,
+            title="Test",
+            description="Test",
+            reasoning="Test",
+            prompt_version="test-v1.0",
+        )
+        result = ReviewResult(
+            agent_name=AgentName.SECURITY,
+            findings=[finding],
+            duration_ms=1000,
+            tokens_used=500,
+            cost_usd=0.01,
+        )
+        assert len(result.findings) == 1
+        assert result.findings[0].id == "test-123"