feat(agents): implement agent framework and CLI

This commit is contained in:
2025-03-08 15:52:29 +00:00
parent 72268ff440
commit f22ca1d5bd
30 changed files with 3466 additions and 0 deletions

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Arbiter test suite."""

490
tests/conftest.py Normal file
View File

@@ -0,0 +1,490 @@
"""Pytest configuration and fixtures."""
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from uuid import uuid4
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from arbiter.db.models import Base, FindingModel, ReviewModel
from arbiter.integrations.base import Comment, CommitStatus, Platform, PlatformClient
from arbiter.llm.client import LLMClient, LLMResponse
from arbiter.llm.prompts import PromptRegistry
from arbiter.models import Policy
from arbiter.models.enums import AgentName, Severity, Verdict
class MockLLMClient(LLMClient):
    """Scripted stand-in for the real LLM client.

    Hands back pre-seeded response strings in order and keeps a log of every
    ``complete`` call so tests can inspect what was sent.
    """

    def __init__(self, responses: list[str] | None = None) -> None:
        """Store the scripted responses and start with an empty call log.

        Args:
            responses: Reply contents to return, one per ``complete`` call.
                A falsy value (``None`` or an empty list) becomes a new
                empty list.
        """
        self.responses = responses or []
        self.calls: list[dict[str, Any]] = []
        self._call_index = 0

    async def complete(
        self,
        messages: list[dict[str, str]],
        model: str,
        **kwargs: Any,
    ) -> LLMResponse:
        """Log the request and reply with the next scripted response.

        Once the scripted responses are used up the reply content is an empty
        string. Token counts and cost are fixed constants.
        """
        self.calls.append({"messages": messages, "model": model, "kwargs": kwargs})

        position = self._call_index
        if position < len(self.responses):
            reply_text = self.responses[position]
            self._call_index = position + 1
        else:
            reply_text = ""

        return LLMResponse(
            content=reply_text,
            model=model,
            tokens_in=100,
            tokens_out=50,
            cost_usd=0.001,
        )

    def reset(self) -> None:
        """Drop the recorded calls and rewind to the first scripted response."""
        self._call_index = 0
        self.calls = []
@pytest.fixture
def mock_llm() -> MockLLMClient:
    """Provide a mock LLM client that has no scripted responses."""
    client = MockLLMClient()
    return client
@pytest.fixture
def mock_llm_with_findings() -> MockLLMClient:
    """Provide a mock LLM client scripted with a single canned finding.

    The one scripted reply is a fenced json block holding a one-element array
    that describes a high-severity SQL injection in ``src/auth.py``.
    """
    canned_reply = """```json
[
{
"file": "src/auth.py",
"line_start": 10,
"line_end": 15,
"severity": "high",
"confidence": 0.9,
"title": "SQL Injection vulnerability",
"description": "User input is directly concatenated into SQL query",
"reasoning": "String concatenation in SQL queries allows attackers to inject malicious SQL",
"suggestion": "Use parameterized queries instead",
"references": ["https://owasp.org/www-community/attacks/SQL_Injection"]
}
]
```"""
    scripted = [canned_reply]
    return MockLLMClient(responses=scripted)
@pytest.fixture
def sample_diff() -> str:
    """Provide a small diff of ``src/auth.py`` that adds a string-built SQL query."""
    diff_text = """diff --git a/src/auth.py b/src/auth.py
index 1234567..abcdefg 100644
--- a/src/auth.py
+++ b/src/auth.py
@@ -8,6 +8,12 @@ def authenticate(username, password):
if not username or not password:
return None
+ # Check user in database
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
+ cursor.execute(query)
+ user = cursor.fetchone()
+ if user and user.password == password:
+ return user
return None
"""
    return diff_text
@pytest.fixture
def sample_policy() -> Policy:
    """Provide a ``Policy`` built entirely from its defaults."""
    default_policy = Policy()
    return default_policy
@pytest.fixture
def prompt_registry(tmp_path: Path) -> PromptRegistry:
    """Provide a ``PromptRegistry`` backed by throwaway template files.

    Writes one ``<agent>-v1.0.md`` template each for the security, style and
    complexity agents into ``tmp_path / "templates"``.
    """
    templates_dir = tmp_path / "templates"
    templates_dir.mkdir()

    # Plain (non-f) string: the double braces are template placeholders and
    # must reach the file verbatim.
    placeholder_tail = "{{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
    for agent_key in ("security", "style", "complexity"):
        template_file = templates_dir / f"{agent_key}-v1.0.md"
        template_file.write_text(f"Review for {agent_key}: {placeholder_tail}")

    return PromptRegistry(templates_dir)
# Database fixtures for integration tests
@pytest.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with all model tables created.

    The engine is disposed once the consuming test has finished.
    """
    database_url = "sqlite+aiosqlite:///:memory:"
    engine = create_async_engine(database_url, echo=False)

    # create_all is a synchronous API, so it is bridged through run_sync.
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield engine

    await engine.dispose()
@pytest.fixture
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to the test engine.

    ``expire_on_commit`` is disabled on the session factory.
    """
    make_session = async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with make_session() as open_session:
        yield open_session
@pytest.fixture
async def test_client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client wired to the app with DB and Redis overridden.

    The app's ``get_db`` dependency is replaced with one that yields the test
    ``db_session``, and ``get_redis`` with one that yields a fresh
    ``MockRedis``. The overrides are cleared in a ``finally`` block so they
    cannot leak onto the shared ``app`` object if creating or closing the
    client raises.
    """
    # Imported lazily so the application is only loaded by tests that need it.
    from arbiter.api.deps import get_db, get_redis
    from arbiter.main import app

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    async def override_get_redis() -> AsyncGenerator[MockRedis, None]:
        yield MockRedis()

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_redis] = override_get_redis
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Always restore the global app state, even on a failed setup/teardown.
        app.dependency_overrides.clear()
class MockRedis:
"""Mock Redis client for testing."""
def __init__(self) -> None:
self._data: dict[str, Any] = {}
async def get(self, key: str) -> str | None:
return self._data.get(key)
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
self._data[key] = value
return True
async def delete(self, key: str) -> int:
if key in self._data:
del self._data[key]
return 1
return 0
async def ping(self) -> bool:
return True
async def llen(self, key: str) -> int:
data = self._data.get(key, [])
return len(data) if isinstance(data, list) else 0
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: # noqa: ARG002
return type("Job", (), {"job_id": "test-job-id"})()
@pytest.fixture
def mock_redis() -> MockRedis:
    """Provide an empty in-memory Redis mock."""
    redis_double = MockRedis()
    return redis_double
class MockPlatformClient(PlatformClient):
    """Recording stand-in for a platform client.

    Each API method first checks whether its own name is listed in
    ``_fail_on`` and raises ``IntegrationError`` if so. Otherwise it records
    its arguments and/or returns canned data.
    """

    def __init__(self) -> None:
        """Initialise empty recorders and the canned diff text."""
        self._comments: list[Comment] = []
        self._posted_comments: list[dict[str, Any]] = []
        self._status_updates: list[dict[str, Any]] = []
        self._diff = "mock diff content"
        self._closed = False
        self._fail_on: set[str] = set()  # Methods to fail on

    def _raise_if_failing(self, method_name: str) -> None:
        """Raise ``IntegrationError`` when ``method_name`` is marked to fail."""
        if method_name not in self._fail_on:
            return
        # Imported only on the failure path.
        from arbiter.integrations import IntegrationError

        raise IntegrationError(f"Mock failure: {method_name}")

    @property
    def platform(self) -> Platform:
        """Always report GitHub as the platform."""
        return Platform.GITHUB

    async def get_pr_diff(
        self,
        repository: str,  # noqa: ARG002
        pr_number: int,  # noqa: ARG002
    ) -> str:
        """Return the canned diff, ignoring the repository and PR number."""
        self._raise_if_failing("get_pr_diff")
        return self._diff

    async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
        """Record the posted comment and return its fabricated URL."""
        self._raise_if_failing("post_comment")
        comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-123"
        record = {
            "repository": repository,
            "pr_number": pr_number,
            "body": body,
            "url": comment_url,
        }
        self._posted_comments.append(record)
        return comment_url

    async def update_commit_status(
        self,
        repository: str,
        sha: str,
        status: CommitStatus,
        description: str,
        context: str,
        target_url: str | None = None,
    ) -> None:
        """Record a commit status update."""
        self._raise_if_failing("update_commit_status")
        record = {
            "repository": repository,
            "sha": sha,
            "status": status,
            "description": description,
            "context": context,
            "target_url": target_url,
        }
        self._status_updates.append(record)

    async def get_pr_info(self, repository: str, pr_number: int) -> Any:
        """Return a fixed, non-draft ``PullRequestInfo`` for the given PR."""
        self._raise_if_failing("get_pr_info")
        from arbiter.integrations.base import PullRequestInfo

        pr_url = f"https://github.com/{repository}/pull/{pr_number}"
        return PullRequestInfo(
            platform=Platform.GITHUB,
            repository=repository,
            pr_number=pr_number,
            head_sha="abc123",
            base_sha="def456",
            head_ref="feature",
            base_ref="main",
            title="Test PR",
            author="testuser",
            url=pr_url,
            is_draft=False,
        )

    async def get_comments(
        self,
        repository: str,  # noqa: ARG002
        pr_number: int,  # noqa: ARG002
    ) -> list[Comment]:
        """Return the preloaded comment list, ignoring the arguments."""
        self._raise_if_failing("get_comments")
        return self._comments

    async def update_comment(
        self, repository: str, pr_number: int, comment_id: str, body: str
    ) -> str:
        """Record the comment update and return its fabricated URL.

        Updates are appended to the same list as newly posted comments.
        """
        self._raise_if_failing("update_comment")
        comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-{comment_id}"
        record = {
            "repository": repository,
            "pr_number": pr_number,
            "comment_id": comment_id,
            "body": body,
            "url": comment_url,
        }
        self._posted_comments.append(record)
        return comment_url

    async def close(self) -> None:
        """Mark the client as closed."""
        self._closed = True
