tests for api, worker, cache

This commit is contained in:
2025-03-22 11:04:46 +00:00
parent 41ab2f04df
commit 9a23e2c9c4
5 changed files with 2655 additions and 0 deletions

1153
tests/test_api.py Normal file

File diff suppressed because it is too large. Load Diff

164
tests/test_cache.py Normal file
View File

@@ -0,0 +1,164 @@
"""Tests for the LLM cache module."""
import pytest
from arbiter.llm.cache import LLMCache, compute_policy_hash
from arbiter.llm.client import LLMResponse
class TestComputePolicyHash:
    """Checks on the digest returned by ``compute_policy_hash``."""

    def test_compute_policy_hash_deterministic(self) -> None:
        """Hashing the same policy twice yields the same digest."""
        policy = {"agents": {"security": {"enabled": True}}}
        first = compute_policy_hash(policy)
        second = compute_policy_hash(policy)
        assert first == second

    def test_policy_hash_varies(self) -> None:
        """Policies differing in a single flag produce different digests."""
        enabled = {"agents": {"security": {"enabled": True}}}
        disabled = {"agents": {"security": {"enabled": False}}}
        assert compute_policy_hash(enabled) != compute_policy_hash(disabled)

    def test_compute_policy_hash_format(self) -> None:
        """The digest is 16 characters drawn from lowercase hex digits."""
        digest = compute_policy_hash({"test": "data"})
        assert len(digest) == 16
        assert set(digest) <= set("0123456789abcdef")
class MockRedisForCache:
"""Mock Redis client for cache testing."""
def __init__(self) -> None:
self._data: dict[str, str] = {}
async def get(self, key: str) -> str | None:
return self._data.get(key)
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
self._data[key] = value
return True
async def delete(self, key: str) -> int:
if key in self._data:
del self._data[key]
return 1
return 0
def scan_iter(self, match: str | None = None): # noqa: ARG002
async def _gen():
for key in list(self._data.keys()):
yield key
return _gen()
class TestLLMCache:
    """Exercises ``LLMCache`` against the in-memory ``MockRedisForCache``."""

    @staticmethod
    def _response(content: str) -> LLMResponse:
        """Build an ``LLMResponse`` with fixed model, token and cost values."""
        return LLMResponse(
            content=content,
            model="gpt-4o",
            tokens_in=100,
            tokens_out=50,
            cost_usd=0.01,
        )

    @pytest.fixture
    def cache(self) -> LLMCache:
        """Provide a cache backed by a fresh mock Redis."""
        return LLMCache(MockRedisForCache())  # type: ignore[arg-type]

    def test_compute_key(self, cache: LLMCache) -> None:
        """Keys carry the cache prefix followed by additional characters."""
        cache_key = cache._compute_key("diff content", "security", "v1.0", "policy123")
        assert cache_key.startswith("arbiter:llm:cache:")
        assert len(cache_key) > 20  # prefix + hash

    def test_compute_key_deterministic(self, cache: LLMCache) -> None:
        """Identical inputs map to the same key."""
        first = cache._compute_key("diff", "security", "v1.0")
        second = cache._compute_key("diff", "security", "v1.0")
        assert first == second

    def test_compute_key_unique(self, cache: LLMCache) -> None:
        """Changing the diff, the agent or the version changes the key."""
        inputs = [
            ("diff1", "security", "v1.0"),
            ("diff2", "security", "v1.0"),
            ("diff1", "style", "v1.0"),
            ("diff1", "security", "v2.0"),
        ]
        keys = {cache._compute_key(*args) for args in inputs}
        assert len(keys) == 4

    def test_serialize_deserialize_response(self, cache: LLMCache) -> None:
        """A response survives a serialize/deserialize round trip."""
        original = self._response("test content")
        restored = cache._deserialize_response(cache._serialize_response(original))
        assert restored.content == original.content
        assert restored.model == original.model
        assert restored.tokens_in == original.tokens_in
        assert restored.tokens_out == original.tokens_out
        assert restored.cost_usd == original.cost_usd

    async def test_cache_get_miss(self, cache: LLMCache) -> None:
        """Looking up an unknown entry returns ``None`` and counts one miss."""
        found = await cache.get("diff", "security", "v1.0")
        assert found is None
        assert cache._misses == 1
        assert cache._hits == 0

    async def test_cache_set_and_get(self, cache: LLMCache) -> None:
        """A stored response is returned by ``get`` and counted as a hit."""
        await cache.set("diff", "security", "v1.0", self._response("cached content"))
        found = await cache.get("diff", "security", "v1.0")
        assert found is not None
        assert found.content == "cached content"
        assert cache._hits == 1

    async def test_cache_invalidate(self, cache: LLMCache) -> None:
        """Invalidating a stored entry reports success and removes it."""
        await cache.set("diff", "security", "v1.0", self._response("test"))
        was_deleted = await cache.invalidate("diff", "security", "v1.0")
        assert was_deleted is True
        found = await cache.get("diff", "security", "v1.0")
        assert found is None

    async def test_cache_invalidate_nonexistent(self, cache: LLMCache) -> None:
        """Invalidating an entry that was never stored reports ``False``."""
        was_deleted = await cache.invalidate("nonexistent", "security", "v1.0")
        assert was_deleted is False

    def test_get_stats(self, cache: LLMCache) -> None:
        """A fresh cache reports zeroed statistics."""
        stats = cache.get_stats()
        for field, expected in (("hits", 0), ("misses", 0), ("total", 0)):
            assert stats[field] == expected
        assert stats["hit_rate"] == 0.0

    async def test_get_stats_after_operations(self, cache: LLMCache) -> None:
        """Two misses followed by one hit give a one-third hit rate."""
        await cache.get("key1", "agent", "v1")  # miss
        await cache.get("key2", "agent", "v1")  # miss
        await cache.set("key1", "agent", "v1", self._response("test"))
        await cache.get("key1", "agent", "v1")  # hit
        stats = cache.get_stats()
        assert stats["hits"] == 1
        assert stats["misses"] == 2
        assert stats["total"] == 3
        assert stats["hit_rate"] == pytest.approx(1 / 3)

157
tests/test_cost.py Normal file
View File

