feat(agents): implement agent framework and CLI
This commit is contained in:
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Arbiter test suite."""
|
||||
490
tests/conftest.py
Normal file
490
tests/conftest.py
Normal file
@@ -0,0 +1,490 @@
|
||||
"""Pytest configuration and fixtures."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from arbiter.db.models import Base, FindingModel, ReviewModel
|
||||
from arbiter.integrations.base import Comment, CommitStatus, Platform, PlatformClient
|
||||
from arbiter.llm.client import LLMClient, LLMResponse
|
||||
from arbiter.llm.prompts import PromptRegistry
|
||||
from arbiter.models import Policy
|
||||
from arbiter.models.enums import AgentName, Severity, Verdict
|
||||
|
||||
|
||||
class MockLLMClient(LLMClient):
|
||||
"""Mock LLM client for testing."""
|
||||
|
||||
def __init__(self, responses: list[str] | None = None) -> None:
|
||||
self.responses = responses or []
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
self._call_index = 0
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
**kwargs: Any,
|
||||
) -> LLMResponse:
|
||||
"""Record the call and return a canned response."""
|
||||
self.calls.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
content = ""
|
||||
if self._call_index < len(self.responses):
|
||||
content = self.responses[self._call_index]
|
||||
self._call_index += 1
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=model,
|
||||
tokens_in=100,
|
||||
tokens_out=50,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.calls = []
|
||||
self._call_index = 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm() -> MockLLMClient:
|
||||
return MockLLMClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_with_findings() -> MockLLMClient:
|
||||
response = """```json
|
||||
[
|
||||
{
|
||||
"file": "src/auth.py",
|
||||
"line_start": 10,
|
||||
"line_end": 15,
|
||||
"severity": "high",
|
||||
"confidence": 0.9,
|
||||
"title": "SQL Injection vulnerability",
|
||||
"description": "User input is directly concatenated into SQL query",
|
||||
"reasoning": "String concatenation in SQL queries allows attackers to inject malicious SQL",
|
||||
"suggestion": "Use parameterized queries instead",
|
||||
"references": ["https://owasp.org/www-community/attacks/SQL_Injection"]
|
||||
}
|
||||
]
|
||||
```"""
|
||||
return MockLLMClient(responses=[response])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_diff() -> str:
|
||||
return """diff --git a/src/auth.py b/src/auth.py
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/src/auth.py
|
||||
+++ b/src/auth.py
|
||||
@@ -8,6 +8,12 @@ def authenticate(username, password):
|
||||
if not username or not password:
|
||||
return None
|
||||
|
||||
+ # Check user in database
|
||||
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
|
||||
+ cursor.execute(query)
|
||||
+ user = cursor.fetchone()
|
||||
+ if user and user.password == password:
|
||||
+ return user
|
||||
return None
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_policy() -> Policy:
|
||||
return Policy()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_registry(tmp_path: Path) -> PromptRegistry:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
|
||||
(templates_dir / "security-v1.0.md").write_text(
|
||||
"Review for security: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
|
||||
)
|
||||
(templates_dir / "style-v1.0.md").write_text(
|
||||
"Review for style: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
|
||||
)
|
||||
(templates_dir / "complexity-v1.0.md").write_text(
|
||||
"Review for complexity: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
|
||||
)
|
||||
|
||||
return PromptRegistry(templates_dir)
|
||||
|
||||
|
||||
# Database fixtures for integration tests
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
session_factory = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
|
||||
from arbiter.api.deps import get_db, get_redis
|
||||
from arbiter.main import app
|
||||
|
||||
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
async def override_get_redis() -> AsyncGenerator[MockRedis, None]:
|
||||
yield MockRedis()
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_redis] = override_get_redis
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class MockRedis:
|
||||
"""Mock Redis client for testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data: dict[str, Any] = {}
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self._data.get(key)
|
||||
|
||||
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
||||
self._data[key] = value
|
||||
return True
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def ping(self) -> bool:
|
||||
return True
|
||||
|
||||
async def llen(self, key: str) -> int:
|
||||
data = self._data.get(key, [])
|
||||
return len(data) if isinstance(data, list) else 0
|
||||
|
||||
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: # noqa: ARG002
|
||||
return type("Job", (), {"job_id": "test-job-id"})()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis() -> MockRedis:
|
||||
return MockRedis()
|
||||
|
||||
|
||||
class MockPlatformClient(PlatformClient):
|
||||
"""Mock platform client for testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._comments: list[Comment] = []
|
||||
self._posted_comments: list[dict[str, Any]] = []
|
||||
self._status_updates: list[dict[str, Any]] = []
|
||||
self._diff = "mock diff content"
|
||||
self._closed = False
|
||||
self._fail_on: set[str] = set() # Methods to fail on
|
||||
|
||||
@property
|
||||
def platform(self) -> Platform:
|
||||
return Platform.GITHUB
|
||||
|
||||
async def get_pr_diff(
|
||||
self,
|
||||
repository: str, # noqa: ARG002
|
||||
pr_number: int, # noqa: ARG002
|
||||
) -> str:
|
||||
if "get_pr_diff" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: get_pr_diff")
|
||||
return self._diff
|
||||
|
||||
async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
|
||||
if "post_comment" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: post_comment")
|
||||
comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-123"
|
||||
self._posted_comments.append(
|
||||
{"repository": repository, "pr_number": pr_number, "body": body, "url": comment_url}
|
||||
)
|
||||
return comment_url
|
||||
|
||||
async def update_commit_status(
|
||||
self,
|
||||
repository: str,
|
||||
sha: str,
|
||||
status: CommitStatus,
|
||||
description: str,
|
||||
context: str,
|
||||
target_url: str | None = None,
|
||||
) -> None:
|
||||
if "update_commit_status" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: update_commit_status")
|
||||
self._status_updates.append(
|
||||
{
|
||||
"repository": repository,
|
||||
"sha": sha,
|
||||
"status": status,
|
||||
"description": description,
|
||||
"context": context,
|
||||
"target_url": target_url,
|
||||
}
|
||||
)
|
||||
|
||||
async def get_pr_info(self, repository: str, pr_number: int) -> Any:
|
||||
if "get_pr_info" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: get_pr_info")
|
||||
from arbiter.integrations.base import PullRequestInfo
|
||||
|
||||
return PullRequestInfo(
|
||||
platform=Platform.GITHUB,
|
||||
repository=repository,
|
||||
pr_number=pr_number,
|
||||
head_sha="abc123",
|
||||
base_sha="def456",
|
||||
head_ref="feature",
|
||||
base_ref="main",
|
||||
title="Test PR",
|
||||
author="testuser",
|
||||
url=f"https://github.com/{repository}/pull/{pr_number}",
|
||||
is_draft=False,
|
||||
)
|
||||
|
||||
async def get_comments(
|
||||
self,
|
||||
repository: str, # noqa: ARG002
|
||||
pr_number: int, # noqa: ARG002
|
||||
) -> list[Comment]:
|
||||
if "get_comments" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: get_comments")
|
||||
return self._comments
|
||||
|
||||
async def update_comment(
|
||||
self, repository: str, pr_number: int, comment_id: str, body: str
|
||||
) -> str:
|
||||
if "update_comment" in self._fail_on:
|
||||
from arbiter.integrations import IntegrationError
|
||||
|
||||
raise IntegrationError("Mock failure: update_comment")
|
||||
comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-{comment_id}"
|
||||
self._posted_comments.append(
|
||||
{
|
||||
"repository": repository,
|
||||
"pr_number": pr_number,
|
||||
"comment_id": comment_id,
|
||||
"body": body,
|
||||
"url": comment_url,
|
||||
}
|
||||
)
|
||||
return comment_url
|
||||
|
||||
async def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_platform_client() -> MockPlatformClient:
|
||||
return MockPlatformClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def completed_review_fixture(db_session: AsyncSession) -> ReviewModel:
|
||||
review = ReviewModel(
|
||||
id=str(uuid4()),
|
||||
repository="owner/repo",
|
||||
pr_number=42,
|
||||
pr_title="Test PR",
|
||||
base_sha="abc1234567890",
|
||||
head_sha="def0987654321",
|
||||
author="testuser",
|
||||
is_draft=False,
|
||||
status="completed",
|
||||
verdict=Verdict.COMMENT,
|
||||
verdict_confidence=0.75,
|
||||
verdict_reasoning="Found some issues to discuss",
|
||||
total_tokens=1500,
|
||||
total_cost_usd=0.015,
|
||||
tokens_by_agent={"security": 500, "style": 500, "complexity": 500},
|
||||
cost_by_agent={"security": 0.005, "style": 0.005, "complexity": 0.005},
|
||||
created_at=datetime.now(UTC),
|
||||
started_at=datetime.now(UTC),
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
db_session.add(review)
|
||||
|
||||
# Add findings
|
||||
finding1 = FindingModel(
|
||||
id=str(uuid4()),
|
||||
review_id=review.id,
|
||||
agent=AgentName.SECURITY,
|
||||
file="src/auth.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="SQL Injection vulnerability",
|
||||
description="User input is directly concatenated into SQL query",
|
||||
reasoning="String concatenation in SQL queries allows attackers to inject malicious SQL",
|
||||
suggestion="Use parameterized queries instead",
|
||||
references=["https://owasp.org/www-community/attacks/SQL_Injection"],
|
||||
prompt_version="security-v1.0",
|
||||
)
|
||||
finding2 = FindingModel(
|
||||
id=str(uuid4()),
|
||||
review_id=review.id,
|
||||
agent=AgentName.STYLE,
|
||||
file="src/auth.py",
|
||||
line_start=20,
|
||||
line_end=25,
|
||||
severity=Severity.MEDIUM,
|
||||
confidence=0.85,
|
||||
title="Long function",
|
||||
description="Function exceeds 50 lines",
|
||||
reasoning="Long functions are harder to test and maintain",
|
||||
suggestion="Consider breaking into smaller functions",
|
||||
references=[],
|
||||
prompt_version="style-v1.0",
|
||||
)
|
||||
db_session.add(finding1)
|
||||
db_session.add(finding2)
|
||||
|
||||
await db_session.commit()
|
||||
await db_session.refresh(review)
|
||||
return review
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def sample_reviews_fixture(db_session: AsyncSession) -> list[ReviewModel]:
|
||||
reviews = []
|
||||
for i in range(5):
|
||||
review = ReviewModel(
|
||||
id=str(uuid4()),
|
||||
repository="owner/repo" if i < 3 else "other/repo",
|
||||
pr_number=i + 1,
|
||||
pr_title=f"Test PR #{i + 1}",
|
||||
base_sha=f"base{i:040d}",
|
||||
head_sha=f"head{i:040d}",
|
||||
author="testuser" if i % 2 == 0 else "otheruser",
|
||||
is_draft=False,
|
||||
status="completed" if i < 4 else "failed",
|
||||
verdict=Verdict.APPROVE if i == 0 else (Verdict.COMMENT if i < 4 else None),
|
||||
verdict_confidence=0.8 if i < 4 else None,
|
||||
total_tokens=1000 * (i + 1),
|
||||
total_cost_usd=0.01 * (i + 1),
|
||||
created_at=datetime.now(UTC),
|
||||
completed_at=datetime.now(UTC) if i < 4 else None,
|
||||
)
|
||||
db_session.add(review)
|
||||
reviews.append(review)
|
||||
|
||||
# Add a finding to each completed review
|
||||
if i < 4:
|
||||
finding = FindingModel(
|
||||
id=str(uuid4()),
|
||||
review_id=review.id,
|
||||
agent=AgentName.SECURITY,
|
||||
file="src/test.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.CRITICAL if i == 0 else Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title=f"Finding {i + 1}",
|
||||
description="Test finding",
|
||||
reasoning="Test reasoning",
|
||||
prompt_version="security-v1.0",
|
||||
)
|
||||
db_session.add(finding)
|
||||
|
||||
await db_session.commit()
|
||||
for review in reviews:
|
||||
await db_session.refresh(review)
|
||||
return reviews
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings() -> Any:
|
||||
class MockSecretStr:
|
||||
def __init__(self, value: str) -> None:
|
||||
self._value = value
|
||||
|
||||
def get_secret_value(self) -> str:
|
||||
return self._value
|
||||
|
||||
class MockSettings:
|
||||
github_token = MockSecretStr("ghp_test_token")
|
||||
gitlab_token = MockSecretStr("glpat_test_token")
|
||||
github_base_url = "https://api.github.com"
|
||||
gitlab_base_url = "https://gitlab.com/api/v4"
|
||||
github_webhook_secret = MockSecretStr("webhook_secret")
|
||||
gitlab_webhook_token = MockSecretStr("gitlab_token")
|
||||
integration_timeout = 30
|
||||
integration_max_retries = 3
|
||||
llm_timeout = 60
|
||||
llm_max_retries = 3
|
||||
post_comments = True
|
||||
update_status = True
|
||||
status_check_context = "arbiter"
|
||||
templates_dir = Path("templates")
|
||||
followup_enabled = True
|
||||
followup_confidence_threshold = 0.5
|
||||
|
||||
return MockSettings()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings_no_github() -> Any:
|
||||
class MockSettings:
|
||||
github_token = None
|
||||
gitlab_token = None
|
||||
github_base_url = "https://api.github.com"
|
||||
gitlab_base_url = "https://gitlab.com/api/v4"
|
||||
integration_timeout = 30
|
||||
integration_max_retries = 3
|
||||
post_comments = False
|
||||
update_status = False
|
||||
status_check_context = "arbiter"
|
||||
templates_dir = Path("templates")
|
||||
|
||||
return MockSettings()
|
||||
69
tests/fixtures/complex-function.diff
vendored
Normal file
69
tests/fixtures/complex-function.diff
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
diff --git a/src/processor.py b/src/processor.py
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/src/processor.py
|
||||
+++ b/src/processor.py
|
||||
@@ -1,5 +1,65 @@
|
||||
"""Data processor module."""
|
||||
|
||||
|
||||
+def process_data(data: dict, config: dict, options: dict | None = None) -> dict:
|
||||
+ """Process data with many nested conditions."""
|
||||
+ result = {}
|
||||
+ options = options or {}
|
||||
+
|
||||
+ if data.get("type") == "A":
|
||||
+ if config.get("mode") == "strict":
|
||||
+ if options.get("validate"):
|
||||
+ if data.get("value") > 100:
|
||||
+ if config.get("transform"):
|
||||
+ result["processed"] = data["value"] * 2
|
||||
+ else:
|
||||
+ result["processed"] = data["value"]
|
||||
+ else:
|
||||
+ if options.get("default"):
|
||||
+ result["processed"] = options["default"]
|
||||
+ else:
|
||||
+ result["processed"] = 0
|
||||
+ else:
|
||||
+ result["processed"] = data.get("value", 0)
|
||||
+ else:
|
||||
+ result["processed"] = data.get("value", 0)
|
||||
+ elif data.get("type") == "B":
|
||||
+ if config.get("mode") == "strict":
|
||||
+ if options.get("validate"):
|
||||
+ if data.get("items"):
|
||||
+ result["processed"] = len(data["items"])
|
||||
+ else:
|
||||
+ result["processed"] = 0
|
||||
+ else:
|
||||
+ result["processed"] = len(data.get("items", []))
|
||||
+ else:
|
||||
+ result["processed"] = len(data.get("items", []))
|
||||
+ elif data.get("type") == "C":
|
||||
+ if config.get("mode") == "strict":
|
||||
+ if options.get("validate"):
|
||||
+ if data.get("text"):
|
||||
+ result["processed"] = data["text"].upper()
|
||||
+ else:
|
||||
+ result["processed"] = ""
|
||||
+ else:
|
||||
+ result["processed"] = data.get("text", "").upper()
|
||||
+ else:
|
||||
+ result["processed"] = data.get("text", "").upper()
|
||||
+ else:
|
||||
+ if config.get("fallback"):
|
||||
+ result["processed"] = config["fallback"]
|
||||
+ else:
|
||||
+ result["processed"] = None
|
||||
+
|
||||
+ if options.get("timestamp"):
|
||||
+ result["timestamp"] = options["timestamp"]
|
||||
+ if options.get("source"):
|
||||
+ result["source"] = options["source"]
|
||||
+
|
||||
+ return result
|
||||
+
|
||||
+
|
||||
def simple_function(x: int) -> int:
|
||||
"""A simple function."""
|
||||
return x * 2
|
||||
31
tests/fixtures/security-issue.diff
vendored
Normal file
31
tests/fixtures/security-issue.diff
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
diff --git a/src/auth.py b/src/auth.py
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/src/auth.py
|
||||
+++ b/src/auth.py
|
||||
@@ -1,10 +1,25 @@
|
||||
"""Authentication module."""
|
||||
|
||||
import sqlite3
|
||||
+import os
|
||||
|
||||
|
||||
def get_user(username: str) -> dict | None:
|
||||
"""Get user from database."""
|
||||
conn = sqlite3.connect("users.db")
|
||||
cursor = conn.cursor()
|
||||
- cursor.execute("SELECT * FROM users WHERE username = ?", (username,))
|
||||
+ # FIXME: this is vulnerable to SQL injection
|
||||
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
|
||||
+ cursor.execute(query)
|
||||
return cursor.fetchone()
|
||||
+
|
||||
+
|
||||
+def run_command(cmd: str) -> str:
|
||||
+ """Run a shell command."""
|
||||
+ # Command injection vulnerability
|
||||
+ return os.popen(cmd).read()
|
||||
+
|
||||
+
|
||||
+# Hardcoded credentials
|
||||
+API_KEY = "sk-1234567890abcdef"
|
||||
+DB_PASSWORD = "admin123"
|
||||
16
tests/fixtures/simple.diff
vendored
Normal file
16
tests/fixtures/simple.diff
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
diff --git a/src/utils.py b/src/utils.py
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/src/utils.py
|
||||
+++ b/src/utils.py
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Utility functions."""
|
||||
|
||||
|
||||
+def add(a: int, b: int) -> int:
|
||||
+ """Add two numbers."""
|
||||
+ return a + b
|
||||
+
|
||||
+
|
||||
def subtract(a: int, b: int) -> int:
|
||||
"""Subtract two numbers."""
|
||||
return a - b
|
||||
305
tests/test_agents.py
Normal file
305
tests/test_agents.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""Tests for review agents."""
|
||||
|
||||
import pytest
|
||||
|
||||
from arbiter.agents import ComplexityAgent, ReviewContext, SecurityAgent, StyleAgent
|
||||
from arbiter.llm.prompts import PromptRegistry
|
||||
from arbiter.models import AgentConfig, AgentName, Policy, Severity
|
||||
from tests.conftest import MockLLMClient
|
||||
|
||||
|
||||
class TestSecurityAgent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_review_returns_result(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ some code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.agent_name == AgentName.SECURITY
|
||||
assert result.findings == []
|
||||
assert result.duration_ms >= 0
|
||||
assert result.tokens_used == 150 # 100 in + 50 out from mock
|
||||
assert result.cost_usd == 0.001
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parses_json_findings(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
response = """```json
|
||||
[
|
||||
{
|
||||
"file": "src/auth.py",
|
||||
"line_start": 10,
|
||||
"line_end": 15,
|
||||
"severity": "high",
|
||||
"confidence": 0.9,
|
||||
"title": "SQL Injection",
|
||||
"description": "User input concatenated",
|
||||
"reasoning": "Allows SQL injection",
|
||||
"suggestion": "Use parameterized queries",
|
||||
"references": ["https://owasp.org"]
|
||||
}
|
||||
]
|
||||
```"""
|
||||
mock_llm = MockLLMClient(responses=[response])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ query = ...", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert len(result.findings) == 1
|
||||
finding = result.findings[0]
|
||||
assert finding.file == "src/auth.py"
|
||||
assert finding.severity == Severity.HIGH
|
||||
assert finding.confidence == 0.9
|
||||
assert finding.title == "SQL Injection"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_configured_model(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
policy = Policy(
|
||||
agents={
|
||||
AgentName.SECURITY: AgentConfig(model="gpt-4o-mini"),
|
||||
AgentName.STYLE: AgentConfig(),
|
||||
AgentName.COMPLEXITY: AgentConfig(),
|
||||
}
|
||||
)
|
||||
context = ReviewContext(diff="+ code", policy=policy)
|
||||
|
||||
await agent.review(context)
|
||||
|
||||
assert mock_llm.calls[0]["model"] == "gpt-4o-mini"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filters_by_severity(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
response = """[
|
||||
{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "high", "confidence": 0.9, "title": "High", "description": "", "reasoning": ""},
|
||||
{"file": "b.py", "line_start": 1, "line_end": 1, "severity": "low", "confidence": 0.9, "title": "Low", "description": "", "reasoning": ""},
|
||||
{"file": "c.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.9, "title": "Info", "description": "", "reasoning": ""}
|
||||
]"""
|
||||
mock_llm = MockLLMClient(responses=[response])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
policy = Policy(
|
||||
agents={
|
||||
AgentName.SECURITY: AgentConfig(severity_threshold=Severity.MEDIUM),
|
||||
AgentName.STYLE: AgentConfig(),
|
||||
AgentName.COMPLEXITY: AgentConfig(),
|
||||
}
|
||||
)
|
||||
context = ReviewContext(diff="+ code", policy=policy)
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
# Only high severity should pass (medium threshold filters low and info)
|
||||
assert len(result.findings) == 1
|
||||
assert result.findings[0].severity == Severity.HIGH
|
||||
|
||||
|
||||
class TestStyleAgent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_review_returns_result(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = StyleAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ some code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.agent_name == AgentName.STYLE
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_default_model(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = StyleAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
await agent.review(context)
|
||||
|
||||
assert mock_llm.calls[0]["model"] == "gpt-4o-mini"
|
||||
|
||||
|
||||
class TestComplexityAgent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_review_returns_result(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = ComplexityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ some code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.agent_name == AgentName.COMPLEXITY
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parses_complexity_findings(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
response = """[
|
||||
{
|
||||
"file": "processor.py",
|
||||
"line_start": 1,
|
||||
"line_end": 50,
|
||||
"severity": "medium",
|
||||
"confidence": 0.8,
|
||||
"title": "High cyclomatic complexity",
|
||||
"description": "Function has 15 branches",
|
||||
"reasoning": "Makes testing and maintenance difficult"
|
||||
}
|
||||
]"""
|
||||
mock_llm = MockLLMClient(responses=[response])
|
||||
agent = ComplexityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ complex code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert len(result.findings) == 1
|
||||
assert result.findings[0].severity == Severity.MEDIUM
|
||||
assert "complexity" in result.findings[0].title.lower()
|
||||
|
||||
|
||||
class TestAgentResponseParsing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_empty_response(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=[""])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_invalid_json(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["not valid json"])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_json_without_code_block(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
response = '[{"file": "a.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Test", "description": "", "reasoning": ""}]'
|
||||
mock_llm = MockLLMClient(responses=[response])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert len(result.findings) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_malformed_finding(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
response = """[
|
||||
{"file": "a.py", "line_start": 1, "severity": "invalid_severity", "confidence": 0.5, "title": "Bad", "description": "", "reasoning": ""},
|
||||
{"file": "b.py", "line_start": 1, "line_end": 1, "severity": "info", "confidence": 0.5, "title": "Valid", "description": "", "reasoning": ""}
|
||||
]"""
|
||||
mock_llm = MockLLMClient(responses=[response])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
# Only the valid finding should be included (first has invalid severity)
|
||||
assert len(result.findings) == 1
|
||||
assert result.findings[0].title == "Valid"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_includes_prompt_additions(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
policy = Policy(
|
||||
agents={
|
||||
AgentName.SECURITY: AgentConfig(prompt_additions="Focus on authentication"),
|
||||
AgentName.STYLE: AgentConfig(),
|
||||
AgentName.COMPLEXITY: AgentConfig(),
|
||||
}
|
||||
)
|
||||
context = ReviewContext(diff="+ code", policy=policy)
|
||||
|
||||
await agent.review(context)
|
||||
|
||||
message_content = mock_llm.calls[0]["messages"][0]["content"]
|
||||
assert "Focus on authentication" in message_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_non_list_json(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=['{"not": "a list"}'])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_non_dict_items(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=['["string", 123, null]'])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
context = ReviewContext(diff="+ code", policy=Policy())
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
assert result.findings == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_without_config_uses_defaults(
|
||||
self,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
mock_llm = MockLLMClient(responses=["[]"])
|
||||
agent = SecurityAgent(mock_llm, prompt_registry)
|
||||
# Create policy with empty agents dict
|
||||
policy = Policy(agents={})
|
||||
context = ReviewContext(diff="+ code", policy=policy)
|
||||
|
||||
result = await agent.review(context)
|
||||
|
||||
# Should use default model (gpt-4o for security)
|
||||
assert mock_llm.calls[0]["model"] == "gpt-4o"
|
||||
assert result.findings == []
|
||||
561
tests/test_cli.py
Normal file
561
tests/test_cli.py
Normal file
@@ -0,0 +1,561 @@
|
||||
"""Tests for CLI commands."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from arbiter.cli import (
|
||||
_severity_color,
|
||||
_severity_icon,
|
||||
_verdict_color,
|
||||
_verdict_icon,
|
||||
app,
|
||||
)
|
||||
from arbiter.deliberation import DeliberationResult
|
||||
from arbiter.models import AgentName, Finding, ReviewResult, Severity, Verdict
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def make_mock_return(
|
||||
findings: list[Finding] | None = None, verdict: Verdict = Verdict.APPROVE
|
||||
) -> tuple[list[ReviewResult], DeliberationResult]:
|
||||
"""Create a mock return value for _run_review."""
|
||||
agent_results = [
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=findings or [],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
]
|
||||
deliberation_result = DeliberationResult(
|
||||
findings=findings or [],
|
||||
verdict=verdict,
|
||||
verdict_confidence=0.9,
|
||||
verdict_reasoning="Test reasoning",
|
||||
total_findings=len(findings) if findings else 0,
|
||||
)
|
||||
return agent_results, deliberation_result
|
||||
|
||||
|
||||
class TestVersionCommand:
|
||||
def test_version_output(self) -> None:
|
||||
result = runner.invoke(app, ["version"])
|
||||
assert result.exit_code == 0
|
||||
assert "arbiter" in result.output
|
||||
assert "0.3.0" in result.output
|
||||
|
||||
|
||||
class TestReviewCommand:
|
||||
def test_file_not_found(self) -> None:
|
||||
result = runner.invoke(app, ["review", "nonexistent.diff"])
|
||||
assert result.exit_code == 1
|
||||
assert "File not found" in result.output
|
||||
|
||||
def test_empty_diff_warning(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "empty.diff"
|
||||
diff_file.write_text("")
|
||||
|
||||
result = runner.invoke(app, ["review", str(diff_file)])
|
||||
assert result.exit_code == 0
|
||||
assert "Empty diff" in result.output
|
||||
|
||||
def test_policy_not_found(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--policy", "nonexistent.yaml"])
|
||||
assert result.exit_code == 1
|
||||
assert "Policy file not found" in result.output
|
||||
|
||||
def test_reads_from_stdin(self) -> None:
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", "-"], input="+ added line\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_json_output_format(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "json"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert '"verdict"' in result.output
|
||||
assert '"findings"' in result.output
|
||||
|
||||
def test_markdown_output_format(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "# Arbiter Review" in result.output
|
||||
|
||||
def test_loads_policy_file(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
policy_file = tmp_path / "policy.yaml"
|
||||
policy_file.write_text("""
|
||||
version: "1.0"
|
||||
agents:
|
||||
security:
|
||||
enabled: true
|
||||
style:
|
||||
enabled: false
|
||||
complexity:
|
||||
enabled: false
|
||||
""")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--policy", str(policy_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Verify policy was passed to _run_review
|
||||
call_args = mock_run.call_args
|
||||
policy = call_args[0][1] # Second positional arg is policy
|
||||
assert len(policy.get_enabled_agents()) == 1
|
||||
|
||||
def test_model_override(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--model", "gpt-4o-mini"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Verify model was set in policy
|
||||
call_args = mock_run.call_args
|
||||
policy = call_args[0][1]
|
||||
for config in policy.agents.values():
|
||||
assert config.model == "gpt-4o-mini"
|
||||
|
||||
def test_static_analysis_flag(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--no-static-analysis"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Verify static_analysis was False
|
||||
call_args = mock_run.call_args
|
||||
assert call_args.kwargs.get("static_analysis") is False
|
||||
|
||||
|
||||
class TestNoArgsHelp:
|
||||
def test_no_args_shows_help(self) -> None:
|
||||
result = runner.invoke(app, [])
|
||||
assert result.exit_code == 0
|
||||
assert "A multi-agent code review system" in result.output
|
||||
|
||||
|
||||
class TestOutputFormatting:
|
||||
def test_severity_color(self) -> None:
|
||||
assert _severity_color(Severity.CRITICAL) == "red bold"
|
||||
assert _severity_color(Severity.HIGH) == "red"
|
||||
assert _severity_color(Severity.MEDIUM) == "yellow"
|
||||
assert _severity_color(Severity.LOW) == "blue"
|
||||
assert _severity_color(Severity.INFO) == "dim"
|
||||
|
||||
def test_severity_icon(self) -> None:
|
||||
assert _severity_icon(Severity.CRITICAL) == "!!"
|
||||
assert _severity_icon(Severity.HIGH) == "!"
|
||||
assert _severity_icon(Severity.MEDIUM) == "*"
|
||||
assert _severity_icon(Severity.LOW) == "-"
|
||||
assert _severity_icon(Severity.INFO) == "i"
|
||||
|
||||
def test_verdict_color(self) -> None:
|
||||
assert _verdict_color(Verdict.APPROVE) == "green"
|
||||
assert _verdict_color(Verdict.COMMENT) == "yellow"
|
||||
assert _verdict_color(Verdict.REQUEST_CHANGES) == "red"
|
||||
|
||||
def test_verdict_icon(self) -> None:
|
||||
assert _verdict_icon(Verdict.APPROVE) == "[ok]"
|
||||
assert _verdict_icon(Verdict.COMMENT) == "[..]"
|
||||
assert _verdict_icon(Verdict.REQUEST_CHANGES) == "[!!]"
|
||||
|
||||
|
||||
class TestRichOutput:
|
||||
def test_rich_format_with_findings(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
finding = Finding(
|
||||
id="test-finding-1",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="SQL Injection",
|
||||
description="User input in query",
|
||||
reasoning="Direct concatenation",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return(findings=[finding])
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_rich_format_no_findings(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "No issues found" in result.output
|
||||
|
||||
def test_rich_format_critical_findings(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
finding = Finding(
|
||||
id="test-finding-1",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=10,
|
||||
line_end=10,
|
||||
severity=Severity.CRITICAL,
|
||||
confidence=0.95,
|
||||
title="Critical Issue",
|
||||
description="This is critical",
|
||||
reasoning="Very bad",
|
||||
suggestion="Fix it immediately",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
|
||||
deliberation = DeliberationResult(
|
||||
findings=[finding],
|
||||
verdict=Verdict.REQUEST_CHANGES,
|
||||
verdict_confidence=0.95,
|
||||
verdict_reasoning="Critical issue found",
|
||||
total_findings=1,
|
||||
critical_count=1,
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[finding],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestMarkdownOutput:
|
||||
def test_markdown_with_findings(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
finding = Finding(
|
||||
id="test-finding-1",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="SQL Injection",
|
||||
description="User input in query",
|
||||
reasoning="Direct concatenation",
|
||||
suggestion="Use parameterized queries",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return(findings=[finding])
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "## Findings" in result.output
|
||||
assert "SQL Injection" in result.output
|
||||
|
||||
def test_markdown_verdict_badges(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
for verdict in [Verdict.APPROVE, Verdict.COMMENT, Verdict.REQUEST_CHANGES]:
|
||||
deliberation = DeliberationResult(
|
||||
verdict=verdict,
|
||||
verdict_confidence=0.9,
|
||||
verdict_reasoning="Test",
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert verdict.value.upper() in result.output
|
||||
|
||||
|
||||
class TestJsonOutput:
|
||||
def test_json_with_conflicts(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
from arbiter.deliberation.conflicts import Conflict, ConflictNature
|
||||
|
||||
conflict = Conflict(
|
||||
id="test-conflict",
|
||||
finding_ids=["f1", "f2"],
|
||||
nature=ConflictNature.TRADE_OFF,
|
||||
description="Test conflict",
|
||||
severity_weight=0.8,
|
||||
)
|
||||
|
||||
deliberation = DeliberationResult(
|
||||
verdict=Verdict.COMMENT,
|
||||
verdict_confidence=0.7,
|
||||
verdict_reasoning="Conflicts found",
|
||||
conflicts=[conflict],
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "json"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
output = json.loads(result.output)
|
||||
assert "conflicts" in output
|
||||
assert len(output["conflicts"]) == 1
|
||||
|
||||
|
||||
class TestWorkDirHandling:
|
||||
def test_work_dir_option(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
work_dir = tmp_path / "src"
|
||||
work_dir.mkdir()
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = make_mock_return()
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--work-dir", str(work_dir)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
call_args = mock_run.call_args
|
||||
assert call_args.kwargs.get("work_dir") == work_dir.resolve()
|
||||
|
||||
|
||||
class TestRichOutputWithConflicts:
|
||||
def test_rich_conflicts(self, tmp_path: Path) -> None:
|
||||
from arbiter.deliberation.conflicts import Conflict, ConflictNature
|
||||
from arbiter.deliberation.synthesis import Resolution
|
||||
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
conflict = Conflict(
|
||||
id="test-conflict",
|
||||
finding_ids=["f1", "f2"],
|
||||
nature=ConflictNature.TRADE_OFF,
|
||||
description="Security vs complexity trade-off",
|
||||
severity_weight=0.8,
|
||||
)
|
||||
|
||||
resolution = Resolution(
|
||||
conflict_id="test-conflict",
|
||||
decision="prefer_first",
|
||||
reasoning="Security takes priority",
|
||||
confidence=0.9,
|
||||
)
|
||||
|
||||
deliberation = DeliberationResult(
|
||||
verdict=Verdict.COMMENT,
|
||||
verdict_confidence=0.7,
|
||||
verdict_reasoning="Conflicts found",
|
||||
conflicts=[conflict],
|
||||
resolutions=[resolution],
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "rich"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
class TestMarkdownOutputWithConflicts:
|
||||
def test_markdown_conflicts(self, tmp_path: Path) -> None:
|
||||
from arbiter.deliberation.conflicts import Conflict, ConflictNature
|
||||
from arbiter.deliberation.synthesis import Resolution
|
||||
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
conflict = Conflict(
|
||||
id="test-conflict",
|
||||
finding_ids=["f1", "f2"],
|
||||
nature=ConflictNature.CONTRADICTORY,
|
||||
description="Contradictory recommendations",
|
||||
severity_weight=0.9,
|
||||
)
|
||||
|
||||
resolution = Resolution(
|
||||
conflict_id="test-conflict",
|
||||
decision="merge",
|
||||
reasoning="Both concerns addressed by combined fix",
|
||||
merged_suggestion="Do both things",
|
||||
confidence=0.85,
|
||||
)
|
||||
|
||||
deliberation = DeliberationResult(
|
||||
verdict=Verdict.COMMENT,
|
||||
verdict_confidence=0.7,
|
||||
verdict_reasoning="Conflicts found",
|
||||
conflicts=[conflict],
|
||||
resolutions=[resolution],
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "## Conflicts" in result.output
|
||||
assert "Contradictory" in result.output
|
||||
assert "Resolution" in result.output
|
||||
|
||||
def test_markdown_findings(self, tmp_path: Path) -> None:
|
||||
diff_file = tmp_path / "test.diff"
|
||||
diff_file.write_text("+ some change")
|
||||
|
||||
findings = [
|
||||
Finding(
|
||||
id="f1",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="Security Issue",
|
||||
description="Vulnerable code",
|
||||
reasoning="Bad pattern",
|
||||
suggestion="Fix it this way",
|
||||
prompt_version="test-v1.0",
|
||||
),
|
||||
Finding(
|
||||
id="f2",
|
||||
agent=AgentName.STYLE,
|
||||
file="test.py",
|
||||
line_start=20,
|
||||
line_end=25,
|
||||
severity=Severity.LOW,
|
||||
confidence=0.8,
|
||||
title="Style Issue",
|
||||
description="Could be cleaner",
|
||||
reasoning="Convention",
|
||||
prompt_version="test-v1.0",
|
||||
),
|
||||
]
|
||||
|
||||
deliberation = DeliberationResult(
|
||||
findings=findings,
|
||||
verdict=Verdict.COMMENT,
|
||||
verdict_confidence=0.75,
|
||||
verdict_reasoning="Issues found",
|
||||
total_findings=2,
|
||||
)
|
||||
|
||||
with patch("arbiter.cli._run_review", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (
|
||||
[
|
||||
ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[findings[0]],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
),
|
||||
ReviewResult(
|
||||
agent_name=AgentName.STYLE,
|
||||
findings=[findings[1]],
|
||||
duration_ms=100,
|
||||
tokens_used=100,
|
||||
cost_usd=0.001,
|
||||
),
|
||||
],
|
||||
deliberation,
|
||||
)
|
||||
result = runner.invoke(app, ["review", str(diff_file), "--format", "markdown"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Security Issue" in result.output
|
||||
assert "Style Issue" in result.output
|
||||
assert "Fix it this way" in result.output
|
||||
313
tests/test_llm.py
Normal file
313
tests/test_llm.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""Tests for LLM client and prompts."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from arbiter.llm.client import LiteLLMClient, LLMResponse
|
||||
from arbiter.llm.prompts import PromptRegistry, PromptTemplate
|
||||
from tests.conftest import MockLLMClient
|
||||
|
||||
|
||||
class TestLiteLLMClient:
|
||||
def test_init_default_values(self) -> None:
|
||||
client = LiteLLMClient()
|
||||
assert client.timeout == 60
|
||||
assert client.max_retries == 3
|
||||
|
||||
def test_init_custom_values(self) -> None:
|
||||
client = LiteLLMClient(timeout=120, max_retries=5)
|
||||
assert client.timeout == 120
|
||||
assert client.max_retries == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_returns_response(self) -> None:
|
||||
client = LiteLLMClient()
|
||||
|
||||
# Mock the litellm.acompletion function
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Test response"
|
||||
mock_response.model = "gpt-4o"
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 5
|
||||
|
||||
with (
|
||||
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||
):
|
||||
mock_acomp.return_value = mock_response
|
||||
mock_cost.return_value = 0.001
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = await client.complete(messages, "gpt-4o")
|
||||
|
||||
assert response.content == "Test response"
|
||||
assert response.model == "gpt-4o"
|
||||
assert response.tokens_in == 10
|
||||
assert response.tokens_out == 5
|
||||
assert response.cost_usd == 0.001
|
||||
|
||||
mock_acomp.assert_called_once_with(
|
||||
model="gpt-4o",
|
||||
messages=messages,
|
||||
timeout=60,
|
||||
num_retries=3,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_handles_empty_content(self) -> None:
|
||||
client = LiteLLMClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = None # None content
|
||||
mock_response.model = "gpt-4o"
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.prompt_tokens = 5
|
||||
mock_response.usage.completion_tokens = 0
|
||||
|
||||
with (
|
||||
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||
):
|
||||
mock_acomp.return_value = mock_response
|
||||
mock_cost.return_value = 0.0
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = await client.complete(messages, "gpt-4o")
|
||||
|
||||
assert response.content == "" # Should be empty string, not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_handles_missing_usage(self) -> None:
|
||||
client = LiteLLMClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response"
|
||||
mock_response.model = "gpt-4o"
|
||||
mock_response.usage = None # No usage data
|
||||
|
||||
with (
|
||||
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||
):
|
||||
mock_acomp.return_value = mock_response
|
||||
mock_cost.return_value = 0.0
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = await client.complete(messages, "gpt-4o")
|
||||
|
||||
assert response.tokens_in == 0
|
||||
assert response.tokens_out == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_uses_fallback_model(self) -> None:
|
||||
client = LiteLLMClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Response"
|
||||
mock_response.model = None # No model in response
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.prompt_tokens = 5
|
||||
mock_response.usage.completion_tokens = 3
|
||||
|
||||
with (
|
||||
patch("arbiter.llm.client.litellm.acompletion", new_callable=AsyncMock) as mock_acomp,
|
||||
patch("arbiter.llm.client.litellm.completion_cost") as mock_cost,
|
||||
):
|
||||
mock_acomp.return_value = mock_response
|
||||
mock_cost.return_value = 0.0
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = await client.complete(messages, "claude-3-opus")
|
||||
|
||||
# Should use the passed model as fallback
|
||||
assert response.model == "claude-3-opus"
|
||||
|
||||
|
||||
class TestLLMResponse:
|
||||
def test_response_creation(self) -> None:
|
||||
response = LLMResponse(
|
||||
content="Hello, world!",
|
||||
model="gpt-4o",
|
||||
tokens_in=10,
|
||||
tokens_out=5,
|
||||
cost_usd=0.001,
|
||||
)
|
||||
assert response.content == "Hello, world!"
|
||||
assert response.model == "gpt-4o"
|
||||
assert response.tokens_in == 10
|
||||
assert response.tokens_out == 5
|
||||
assert response.cost_usd == 0.001
|
||||
|
||||
|
||||
class TestMockLLMClient:
|
||||
@pytest.mark.asyncio
|
||||
async def test_records_calls(self) -> None:
|
||||
client = MockLLMClient()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
await client.complete(messages, "gpt-4o")
|
||||
|
||||
assert len(client.calls) == 1
|
||||
assert client.calls[0]["messages"] == messages
|
||||
assert client.calls[0]["model"] == "gpt-4o"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_canned_responses(self) -> None:
|
||||
client = MockLLMClient(responses=["First", "Second"])
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
response1 = await client.complete(messages, "gpt-4o")
|
||||
response2 = await client.complete(messages, "gpt-4o")
|
||||
response3 = await client.complete(messages, "gpt-4o")
|
||||
|
||||
assert response1.content == "First"
|
||||
assert response2.content == "Second"
|
||||
assert response3.content == "" # Exhausted
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset(self) -> None:
|
||||
client = MockLLMClient(responses=["Hello"])
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
|
||||
await client.complete(messages, "gpt-4o")
|
||||
assert len(client.calls) == 1
|
||||
|
||||
client.reset()
|
||||
assert len(client.calls) == 0
|
||||
|
||||
response = await client.complete(messages, "gpt-4o")
|
||||
assert response.content == "Hello" # Responses reset too
|
||||
|
||||
|
||||
class TestPromptTemplate:
|
||||
def test_template_creation(self) -> None:
|
||||
template = PromptTemplate(
|
||||
name="security",
|
||||
version="1.0",
|
||||
content="Review: {{diff}}",
|
||||
)
|
||||
assert template.name == "security"
|
||||
assert template.version == "1.0"
|
||||
assert template.full_name == "security-v1.0"
|
||||
|
||||
def test_render_substitution(self) -> None:
|
||||
template = PromptTemplate(
|
||||
name="test",
|
||||
version="1.0",
|
||||
content="File: {{file}}\nDiff: {{diff}}",
|
||||
)
|
||||
result = template.render(file="test.py", diff="+ added line")
|
||||
assert result == "File: test.py\nDiff: + added line"
|
||||
|
||||
def test_render_missing_variable(self) -> None:
|
||||
template = PromptTemplate(
|
||||
name="test",
|
||||
version="1.0",
|
||||
content="Value: {{value}}",
|
||||
)
|
||||
result = template.render()
|
||||
assert result == "Value: {{value}}"
|
||||
|
||||
def test_render_multiple_occurrences(self) -> None:
|
||||
template = PromptTemplate(
|
||||
name="test",
|
||||
version="1.0",
|
||||
content="{{name}} and {{name}} again",
|
||||
)
|
||||
result = template.render(name="test")
|
||||
assert result == "test and test again"
|
||||
|
||||
|
||||
class TestPromptRegistry:
|
||||
def test_get_template(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
(templates_dir / "security-v1.0.md").write_text("Security review: {{diff}}")
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
template = registry.get("security", "1.0")
|
||||
|
||||
assert template.name == "security"
|
||||
assert template.version == "1.0"
|
||||
assert "{{diff}}" in template.content
|
||||
|
||||
def test_get_template_cached(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
(templates_dir / "security-v1.0.md").write_text("Content")
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
template1 = registry.get("security", "1.0")
|
||||
template2 = registry.get("security", "1.0")
|
||||
|
||||
assert template1 is template2
|
||||
|
||||
def test_get_template_not_found(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
registry.get("missing", "1.0")
|
||||
|
||||
def test_list_templates(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
(templates_dir / "security-v1.0.md").write_text("Content")
|
||||
(templates_dir / "style-v2.0.md").write_text("Content")
|
||||
(templates_dir / "readme.md").write_text("Not a template")
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
templates = registry.list_templates()
|
||||
|
||||
assert len(templates) == 2
|
||||
assert ("security", "1.0") in templates
|
||||
assert ("style", "2.0") in templates
|
||||
|
||||
def test_list_templates_empty_dir(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
templates = registry.list_templates()
|
||||
|
||||
assert templates == []
|
||||
|
||||
def test_list_templates_missing_dir(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "missing"
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
templates = registry.list_templates()
|
||||
|
||||
assert templates == []
|
||||
|
||||
def test_clear_cache(self, tmp_path: Path) -> None:
|
||||
templates_dir = tmp_path / "templates"
|
||||
templates_dir.mkdir()
|
||||
(templates_dir / "test-v1.0.md").write_text("Original")
|
||||
|
||||
registry = PromptRegistry(templates_dir)
|
||||
template1 = registry.get("test", "1.0")
|
||||
assert template1.content == "Original"
|
||||
|
||||
# Modify file
|
||||
(templates_dir / "test-v1.0.md").write_text("Modified")
|
||||
|
||||
# Still cached
|
||||
template2 = registry.get("test", "1.0")
|
||||
assert template2.content == "Original"
|
||||
|
||||
# Clear cache
|
||||
registry.clear_cache()
|
||||
|
||||
# Now reads new content
|
||||
template3 = registry.get("test", "1.0")
|
||||
assert template3.content == "Modified"
|
||||
224
tests/test_models.py
Normal file
224
tests/test_models.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Tests for data models."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from arbiter.models import (
|
||||
AgentConfig,
|
||||
AgentName,
|
||||
Finding,
|
||||
Policy,
|
||||
ReviewResult,
|
||||
Severity,
|
||||
Verdict,
|
||||
)
|
||||
|
||||
|
||||
class TestEnums:
|
||||
def test_severity_values(self) -> None:
|
||||
assert Severity.CRITICAL == "critical"
|
||||
assert Severity.HIGH == "high"
|
||||
assert Severity.MEDIUM == "medium"
|
||||
assert Severity.LOW == "low"
|
||||
assert Severity.INFO == "info"
|
||||
|
||||
def test_verdict_values(self) -> None:
|
||||
assert Verdict.APPROVE == "approve"
|
||||
assert Verdict.REQUEST_CHANGES == "request_changes"
|
||||
assert Verdict.COMMENT == "comment"
|
||||
|
||||
def test_agent_name_values(self) -> None:
|
||||
assert AgentName.SECURITY == "security"
|
||||
assert AgentName.STYLE == "style"
|
||||
assert AgentName.COMPLEXITY == "complexity"
|
||||
|
||||
def test_severity_from_string(self) -> None:
|
||||
assert Severity("critical") == Severity.CRITICAL
|
||||
assert Severity("high") == Severity.HIGH
|
||||
|
||||
|
||||
class TestFinding:
|
||||
def test_finding_creation(self) -> None:
|
||||
finding = Finding(
|
||||
id="test-123",
|
||||
agent=AgentName.SECURITY,
|
||||
file="src/auth.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="SQL Injection",
|
||||
description="User input concatenated in SQL query",
|
||||
reasoning="Allows attackers to execute arbitrary SQL",
|
||||
suggestion="Use parameterized queries",
|
||||
references=["https://owasp.org"],
|
||||
prompt_version="security-v1.0",
|
||||
)
|
||||
assert finding.id == "test-123"
|
||||
assert finding.agent == AgentName.SECURITY
|
||||
assert finding.severity == Severity.HIGH
|
||||
assert finding.confidence == 0.9
|
||||
|
||||
def test_finding_confidence_validation(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
Finding(
|
||||
id="test",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=1,
|
||||
line_end=1,
|
||||
severity=Severity.INFO,
|
||||
confidence=1.5,
|
||||
title="Test",
|
||||
description="Test",
|
||||
reasoning="Test",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
|
||||
def test_finding_line_validation(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
Finding(
|
||||
id="test",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=0,
|
||||
line_end=1,
|
||||
severity=Severity.INFO,
|
||||
confidence=0.5,
|
||||
title="Test",
|
||||
description="Test",
|
||||
reasoning="Test",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
|
||||
def test_finding_serialization(self) -> None:
|
||||
finding = Finding(
|
||||
id="test-123",
|
||||
agent=AgentName.SECURITY,
|
||||
file="src/auth.py",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="Test",
|
||||
description="Test desc",
|
||||
reasoning="Test reason",
|
||||
prompt_version="security-v1.0",
|
||||
)
|
||||
data = finding.model_dump()
|
||||
assert data["id"] == "test-123"
|
||||
assert data["agent"] == "security"
|
||||
assert data["severity"] == "high"
|
||||
|
||||
|
||||
class TestAgentConfig:
|
||||
def test_default_values(self) -> None:
|
||||
config = AgentConfig()
|
||||
assert config.enabled is True
|
||||
assert config.model is None
|
||||
assert config.severity_threshold == Severity.INFO
|
||||
assert config.prompt_additions is None
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
config = AgentConfig(
|
||||
enabled=False,
|
||||
model="gpt-4o",
|
||||
severity_threshold=Severity.MEDIUM,
|
||||
prompt_additions="Focus on auth",
|
||||
)
|
||||
assert config.enabled is False
|
||||
assert config.model == "gpt-4o"
|
||||
assert config.severity_threshold == Severity.MEDIUM
|
||||
|
||||
|
||||
class TestPolicy:
|
||||
def test_default_policy(self) -> None:
|
||||
policy = Policy()
|
||||
assert len(policy.agents) == 3
|
||||
assert AgentName.SECURITY in policy.agents
|
||||
assert AgentName.STYLE in policy.agents
|
||||
assert AgentName.COMPLEXITY in policy.agents
|
||||
|
||||
def test_get_enabled_agents(self) -> None:
|
||||
policy = Policy()
|
||||
enabled = policy.get_enabled_agents()
|
||||
assert len(enabled) == 3
|
||||
assert AgentName.SECURITY in enabled
|
||||
|
||||
def test_get_enabled_agents_with_disabled(self) -> None:
|
||||
policy = Policy(
|
||||
agents={
|
||||
AgentName.SECURITY: AgentConfig(enabled=True),
|
||||
AgentName.STYLE: AgentConfig(enabled=False),
|
||||
AgentName.COMPLEXITY: AgentConfig(enabled=True),
|
||||
}
|
||||
)
|
||||
enabled = policy.get_enabled_agents()
|
||||
assert len(enabled) == 2
|
||||
assert AgentName.STYLE not in enabled
|
||||
|
||||
def test_load_from_yaml(self, tmp_path: Path) -> None:
|
||||
policy_file = tmp_path / "policy.yaml"
|
||||
policy_file.write_text("""
|
||||
version: "1.1"
|
||||
agents:
|
||||
security:
|
||||
enabled: true
|
||||
model: gpt-4o
|
||||
severity_threshold: high
|
||||
style:
|
||||
enabled: false
|
||||
complexity:
|
||||
enabled: true
|
||||
""")
|
||||
policy = Policy.load(policy_file)
|
||||
assert policy.version == "1.1"
|
||||
assert policy.agents[AgentName.SECURITY].model == "gpt-4o"
|
||||
assert policy.agents[AgentName.SECURITY].severity_threshold == Severity.HIGH
|
||||
assert policy.agents[AgentName.STYLE].enabled is False
|
||||
|
||||
def test_load_empty_yaml(self, tmp_path: Path) -> None:
|
||||
policy_file = tmp_path / "policy.yaml"
|
||||
policy_file.write_text("")
|
||||
policy = Policy.load(policy_file)
|
||||
assert policy.version == "1.0"
|
||||
assert len(policy.agents) == 3
|
||||
|
||||
|
||||
class TestReviewResult:
|
||||
def test_review_result_creation(self) -> None:
|
||||
result = ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[],
|
||||
duration_ms=1000,
|
||||
tokens_used=500,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
assert result.agent_name == AgentName.SECURITY
|
||||
assert result.duration_ms == 1000
|
||||
assert result.cost_usd == 0.01
|
||||
|
||||
def test_review_result_with_findings(self) -> None:
|
||||
finding = Finding(
|
||||
id="test-123",
|
||||
agent=AgentName.SECURITY,
|
||||
file="test.py",
|
||||
line_start=1,
|
||||
line_end=1,
|
||||
severity=Severity.HIGH,
|
||||
confidence=0.9,
|
||||
title="Test",
|
||||
description="Test",
|
||||
reasoning="Test",
|
||||
prompt_version="test-v1.0",
|
||||
)
|
||||
result = ReviewResult(
|
||||
agent_name=AgentName.SECURITY,
|
||||
findings=[finding],
|
||||
duration_ms=1000,
|
||||
tokens_used=500,
|
||||
cost_usd=0.01,
|
||||
)
|
||||
assert len(result.findings) == 1
|
||||
assert result.findings[0].id == "test-123"
|
||||
Reference in New Issue
Block a user