@pytest.fixture
def mock_platform_client() -> MockPlatformClient:
    """Provide a fresh mock platform client with no failures configured."""
    platform_double = MockPlatformClient()
    return platform_double
@pytest.fixture
async def completed_review_fixture(db_session: AsyncSession) -> ReviewModel:
    """Persist and return a completed review that has two findings attached.

    The review is for PR 42 of ``owner/repo`` with a COMMENT verdict. Its
    findings are one high-severity security finding and one medium-severity
    style finding, both in ``src/auth.py``.
    """
    review = ReviewModel(
        id=str(uuid4()),
        repository="owner/repo",
        pr_number=42,
        pr_title="Test PR",
        base_sha="abc1234567890",
        head_sha="def0987654321",
        author="testuser",
        is_draft=False,
        status="completed",
        verdict=Verdict.COMMENT,
        verdict_confidence=0.75,
        verdict_reasoning="Found some issues to discuss",
        total_tokens=1500,
        total_cost_usd=0.015,
        tokens_by_agent={"security": 500, "style": 500, "complexity": 500},
        cost_by_agent={"security": 0.005, "style": 0.005, "complexity": 0.005},
        created_at=datetime.now(UTC),
        started_at=datetime.now(UTC),
        completed_at=datetime.now(UTC),
    )
    db_session.add(review)

    security_finding = FindingModel(
        id=str(uuid4()),
        review_id=review.id,
        agent=AgentName.SECURITY,
        file="src/auth.py",
        line_start=10,
        line_end=15,
        severity=Severity.HIGH,
        confidence=0.9,
        title="SQL Injection vulnerability",
        description="User input is directly concatenated into SQL query",
        reasoning="String concatenation in SQL queries allows attackers to inject malicious SQL",
        suggestion="Use parameterized queries instead",
        references=["https://owasp.org/www-community/attacks/SQL_Injection"],
        prompt_version="security-v1.0",
    )
    style_finding = FindingModel(
        id=str(uuid4()),
        review_id=review.id,
        agent=AgentName.STYLE,
        file="src/auth.py",
        line_start=20,
        line_end=25,
        severity=Severity.MEDIUM,
        confidence=0.85,
        title="Long function",
        description="Function exceeds 50 lines",
        reasoning="Long functions are harder to test and maintain",
        suggestion="Consider breaking into smaller functions",
        references=[],
        prompt_version="style-v1.0",
    )
    db_session.add_all([security_finding, style_finding])

    await db_session.commit()
    await db_session.refresh(review)
    return review
@pytest.fixture
async def sample_reviews_fixture(db_session: AsyncSession) -> list[ReviewModel]:
    """Persist and return five reviews with varied attributes.

    The first three belong to ``owner/repo`` and the rest to ``other/repo``.
    The first four are completed and carry one security finding each; the
    last one is failed, with no verdict, confidence, completion time or
    finding. Authors alternate between ``testuser`` and ``otheruser``.
    """
    reviews = []
    for index in range(5):
        ordinal = index + 1
        is_completed = index < 4

        if index == 0:
            verdict = Verdict.APPROVE
        elif is_completed:
            verdict = Verdict.COMMENT
        else:
            verdict = None

        review = ReviewModel(
            id=str(uuid4()),
            repository="owner/repo" if index < 3 else "other/repo",
            pr_number=ordinal,
            pr_title=f"Test PR #{ordinal}",
            base_sha=f"base{index:040d}",
            head_sha=f"head{index:040d}",
            author="otheruser" if index % 2 else "testuser",
            is_draft=False,
            status="completed" if is_completed else "failed",
            verdict=verdict,
            verdict_confidence=0.8 if is_completed else None,
            total_tokens=1000 * ordinal,
            total_cost_usd=0.01 * ordinal,
            created_at=datetime.now(UTC),
            completed_at=datetime.now(UTC) if is_completed else None,
        )
        db_session.add(review)
        reviews.append(review)

        # Only completed reviews get a finding.
        if not is_completed:
            continue
        db_session.add(
            FindingModel(
                id=str(uuid4()),
                review_id=review.id,
                agent=AgentName.SECURITY,
                file="src/test.py",
                line_start=10,
                line_end=15,
                severity=Severity.CRITICAL if index == 0 else Severity.HIGH,
                confidence=0.9,
                title=f"Finding {ordinal}",
                description="Test finding",
                reasoning="Test reasoning",
                prompt_version="security-v1.0",
            )
        )

    await db_session.commit()
    for persisted in reviews:
        await db_session.refresh(persisted)
    return reviews
@pytest.fixture
def mock_settings() -> Any:
    """Provide a settings double with tokens, timeouts and feature flags set.

    Secret-like values are wrapped in a minimal object exposing
    ``get_secret_value()``.
    """

    class MockSecretStr:
        """Tiny wrapper that hands back its value via ``get_secret_value``."""

        def __init__(self, value: str) -> None:
            self._value = value

        def get_secret_value(self) -> str:
            return self._value

    class MockSettings:
        """Static settings values, defined as class attributes."""

        # Platform credentials and endpoints
        github_token = MockSecretStr("ghp_test_token")
        gitlab_token = MockSecretStr("glpat_test_token")
        github_base_url = "https://api.github.com"
        gitlab_base_url = "https://gitlab.com/api/v4"
        github_webhook_secret = MockSecretStr("webhook_secret")
        gitlab_webhook_token = MockSecretStr("gitlab_token")

        # Timeouts and retry counts
        integration_timeout = 30
        integration_max_retries = 3
        llm_timeout = 60
        llm_max_retries = 3

        # Reporting behaviour
        post_comments = True
        update_status = True
        status_check_context = "arbiter"

        templates_dir = Path("templates")

        # Follow-up configuration
        followup_enabled = True
        followup_confidence_threshold = 0.5

    settings_double = MockSettings()
    return settings_double
@pytest.fixture
def mock_settings_no_github() -> Any:
    """Provide a settings double with no platform tokens.

    Both tokens are ``None`` and comment posting and status updates are
    switched off.
    """

    class MockSettings:
        """Static settings values, defined as class attributes."""

        # No credentials configured
        github_token = None
        gitlab_token = None
        github_base_url = "https://api.github.com"
        gitlab_base_url = "https://gitlab.com/api/v4"

        integration_timeout = 30
        integration_max_retries = 3

        # Reporting disabled
        post_comments = False
        update_status = False
        status_check_context = "arbiter"

        templates_dir = Path("templates")

    settings_double = MockSettings()
    return settings_double

69
tests/fixtures/complex-function.diff vendored Normal file
View File