@@ -0,0 +1,157 @@
"""Tests for the cost tracking module."""
import pytest
from arbiter.models.cost import AgentCost, CostEstimate, ReviewCost
from arbiter.models.enums import AgentName
class TestAgentCost:
    """Construction behaviour of ``AgentCost``."""

    def test_agent_cost_creation(self) -> None:
        """Explicitly supplied values are readable back from the instance."""
        values = {
            "agent": AgentName.SECURITY,
            "tokens_in": 100,
            "tokens_out": 50,
            "total_tokens": 150,
            "cost_usd": 0.01,
        }
        agent_cost = AgentCost(**values)
        assert agent_cost.agent == AgentName.SECURITY
        assert agent_cost.total_tokens == 150
        assert agent_cost.cost_usd == 0.01

    def test_agent_cost_defaults(self) -> None:
        """Token and cost fields default to zero when only the agent is given."""
        agent_cost = AgentCost(agent=AgentName.STYLE)
        assert agent_cost.tokens_in == 0
        assert agent_cost.tokens_out == 0
        assert agent_cost.total_tokens == 0
        assert agent_cost.cost_usd == 0.0
class TestReviewCost:
    """Aggregation, export and budget checks for ``ReviewCost``."""

    def test_review_cost_defaults(self) -> None:
        """A new instance has zero totals, no agent costs and no cache counts."""
        review_cost = ReviewCost()
        assert review_cost.total_tokens == 0
        assert review_cost.total_cost_usd == 0.0
        assert review_cost.agent_costs == []
        assert review_cost.cache_hits == 0
        assert review_cost.cache_misses == 0

    def test_add_agent_cost(self) -> None:
        """Costs from two agents are summed into the totals."""
        review_cost = ReviewCost()
        review_cost.add_agent_cost(
            AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01
        )
        review_cost.add_agent_cost(
            AgentName.STYLE, tokens_in=80, tokens_out=40, cost_usd=0.008
        )
        assert len(review_cost.agent_costs) == 2
        assert review_cost.total_tokens_in == 180
        assert review_cost.total_tokens_out == 90
        assert review_cost.total_tokens == 270
        assert review_cost.total_cost_usd == pytest.approx(0.018)

    def test_add_deliberation_cost(self) -> None:
        """Deliberation usage is tracked separately and included in the total."""
        review_cost = ReviewCost()
        review_cost.add_deliberation_cost(tokens_in=50, tokens_out=100, cost_usd=0.005)
        assert review_cost.deliberation_tokens_in == 50
        assert review_cost.deliberation_tokens_out == 100
        assert review_cost.deliberation_cost_usd == 0.005
        assert review_cost.total_tokens == 150

    def test_combined_costs(self) -> None:
        """Agent and deliberation usage add up in the overall totals."""
        review_cost = ReviewCost()
        review_cost.add_agent_cost(
            AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01
        )
        review_cost.add_deliberation_cost(tokens_in=50, tokens_out=25, cost_usd=0.005)
        assert review_cost.total_tokens_in == 150
        assert review_cost.total_tokens_out == 75
        assert review_cost.total_tokens == 225
        assert review_cost.total_cost_usd == pytest.approx(0.015)

    def test_to_agent_dict(self) -> None:
        """``to_agent_dict`` maps each agent name to its token total."""
        review_cost = ReviewCost()
        review_cost.add_agent_cost(
            AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01
        )
        review_cost.add_agent_cost(
            AgentName.STYLE, tokens_in=80, tokens_out=40, cost_usd=0.008
        )
        tokens_by_agent = review_cost.to_agent_dict()
        assert tokens_by_agent == {"security": 150, "style": 120}

    def test_to_cost_dict(self) -> None:
        """``to_cost_dict`` maps each agent name to its USD cost."""
        review_cost = ReviewCost()
        review_cost.add_agent_cost(
            AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01
        )
        review_cost.add_agent_cost(
            AgentName.STYLE, tokens_in=80, tokens_out=40, cost_usd=0.008
        )
        usd_by_agent = review_cost.to_cost_dict()
        assert usd_by_agent == {"security": 0.01, "style": 0.008}

    def test_is_within_budget_true(self) -> None:
        """Usage below both limits is within budget."""
        review_cost = ReviewCost()
        review_cost.add_agent_cost(
            AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01
        )
        within = review_cost.is_within_budget(max_tokens=1000, max_cost_usd=0.50)
        assert within is True

    def test_is_within_budget_false_tokens(self) -> None:
        """Exceeding the token limit alone breaks the budget."""
        review_cost = ReviewCost()
        review_cost.add_agent_cost(
            AgentName.SECURITY, tokens_in=1000, tokens_out=500, cost_usd=0.01
        )
        within = review_cost.is_within_budget(max_tokens=1000, max_cost_usd=0.50)
        assert within is False

    def test_is_within_budget_false_cost(self) -> None:
        """Exceeding the cost limit alone breaks the budget."""
        review_cost = ReviewCost()
        review_cost.add_agent_cost(
            AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=1.0
        )
        within = review_cost.is_within_budget(max_tokens=10000, max_cost_usd=0.50)
        assert within is False
class TestCostEstimate:
    """Behaviour of ``CostEstimate.estimate`` across diff sizes, models and agents."""

    def test_estimate_small_diff(self) -> None:
        """A small diff on the mini model yields positive, in-budget figures."""
        small = CostEstimate.estimate(
            diff_size=1000,
            agents=[AgentName.SECURITY, AgentName.STYLE],
            model="gpt-4o-mini",
        )
        assert small.estimated_tokens > 0
        assert small.estimated_cost_usd > 0
        assert small.agents_enabled == [AgentName.SECURITY, AgentName.STYLE]
        assert small.model == "gpt-4o-mini"
        assert small.within_budget is True

    def test_estimate_large_diff(self) -> None:
        """A large diff with tight limits is reported as over budget."""
        limits = {"max_tokens": 10000, "max_cost_usd": 0.10}
        large = CostEstimate.estimate(
            diff_size=100000,
            agents=[AgentName.SECURITY, AgentName.STYLE, AgentName.COMPLEXITY],
            model="gpt-4o",
            **limits,
        )
        # Large diff with expensive model should exceed budget
        assert large.within_budget is False

    def test_estimate_gpt4o_vs_mini(self) -> None:
        """For the same input, ``gpt-4o`` is estimated costlier than the mini model."""
        cost_by_model = {
            model: CostEstimate.estimate(
                diff_size=10000,
                agents=[AgentName.SECURITY],
                model=model,
            ).estimated_cost_usd
            for model in ("gpt-4o", "gpt-4o-mini")
        }
        assert cost_by_model["gpt-4o"] > cost_by_model["gpt-4o-mini"]

    def test_estimate_more_agents_higher_cost(self) -> None:
        """Enabling three agents raises both estimated tokens and cost."""
        single = CostEstimate.estimate(
            diff_size=5000,
            agents=[AgentName.SECURITY],
            model="gpt-4o",
        )
        triple = CostEstimate.estimate(
            diff_size=5000,
            agents=[AgentName.SECURITY, AgentName.STYLE, AgentName.COMPLEXITY],
            model="gpt-4o",
        )
        assert triple.estimated_tokens > single.estimated_tokens
        assert triple.estimated_cost_usd > single.estimated_cost_usd