@@ -0,0 +1,69 @@
diff --git a/src/processor.py b/src/processor.py
index 1234567..abcdefg 100644
--- a/src/processor.py
+++ b/src/processor.py
@@ -1,5 +1,65 @@
"""Data processor module."""
+def process_data(data: dict, config: dict, options: dict | None = None) -> dict:
+ """Process data with many nested conditions."""
+ result = {}
+ options = options or {}
+
+ if data.get("type") == "A":
+ if config.get("mode") == "strict":
+ if options.get("validate"):
+ if data.get("value") > 100:
+ if config.get("transform"):
+ result["processed"] = data["value"] * 2
+ else:
+ result["processed"] = data["value"]
+ else:
+ if options.get("default"):
+ result["processed"] = options["default"]
+ else:
+ result["processed"] = 0
+ else:
+ result["processed"] = data.get("value", 0)
+ else:
+ result["processed"] = data.get("value", 0)
+ elif data.get("type") == "B":
+ if config.get("mode") == "strict":
+ if options.get("validate"):
+ if data.get("items"):
+ result["processed"] = len(data["items"])
+ else:
+ result["processed"] = 0
+ else:
+ result["processed"] = len(data.get("items", []))
+ else:
+ result["processed"] = len(data.get("items", []))
+ elif data.get("type") == "C":
+ if config.get("mode") == "strict":
+ if options.get("validate"):
+ if data.get("text"):
+ result["processed"] = data["text"].upper()
+ else:
+ result["processed"] = ""
+ else:
+ result["processed"] = data.get("text", "").upper()
+ else:
+ result["processed"] = data.get("text", "").upper()
+ else:
+ if config.get("fallback"):
+ result["processed"] = config["fallback"]
+ else:
+ result["processed"] = None
+
+ if options.get("timestamp"):
+ result["timestamp"] = options["timestamp"]
+ if options.get("source"):
+ result["source"] = options["source"]
+
+ return result
+
+
def simple_function(x: int) -> int:
"""A simple function."""
return x * 2

31
tests/fixtures/security-issue.diff vendored Normal file
View File

@@ -0,0 +1,31 @@
diff --git a/src/auth.py b/src/auth.py
index 1234567..abcdefg 100644
--- a/src/auth.py
+++ b/src/auth.py
@@ -1,10 +1,25 @@
"""Authentication module."""
import sqlite3
+import os
def get_user(username: str) -> dict | None:
"""Get user from database."""
conn = sqlite3.connect("users.db")
cursor = conn.cursor()
- cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
+ # FIXME: this is vulnerable to SQL injection
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
+ cursor.execute(query)
return cursor.fetchone()
+
+
+def run_command(cmd: str) -> str:
+ """Run a shell command."""
+ # Command injection vulnerability
+ return os.popen(cmd).read()
+
+
+# Hardcoded credentials
+API_KEY = "sk-1234567890abcdef"
+DB_PASSWORD = "admin123"

16
tests/fixtures/simple.diff vendored Normal file
View File

@@ -0,0 +1,16 @@
diff --git a/src/utils.py b/src/utils.py
index 1234567..abcdefg 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,5 +1,8 @@
"""Utility functions."""
+def add(a: int, b: int) -> int:
+ """Add two numbers."""
+ return a + b
+
+
def subtract(a: int, b: int) -> int:
"""Subtract two numbers."""
return a - b

305
tests/test_agents.py Normal file
View File