222
tests/test_db_models.py Normal file
View File

@@ -0,0 +1,222 @@
"""Tests for database models."""
from uuid import uuid4
from arbiter.db.models import (
Base,
ConflictModel,
DeliberationStepModel,
FindingModel,
PolicyModel,
ReviewModel,
)
from arbiter.deliberation.conflicts import ConflictNature
from arbiter.deliberation.coordinator import StepType
from arbiter.models.enums import AgentName, Severity, Verdict
class TestReviewModel:
    """Attribute round-trips on ``ReviewModel`` instances built in memory."""

    def test_review_model_creation(self) -> None:
        """Core pull-request fields are stored as given."""
        pending_review = ReviewModel(
            id=str(uuid4()),
            repository="owner/repo",
            pr_number=42,
            pr_title="Test PR",
            base_sha="abc1234567890123456789012345678901234567",
            head_sha="def1234567890123456789012345678901234567",
            author="testuser",
            is_draft=False,
            status="pending",
        )
        assert pending_review.repository == "owner/repo"
        assert pending_review.pr_number == 42
        assert pending_review.status == "pending"
        assert pending_review.is_draft is False

    def test_review_model_with_verdict(self) -> None:
        """Verdict and its confidence are stored as given."""
        completed_review = ReviewModel(
            id=str(uuid4()),
            repository="owner/repo",
            pr_number=1,
            base_sha="a" * 40,
            head_sha="b" * 40,
            status="completed",
            verdict=Verdict.COMMENT,
            verdict_confidence=0.75,
            verdict_reasoning="Found some issues",
        )
        assert completed_review.verdict == Verdict.COMMENT
        assert completed_review.verdict_confidence == 0.75

    def test_review_model_cost_tracking(self) -> None:
        """Token and cost totals, including per-agent maps, are stored as given."""
        per_agent_tokens = {"security": 500, "style": 500, "complexity": 500}
        per_agent_cost = {"security": 0.005, "style": 0.005, "complexity": 0.005}
        costed_review = ReviewModel(
            id=str(uuid4()),
            repository="owner/repo",
            pr_number=1,
            base_sha="a" * 40,
            head_sha="b" * 40,
            total_tokens=1500,
            total_cost_usd=0.015,
            tokens_by_agent=per_agent_tokens,
            cost_by_agent=per_agent_cost,
        )
        assert costed_review.total_tokens == 1500
        assert costed_review.total_cost_usd == 0.015
        assert costed_review.tokens_by_agent["security"] == 500
class TestFindingModel:
    """Attribute round-trips on ``FindingModel``."""

    def test_finding_model_creation(self) -> None:
        """Agent, severity, confidence and line range are stored as given."""
        finding_id, owning_review_id = str(uuid4()), str(uuid4())
        stored = FindingModel(
            id=finding_id,
            review_id=owning_review_id,
            agent=AgentName.SECURITY,
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection",
            description="User input concatenated in SQL",
            reasoning="String concatenation allows injection",
            suggestion="Use parameterized queries",
            references=["https://owasp.org"],
            prompt_version="security-v1.0",
        )
        assert stored.agent == AgentName.SECURITY
        assert stored.severity == Severity.HIGH
        assert stored.confidence == 0.9
        assert (stored.line_start, stored.line_end) == (10, 15)
class TestConflictModel:
    """Attribute round-trips on ``ConflictModel``."""

    def test_conflict_model_creation(self) -> None:
        """Nature, finding ids and severity weight are stored as given."""
        trade_off = ConflictModel(
            id=str(uuid4()),
            review_id=str(uuid4()),
            finding_ids=["finding-1", "finding-2"],
            nature=ConflictNature.TRADE_OFF,
            description="Security vs simplicity trade-off",
            severity_weight=0.7,
        )
        assert trade_off.nature == ConflictNature.TRADE_OFF
        assert len(trade_off.finding_ids) == 2
        assert trade_off.severity_weight == 0.7

    def test_conflict_model_with_resolution(self) -> None:
        """Resolution text and the winning finding id are stored as given."""
        resolved = ConflictModel(
            id=str(uuid4()),
            review_id=str(uuid4()),
            finding_ids=["finding-1", "finding-2"],
            nature=ConflictNature.CONTRADICTORY,
            description="Opposing recommendations",
            severity_weight=0.8,
            resolution="Security takes precedence",
            winning_finding_id="finding-1",
        )
        assert resolved.resolution is not None
        assert resolved.winning_finding_id == "finding-1"
class TestDeliberationStepModel:
    """Attribute round-trips on ``DeliberationStepModel``."""

    def test_deliberation_step_creation(self) -> None:
        """Step type, sequence and details are stored as given."""
        merge_step = DeliberationStepModel(
            id=str(uuid4()),
            review_id=str(uuid4()),
            step_type=StepType.MERGE,
            description="Merged 5 findings",
            details={"groups": 3, "unique": 5},
            sequence=0,
        )
        assert merge_step.step_type == StepType.MERGE
        assert merge_step.sequence == 0
        assert merge_step.details["groups"] == 3

    def test_all_step_types(self) -> None:
        """One step can be built for each of the four step types used here."""
        review_id = str(uuid4())
        plan = [
            (StepType.MERGE, "Merge step"),
            (StepType.CONFLICT_DETECTION, "Conflict detection step"),
            (StepType.SYNTHESIS, "Synthesis step"),
            (StepType.VERDICT, "Verdict step"),
        ]
        # The position in the plan doubles as the step's sequence number.
        steps = [
            DeliberationStepModel(
                id=str(uuid4()),
                review_id=review_id,
                step_type=kind,
                description=text,
                sequence=position,
            )
            for position, (kind, text) in enumerate(plan)
        ]
        assert len(steps) == 4
        assert steps[0].step_type == StepType.MERGE
        assert steps[3].step_type == StepType.VERDICT
class TestPolicyModel:
    """Attribute round-trips on ``PolicyModel``."""

    def test_policy_model_creation(self) -> None:
        """Name, organization and active flag are stored as given."""
        default_policy = PolicyModel(
            id=str(uuid4()),
            name="default",
            organization="test-org",
            description="Default policy",
            is_active=True,
        )
        assert default_policy.name == "default"
        assert default_policy.organization == "test-org"
        assert default_policy.is_active is True

    def test_policy_model_with_config(self) -> None:
        """Nested agent and cost-control configuration is stored as given."""
        agents = {
            "security": {"enabled": True, "model": "gpt-4o"},
            "style": {"enabled": True},
            "complexity": {"enabled": False},
        }
        controls = {
            "max_tokens": 50000,
            "max_cost_usd": 0.50,
        }
        thresholds = {
            "critical_threshold": 1,
            "high_threshold": 3,
        }
        strict_policy = PolicyModel(
            id=str(uuid4()),
            name="strict",
            agents_config=agents,
            cost_controls=controls,
            verdict_thresholds=thresholds,
        )
        assert strict_policy.agents_config["security"]["model"] == "gpt-4o"
        assert strict_policy.cost_controls["max_tokens"] == 50000
class TestBase:
    """Sanity check on the shared model base class."""

    def test_base_is_declarative_base(self) -> None:
        """``Base`` exposes both ``metadata`` and ``registry`` attributes."""
        for attribute in ("metadata", "registry"):
            assert hasattr(Base, attribute)

959
tests/test_worker.py Normal file
View File

@@ -0,0 +1,959 @@
"""Tests for the worker module."""
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from arbiter.db.models import (
ConflictModel,
DeliberationStepModel,
FindingModel,
ReviewModel,
)
from arbiter.integrations import ARBITER_MARKER
from arbiter.integrations.base import Comment, CommitStatus, Platform
from arbiter.models.enums import AgentName, Severity, Verdict
from arbiter.worker.queue import JobPriority, cancel_job, generate_job_id, get_job_status
from arbiter.worker.tasks import (
_post_or_update_comment,
_verdict_to_status,
detect_platform,
get_platform_client,
process_followup,
process_review,
)
from tests.conftest import MockPlatformClient
class TestJobQueue:
    """Checks on job-id generation and ``JobPriority`` values."""

    def test_generate_job_id_deterministic(self) -> None:
        """The same repository, PR number and SHA give the same id."""
        args = ("owner/repo", 42, "abc123")
        assert generate_job_id(*args) == generate_job_id(*args)

    def test_job_id_unique(self) -> None:
        """Changing the SHA, the PR number or the repository changes the id."""
        variants = [
            ("owner/repo", 42, "abc123"),
            ("owner/repo", 42, "def456"),  # Different SHA
            ("owner/repo", 43, "abc123"),  # Different PR
            ("other/repo", 42, "abc123"),  # Different repo
        ]
        ids = {generate_job_id(*variant) for variant in variants}
        assert len(ids) == 4  # All unique

    def test_generate_job_id_format(self) -> None:
        """Ids are 16 characters drawn from lowercase hex digits."""
        job_id = generate_job_id("owner/repo", 42, "abc123")
        assert len(job_id) == 16
        assert set(job_id) <= set("0123456789abcdef")

    def test_job_priority_ordering(self) -> None:
        """HIGH sorts before NORMAL before LOW, with integer values 1, 2 and 3."""
        assert JobPriority.HIGH < JobPriority.NORMAL < JobPriority.LOW
        assert int(JobPriority.HIGH) == 1
        assert int(JobPriority.NORMAL) == 2
        assert int(JobPriority.LOW) == 3
class TestWorkerSettings:
    """Presence checks on ``WorkerSettings`` class attributes."""

    def test_worker_settings_has_functions(self) -> None:
        """``functions`` is set and non-empty."""
        from arbiter.worker.settings import WorkerSettings

        registered = WorkerSettings.functions
        assert registered is not None
        assert len(registered) > 0

    def test_worker_settings_has_cron_jobs(self) -> None:
        """``cron_jobs`` is set."""
        from arbiter.worker.settings import WorkerSettings

        assert WorkerSettings.cron_jobs is not None

    def test_worker_settings_lifecycle_hooks(self) -> None:
        """Both ``on_startup`` and ``on_shutdown`` hooks are set."""
        from arbiter.worker.settings import WorkerSettings

        assert WorkerSettings.on_startup is not None
        assert WorkerSettings.on_shutdown is not None
class TestReviewTask:
    """Placeholder coverage for the review task entry point."""

    @pytest.fixture
    def mock_context(self) -> dict[str, Any]:
        """Provide a worker context whose ``settings`` and ``redis`` are ``None``."""
        return dict(settings=None, redis=None)

    async def test_review_task_requires_diff(self, mock_context: dict[str, Any]) -> None:  # noqa: ARG002
        """Only confirms ``process_review`` is importable and callable."""
        from arbiter.worker.tasks import process_review

        # Note: This would need database fixtures to run fully
        # For now, we just verify the function signature
        assert callable(process_review)