@@ -0,0 +1,305 @@
"""Tests for review agents."""
import pytest
from arbiter.agents import ComplexityAgent, ReviewContext, SecurityAgent, StyleAgent
from arbiter.llm.prompts import PromptRegistry
from arbiter.models import AgentConfig, AgentName, Policy, Severity
from tests.conftest import MockLLMClient
class TestSecurityAgent:
    """Behavioural tests for ``SecurityAgent`` driven by a scripted LLM."""

    @pytest.mark.asyncio
    async def test_review_returns_result(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """An empty JSON array yields an empty, fully populated result."""
        llm = MockLLMClient(responses=["[]"])
        security_agent = SecurityAgent(llm, prompt_registry)

        result = await security_agent.review(
            ReviewContext(diff="+ some code", policy=Policy())
        )

        assert result.agent_name == AgentName.SECURITY
        assert result.findings == []
        assert result.duration_ms >= 0
        assert result.tokens_used == 150  # 100 in + 50 out from mock
        assert result.cost_usd == 0.001

    @pytest.mark.asyncio
    async def test_parses_json_findings(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """A finding inside a fenced json block is parsed into a finding."""
        fenced_reply = """```json
[
{
"file": "src/auth.py",
"line_start": 10,
"line_end": 15,
"severity": "high",
"confidence": 0.9,
"title": "SQL Injection",
"description": "User input concatenated",
"reasoning": "Allows SQL injection",
"suggestion": "Use parameterized queries",
"references": ["https://owasp.org"]
}
]
```"""
        llm = MockLLMClient(responses=[fenced_reply])
        security_agent = SecurityAgent(llm, prompt_registry)

        result = await security_agent.review(
            ReviewContext(diff="+ query = ...", policy=Policy())
        )

        assert len(result.findings) == 1
        parsed = result.findings[0]
        assert parsed.file == "src/auth.py"
        assert parsed.severity == Severity.HIGH
        assert parsed.confidence == 0.9
        assert parsed.title == "SQL Injection"

    @pytest.mark.asyncio
    async def test_uses_configured_model(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """The model named in the security agent's config is sent to the LLM."""
        llm = MockLLMClient(responses=["[]"])
        security_agent = SecurityAgent(llm, prompt_registry)
        agent_configs = {
            AgentName.SECURITY: AgentConfig(model="gpt-4o-mini"),
            AgentName.STYLE: AgentConfig(),
            AgentName.COMPLEXITY: AgentConfig(),
        }
        policy = Policy(agents=agent_configs)

        await security_agent.review(ReviewContext(diff="+ code", policy=policy))

        first_call = llm.calls[0]
        assert first_call["model"] == "gpt-4o-mini"

    @pytest.mark.asyncio
    async def test_filters_by_severity(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """Findings below the configured severity threshold are dropped."""
        mixed_reply = """[
{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "high", "confidence": 0.9, "title": "High", "description": "", "reasoning": ""},
{"file": "b.py", "line_start": 1, "line_end": 1, "severity": "low", "confidence": 0.9, "title": "Low", "description": "", "reasoning": ""},
{"file": "c.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.9, "title": "Info", "description": "", "reasoning": ""}
]"""
        llm = MockLLMClient(responses=[mixed_reply])
        security_agent = SecurityAgent(llm, prompt_registry)
        agent_configs = {
            AgentName.SECURITY: AgentConfig(severity_threshold=Severity.MEDIUM),
            AgentName.STYLE: AgentConfig(),
            AgentName.COMPLEXITY: AgentConfig(),
        }
        policy = Policy(agents=agent_configs)

        result = await security_agent.review(ReviewContext(diff="+ code", policy=policy))

        # Only high severity should pass (medium threshold filters low and info)
        assert len(result.findings) == 1
        assert result.findings[0].severity == Severity.HIGH
class TestStyleAgent:
    """Behavioural tests for ``StyleAgent`` driven by a scripted LLM."""

    @pytest.mark.asyncio
    async def test_review_returns_result(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """An empty JSON array yields an empty result tagged with the style agent."""
        llm = MockLLMClient(responses=["[]"])
        style_agent = StyleAgent(llm, prompt_registry)

        result = await style_agent.review(
            ReviewContext(diff="+ some code", policy=Policy())
        )

        assert result.agent_name == AgentName.STYLE
        assert result.findings == []

    @pytest.mark.asyncio
    async def test_uses_default_model(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """With a default policy the style agent requests ``gpt-4o-mini``."""
        llm = MockLLMClient(responses=["[]"])
        style_agent = StyleAgent(llm, prompt_registry)

        await style_agent.review(ReviewContext(diff="+ code", policy=Policy()))

        first_call = llm.calls[0]
        assert first_call["model"] == "gpt-4o-mini"
class TestComplexityAgent:
    """Behavioural tests for ``ComplexityAgent`` driven by a scripted LLM."""

    @pytest.mark.asyncio
    async def test_review_returns_result(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """An empty JSON array yields an empty result tagged with the complexity agent."""
        llm = MockLLMClient(responses=["[]"])
        complexity_agent = ComplexityAgent(llm, prompt_registry)

        result = await complexity_agent.review(
            ReviewContext(diff="+ some code", policy=Policy())
        )

        assert result.agent_name == AgentName.COMPLEXITY
        assert result.findings == []

    @pytest.mark.asyncio
    async def test_parses_complexity_findings(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """A medium-severity complexity finding is parsed from bare JSON."""
        bare_reply = """[
{
"file": "processor.py",
"line_start": 1,
"line_end": 50,
"severity": "medium",
"confidence": 0.8,
"title": "High cyclomatic complexity",
"description": "Function has 15 branches",
"reasoning": "Makes testing and maintenance difficult"
}
]"""
        llm = MockLLMClient(responses=[bare_reply])
        complexity_agent = ComplexityAgent(llm, prompt_registry)

        result = await complexity_agent.review(
            ReviewContext(diff="+ complex code", policy=Policy())
        )

        assert len(result.findings) == 1
        only_finding = result.findings[0]
        assert only_finding.severity == Severity.MEDIUM
        assert "complexity" in only_finding.title.lower()
class TestAgentResponseParsing:
    """Tests for how an agent copes with odd or malformed LLM replies."""

    @pytest.mark.asyncio
    async def test_handles_empty_response(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """An empty reply produces no findings."""
        llm = MockLLMClient(responses=[""])
        security_agent = SecurityAgent(llm, prompt_registry)

        result = await security_agent.review(ReviewContext(diff="+ code", policy=Policy()))

        assert result.findings == []

    @pytest.mark.asyncio
    async def test_handles_invalid_json(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """A reply that is not JSON produces no findings."""
        llm = MockLLMClient(responses=["not valid json"])
        security_agent = SecurityAgent(llm, prompt_registry)

        result = await security_agent.review(ReviewContext(diff="+ code", policy=Policy()))

        assert result.findings == []

    @pytest.mark.asyncio
    async def test_handles_json_without_code_block(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """Bare JSON with no surrounding code fence is still parsed."""
        bare_reply = '[{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Test", "description": "", "reasoning": ""}]'
        llm = MockLLMClient(responses=[bare_reply])
        security_agent = SecurityAgent(llm, prompt_registry)

        result = await security_agent.review(ReviewContext(diff="+ code", policy=Policy()))

        assert len(result.findings) == 1

    @pytest.mark.asyncio
    async def test_handles_malformed_finding(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """A malformed entry is skipped while a valid sibling is kept."""
        mixed_reply = """[
{"file": "a.py", "line_start": 1, "severity": "invalid_severity", "confidence": 0.5, "title": "Bad", "description": "", "reasoning": ""},
{"file": "b.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Valid", "description": "", "reasoning": ""}
]"""
        llm = MockLLMClient(responses=[mixed_reply])
        security_agent = SecurityAgent(llm, prompt_registry)

        result = await security_agent.review(ReviewContext(diff="+ code", policy=Policy()))

        # Only the valid finding should be included (first has invalid severity)
        assert len(result.findings) == 1
        assert result.findings[0].title == "Valid"

    @pytest.mark.asyncio
    async def test_includes_prompt_additions(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """Configured prompt additions appear in the first message sent."""
        llm = MockLLMClient(responses=["[]"])
        security_agent = SecurityAgent(llm, prompt_registry)
        agent_configs = {
            AgentName.SECURITY: AgentConfig(prompt_additions="Focus on authentication"),
            AgentName.STYLE: AgentConfig(),
            AgentName.COMPLEXITY: AgentConfig(),
        }
        policy = Policy(agents=agent_configs)

        await security_agent.review(ReviewContext(diff="+ code", policy=policy))

        first_message = llm.calls[0]["messages"][0]
        assert "Focus on authentication" in first_message["content"]

    @pytest.mark.asyncio
    async def test_handles_non_list_json(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """A JSON object instead of an array produces no findings."""
        llm = MockLLMClient(responses=['{"not": "a list"}'])
        security_agent = SecurityAgent(llm, prompt_registry)

        result = await security_agent.review(ReviewContext(diff="+ code", policy=Policy()))

        assert result.findings == []

    @pytest.mark.asyncio
    async def test_handles_non_dict_items(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """Array items that are not objects produce no findings."""
        llm = MockLLMClient(responses=['["string", 123, null]'])
        security_agent = SecurityAgent(llm, prompt_registry)

        result = await security_agent.review(ReviewContext(diff="+ code", policy=Policy()))

        assert result.findings == []

    @pytest.mark.asyncio
    async def test_agent_without_config_uses_defaults(
        self,
        prompt_registry: PromptRegistry,
    ) -> None:
        """With no per-agent config the security agent uses its default model."""
        llm = MockLLMClient(responses=["[]"])
        security_agent = SecurityAgent(llm, prompt_registry)
        # Create policy with empty agents dict
        bare_policy = Policy(agents={})

        result = await security_agent.review(ReviewContext(diff="+ code", policy=bare_policy))

        # Should use default model (gpt-4o for security)
        assert llm.calls[0]["model"] == "gpt-4o"
        assert result.findings == []

561
tests/test_cli.py Normal file
View File

@@ -0,0 +1,561 @@
"""Tests for CLI commands."""
import json
from pathlib import Path
from unittest.mock import AsyncMock, patch
from typer.testing import CliRunner
from arbiter.cli import (
_severity_color,
_severity_icon,
_verdict_color,
_verdict_icon,
app,
)
from arbiter.deliberation import DeliberationResult
from arbiter.models import AgentName, Finding, ReviewResult, Severity, Verdict
runner = CliRunner()
def make_mock_return(
    findings: list[Finding] | None = None, verdict: Verdict = Verdict.APPROVE
) -> tuple[list[ReviewResult], DeliberationResult]:
    """Build the ``(agent_results, deliberation_result)`` pair for a patched ``_run_review``.

    Args:
        findings: Findings to place in both the single security agent result
            and the deliberation result; a falsy value means none.
        verdict: Verdict to put on the deliberation result.

    Returns:
        A one-element list of agent results and the deliberation result.
    """
    # Evaluated separately so that, with no findings given, the two results
    # do not share one list object.
    agent_findings = findings or []
    deliberated_findings = findings or []
    finding_count = len(findings) if findings else 0

    security_result = ReviewResult(
        agent_name=AgentName.SECURITY,
        findings=agent_findings,
        duration_ms=100,
        tokens_used=100,
        cost_usd=0.001,
    )
    deliberation_result = DeliberationResult(
        findings=deliberated_findings,
        verdict=verdict,
        verdict_confidence=0.9,
        verdict_reasoning="Test reasoning",
        total_findings=finding_count,
    )
    return [security_result], deliberation_result
class TestVersionCommand:
    """Tests for the ``version`` CLI command."""

    def test_version_output(self) -> None:
        """The command exits cleanly and prints the program name and version."""
        outcome = runner.invoke(app, ["version"])

        assert outcome.exit_code == 0
        printed = outcome.output
        assert "arbiter" in printed
        assert "0.3.0" in printed
class TestReviewCommand:
    """Tests for the ``review`` CLI command.

    Tests that get past input validation patch ``arbiter.cli._run_review``
    with an ``AsyncMock`` so no real review is executed.
    """

    def test_file_not_found(self) -> None:
        """A missing diff file exits with code 1 and an explanatory message."""
        result = runner.invoke(app, ["review", "nonexistent.diff"])

        assert result.exit_code == 1
        assert "File not found" in result.output

    def test_empty_diff_warning(self, tmp_path: Path) -> None:
        """An empty diff file exits cleanly with a warning."""
        diff_file = tmp_path / "empty.diff"
        diff_file.write_text("")

        result = runner.invoke(app, ["review", str(diff_file)])

        assert result.exit_code == 0
        assert "Empty diff" in result.output

    def test_policy_not_found(self, tmp_path: Path) -> None:
        """A missing policy file exits with code 1 and an explanatory message."""
        diff_file = tmp_path / "test.diff"
        diff_file.write_text("+ some change")

        result = runner.invoke(app, ["review", str(diff_file), "--policy", "nonexistent.yaml"])

        assert result.exit_code == 1
        assert "Policy file not found" in result.output

    def test_reads_from_stdin(self) -> None:
        """Passing ``-`` as the path reads the diff from standard input."""
        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
            mock_run.return_value = make_mock_return()
            result = runner.invoke(app, ["review", "-"], input="+ added line\n")

        assert result.exit_code == 0

    def test_json_output_format(self, tmp_path: Path) -> None:
        """``--format json`` output contains the verdict and findings keys."""
        diff_file = tmp_path / "test.diff"
        diff_file.write_text("+ some change")

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
            mock_run.return_value = make_mock_return()
            result = runner.invoke(app, ["review", str(diff_file), "--format", "json"])

        assert result.exit_code == 0
        assert '"verdict"' in result.output
        assert '"findings"' in result.output

    def test_markdown_output_format(self, tmp_path: Path) -> None:
        """``--format markdown`` output starts from the review heading."""
        diff_file = tmp_path / "test.diff"
        diff_file.write_text("+ some change")

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
            mock_run.return_value = make_mock_return()
            result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])

        assert result.exit_code == 0
        assert "# Arbiter Review" in result.output

    def test_loads_policy_file(self, tmp_path: Path) -> None:
        """A policy file given via ``--policy`` is parsed and passed on."""
        diff_file = tmp_path / "test.diff"
        diff_file.write_text("+ some change")
        policy_file = tmp_path / "policy.yaml"
        # Nesting is written with explicit indentation: each agent must sit
        # under ``agents`` with its own ``enabled`` key, otherwise the YAML
        # collapses into duplicate flat keys.
        policy_file.write_text(
            "\n"
            'version: "1.0"\n'
            "agents:\n"
            "  security:\n"
            "    enabled: true\n"
            "  style:\n"
            "    enabled: false\n"
            "  complexity:\n"
            "    enabled: false\n"
        )

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
            mock_run.return_value = make_mock_return()
            result = runner.invoke(app, ["review", str(diff_file), "--policy", str(policy_file)])

        assert result.exit_code == 0
        # Verify policy was passed to _run_review
        call_args = mock_run.call_args
        policy = call_args[0][1]  # Second positional arg is policy
        assert len(policy.get_enabled_agents()) == 1

    def test_model_override(self, tmp_path: Path) -> None:
        """``--model`` sets the model on every agent config in the policy."""
        diff_file = tmp_path / "test.diff"
        diff_file.write_text("+ some change")

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
            mock_run.return_value = make_mock_return()
            result = runner.invoke(app, ["review", str(diff_file), "--model", "gpt-4o-mini"])

        assert result.exit_code == 0
        # Verify model was set in policy
        call_args = mock_run.call_args
        policy = call_args[0][1]
        for config in policy.agents.values():
            assert config.model == "gpt-4o-mini"

    def test_static_analysis_flag(self, tmp_path: Path) -> None:
        """``--no-static-analysis`` passes ``static_analysis=False`` through."""
        diff_file = tmp_path / "test.diff"
        diff_file.write_text("+ some change")

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
            mock_run.return_value = make_mock_return()
            result = runner.invoke(app, ["review", str(diff_file), "--no-static-analysis"])

        assert result.exit_code == 0
        # Verify static_analysis was False
        call_args = mock_run.call_args
        assert call_args.kwargs.get("static_analysis") is False
class TestNoArgsHelp:
    """Tests for invoking the CLI with no arguments."""

    def test_no_args_shows_help(self) -> None:
        """With no arguments the CLI exits cleanly and prints its description."""
        no_arguments: list[str] = []
        outcome = runner.invoke(app, no_arguments)

        assert outcome.exit_code == 0
        assert "A multi-agent code review system" in outcome.output
class TestOutputFormatting:
    """Tests for the severity and verdict colour/icon helper functions."""

    def test_severity_color(self) -> None:
        """Each severity maps to its expected colour style."""
        expectations = [
            (Severity.CRITICAL, "red bold"),
            (Severity.HIGH, "red"),
            (Severity.MEDIUM, "yellow"),
            (Severity.LOW, "blue"),
            (Severity.INFO, "dim"),
        ]
        for severity, expected_color in expectations:
            assert _severity_color(severity) == expected_color

    def test_severity_icon(self) -> None:
        """Each severity maps to its expected icon."""
        expectations = [
            (Severity.CRITICAL, "!!"),
            (Severity.HIGH, "!"),
            (Severity.MEDIUM, "*"),
            (Severity.LOW, "-"),
            (Severity.INFO, "i"),
        ]
        for severity, expected_icon in expectations:
            assert _severity_icon(severity) == expected_icon

    def test_verdict_color(self) -> None:
        """Each verdict maps to its expected colour."""
        expectations = [
            (Verdict.APPROVE, "green"),
            (Verdict.COMMENT, "yellow"),
            (Verdict.REQUEST_CHANGES, "red"),
        ]
        for verdict, expected_color in expectations:
            assert _verdict_color(verdict) == expected_color

    def test_verdict_icon(self) -> None:
        """Each verdict maps to its expected icon."""
        expectations = [
            (Verdict.APPROVE, "[ok]"),
            (Verdict.COMMENT, "[..]"),
            (Verdict.REQUEST_CHANGES, "[!!]"),
        ]
        for verdict, expected_icon in expectations:
            assert _verdict_icon(verdict) == expected_icon
class TestRichOutput:
    """Tests for the ``rich`` output format of the ``review`` command."""

    def test_rich_format_with_findings(self, tmp_path: Path) -> None:
        """Rendering a high-severity finding in rich format exits cleanly."""
        diff_file = tmp_path / "test.diff"
        diff_file.write_text("+ some change")
        high_finding = Finding(
            id="test-finding-1",
            agent=AgentName.SECURITY,
            file="test.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection",
            description="User input in query",
            reasoning="Direct concatenation",
            prompt_version="test-v1.0",
        )
        cli_args = ["review", str(diff_file), "--format", "rich"]

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
            mock_run.return_value = make_mock_return(findings=[high_finding])
            outcome = runner.invoke(app, cli_args)

        assert outcome.exit_code == 0

    def test_rich_format_no_findings(self, tmp_path: Path) -> None:
        """With no findings the rich output reports that nothing was found."""
        diff_file = tmp_path / "test.diff"
        diff_file.write_text("+ some change")
        cli_args = ["review", str(diff_file), "--format", "rich"]

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
            mock_run.return_value = make_mock_return()
            outcome = runner.invoke(app, cli_args)

        assert outcome.exit_code == 0
        assert "No issues found" in outcome.output

    def test_rich_format_critical_findings(self, tmp_path: Path) -> None:
        """A critical finding with a REQUEST_CHANGES verdict still exits with 0."""
        diff_file = tmp_path / "test.diff"
        diff_file.write_text("+ some change")
        critical_finding = Finding(
            id="test-finding-1",
            agent=AgentName.SECURITY,
            file="test.py",
            line_start=10,
            line_end=10,
            severity=Severity.CRITICAL,
            confidence=0.95,
            title="Critical Issue",
            description="This is critical",
            reasoning="Very bad",
            suggestion="Fix it immediately",
            prompt_version="test-v1.0",
        )
        deliberation = DeliberationResult(
            findings=[critical_finding],
            verdict=Verdict.REQUEST_CHANGES,
            verdict_confidence=0.95,
            verdict_reasoning="Critical issue found",
            total_findings=1,
            critical_count=1,
        )
        cli_args = ["review", str(diff_file), "--format", "rich"]

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
            security_result = ReviewResult(
                agent_name=AgentName.SECURITY,
                findings=[critical_finding],
                duration_ms=100,
                tokens_used=100,
                cost_usd=0.001,
            )
            mock_run.return_value = ([security_result], deliberation)
            outcome = runner.invoke(app, cli_args)

        assert outcome.exit_code == 0
class TestMarkdownOutput:
    """Run the `review` command with `--format markdown` against stubbed results."""

    def test_markdown_with_findings(self, tmp_path: Path) -> None:
        """Output has a "## Findings" heading and the finding's title."""
        diff_path = tmp_path / "test.diff"
        diff_path.write_text("+ some change")
        sql_injection = Finding(
            id="test-finding-1",
            agent=AgentName.SECURITY,
            file="test.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection",
            description="User input in query",
            reasoning="Direct concatenation",
            suggestion="Use parameterized queries",
            prompt_version="test-v1.0",
        )
        cli_args = ["review", str(diff_path), "--format", "markdown"]

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as fake_review:
            fake_review.return_value = make_mock_return(findings=[sql_injection])
            outcome = runner.invoke(app, cli_args)

        assert outcome.exit_code == 0
        assert "## Findings" in outcome.output
        assert "SQL Injection" in outcome.output

    def test_markdown_verdict_badges(self, tmp_path: Path) -> None:
        """Each verdict's upper-cased value appears in the markdown output."""
        diff_path = tmp_path / "test.diff"
        diff_path.write_text("+ some change")
        cli_args = ["review", str(diff_path), "--format", "markdown"]

        for verdict in (Verdict.APPROVE, Verdict.COMMENT, Verdict.REQUEST_CHANGES):
            deliberation = DeliberationResult(
                verdict=verdict,
                verdict_confidence=0.9,
                verdict_reasoning="Test",
            )
            # Fresh result objects per iteration, one CLI invocation per verdict.
            agent_results = [
                ReviewResult(
                    agent_name=AgentName.SECURITY,
                    findings=[],
                    duration_ms=100,
                    tokens_used=100,
                    cost_usd=0.001,
                )
            ]
            with patch("arbiter.cli._run_review", new_callable=AsyncMock) as fake_review:
                fake_review.return_value = (agent_results, deliberation)
                outcome = runner.invoke(app, cli_args)

            assert outcome.exit_code == 0
            assert verdict.value.upper() in outcome.output
class TestJsonOutput:
    """Run the `review` command with `--format json` against stubbed results."""

    def test_json_with_conflicts(self, tmp_path: Path) -> None:
        """The JSON payload has a "conflicts" list holding the one stubbed conflict."""
        from arbiter.deliberation.conflicts import Conflict, ConflictNature

        diff_path = tmp_path / "test.diff"
        diff_path.write_text("+ some change")

        trade_off = Conflict(
            id="test-conflict",
            finding_ids=["f1", "f2"],
            nature=ConflictNature.TRADE_OFF,
            description="Test conflict",
            severity_weight=0.8,
        )
        deliberation = DeliberationResult(
            verdict=Verdict.COMMENT,
            verdict_confidence=0.7,
            verdict_reasoning="Conflicts found",
            conflicts=[trade_off],
        )
        agent_results = [
            ReviewResult(
                agent_name=AgentName.SECURITY,
                findings=[],
                duration_ms=100,
                tokens_used=100,
                cost_usd=0.001,
            )
        ]
        cli_args = ["review", str(diff_path), "--format", "json"]

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as fake_review:
            fake_review.return_value = (agent_results, deliberation)
            outcome = runner.invoke(app, cli_args)

        assert outcome.exit_code == 0
        # The whole captured output must parse as one JSON document.
        payload = json.loads(outcome.output)
        assert "conflicts" in payload
        assert len(payload["conflicts"]) == 1
class TestWorkDirHandling:
    """Check how the `review` command forwards `--work-dir`."""

    def test_work_dir_option(self, tmp_path: Path) -> None:
        """`--work-dir` reaches `_run_review` as the resolved `work_dir` keyword."""
        diff_path = tmp_path / "test.diff"
        diff_path.write_text("+ some change")
        source_root = tmp_path / "src"
        source_root.mkdir()
        cli_args = ["review", str(diff_path), "--work-dir", str(source_root)]

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as fake_review:
            fake_review.return_value = make_mock_return()
            outcome = runner.invoke(app, cli_args)

        assert outcome.exit_code == 0
        # Inspect the keyword arguments of the recorded call to the stub.
        forwarded = fake_review.call_args.kwargs
        assert forwarded.get("work_dir") == source_root.resolve()
class TestRichOutputWithConflicts:
    """Run `review --format rich` when the deliberation carries conflicts."""

    def test_rich_conflicts(self, tmp_path: Path) -> None:
        """Exit code is 0 with one TRADE_OFF conflict and its resolution."""
        from arbiter.deliberation.conflicts import Conflict, ConflictNature
        from arbiter.deliberation.synthesis import Resolution

        diff_path = tmp_path / "test.diff"
        diff_path.write_text("+ some change")

        trade_off = Conflict(
            id="test-conflict",
            finding_ids=["f1", "f2"],
            nature=ConflictNature.TRADE_OFF,
            description="Security vs complexity trade-off",
            severity_weight=0.8,
        )
        # The resolution refers to the conflict through its id.
        chosen = Resolution(
            conflict_id="test-conflict",
            decision="prefer_first",
            reasoning="Security takes priority",
            confidence=0.9,
        )
        deliberation = DeliberationResult(
            verdict=Verdict.COMMENT,
            verdict_confidence=0.7,
            verdict_reasoning="Conflicts found",
            conflicts=[trade_off],
            resolutions=[chosen],
        )
        agent_results = [
            ReviewResult(
                agent_name=AgentName.SECURITY,
                findings=[],
                duration_ms=100,
                tokens_used=100,
                cost_usd=0.001,
            )
        ]
        cli_args = ["review", str(diff_path), "--format", "rich"]

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as fake_review:
            fake_review.return_value = (agent_results, deliberation)
            outcome = runner.invoke(app, cli_args)

        assert outcome.exit_code == 0
class TestMarkdownOutputWithConflicts:
    """Run `review --format markdown` with conflicts and multi-agent findings."""

    def test_markdown_conflicts(self, tmp_path: Path) -> None:
        """Output contains "## Conflicts", "Contradictory" and "Resolution"."""
        from arbiter.deliberation.conflicts import Conflict, ConflictNature
        from arbiter.deliberation.synthesis import Resolution

        diff_path = tmp_path / "test.diff"
        diff_path.write_text("+ some change")

        contradiction = Conflict(
            id="test-conflict",
            finding_ids=["f1", "f2"],
            nature=ConflictNature.CONTRADICTORY,
            description="Contradictory recommendations",
            severity_weight=0.9,
        )
        merged = Resolution(
            conflict_id="test-conflict",
            decision="merge",
            reasoning="Both concerns addressed by combined fix",
            merged_suggestion="Do both things",
            confidence=0.85,
        )
        deliberation = DeliberationResult(
            verdict=Verdict.COMMENT,
            verdict_confidence=0.7,
            verdict_reasoning="Conflicts found",
            conflicts=[contradiction],
            resolutions=[merged],
        )
        agent_results = [
            ReviewResult(
                agent_name=AgentName.SECURITY,
                findings=[],
                duration_ms=100,
                tokens_used=100,
                cost_usd=0.001,
            )
        ]
        cli_args = ["review", str(diff_path), "--format", "markdown"]

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as fake_review:
            fake_review.return_value = (agent_results, deliberation)
            outcome = runner.invoke(app, cli_args)

        assert outcome.exit_code == 0
        for expected_text in ("## Conflicts", "Contradictory", "Resolution"):
            assert expected_text in outcome.output

    def test_markdown_findings(self, tmp_path: Path) -> None:
        """Both finding titles and the one suggestion appear in the output."""
        diff_path = tmp_path / "test.diff"
        diff_path.write_text("+ some change")

        security_finding = Finding(
            id="f1",
            agent=AgentName.SECURITY,
            file="test.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="Security Issue",
            description="Vulnerable code",
            reasoning="Bad pattern",
            suggestion="Fix it this way",
            prompt_version="test-v1.0",
        )
        # The style finding has no suggestion.
        style_finding = Finding(
            id="f2",
            agent=AgentName.STYLE,
            file="test.py",
            line_start=20,
            line_end=25,
            severity=Severity.LOW,
            confidence=0.8,
            title="Style Issue",
            description="Could be cleaner",
            reasoning="Convention",
            prompt_version="test-v1.0",
        )
        deliberation = DeliberationResult(
            findings=[security_finding, style_finding],
            verdict=Verdict.COMMENT,
            verdict_confidence=0.75,
            verdict_reasoning="Issues found",
            total_findings=2,
        )
        # One ReviewResult per agent, each carrying only that agent's finding.
        agent_results = [
            ReviewResult(
                agent_name=AgentName.SECURITY,
                findings=[security_finding],
                duration_ms=100,
                tokens_used=100,
                cost_usd=0.001,
            ),
            ReviewResult(
                agent_name=AgentName.STYLE,
                findings=[style_finding],
                duration_ms=100,
                tokens_used=100,
                cost_usd=0.001,
            ),
        ]
        cli_args = ["review", str(diff_path), "--format", "markdown"]

        with patch("arbiter.cli._run_review", new_callable=AsyncMock) as fake_review:
            fake_review.return_value = (agent_results, deliberation)
            outcome = runner.invoke(app, cli_args)

        assert outcome.exit_code == 0
        for expected_text in ("Security Issue", "Style Issue", "Fix it this way"):
            assert expected_text in outcome.output

313
tests/test_llm.py Normal file
View File

@@ -0,0 +1,313 @@
"""Tests for LLM client and prompts."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from arbiter.llm.client import LiteLLMClient, LLMResponse
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
from tests.conftest import MockLLMClient
class TestLiteLLMClient:
    """Tests for `LiteLLMClient` with the litellm calls patched out."""

    def test_init_default_values(self) -> None:
        """A client built without arguments has timeout 60 and max_retries 3."""
        llm = LiteLLMClient()
        assert (llm.timeout, llm.max_retries) == (60, 3)

    def test_init_custom_values(self) -> None:
        """Constructor arguments are stored on the client unchanged."""
        llm = LiteLLMClient(timeout=120, max_retries=5)
        assert (llm.timeout, llm.max_retries) == (120, 5)

    @pytest.mark.asyncio
    async def test_complete_returns_response(self) -> None:
        """`complete` maps the litellm response fields and forwards its settings."""
        llm = LiteLLMClient()
        prompt = [{"role": "user", "content": "Hello"}]

        # Build a stand-in for the object litellm.acompletion would return.
        choice = MagicMock()
        choice.message.content = "Test response"
        completion = MagicMock()
        completion.choices = [choice]
        completion.model = "gpt-4o"
        completion.usage = MagicMock(prompt_tokens=10, completion_tokens=5)

        with (
            patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as acompletion,
            patch("arbiter.llm.client.litellm.completion_cost") as completion_cost,
        ):
            acompletion.return_value = completion
            completion_cost.return_value = 0.001
            reply = await llm.complete(prompt, "gpt-4o")

        assert reply.content == "Test response"
        assert reply.model == "gpt-4o"
        assert reply.tokens_in == 10
        assert reply.tokens_out == 5
        assert reply.cost_usd == 0.001
        acompletion.assert_called_once_with(
            model="gpt-4o",
            messages=prompt,
            timeout=60,
            num_retries=3,
        )

    @pytest.mark.asyncio
    async def test_complete_handles_empty_content(self) -> None:
        """A `None` message content comes back as an empty string."""
        llm = LiteLLMClient()
        prompt = [{"role": "user", "content": "Hello"}]

        choice = MagicMock()
        choice.message.content = None  # None content
        completion = MagicMock()
        completion.choices = [choice]
        completion.model = "gpt-4o"
        completion.usage = MagicMock(prompt_tokens=5, completion_tokens=0)

        with (
            patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as acompletion,
            patch("arbiter.llm.client.litellm.completion_cost") as completion_cost,
        ):
            acompletion.return_value = completion
            completion_cost.return_value = 0.0
            reply = await llm.complete(prompt, "gpt-4o")

        assert reply.content == ""  # Should be empty string, not None

    @pytest.mark.asyncio
    async def test_complete_handles_missing_usage(self) -> None:
        """Both token counts are 0 when the response has `usage` set to `None`."""
        llm = LiteLLMClient()
        prompt = [{"role": "user", "content": "Hello"}]

        choice = MagicMock()
        choice.message.content = "Response"
        completion = MagicMock()
        completion.choices = [choice]
        completion.model = "gpt-4o"
        completion.usage = None  # No usage data

        with (
            patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as acompletion,
            patch("arbiter.llm.client.litellm.completion_cost") as completion_cost,
        ):
            acompletion.return_value = completion
            completion_cost.return_value = 0.0
            reply = await llm.complete(prompt, "gpt-4o")

        assert (reply.tokens_in, reply.tokens_out) == (0, 0)

    @pytest.mark.asyncio
    async def test_complete_uses_fallback_model(self) -> None:
        """The requested model name is reported when the response has none."""
        llm = LiteLLMClient()
        prompt = [{"role": "user", "content": "Hello"}]

        choice = MagicMock()
        choice.message.content = "Response"
        completion = MagicMock()
        completion.choices = [choice]
        completion.model = None  # No model in response
        completion.usage = MagicMock(prompt_tokens=5, completion_tokens=3)

        with (
            patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as acompletion,
            patch("arbiter.llm.client.litellm.completion_cost") as completion_cost,
        ):
            acompletion.return_value = completion
            completion_cost.return_value = 0.0
            reply = await llm.complete(prompt, "claude-3-opus")

        # Should use the passed model as fallback
        assert reply.model == "claude-3-opus"
class TestLLMResponse:
    """Tests for the `LLMResponse` value object."""

    def test_response_creation(self) -> None:
        """Every constructor argument is readable back as an attribute."""
        fields = {
            "content": "Hello, world!",
            "model": "gpt-4o",
            "tokens_in": 10,
            "tokens_out": 5,
            "cost_usd": 0.001,
        }
        reply = LLMResponse(**fields)

        assert reply.content == "Hello, world!"
        assert reply.model == "gpt-4o"
        assert reply.tokens_in == 10
        assert reply.tokens_out == 5
        assert reply.cost_usd == 0.001
class TestMockLLMClient:
    """Tests for the `MockLLMClient` test double from `tests.conftest`."""

    @pytest.mark.asyncio
    async def test_records_calls(self) -> None:
        """A call to `complete` is recorded with its messages and model."""
        fake_llm = MockLLMClient()
        prompt = [{"role": "user", "content": "Hello"}]

        await fake_llm.complete(prompt, "gpt-4o")

        assert len(fake_llm.calls) == 1
        recorded = fake_llm.calls[0]
        assert recorded["messages"] == prompt
        assert recorded["model"] == "gpt-4o"

    @pytest.mark.asyncio
    async def test_returns_canned_responses(self) -> None:
        """Responses are served in order, then an empty string once exhausted."""
        fake_llm = MockLLMClient(responses=["First", "Second"])
        prompt = [{"role": "user", "content": "Hello"}]

        # Three sequential calls against only two canned responses.
        replies = [await fake_llm.complete(prompt, "gpt-4o") for _ in range(3)]

        assert replies[0].content == "First"
        assert replies[1].content == "Second"
        assert replies[2].content == ""  # Exhausted

    @pytest.mark.asyncio
    async def test_reset(self) -> None:
        """`reset` clears the recorded calls and replays the responses from the start."""
        fake_llm = MockLLMClient(responses=["Hello"])
        prompt = [{"role": "user", "content": "Hi"}]

        await fake_llm.complete(prompt, "gpt-4o")
        assert len(fake_llm.calls) == 1

        fake_llm.reset()
        assert len(fake_llm.calls) == 0

        replay = await fake_llm.complete(prompt, "gpt-4o")
        assert replay.content == "Hello"  # Responses reset too
class TestPromptTemplate:
    """Tests for `PromptTemplate` construction and `{{name}}` rendering."""

    def test_template_creation(self) -> None:
        """`full_name` combines name and version as "security-v1.0"."""
        tpl = PromptTemplate(name="security", version="1.0", content="Review: {{diff}}")

        assert tpl.name == "security"
        assert tpl.version == "1.0"
        assert tpl.full_name == "security-v1.0"

    def test_render_substitution(self) -> None:
        """Each placeholder is replaced by the matching keyword argument."""
        tpl = PromptTemplate(
            name="test",
            version="1.0",
            content="File: {{file}}\nDiff: {{diff}}",
        )

        rendered = tpl.render(file="test.py", diff="+ added line")

        assert rendered == "File: test.py\nDiff: + added line"

    def test_render_missing_variable(self) -> None:
        """A placeholder with no supplied value is left in the text as-is."""
        tpl = PromptTemplate(name="test", version="1.0", content="Value: {{value}}")

        rendered = tpl.render()

        assert rendered == "Value: {{value}}"

    def test_render_multiple_occurrences(self) -> None:
        """Every occurrence of the same placeholder is substituted."""
        tpl = PromptTemplate(
            name="test",
            version="1.0",
            content="{{name}} and {{name}} again",
        )

        rendered = tpl.render(name="test")

        assert rendered == "test and test again"
class TestPromptRegistry:
    """Tests for `PromptRegistry` lookup, listing and caching on a temp directory."""

    def test_get_template(self, tmp_path: Path) -> None:
        """`get` loads "security-v1.0.md" as the ("security", "1.0") template."""
        prompt_dir = tmp_path / "templates"
        prompt_dir.mkdir()
        (prompt_dir / "security-v1.0.md").write_text("Security review: {{diff}}")

        loaded = PromptRegistry(prompt_dir).get("security", "1.0")

        assert loaded.name == "security"
        assert loaded.version == "1.0"
        assert "{{diff}}" in loaded.content

    def test_get_template_cached(self, tmp_path: Path) -> None:
        """Two lookups of the same template return the very same object."""
        prompt_dir = tmp_path / "templates"
        prompt_dir.mkdir()
        (prompt_dir / "security-v1.0.md").write_text("Content")
        registry = PromptRegistry(prompt_dir)

        first = registry.get("security", "1.0")
        second = registry.get("security", "1.0")

        assert first is second

    def test_get_template_not_found(self, tmp_path: Path) -> None:
        """Asking for a template with no file raises `FileNotFoundError`."""
        prompt_dir = tmp_path / "templates"
        prompt_dir.mkdir()
        registry = PromptRegistry(prompt_dir)

        with pytest.raises(FileNotFoundError):
            registry.get("missing", "1.0")

    def test_list_templates(self, tmp_path: Path) -> None:
        """Only the two versioned files are listed; "readme.md" is not."""
        prompt_dir = tmp_path / "templates"
        prompt_dir.mkdir()
        file_contents = (
            ("security-v1.0.md", "Content"),
            ("style-v2.0.md", "Content"),
            ("readme.md", "Not a template"),
        )
        for file_name, text in file_contents:
            (prompt_dir / file_name).write_text(text)

        listed = PromptRegistry(prompt_dir).list_templates()

        assert len(listed) == 2
        assert ("security", "1.0") in listed
        assert ("style", "2.0") in listed

    def test_list_templates_empty_dir(self, tmp_path: Path) -> None:
        """An empty directory yields an empty list."""
        prompt_dir = tmp_path / "templates"
        prompt_dir.mkdir()

        listed = PromptRegistry(prompt_dir).list_templates()

        assert listed == []

    def test_list_templates_missing_dir(self, tmp_path: Path) -> None:
        """A directory that was never created yields an empty list."""
        absent_dir = tmp_path / "missing"

        listed = PromptRegistry(absent_dir).list_templates()

        assert listed == []

    def test_clear_cache(self, tmp_path: Path) -> None:
        """Edits on disk become visible only after `clear_cache`."""
        prompt_dir = tmp_path / "templates"
        prompt_dir.mkdir()
        template_file = prompt_dir / "test-v1.0.md"
        template_file.write_text("Original")
        registry = PromptRegistry(prompt_dir)

        before_edit = registry.get("test", "1.0")
        assert before_edit.content == "Original"

        # Modify file
        template_file.write_text("Modified")

        # Still cached
        while_cached = registry.get("test", "1.0")
        assert while_cached.content == "Original"

        # Clear cache
        registry.clear_cache()

        # Now reads new content
        after_clear = registry.get("test", "1.0")
        assert after_clear.content == "Modified"

224
tests/test_models.py Normal file
View File

@@ -0,0 +1,224 @@
"""Tests for data models."""
from pathlib import Path
import pytest
from arbiter.models import (
AgentConfig,
AgentName,
Finding,
Policy,
ReviewResult,
Severity,
Verdict,
)
class TestEnums:
    """Tests pinning the string values of the model enums."""

    def test_severity_values(self) -> None:
        """Each `Severity` member compares equal to its lowercase string."""
        expected_values = (
            (Severity.CRITICAL, "critical"),
            (Severity.HIGH, "high"),
            (Severity.MEDIUM, "medium"),
            (Severity.LOW, "low"),
            (Severity.INFO, "info"),
        )
        for member, text in expected_values:
            assert member == text

    def test_verdict_values(self) -> None:
        """Each `Verdict` member compares equal to its snake_case string."""
        expected_values = (
            (Verdict.APPROVE, "approve"),
            (Verdict.REQUEST_CHANGES, "request_changes"),
            (Verdict.COMMENT, "comment"),
        )
        for member, text in expected_values:
            assert member == text

    def test_agent_name_values(self) -> None:
        """Each `AgentName` member compares equal to its lowercase string."""
        expected_values = (
            (AgentName.SECURITY, "security"),
            (AgentName.STYLE, "style"),
            (AgentName.COMPLEXITY, "complexity"),
        )
        for member, text in expected_values:
            assert member == text

    def test_severity_from_string(self) -> None:
        """`Severity(<string>)` looks up the matching member by value."""
        for text, member in (("critical", Severity.CRITICAL), ("high", Severity.HIGH)):
            assert Severity(text) == member
class TestFinding:
    """Tests for `Finding` construction, validation and serialization."""

    def test_finding_creation(self) -> None:
        """A fully populated finding exposes its id, agent, severity and confidence."""
        sql_injection = Finding(
            id="test-123",
            agent=AgentName.SECURITY,
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection",
            description="User input concatenated in SQL query",
            reasoning="Allows attackers to execute arbitrary SQL",
            suggestion="Use parameterized queries",
            references=["https://owasp.org"],
            prompt_version="security-v1.0",
        )

        assert sql_injection.id == "test-123"
        assert sql_injection.agent == AgentName.SECURITY
        assert sql_injection.severity == Severity.HIGH
        assert sql_injection.confidence == 0.9

    def test_finding_confidence_validation(self) -> None:
        """A confidence of 1.5 is rejected with a `ValueError`."""
        invalid_fields = {
            "id": "test",
            "agent": AgentName.SECURITY,
            "file": "test.py",
            "line_start": 1,
            "line_end": 1,
            "severity": Severity.INFO,
            "confidence": 1.5,
            "title": "Test",
            "description": "Test",
            "reasoning": "Test",
            "prompt_version": "test-v1.0",
        }
        with pytest.raises(ValueError):
            Finding(**invalid_fields)

    def test_finding_line_validation(self) -> None:
        """A `line_start` of 0 is rejected with a `ValueError`."""
        invalid_fields = {
            "id": "test",
            "agent": AgentName.SECURITY,
            "file": "test.py",
            "line_start": 0,
            "line_end": 1,
            "severity": Severity.INFO,
            "confidence": 0.5,
            "title": "Test",
            "description": "Test",
            "reasoning": "Test",
            "prompt_version": "test-v1.0",
        }
        with pytest.raises(ValueError):
            Finding(**invalid_fields)

    def test_finding_serialization(self) -> None:
        """`model_dump` emits the id plus the agent and severity as plain strings."""
        sample = Finding(
            id="test-123",
            agent=AgentName.SECURITY,
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="Test",
            description="Test desc",
            reasoning="Test reason",
            prompt_version="security-v1.0",
        )

        dumped = sample.model_dump()

        assert dumped["id"] == "test-123"
        assert dumped["agent"] == "security"
        assert dumped["severity"] == "high"
class TestAgentConfig:
    """Tests for `AgentConfig` defaults and explicit overrides."""

    def test_default_values(self) -> None:
        """Defaults: enabled, no model, INFO threshold, no prompt additions."""
        defaults = AgentConfig()

        assert defaults.enabled is True
        assert defaults.model is None
        assert defaults.severity_threshold == Severity.INFO
        assert defaults.prompt_additions is None

    def test_custom_values(self) -> None:
        """Explicit constructor arguments override the defaults."""
        overrides = {
            "enabled": False,
            "model": "gpt-4o",
            "severity_threshold": Severity.MEDIUM,
            "prompt_additions": "Focus on auth",
        }
        custom = AgentConfig(**overrides)

        assert custom.enabled is False
        assert custom.model == "gpt-4o"
        assert custom.severity_threshold == Severity.MEDIUM
class TestPolicy:
    """Tests for `Policy` defaults, agent filtering and YAML loading."""

    def test_default_policy(self) -> None:
        """A default policy has entries for exactly the three agents."""
        policy = Policy()

        assert len(policy.agents) == 3
        for agent in (AgentName.SECURITY, AgentName.STYLE, AgentName.COMPLEXITY):
            assert agent in policy.agents

    def test_get_enabled_agents(self) -> None:
        """All three agents of a default policy are reported as enabled."""
        enabled = Policy().get_enabled_agents()

        assert len(enabled) == 3
        assert AgentName.SECURITY in enabled

    def test_get_enabled_agents_with_disabled(self) -> None:
        """An agent configured with `enabled=False` is left out."""
        policy = Policy(
            agents={
                AgentName.SECURITY: AgentConfig(enabled=True),
                AgentName.STYLE: AgentConfig(enabled=False),
                AgentName.COMPLEXITY: AgentConfig(enabled=True),
            }
        )

        enabled = policy.get_enabled_agents()

        assert len(enabled) == 2
        assert AgentName.STYLE not in enabled

    def test_load_from_yaml(self, tmp_path: Path) -> None:
        """`Policy.load` reads the version and per-agent settings from a YAML file."""
        # Each line is spelled out with its own indentation: the per-agent
        # settings must nest under `agents:` for the assertions below to hold,
        # and YAML nesting is whitespace-sensitive.
        policy_yaml = (
            'version: "1.1"\n'
            "agents:\n"
            "  security:\n"
            "    enabled: true\n"
            "    model: gpt-4o\n"
            "    severity_threshold: high\n"
            "  style:\n"
            "    enabled: false\n"
            "  complexity:\n"
            "    enabled: true\n"
        )
        policy_file = tmp_path / "policy.yaml"
        policy_file.write_text(policy_yaml)

        policy = Policy.load(policy_file)

        assert policy.version == "1.1"
        security = policy.agents[AgentName.SECURITY]
        assert security.model == "gpt-4o"
        assert security.severity_threshold == Severity.HIGH
        assert policy.agents[AgentName.STYLE].enabled is False

    def test_load_empty_yaml(self, tmp_path: Path) -> None:
        """An empty file loads as version "1.0" with three agents."""
        policy_file = tmp_path / "policy.yaml"
        policy_file.write_text("")

        policy = Policy.load(policy_file)

        assert policy.version == "1.0"
        assert len(policy.agents) == 3
class TestReviewResult:
    """Tests for `ReviewResult` construction."""

    def test_review_result_creation(self) -> None:
        """A result without findings exposes its agent, duration and cost."""
        empty_result = ReviewResult(
            agent_name=AgentName.SECURITY,
            findings=[],
            duration_ms=1000,
            tokens_used=500,
            cost_usd=0.01,
        )

        assert empty_result.agent_name == AgentName.SECURITY
        assert empty_result.duration_ms == 1000
        assert empty_result.cost_usd == 0.01

    def test_review_result_with_findings(self) -> None:
        """A finding passed to the constructor is kept in `findings`."""
        only_finding = Finding(
            id="test-123",
            agent=AgentName.SECURITY,
            file="test.py",
            line_start=1,
            line_end=1,
            severity=Severity.HIGH,
            confidence=0.9,
            title="Test",
            description="Test",
            reasoning="Test",
            prompt_version="test-v1.0",
        )

        populated = ReviewResult(
            agent_name=AgentName.SECURITY,
            findings=[only_finding],
            duration_ms=1000,
            tokens_used=500,
            cost_usd=0.01,
        )

        assert len(populated.findings) == 1
        assert populated.findings[0].id == "test-123"