class MockRedisForQueue:
"""Mock Redis with enqueue_job support."""
def __init__(self) -> None:
self._data: dict[str, Any] = {}
self._jobs: list[dict[str, Any]] = []
async def get(self, key: str) -> str | None:
return self._data.get(key)
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
self._data[key] = value
return True
async def delete(self, key: str) -> int:
if key in self._data:
del self._data[key]
return 1
return 0
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any:
job = {"func": func_name, "kwargs": kwargs}
self._jobs.append(job)
return type("Job", (), {"job_id": kwargs.get("_job_id", "test-id")})()
class TestEnqueueReview:
    """Tests for ``enqueue_review`` against a mocked Redis pool."""

    @pytest.fixture
    def mock_redis_pool(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Replace ``get_redis_pool`` in the queue module with one returning a mock."""
        fake_pool = MockRedisForQueue()

        async def get_pool() -> MockRedisForQueue:
            return fake_pool

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool)
        return fake_pool

    async def test_enqueue_review_creates_job(self, mock_redis_pool: MockRedisForQueue) -> None:
        """A single call enqueues exactly one ``process_review`` job."""
        from arbiter.worker.queue import enqueue_review

        job_id = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            pr_title="Test PR",
            author="testuser",
            is_draft=False,
        )
        recorded = mock_redis_pool._jobs
        assert job_id is not None
        assert len(recorded) == 1
        assert recorded[0]["func"] == "process_review"

    async def test_enqueue_review_deduplication(
        self,
        mock_redis_pool: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """A repeated call with identical parameters returns ``None``."""
        from arbiter.worker.queue import enqueue_review

        request = {
            "repository": "owner/repo",
            "pr_number": 42,
            "base_sha": "abc123",
            "head_sha": "def456",
        }
        # First call should succeed
        first_id = await enqueue_review(**request)
        assert first_id is not None
        # Second call with same params should be deduplicated
        second_id = await enqueue_review(**request)
        assert second_id is None

    async def test_enqueue_review_draft_lower_priority(
        self, mock_redis_pool: MockRedisForQueue
    ) -> None:
        """A draft PR is enqueued on the queue whose name contains priority 3."""
        from arbiter.worker.queue import enqueue_review

        await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            is_draft=True,
        )
        assert len(mock_redis_pool._jobs) == 1
        queue_name = mock_redis_pool._jobs[0]["kwargs"]["_queue_name"]
        # Draft PRs should be in the low priority queue
        assert "arbiter:queue:3" in queue_name
class TestJobStatusAndCancel:
    """Tests for ``get_job_status`` and ``cancel_job`` against a mocked pool."""

    @pytest.fixture
    def mock_redis_pool_with_jobs(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Patch ``get_redis_pool`` with a mock pre-seeded with one pending job."""
        seeded_pool = MockRedisForQueue()
        seeded_pool._data["arbiter:job:test-job-id"] = "pending"

        async def get_pool() -> MockRedisForQueue:
            return seeded_pool

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool)
        return seeded_pool

    async def test_get_job_status_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """A known job id returns its id and stored status."""
        job_status = await get_job_status("test-job-id")
        assert job_status is not None
        assert job_status["job_id"] == "test-job-id"
        assert job_status["status"] == "pending"

    async def test_get_job_status_not_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """An unknown job id returns ``None``."""
        job_status = await get_job_status("nonexistent")
        assert job_status is None

    async def test_cancel_job_success(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """Cancelling a known job returns ``True``."""
        cancelled = await cancel_job("test-job-id")
        assert cancelled is True

    async def test_cancel_job_not_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """Cancelling an unknown job returns ``False``."""
        cancelled = await cancel_job("nonexistent")
        assert cancelled is False
class TestWorkerStartupShutdown:
    """Tests for the worker lifecycle hooks and ``WorkerSettings`` helpers."""

    async def test_startup_hook(self) -> None:
        """``startup`` awaits ``init_db`` once and stores settings in the context."""
        from unittest.mock import AsyncMock

        from arbiter.worker.settings import startup

        # Mock init_db and get_settings
        with (
            patch("arbiter.worker.settings.init_db", new_callable=AsyncMock) as mock_init,
            patch("arbiter.worker.settings.get_settings") as mock_settings,
        ):
            mock_settings.return_value = "mock_settings"
            ctx: dict[str, Any] = {}
            await startup(ctx)
            # assert_awaited_once also fails when the coroutine was created
            # but never awaited, which assert_called_once would let through.
            mock_init.assert_awaited_once()
            assert ctx["settings"] == "mock_settings"

    async def test_shutdown_hook(self) -> None:
        """``shutdown`` awaits ``close_db`` exactly once."""
        from unittest.mock import AsyncMock

        from arbiter.worker.settings import shutdown

        with patch("arbiter.worker.settings.close_db", new_callable=AsyncMock) as mock_close:
            ctx: dict[str, Any] = {}
            await shutdown(ctx)
            mock_close.assert_awaited_once()

    async def test_health_check(self) -> None:
        """``health_check`` returns the string ``"healthy"``."""
        from arbiter.worker.settings import health_check

        ctx: dict[str, Any] = {}
        result = await health_check(ctx)
        assert result == "healthy"

    def test_worker_settings_redis_settings(self) -> None:
        """``WorkerSettings.redis_settings()`` returns a value."""
        from arbiter.worker.settings import WorkerSettings

        redis_settings = WorkerSettings.redis_settings()
        assert redis_settings is not None

    def test_worker_settings_get_functions(self) -> None:
        """``_get_functions`` returns the review and follow-up tasks only."""
        from arbiter.worker.settings import WorkerSettings

        functions = WorkerSettings._get_functions()
        assert len(functions) == 2
        # Verify the functions are the expected ones
        func_names = [f.__name__ for f in functions]
        assert "process_review" in func_names
        assert "process_followup" in func_names
class TestDetectPlatform:
    """Tests for ``detect_platform`` with and without an explicit platform hint."""

    def test_detect_platform_from_webhook_github(self) -> None:
        """The hint ``"github"`` resolves to ``Platform.GITHUB``."""
        assert detect_platform("owner/repo", "github") == Platform.GITHUB

    def test_detect_platform_from_webhook_gitlab(self) -> None:
        """The hint ``"gitlab"`` resolves to ``Platform.GITLAB``."""
        assert detect_platform("owner/repo", "gitlab") == Platform.GITLAB

    def test_detect_platform_case_insensitive(self) -> None:
        """Upper- and mixed-case hints resolve the same way."""
        upper = detect_platform("owner/repo", "GITHUB")
        assert upper == Platform.GITHUB
        mixed = detect_platform("owner/repo", "GitLab")
        assert mixed == Platform.GITLAB

    def test_detect_platform_defaults_to_github(self) -> None:
        """Without a hint the result is ``Platform.GITHUB``."""
        detected = detect_platform("owner/repo")
        assert detected == Platform.GITHUB
class TestGetPlatformClient:
    """Tests for ``get_platform_client`` under different settings fixtures."""

    def test_get_platform_client_github_no_token(self, mock_settings_no_github: Any) -> None:
        """No GitHub client is returned with the ``mock_settings_no_github`` fixture."""
        github_client = get_platform_client(Platform.GITHUB, mock_settings_no_github)
        assert github_client is None

    def test_get_platform_client_gitlab_no_token(self, mock_settings_no_github: Any) -> None:
        """No GitLab client is returned with the ``mock_settings_no_github`` fixture."""
        gitlab_client = get_platform_client(Platform.GITLAB, mock_settings_no_github)
        assert gitlab_client is None

    def test_github_client_with_token(self, mock_settings: Any) -> None:
        """With ``mock_settings`` a ``GitHubClient`` instance is returned."""
        from arbiter.integrations import GitHubClient

        github_client = get_platform_client(Platform.GITHUB, mock_settings)
        assert github_client is not None
        assert isinstance(github_client, GitHubClient)

    def test_gitlab_client_with_token(self, mock_settings: Any) -> None:
        """With ``mock_settings`` a ``GitLabClient`` instance is returned."""
        from arbiter.integrations import GitLabClient

        gitlab_client = get_platform_client(Platform.GITLAB, mock_settings)
        assert gitlab_client is not None
        assert isinstance(gitlab_client, GitLabClient)
class TestVerdictToStatus:
    """Mapping from review verdicts to commit statuses."""

    def test_approve_returns_success(self) -> None:
        """``APPROVE`` maps to ``SUCCESS``."""
        status = _verdict_to_status(Verdict.APPROVE)
        assert status == CommitStatus.SUCCESS

    def test_request_changes_returns_failure(self) -> None:
        """``REQUEST_CHANGES`` maps to ``FAILURE``."""
        status = _verdict_to_status(Verdict.REQUEST_CHANGES)
        assert status == CommitStatus.FAILURE

    def test_comment_returns_success(self) -> None:
        """``COMMENT`` maps to ``SUCCESS``."""
        status = _verdict_to_status(Verdict.COMMENT)
        assert status == CommitStatus.SUCCESS
class TestPostOrUpdateComment:
    """Tests for ``_post_or_update_comment`` using the mock platform client."""

    async def test_post_new_comment(self, mock_platform_client: MockPlatformClient) -> None:
        """With no existing comments, the body is posted and a URL returned."""
        body = f"Test comment {ARBITER_MARKER}"
        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)
        posted = mock_platform_client._posted_comments
        assert url is not None
        assert "owner/repo" in url
        assert len(posted) == 1
        assert posted[0]["body"] == body

    async def test_update_existing_comment(self, mock_platform_client: MockPlatformClient) -> None:
        """An existing comment carrying the marker is updated in place."""
        # Add an existing Arbiter comment
        previous = Comment(
            id="existing-123",
            body=f"Old review {ARBITER_MARKER}",
            author="arbiter-bot",
            url="https://github.com/owner/repo/pull/42#comment-existing-123",
            created_at=datetime.now(UTC),
        )
        mock_platform_client._comments = [previous]
        body = f"Updated review {ARBITER_MARKER}"
        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)
        posted = mock_platform_client._posted_comments
        assert url is not None
        assert len(posted) == 1
        # Should be an update, not a new post
        assert posted[0].get("comment_id") == "existing-123"

    async def test_fallback_on_fetch_failure(
        self, mock_platform_client: MockPlatformClient
    ) -> None:
        """If fetching comments fails, a new comment is still posted."""
        mock_platform_client._fail_on.add("get_comments")
        body = f"Test comment {ARBITER_MARKER}"
        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)
        posted = mock_platform_client._posted_comments
        # Should still post a new comment
        assert url is not None
        assert len(posted) == 1
        # Should be a new post since fetching failed
        assert posted[0].get("comment_id") is None

    async def test_returns_none_on_post_failure(
        self, mock_platform_client: MockPlatformClient
    ) -> None:
        """If posting fails, ``None`` is returned."""
        mock_platform_client._fail_on.add("post_comment")
        body = f"Test comment {ARBITER_MARKER}"
        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)
        assert url is None
class TestProcessReview:
@pytest.fixture
def mock_deliberation_result(self) -> Any:
    """Build a ``DeliberationResult`` with one finding, one conflict and one step."""
    from arbiter.deliberation import DeliberationResult, DeliberationStep
    from arbiter.deliberation.conflicts import Conflict, ConflictNature
    from arbiter.deliberation.coordinator import StepType
    from arbiter.models import Finding

    sql_injection = Finding(
        id=str(uuid4()),
        agent=AgentName.SECURITY,
        file="src/auth.py",
        line_start=10,
        line_end=15,
        severity=Severity.HIGH,
        confidence=0.9,
        title="SQL Injection",
        description="User input concatenated into SQL",
        reasoning="Allows SQL injection attacks",
        prompt_version="security-v1.0",
    )
    trade_off = Conflict(
        id="conflict-1",
        finding_ids=["f1", "f2"],
        nature=ConflictNature.TRADE_OFF,
        description="Trade-off detected",
        severity_weight=0.5,
    )
    merge_step = DeliberationStep(
        step_type=StepType.MERGE,
        timestamp=datetime.now(UTC),
        description="Merged findings",
        details={"count": 1},
    )
    return DeliberationResult(
        verdict=Verdict.COMMENT,
        verdict_confidence=0.75,
        verdict_reasoning="Found security issues",
        findings=[sql_injection],
        conflicts=[trade_off],
        steps=[merge_step],
        tokens_used=500,
        cost_usd=0.005,
    )
@pytest.fixture
def mock_review_result(self) -> Any:
    """Build a ``ReviewResult`` for the security agent with no findings."""
    from arbiter.models import ReviewResult

    fields = {
        "agent_name": AgentName.SECURITY,
        "findings": [],
        "duration_ms": 1000,
        "tokens_used": 500,
        "cost_usd": 0.005,
    }
    return ReviewResult(**fields)
async def test_process_review_creates_review_record(
    self,
    db_session: AsyncSession,
    mock_deliberation_result: Any,
    mock_review_result: Any,
) -> None:
    """A successful run persists a completed review carrying the verdict."""
    # Mock the review pipeline
    pipeline_patch = patch(
        "arbiter.worker.tasks._run_review_pipeline",
        return_value=([mock_review_result], mock_deliberation_result),
    )
    session_patch = patch(
        "arbiter.worker.tasks.async_session_factory",
        return_value=lambda: db_session,
    )
    client_patch = patch(
        "arbiter.worker.tasks.get_platform_client",
        return_value=None,
    )
    with pipeline_patch, session_patch, client_patch:
        review_id = await process_review(
            {},
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            pr_title="Test PR",
            author="testuser",
            diff_content="mock diff",
        )
    # Verify review was created
    query = select(ReviewModel).where(ReviewModel.id == review_id)
    stored_review = (await db_session.execute(query)).scalar_one()
    assert stored_review.repository == "owner/repo"
    assert stored_review.pr_number == 42
    assert stored_review.status == "completed"
    assert stored_review.verdict == Verdict.COMMENT
async def test_process_review_stores_findings(
    self,
    db_session: AsyncSession,
    mock_deliberation_result: Any,
    mock_review_result: Any,
) -> None:
    """The finding from the deliberation result is persisted for the review."""
    pipeline_patch = patch(
        "arbiter.worker.tasks._run_review_pipeline",
        return_value=([mock_review_result], mock_deliberation_result),
    )
    session_patch = patch(
        "arbiter.worker.tasks.async_session_factory",
        return_value=lambda: db_session,
    )
    client_patch = patch(
        "arbiter.worker.tasks.get_platform_client",
        return_value=None,
    )
    with pipeline_patch, session_patch, client_patch:
        review_id = await process_review(
            {},
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            diff_content="mock diff",
        )
    # Verify findings were stored
    query = select(FindingModel).where(FindingModel.review_id == review_id)
    stored_findings = (await db_session.execute(query)).scalars().all()
    assert len(stored_findings) == 1
    assert stored_findings[0].title == "SQL Injection"
    assert stored_findings[0].severity == Severity.HIGH
async def test_process_review_stores_conflicts(
    self,
    db_session: AsyncSession,
    mock_deliberation_result: Any,
    mock_review_result: Any,
) -> None:
    """The conflict from the deliberation result is persisted for the review."""
    pipeline_patch = patch(
        "arbiter.worker.tasks._run_review_pipeline",
        return_value=([mock_review_result], mock_deliberation_result),
    )
    session_patch = patch(
        "arbiter.worker.tasks.async_session_factory",
        return_value=lambda: db_session,
    )
    client_patch = patch(
        "arbiter.worker.tasks.get_platform_client",
        return_value=None,
    )
    with pipeline_patch, session_patch, client_patch:
        review_id = await process_review(
            {},
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            diff_content="mock diff",
        )
    # Verify conflicts were stored
    query = select(ConflictModel).where(ConflictModel.review_id == review_id)
    stored_conflicts = (await db_session.execute(query)).scalars().all()
    assert len(stored_conflicts) == 1
    assert stored_conflicts[0].description == "Trade-off detected"
async def test_process_review_stores_deliberation_steps(
    self,
    db_session: AsyncSession,
    mock_deliberation_result: Any,
    mock_review_result: Any,
) -> None:
    """The deliberation step from the result is persisted for the review."""
    pipeline_patch = patch(
        "arbiter.worker.tasks._run_review_pipeline",
        return_value=([mock_review_result], mock_deliberation_result),
    )
    session_patch = patch(
        "arbiter.worker.tasks.async_session_factory",
        return_value=lambda: db_session,
    )
    client_patch = patch(
        "arbiter.worker.tasks.get_platform_client",
        return_value=None,
    )
    with pipeline_patch, session_patch, client_patch:
        review_id = await process_review(
            {},
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            diff_content="mock diff",
        )
    # Verify deliberation steps were stored
    query = select(DeliberationStepModel).where(DeliberationStepModel.review_id == review_id)
    stored_steps = (await db_session.execute(query)).scalars().all()
    assert len(stored_steps) == 1
    assert stored_steps[0].description == "Merged findings"
async def test_process_review_handles_errors(
    self,
    db_session: AsyncSession,
) -> None:
    """A pipeline error propagates and the review is recorded as failed."""
    pipeline_patch = patch(
        "arbiter.worker.tasks._run_review_pipeline",
        side_effect=ValueError("Test error"),
    )
    session_patch = patch(
        "arbiter.worker.tasks.async_session_factory",
        return_value=lambda: db_session,
    )
    client_patch = patch(
        "arbiter.worker.tasks.get_platform_client",
        return_value=None,
    )
    # pytest.raises is entered last, so it is the innermost manager and
    # sees the exception before the patches are torn down.
    with pipeline_patch, session_patch, client_patch, pytest.raises(
        ValueError, match="Test error"
    ):
        await process_review(
            {},
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            diff_content="mock diff",
        )
    # Verify review was marked as failed
    query = select(ReviewModel).where(ReviewModel.repository == "owner/repo")
    failed_review = (await db_session.execute(query)).scalar_one()
    assert failed_review.status == "failed"
    assert "Test error" in (failed_review.error_message or "")
async def test_process_review_posts_comment(
    self,
    db_session: AsyncSession,
    mock_platform_client: MockPlatformClient,
    mock_deliberation_result: Any,
    mock_review_result: Any,
    mock_settings: Any,
) -> None:
    """Check that ``process_review`` posts exactly one marked comment.

    The platform-client lookup is patched to return the mock client, and
    the test inspects the comments that client recorded.
    """
    pipeline_patch = patch(
        "arbiter.worker.tasks._run_review_pipeline",
        return_value=([mock_review_result], mock_deliberation_result),
    )
    session_patch = patch(
        "arbiter.worker.tasks.async_session_factory",
        return_value=lambda: db_session,
    )
    client_patch = patch(
        "arbiter.worker.tasks.get_platform_client",
        return_value=mock_platform_client,
    )
    settings_patch = patch(
        "arbiter.worker.tasks.get_settings",
        return_value=mock_settings,
    )
    job_kwargs = {
        "repository": "owner/repo",
        "pr_number": 42,
        "base_sha": "abc123",
        "head_sha": "def456",
        "diff_content": "mock diff",
        "platform": "github",
    }

    with pipeline_patch, session_patch, client_patch, settings_patch:
        await process_review({}, **job_kwargs)

    # One comment is expected, and its body must contain the Arbiter marker.
    posted = mock_platform_client._posted_comments
    assert len(posted) == 1
    assert ARBITER_MARKER in posted[0]["body"]
async def test_process_review_updates_status(
    self,
    db_session: AsyncSession,
    mock_platform_client: MockPlatformClient,
    mock_deliberation_result: Any,
    mock_review_result: Any,
    mock_settings: Any,
) -> None:
    """Check that the last commit status reported is ``CommitStatus.SUCCESS``.

    At least one status update must have been recorded by the mock
    platform client; only the final one is compared against SUCCESS.
    """
    pipeline_patch = patch(
        "arbiter.worker.tasks._run_review_pipeline",
        return_value=([mock_review_result], mock_deliberation_result),
    )
    session_patch = patch(
        "arbiter.worker.tasks.async_session_factory",
        return_value=lambda: db_session,
    )
    client_patch = patch(
        "arbiter.worker.tasks.get_platform_client",
        return_value=mock_platform_client,
    )
    settings_patch = patch(
        "arbiter.worker.tasks.get_settings",
        return_value=mock_settings,
    )
    job_kwargs = {
        "repository": "owner/repo",
        "pr_number": 42,
        "base_sha": "abc123",
        "head_sha": "def456",
        "diff_content": "mock diff",
        "platform": "github",
    }

    with pipeline_patch, session_patch, client_patch, settings_patch:
        await process_review({}, **job_kwargs)

    # Status is updated at least once (pending, then final).
    status_updates = mock_platform_client._status_updates
    assert len(status_updates) >= 1

    # The most recent update is the final status.
    last_update = status_updates[-1]
    assert last_update["status"] == CommitStatus.SUCCESS
async def test_process_review_requires_diff(
    self,
    db_session: AsyncSession,
) -> None:
    """Check that omitting ``diff_content`` raises the expected ``ValueError``."""
    session_patch = patch(
        "arbiter.worker.tasks.async_session_factory",
        return_value=lambda: db_session,
    )
    client_patch = patch(
        "arbiter.worker.tasks.get_platform_client",
        return_value=None,
    )
    expect_error = pytest.raises(ValueError, match="diff_content not provided")
    # diff_content is deliberately left out of the call arguments.
    job_kwargs = {
        "repository": "owner/repo",
        "pr_number": 42,
        "base_sha": "abc123",
        "head_sha": "def456",
    }

    with session_patch, client_patch, expect_error:
        await process_review({}, **job_kwargs)
class TestProcessFollowup:
    """Cases in which ``process_followup`` is expected to return ``None``."""

    async def test_process_followup_no_review(
        self,
        db_session: AsyncSession,
        mock_settings: Any,
    ) -> None:
        """A follow-up on PR 999 (no review fixture loaded) returns ``None``."""
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=mock_settings,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )
        followup_kwargs = {
            "repository": "owner/repo",
            "pr_number": 999,  # Non-existent PR
            "comment_id": "comment-123",
            "comment_body": "Why is this a security issue?",
            "author": "testuser",
            "platform": "github",
        }

        with session_patch, settings_patch, client_patch:
            outcome = await process_followup({}, **followup_kwargs)

        assert outcome is None

    async def test_process_followup_disabled(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
    ) -> None:
        """With ``followup_enabled = False`` the task returns ``None``."""

        class DisabledSettings:
            # Only the flag this test sets; follow-ups are switched off.
            followup_enabled = False

        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=DisabledSettings(),
        )
        followup_kwargs = {
            "repository": "owner/repo",
            "pr_number": 42,
            "comment_id": "comment-123",
            "comment_body": "Why is this a security issue?",
            "author": "testuser",
            "platform": "github",
        }

        with session_patch, settings_patch:
            outcome = await process_followup({}, **followup_kwargs)

        assert outcome is None

    async def test_process_followup_not_a_question(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
        mock_settings: Any,
    ) -> None:
        """A comment that is a statement rather than a question returns ``None``."""
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=mock_settings,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )
        followup_kwargs = {
            "repository": "owner/repo",
            "pr_number": 42,
            "comment_id": "comment-123",
            "comment_body": "This looks good to me.",  # Not a question
            "author": "testuser",
            "platform": "github",
        }

        with session_patch, settings_patch, client_patch:
            outcome = await process_followup({}, **followup_kwargs)

        assert outcome is None

    async def test_process_followup_low_confidence(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
    ) -> None:
        """With a 0.99 confidence threshold the follow-up returns ``None``."""

        class HighThresholdSettings:
            # Stand-in settings object used only by this test.
            followup_enabled = True
            followup_confidence_threshold = 0.99  # Very high threshold
            llm_timeout = 60
            llm_max_retries = 3
            templates_dir = Path("templates")
            post_comments = False

        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=HighThresholdSettings(),
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )
        followup_kwargs = {
            "repository": "owner/repo",
            "pr_number": 42,
            "comment_id": "comment-123",
            # A question but low confidence
            "comment_body": "What does this mean?",
            "author": "testuser",
            "platform": "github",
        }

        with session_patch, settings_patch, client_patch:
            outcome = await process_followup({}, **followup_kwargs)

        assert outcome is None
class TestEnqueueFollowup:
    """Tests for ``enqueue_followup`` against a mocked Redis pool."""

    @pytest.fixture
    def mock_redis_pool_followup(self, monkeypatch: pytest.MonkeyPatch) -> "MockRedisForQueue":
        """Replace ``get_redis_pool`` in the queue module with a mock-backed stub."""
        fake_pool = MockRedisForQueue()

        async def _fake_get_redis_pool() -> MockRedisForQueue:
            # Always hand out the same mock so the test can inspect it later.
            return fake_pool

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", _fake_get_redis_pool)
        return fake_pool

    async def test_enqueue_followup_creates_job(
        self, mock_redis_pool_followup: "MockRedisForQueue"
    ) -> None:
        """Enqueuing a follow-up records one ``process_followup`` job."""
        from arbiter.worker.queue import enqueue_followup

        followup_kwargs = {
            "repository": "owner/repo",
            "pr_number": 42,
            "comment_id": "comment-123",
            "comment_body": "Why is this a security issue?",
            "author": "testuser",
            "platform": "github",
        }
        job_id = await enqueue_followup(**followup_kwargs)

        assert job_id is not None
        queued_jobs = mock_redis_pool_followup._jobs
        assert len(queued_jobs) == 1
        assert queued_jobs[0]["func"] == "process_followup"

    async def test_enqueue_followup_deduplication(
        self, mock_redis_pool_followup: "MockRedisForQueue"
    ) -> None:
        """A follow-up whose job key is already stored is not enqueued again."""
        from arbiter.worker.queue import enqueue_followup, generate_followup_job_id

        # Seed the mock store so the job appears to exist already.
        job_id = generate_followup_job_id("owner/repo", 42, "comment-123")
        mock_redis_pool_followup._data[f"arbiter:followup:{job_id}"] = "pending"

        followup_kwargs = {
            "repository": "owner/repo",
            "pr_number": 42,
            "comment_id": "comment-123",
            "comment_body": "Why is this a security issue?",
            "author": "testuser",
            "platform": "github",
        }
        outcome = await enqueue_followup(**followup_kwargs)

        assert outcome is None
        # Nothing new may have been queued.
        assert len(mock_redis_pool_followup._jobs) == 0
class TestGenerateFollowupJobId:
    """Tests for ``generate_followup_job_id``."""

    def test_followup_job_id_stable(self) -> None:
        """Identical arguments produce identical job ids."""
        from arbiter.worker.queue import generate_followup_job_id

        first = generate_followup_job_id("owner/repo", 42, "comment-123")
        second = generate_followup_job_id("owner/repo", 42, "comment-123")

        assert first == second

    def test_generate_followup_job_id_unique(self) -> None:
        """Changing the repository, PR number or comment id changes the job id."""
        from arbiter.worker.queue import generate_followup_job_id

        # Baseline plus one variation of each argument.
        argument_sets = [
            ("owner/repo", 42, "comment-123"),
            ("owner/repo", 42, "comment-456"),
            ("owner/repo", 43, "comment-123"),
            ("other/repo", 42, "comment-123"),
        ]
        distinct_ids = {generate_followup_job_id(*args) for args in argument_sets}

        assert len(distinct_ids) == 4