tests for api, worker, cache
This commit is contained in:
1153
tests/test_api.py
Normal file
1153
tests/test_api.py
Normal file
File diff suppressed because it is too large
Load Diff
164
tests/test_cache.py
Normal file
164
tests/test_cache.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Tests for the LLM cache module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from arbiter.llm.cache import LLMCache, compute_policy_hash
|
||||||
|
from arbiter.llm.client import LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputePolicyHash:
    """Checks on ``compute_policy_hash``: stability, sensitivity and output shape."""

    def test_compute_policy_hash_deterministic(self) -> None:
        """Hashing the same policy twice gives the same value."""
        policy = {"agents": {"security": {"enabled": True}}}
        first = compute_policy_hash(policy)
        second = compute_policy_hash(policy)
        assert first == second

    def test_policy_hash_varies(self) -> None:
        """Policies that differ in a single flag hash to different values."""
        enabled = {"agents": {"security": {"enabled": True}}}
        disabled = {"agents": {"security": {"enabled": False}}}
        assert compute_policy_hash(enabled) != compute_policy_hash(disabled)

    def test_compute_policy_hash_format(self) -> None:
        """The hash is 16 characters long and uses lowercase hex digits only."""
        digest = compute_policy_hash({"test": "data"})
        assert len(digest) == 16
        assert set(digest) <= set("0123456789abcdef")
|
||||||
|
|
||||||
|
|
||||||
|
class MockRedisForCache:
|
||||||
|
"""Mock Redis client for cache testing."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._data: dict[str, str] = {}
|
||||||
|
|
||||||
|
async def get(self, key: str) -> str | None:
|
||||||
|
return self._data.get(key)
|
||||||
|
|
||||||
|
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
||||||
|
self._data[key] = value
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> int:
|
||||||
|
if key in self._data:
|
||||||
|
del self._data[key]
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def scan_iter(self, match: str | None = None): # noqa: ARG002
|
||||||
|
async def _gen():
|
||||||
|
for key in list(self._data.keys()):
|
||||||
|
yield key
|
||||||
|
|
||||||
|
return _gen()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMCache:
    """Exercises ``LLMCache`` against the in-memory ``MockRedisForCache``."""

    @pytest.fixture
    def cache(self) -> LLMCache:
        """Provide a fresh ``LLMCache`` backed by an empty mock Redis."""
        backend = MockRedisForCache()
        return LLMCache(backend)  # type: ignore[arg-type]

    @staticmethod
    def _make_response(content: str) -> LLMResponse:
        """Build the ``LLMResponse`` shared by these tests, varying only *content*."""
        return LLMResponse(
            content=content,
            model="gpt-4o",
            tokens_in=100,
            tokens_out=50,
            cost_usd=0.01,
        )

    def test_compute_key(self, cache: LLMCache) -> None:
        """Keys carry the cache prefix followed by additional characters."""
        key = cache._compute_key("diff content", "security", "v1.0", "policy123")
        assert key.startswith("arbiter:llm:cache:")
        assert len(key) > 20  # prefix + hash

    def test_compute_key_deterministic(self, cache: LLMCache) -> None:
        """Identical inputs produce identical keys."""
        first = cache._compute_key("diff", "security", "v1.0")
        second = cache._compute_key("diff", "security", "v1.0")
        assert first == second

    def test_compute_key_unique(self, cache: LLMCache) -> None:
        """Changing the diff, the agent or the version each changes the key."""
        keys = [
            cache._compute_key("diff1", "security", "v1.0"),
            cache._compute_key("diff2", "security", "v1.0"),
            cache._compute_key("diff1", "style", "v1.0"),
            cache._compute_key("diff1", "security", "v2.0"),
        ]
        assert len(set(keys)) == 4

    def test_serialize_deserialize_response(self, cache: LLMCache) -> None:
        """A response survives a serialize/deserialize round trip field by field."""
        original = self._make_response("test content")

        restored = cache._deserialize_response(cache._serialize_response(original))

        assert restored.content == original.content
        assert restored.model == original.model
        assert restored.tokens_in == original.tokens_in
        assert restored.tokens_out == original.tokens_out
        assert restored.cost_usd == original.cost_usd

    async def test_cache_get_miss(self, cache: LLMCache) -> None:
        """Looking up an unknown entry returns ``None`` and counts one miss."""
        found = await cache.get("diff", "security", "v1.0")
        assert found is None
        assert cache._misses == 1
        assert cache._hits == 0

    async def test_cache_set_and_get(self, cache: LLMCache) -> None:
        """A stored response is returned by a later lookup and counts one hit."""
        stored = self._make_response("cached content")

        await cache.set("diff", "security", "v1.0", stored)
        found = await cache.get("diff", "security", "v1.0")

        assert found is not None
        assert found.content == "cached content"
        assert cache._hits == 1

    async def test_cache_invalidate(self, cache: LLMCache) -> None:
        """Invalidating a stored entry reports ``True`` and removes it."""
        stored = self._make_response("test")

        await cache.set("diff", "security", "v1.0", stored)
        was_deleted = await cache.invalidate("diff", "security", "v1.0")
        assert was_deleted is True

        found = await cache.get("diff", "security", "v1.0")
        assert found is None

    async def test_cache_invalidate_nonexistent(self, cache: LLMCache) -> None:
        """Invalidating an entry that was never stored reports ``False``."""
        was_deleted = await cache.invalidate("nonexistent", "security", "v1.0")
        assert was_deleted is False

    def test_get_stats(self, cache: LLMCache) -> None:
        """A fresh cache reports zero for every statistic."""
        stats = cache.get_stats()
        assert stats["hits"] == 0
        assert stats["misses"] == 0
        assert stats["total"] == 0
        assert stats["hit_rate"] == 0.0

    async def test_get_stats_after_operations(self, cache: LLMCache) -> None:
        """Two misses followed by one hit give a hit rate of one third."""
        await cache.get("key1", "agent", "v1")  # miss
        await cache.get("key2", "agent", "v1")  # miss

        stored = self._make_response("test")
        await cache.set("key1", "agent", "v1", stored)
        await cache.get("key1", "agent", "v1")  # hit

        stats = cache.get_stats()
        assert stats["hits"] == 1
        assert stats["misses"] == 2
        assert stats["total"] == 3
        assert stats["hit_rate"] == pytest.approx(1 / 3)
|
||||||
157
tests/test_cost.py
Normal file
157
tests/test_cost.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""Tests for the cost tracking module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from arbiter.models.cost import AgentCost, CostEstimate, ReviewCost
|
||||||
|
from arbiter.models.enums import AgentName
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentCost:
    """Construction behaviour of the ``AgentCost`` model."""

    def test_agent_cost_creation(self) -> None:
        """Explicitly supplied fields can be read back unchanged."""
        security_cost = AgentCost(
            agent=AgentName.SECURITY,
            tokens_in=100,
            tokens_out=50,
            total_tokens=150,
            cost_usd=0.01,
        )

        assert security_cost.agent == AgentName.SECURITY
        assert security_cost.total_tokens == 150
        assert security_cost.cost_usd == 0.01

    def test_agent_cost_defaults(self) -> None:
        """With only the agent given, every counter and the cost start at zero."""
        style_cost = AgentCost(agent=AgentName.STYLE)

        assert style_cost.tokens_in == 0
        assert style_cost.tokens_out == 0
        assert style_cost.total_tokens == 0
        assert style_cost.cost_usd == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestReviewCost:
    """Aggregation, export and budget checks on ``ReviewCost``."""

    def test_review_cost_defaults(self) -> None:
        """A new ``ReviewCost`` has zero totals, no agent costs and no cache counts."""
        ledger = ReviewCost()

        assert ledger.total_tokens == 0
        assert ledger.total_cost_usd == 0.0
        assert ledger.agent_costs == []
        assert ledger.cache_hits == 0
        assert ledger.cache_misses == 0

    def test_add_agent_cost(self) -> None:
        """Costs added for two agents are summed into the totals."""
        ledger = ReviewCost()
        ledger.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01)
        ledger.add_agent_cost(AgentName.STYLE, tokens_in=80, tokens_out=40, cost_usd=0.008)

        assert len(ledger.agent_costs) == 2
        assert ledger.total_tokens_in == 180
        assert ledger.total_tokens_out == 90
        assert ledger.total_tokens == 270
        assert ledger.total_cost_usd == pytest.approx(0.018)

    def test_add_deliberation_cost(self) -> None:
        """Deliberation usage is recorded separately and counted in the token total."""
        ledger = ReviewCost()
        ledger.add_deliberation_cost(tokens_in=50, tokens_out=100, cost_usd=0.005)

        assert ledger.deliberation_tokens_in == 50
        assert ledger.deliberation_tokens_out == 100
        assert ledger.deliberation_cost_usd == 0.005
        assert ledger.total_tokens == 150

    def test_combined_costs(self) -> None:
        """Agent and deliberation usage add up in the overall totals."""
        ledger = ReviewCost()
        ledger.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01)
        ledger.add_deliberation_cost(tokens_in=50, tokens_out=25, cost_usd=0.005)

        assert ledger.total_tokens_in == 150
        assert ledger.total_tokens_out == 75
        assert ledger.total_tokens == 225
        assert ledger.total_cost_usd == pytest.approx(0.015)

    def test_to_agent_dict(self) -> None:
        """``to_agent_dict`` maps each agent name to its combined token count."""
        ledger = ReviewCost()
        ledger.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01)
        ledger.add_agent_cost(AgentName.STYLE, tokens_in=80, tokens_out=40, cost_usd=0.008)

        tokens_by_agent = ledger.to_agent_dict()
        assert tokens_by_agent == {"security": 150, "style": 120}

    def test_to_cost_dict(self) -> None:
        """``to_cost_dict`` maps each agent name to its cost in USD."""
        ledger = ReviewCost()
        ledger.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01)
        ledger.add_agent_cost(AgentName.STYLE, tokens_in=80, tokens_out=40, cost_usd=0.008)

        usd_by_agent = ledger.to_cost_dict()
        assert usd_by_agent == {"security": 0.01, "style": 0.008}

    def test_is_within_budget_true(self) -> None:
        """Usage below both limits is within budget."""
        ledger = ReviewCost()
        ledger.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01)

        assert ledger.is_within_budget(max_tokens=1000, max_cost_usd=0.50) is True

    def test_is_within_budget_false_tokens(self) -> None:
        """1500 tokens against a 1000-token limit is over budget."""
        ledger = ReviewCost()
        ledger.add_agent_cost(AgentName.SECURITY, tokens_in=1000, tokens_out=500, cost_usd=0.01)

        assert ledger.is_within_budget(max_tokens=1000, max_cost_usd=0.50) is False

    def test_is_within_budget_false_cost(self) -> None:
        """A cost of 1.0 USD against a 0.50 USD limit is over budget."""
        ledger = ReviewCost()
        ledger.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=1.0)

        assert ledger.is_within_budget(max_tokens=10000, max_cost_usd=0.50) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestCostEstimate:
    """Behaviour of the ``CostEstimate.estimate`` factory."""

    def test_estimate_small_diff(self) -> None:
        """A small diff yields positive estimates and stays within the default budget."""
        small = CostEstimate.estimate(
            diff_size=1000,
            agents=[AgentName.SECURITY, AgentName.STYLE],
            model="gpt-4o-mini",
        )

        assert small.estimated_tokens > 0
        assert small.estimated_cost_usd > 0
        assert small.agents_enabled == [AgentName.SECURITY, AgentName.STYLE]
        assert small.model == "gpt-4o-mini"
        assert small.within_budget is True

    def test_estimate_large_diff(self) -> None:
        """A large diff with tight limits is reported as over budget."""
        large = CostEstimate.estimate(
            diff_size=100000,
            agents=[AgentName.SECURITY, AgentName.STYLE, AgentName.COMPLEXITY],
            model="gpt-4o",
            max_tokens=10000,
            max_cost_usd=0.10,
        )

        # Large diff with expensive model should exceed budget
        assert large.within_budget is False

    def test_estimate_gpt4o_vs_mini(self) -> None:
        """For the same diff and agent, ``gpt-4o`` is estimated costlier than ``gpt-4o-mini``."""
        with_4o = CostEstimate.estimate(
            diff_size=10000,
            agents=[AgentName.SECURITY],
            model="gpt-4o",
        )
        with_mini = CostEstimate.estimate(
            diff_size=10000,
            agents=[AgentName.SECURITY],
            model="gpt-4o-mini",
        )

        assert with_4o.estimated_cost_usd > with_mini.estimated_cost_usd

    def test_estimate_more_agents_higher_cost(self) -> None:
        """Three agents are estimated to use more tokens and cost more than one."""
        single_agent = CostEstimate.estimate(
            diff_size=5000,
            agents=[AgentName.SECURITY],
            model="gpt-4o",
        )
        triple_agent = CostEstimate.estimate(
            diff_size=5000,
            agents=[AgentName.SECURITY, AgentName.STYLE, AgentName.COMPLEXITY],
            model="gpt-4o",
        )

        assert triple_agent.estimated_tokens > single_agent.estimated_tokens
        assert triple_agent.estimated_cost_usd > single_agent.estimated_cost_usd
|
||||||
222
tests/test_db_models.py
Normal file
222
tests/test_db_models.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""Tests for database models."""
|
||||||
|
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from arbiter.db.models import (
|
||||||
|
Base,
|
||||||
|
ConflictModel,
|
||||||
|
DeliberationStepModel,
|
||||||
|
FindingModel,
|
||||||
|
PolicyModel,
|
||||||
|
ReviewModel,
|
||||||
|
)
|
||||||
|
from arbiter.deliberation.conflicts import ConflictNature
|
||||||
|
from arbiter.deliberation.coordinator import StepType
|
||||||
|
from arbiter.models.enums import AgentName, Severity, Verdict
|
||||||
|
|
||||||
|
|
||||||
|
class TestReviewModel:
    """Attribute round-trips on ``ReviewModel`` instances built in memory."""

    def test_review_model_creation(self) -> None:
        """Core pull-request fields passed to the constructor are readable back."""
        pending_review = ReviewModel(
            id=str(uuid4()),
            repository="owner/repo",
            pr_number=42,
            pr_title="Test PR",
            base_sha="abc1234567890123456789012345678901234567",
            head_sha="def1234567890123456789012345678901234567",
            author="testuser",
            is_draft=False,
            status="pending",
        )

        assert pending_review.repository == "owner/repo"
        assert pending_review.pr_number == 42
        assert pending_review.status == "pending"
        assert pending_review.is_draft is False

    def test_review_model_with_verdict(self) -> None:
        """Verdict and confidence passed to the constructor are readable back."""
        completed_review = ReviewModel(
            id=str(uuid4()),
            repository="owner/repo",
            pr_number=1,
            base_sha="a" * 40,
            head_sha="b" * 40,
            status="completed",
            verdict=Verdict.COMMENT,
            verdict_confidence=0.75,
            verdict_reasoning="Found some issues",
        )

        assert completed_review.verdict == Verdict.COMMENT
        assert completed_review.verdict_confidence == 0.75

    def test_review_model_cost_tracking(self) -> None:
        """Token and cost totals, including the per-agent mapping, are readable back."""
        costed_review = ReviewModel(
            id=str(uuid4()),
            repository="owner/repo",
            pr_number=1,
            base_sha="a" * 40,
            head_sha="b" * 40,
            total_tokens=1500,
            total_cost_usd=0.015,
            tokens_by_agent={"security": 500, "style": 500, "complexity": 500},
            cost_by_agent={"security": 0.005, "style": 0.005, "complexity": 0.005},
        )

        assert costed_review.total_tokens == 1500
        assert costed_review.total_cost_usd == 0.015
        assert costed_review.tokens_by_agent["security"] == 500
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindingModel:
    """Attribute round-trips on ``FindingModel``."""

    def test_finding_model_creation(self) -> None:
        """A fully populated finding exposes agent, severity, confidence and line span."""
        sql_injection = FindingModel(
            id=str(uuid4()),
            review_id=str(uuid4()),
            agent=AgentName.SECURITY,
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection",
            description="User input concatenated in SQL",
            reasoning="String concatenation allows injection",
            suggestion="Use parameterized queries",
            references=["https://owasp.org"],
            prompt_version="security-v1.0",
        )

        assert sql_injection.agent == AgentName.SECURITY
        assert sql_injection.severity == Severity.HIGH
        assert sql_injection.confidence == 0.9
        assert sql_injection.line_start == 10
        assert sql_injection.line_end == 15
|
||||||
|
|
||||||
|
|
||||||
|
class TestConflictModel:
    """Attribute round-trips on ``ConflictModel``, with and without a resolution."""

    def test_conflict_model_creation(self) -> None:
        """An unresolved conflict keeps its nature, finding ids and weight."""
        trade_off = ConflictModel(
            id=str(uuid4()),
            review_id=str(uuid4()),
            finding_ids=["finding-1", "finding-2"],
            nature=ConflictNature.TRADE_OFF,
            description="Security vs simplicity trade-off",
            severity_weight=0.7,
        )

        assert trade_off.nature == ConflictNature.TRADE_OFF
        assert len(trade_off.finding_ids) == 2
        assert trade_off.severity_weight == 0.7

    def test_conflict_model_with_resolution(self) -> None:
        """A resolved conflict exposes its resolution text and winning finding id."""
        resolved = ConflictModel(
            id=str(uuid4()),
            review_id=str(uuid4()),
            finding_ids=["finding-1", "finding-2"],
            nature=ConflictNature.CONTRADICTORY,
            description="Opposing recommendations",
            severity_weight=0.8,
            resolution="Security takes precedence",
            winning_finding_id="finding-1",
        )

        assert resolved.resolution is not None
        assert resolved.winning_finding_id == "finding-1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeliberationStepModel:
    """Attribute round-trips on ``DeliberationStepModel``."""

    def test_deliberation_step_creation(self) -> None:
        """A merge step keeps its type, sequence number and details mapping."""
        merge_step = DeliberationStepModel(
            id=str(uuid4()),
            review_id=str(uuid4()),
            step_type=StepType.MERGE,
            description="Merged 5 findings",
            details={"groups": 3, "unique": 5},
            sequence=0,
        )

        assert merge_step.step_type == StepType.MERGE
        assert merge_step.sequence == 0
        assert merge_step.details["groups"] == 3

    def test_all_step_types(self) -> None:
        """One step of each of four step types can be built for the same review."""
        review_id = str(uuid4())

        # (step type, description) in sequence order; the index becomes ``sequence``.
        plan = [
            (StepType.MERGE, "Merge step"),
            (StepType.CONFLICT_DETECTION, "Conflict detection step"),
            (StepType.SYNTHESIS, "Synthesis step"),
            (StepType.VERDICT, "Verdict step"),
        ]
        steps = [
            DeliberationStepModel(
                id=str(uuid4()),
                review_id=review_id,
                step_type=kind,
                description=text,
                sequence=position,
            )
            for position, (kind, text) in enumerate(plan)
        ]

        assert len(steps) == 4
        assert steps[0].step_type == StepType.MERGE
        assert steps[3].step_type == StepType.VERDICT
|
||||||
|
|
||||||
|
|
||||||
|
class TestPolicyModel:
    """Attribute round-trips on ``PolicyModel``."""

    def test_policy_model_creation(self) -> None:
        """Name, organization and active flag are readable back."""
        default_policy = PolicyModel(
            id=str(uuid4()),
            name="default",
            organization="test-org",
            description="Default policy",
            is_active=True,
        )

        assert default_policy.name == "default"
        assert default_policy.organization == "test-org"
        assert default_policy.is_active is True

    def test_policy_model_with_config(self) -> None:
        """Nested agent and cost-control configuration is stored as given."""
        agents = {
            "security": {"enabled": True, "model": "gpt-4o"},
            "style": {"enabled": True},
            "complexity": {"enabled": False},
        }
        limits = {
            "max_tokens": 50000,
            "max_cost_usd": 0.50,
        }
        thresholds = {
            "critical_threshold": 1,
            "high_threshold": 3,
        }
        strict_policy = PolicyModel(
            id=str(uuid4()),
            name="strict",
            agents_config=agents,
            cost_controls=limits,
            verdict_thresholds=thresholds,
        )

        assert strict_policy.agents_config["security"]["model"] == "gpt-4o"
        assert strict_policy.cost_controls["max_tokens"] == 50000
|
||||||
|
|
||||||
|
|
||||||
|
class TestBase:
    """Sanity check on the shared ORM ``Base`` class."""

    def test_base_is_declarative_base(self) -> None:
        """``Base`` exposes both the ``metadata`` and ``registry`` attributes."""
        for attribute in ("metadata", "registry"):
            assert hasattr(Base, attribute)
|
||||||
959
tests/test_worker.py
Normal file
959
tests/test_worker.py
Normal file
@@ -0,0 +1,959 @@
|
|||||||
|
"""Tests for the worker module."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from arbiter.db.models import (
|
||||||
|
ConflictModel,
|
||||||
|
DeliberationStepModel,
|
||||||
|
FindingModel,
|
||||||
|
ReviewModel,
|
||||||
|
)
|
||||||
|
from arbiter.integrations import ARBITER_MARKER
|
||||||
|
from arbiter.integrations.base import Comment, CommitStatus, Platform
|
||||||
|
from arbiter.models.enums import AgentName, Severity, Verdict
|
||||||
|
from arbiter.worker.queue import JobPriority, cancel_job, generate_job_id, get_job_status
|
||||||
|
from arbiter.worker.tasks import (
|
||||||
|
_post_or_update_comment,
|
||||||
|
_verdict_to_status,
|
||||||
|
detect_platform,
|
||||||
|
get_platform_client,
|
||||||
|
process_followup,
|
||||||
|
process_review,
|
||||||
|
)
|
||||||
|
from tests.conftest import MockPlatformClient
|
||||||
|
|
||||||
|
|
||||||
|
class TestJobQueue:
    """Checks on ``generate_job_id`` and the ``JobPriority`` ordering."""

    def test_generate_job_id_deterministic(self) -> None:
        """The same repository, PR number and SHA always give the same id."""
        first = generate_job_id("owner/repo", 42, "abc123")
        second = generate_job_id("owner/repo", 42, "abc123")
        assert first == second

    def test_job_id_unique(self) -> None:
        """Changing the SHA, the PR number or the repository changes the id."""
        job_ids = [
            generate_job_id("owner/repo", 42, "abc123"),
            generate_job_id("owner/repo", 42, "def456"),  # Different SHA
            generate_job_id("owner/repo", 43, "abc123"),  # Different PR
            generate_job_id("other/repo", 42, "abc123"),  # Different repo
        ]

        assert len(set(job_ids)) == 4  # All unique

    def test_generate_job_id_format(self) -> None:
        """Job ids are 16 characters long and use lowercase hex digits only."""
        job_id = generate_job_id("owner/repo", 42, "abc123")
        assert len(job_id) == 16
        assert set(job_id) <= set("0123456789abcdef")

    def test_job_priority_ordering(self) -> None:
        """HIGH, NORMAL and LOW are the integers 1, 2 and 3, in ascending order."""
        assert JobPriority.HIGH < JobPriority.NORMAL < JobPriority.LOW
        assert int(JobPriority.HIGH) == 1
        assert int(JobPriority.NORMAL) == 2
        assert int(JobPriority.LOW) == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerSettings:
    """Presence checks on the attributes of ``WorkerSettings``."""

    def test_worker_settings_has_functions(self) -> None:
        """``functions`` is set and non-empty."""
        from arbiter.worker.settings import WorkerSettings

        registered = WorkerSettings.functions
        assert registered is not None
        assert len(registered) > 0

    def test_worker_settings_has_cron_jobs(self) -> None:
        """``cron_jobs`` is set."""
        from arbiter.worker.settings import WorkerSettings

        assert WorkerSettings.cron_jobs is not None

    def test_worker_settings_lifecycle_hooks(self) -> None:
        """Both the startup and the shutdown hook are set."""
        from arbiter.worker.settings import WorkerSettings

        assert WorkerSettings.on_startup is not None
        assert WorkerSettings.on_shutdown is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestReviewTask:
    """Placeholder coverage for the ``process_review`` task."""

    @pytest.fixture
    def mock_context(self) -> dict[str, Any]:
        """Return a minimal worker context with ``settings`` and ``redis`` unset."""
        context: dict[str, Any] = {"settings": None, "redis": None}
        return context

    async def test_review_task_requires_diff(
        self,
        mock_context: dict[str, Any],  # noqa: ARG002
    ) -> None:
        """Smoke-check that ``process_review`` can be imported and is callable.

        NOTE(review): despite the test name, nothing here exercises a
        missing-diff code path; the task itself is never invoked.
        """
        from arbiter.worker.tasks import process_review

        # Note: This would need database fixtures to run fully
        # For now, we just verify the function signature
        assert callable(process_review)
|
||||||
|
|
||||||
|
|
||||||
|
class MockRedisForQueue:
|
||||||
|
"""Mock Redis with enqueue_job support."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._data: dict[str, Any] = {}
|
||||||
|
self._jobs: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def get(self, key: str) -> str | None:
|
||||||
|
return self._data.get(key)
|
||||||
|
|
||||||
|
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
||||||
|
self._data[key] = value
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> int:
|
||||||
|
if key in self._data:
|
||||||
|
del self._data[key]
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any:
|
||||||
|
job = {"func": func_name, "kwargs": kwargs}
|
||||||
|
self._jobs.append(job)
|
||||||
|
return type("Job", (), {"job_id": kwargs.get("_job_id", "test-id")})()
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueReview:
    """Tests ``enqueue_review`` with the Redis pool replaced by a mock."""

    @pytest.fixture
    def mock_redis_pool(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Patch ``get_redis_pool`` in the queue module to return a fresh mock."""
        fake_redis = MockRedisForQueue()

        async def _fake_pool() -> MockRedisForQueue:
            return fake_redis

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", _fake_pool)
        return fake_redis

    async def test_enqueue_review_creates_job(self, mock_redis_pool: MockRedisForQueue) -> None:
        """Enqueuing a review records exactly one ``process_review`` job."""
        from arbiter.worker.queue import enqueue_review

        job_id = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            pr_title="Test PR",
            author="testuser",
            is_draft=False,
        )

        assert job_id is not None
        recorded = mock_redis_pool._jobs
        assert len(recorded) == 1
        assert recorded[0]["func"] == "process_review"

    async def test_enqueue_review_deduplication(
        self,
        mock_redis_pool: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """A second enqueue with identical parameters returns ``None``."""
        from arbiter.worker.queue import enqueue_review

        request = {
            "repository": "owner/repo",
            "pr_number": 42,
            "base_sha": "abc123",
            "head_sha": "def456",
        }

        # First call should succeed
        first_id = await enqueue_review(**request)
        assert first_id is not None

        # Second call with same params should be deduplicated
        second_id = await enqueue_review(**request)
        assert second_id is None

    async def test_enqueue_review_draft_lower_priority(
        self, mock_redis_pool: MockRedisForQueue
    ) -> None:
        """A draft PR is enqueued on a queue whose name contains ``arbiter:queue:3``."""
        from arbiter.worker.queue import enqueue_review

        await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            is_draft=True,
        )

        assert len(mock_redis_pool._jobs) == 1
        queued = mock_redis_pool._jobs[0]
        # Draft PRs should be in the low priority queue
        assert "arbiter:queue:3" in queued["kwargs"]["_queue_name"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestJobStatusAndCancel:
    """Tests ``get_job_status`` and ``cancel_job`` against a pre-seeded mock Redis."""

    @pytest.fixture
    def mock_redis_pool_with_jobs(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Patch ``get_redis_pool`` with a mock that already holds one pending job."""
        fake_redis = MockRedisForQueue()
        fake_redis._data["arbiter:job:test-job-id"] = "pending"

        async def _fake_pool() -> MockRedisForQueue:
            return fake_redis

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", _fake_pool)
        return fake_redis

    async def test_get_job_status_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """The seeded job resolves to a mapping with its id and ``pending`` status."""
        job_info = await get_job_status("test-job-id")

        assert job_info is not None
        assert job_info["job_id"] == "test-job-id"
        assert job_info["status"] == "pending"

    async def test_get_job_status_not_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """An unknown job id resolves to ``None``."""
        job_info = await get_job_status("nonexistent")
        assert job_info is None

    async def test_cancel_job_success(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """Cancelling the seeded job reports ``True``."""
        cancelled = await cancel_job("test-job-id")
        assert cancelled is True

    async def test_cancel_job_not_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """Cancelling an unknown job reports ``False``."""
        cancelled = await cancel_job("nonexistent")
        assert cancelled is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerStartupShutdown:
    """Tests the worker lifecycle hooks, health check and settings helpers."""

    async def test_startup_hook(self) -> None:
        """``startup`` calls ``init_db`` once and stores the settings in the context."""
        from unittest.mock import AsyncMock, patch

        from arbiter.worker.settings import startup

        # Mock init_db and get_settings
        with (
            patch("arbiter.worker.settings.init_db", new_callable=AsyncMock) as mock_init,
            patch("arbiter.worker.settings.get_settings", return_value="mock_settings"),
        ):
            worker_ctx: dict[str, Any] = {}
            await startup(worker_ctx)

            mock_init.assert_called_once()
            assert worker_ctx["settings"] == "mock_settings"

    async def test_shutdown_hook(self) -> None:
        """``shutdown`` calls ``close_db`` exactly once."""
        from unittest.mock import AsyncMock, patch

        from arbiter.worker.settings import shutdown

        with patch("arbiter.worker.settings.close_db", new_callable=AsyncMock) as mock_close:
            worker_ctx: dict[str, Any] = {}
            await shutdown(worker_ctx)
            mock_close.assert_called_once()

    async def test_health_check(self) -> None:
        """``health_check`` returns the string ``healthy`` for an empty context."""
        from arbiter.worker.settings import health_check

        worker_ctx: dict[str, Any] = {}
        outcome = await health_check(worker_ctx)
        assert outcome == "healthy"

    def test_worker_settings_redis_settings(self) -> None:
        """``redis_settings()`` returns a value other than ``None``."""
        from arbiter.worker.settings import WorkerSettings

        assert WorkerSettings.redis_settings() is not None

    def test_worker_settings_get_functions(self) -> None:
        """``_get_functions`` returns exactly the review and follow-up tasks."""
        from arbiter.worker.settings import WorkerSettings

        task_functions = WorkerSettings._get_functions()
        assert len(task_functions) == 2
        # Verify the functions are the expected ones
        names = [task.__name__ for task in task_functions]
        assert "process_review" in names
        assert "process_followup" in names
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetectPlatform:
    """Checks how ``detect_platform`` maps an optional platform hint to a ``Platform`` member."""

    def test_detect_platform_from_webhook_github(self) -> None:
        """A ``"github"`` hint is expected to resolve to ``Platform.GITHUB``."""
        detected = detect_platform("owner/repo", "github")

        assert detected == Platform.GITHUB

    def test_detect_platform_from_webhook_gitlab(self) -> None:
        """A ``"gitlab"`` hint is expected to resolve to ``Platform.GITLAB``."""
        detected = detect_platform("owner/repo", "gitlab")

        assert detected == Platform.GITLAB

    def test_detect_platform_case_insensitive(self) -> None:
        """Upper- and mixed-case hints are expected to resolve like their lowercase forms."""
        cases = (
            ("GITHUB", Platform.GITHUB),
            ("GitLab", Platform.GITLAB),
        )
        for hint, expected in cases:
            assert detect_platform("owner/repo", hint) == expected

    def test_detect_platform_defaults_to_github(self) -> None:
        """Without a hint the result is expected to be ``Platform.GITHUB``."""
        detected = detect_platform("owner/repo")

        assert detected == Platform.GITHUB
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetPlatformClient:
    """Checks which client (if any) ``get_platform_client`` returns for each platform."""

    def test_get_platform_client_github_no_token(self, mock_settings_no_github: Any) -> None:
        """With the ``mock_settings_no_github`` fixture, no GitHub client is expected."""
        resolved = get_platform_client(Platform.GITHUB, mock_settings_no_github)

        assert resolved is None

    def test_get_platform_client_gitlab_no_token(self, mock_settings_no_github: Any) -> None:
        """With the ``mock_settings_no_github`` fixture, no GitLab client is expected."""
        resolved = get_platform_client(Platform.GITLAB, mock_settings_no_github)

        assert resolved is None

    def test_github_client_with_token(self, mock_settings: Any) -> None:
        """With the ``mock_settings`` fixture, a ``GitHubClient`` instance is expected."""
        from arbiter.integrations import GitHubClient

        resolved = get_platform_client(Platform.GITHUB, mock_settings)

        assert resolved is not None
        assert isinstance(resolved, GitHubClient)

    def test_gitlab_client_with_token(self, mock_settings: Any) -> None:
        """With the ``mock_settings`` fixture, a ``GitLabClient`` instance is expected."""
        from arbiter.integrations import GitLabClient

        resolved = get_platform_client(Platform.GITLAB, mock_settings)

        assert resolved is not None
        assert isinstance(resolved, GitLabClient)
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerdictToStatus:
    """Checks the ``Verdict`` -> ``CommitStatus`` mapping done by ``_verdict_to_status``."""

    def test_approve_returns_success(self) -> None:
        """``Verdict.APPROVE`` is expected to map to ``CommitStatus.SUCCESS``."""
        mapped = _verdict_to_status(Verdict.APPROVE)

        assert mapped == CommitStatus.SUCCESS

    def test_request_changes_returns_failure(self) -> None:
        """``Verdict.REQUEST_CHANGES`` is expected to map to ``CommitStatus.FAILURE``."""
        mapped = _verdict_to_status(Verdict.REQUEST_CHANGES)

        assert mapped == CommitStatus.FAILURE

    def test_comment_returns_success(self) -> None:
        """``Verdict.COMMENT`` is expected to map to ``CommitStatus.SUCCESS``."""
        mapped = _verdict_to_status(Verdict.COMMENT)

        assert mapped == CommitStatus.SUCCESS
|
||||||
|
|
||||||
|
|
||||||
|
class TestPostOrUpdateComment:
    """Exercises ``_post_or_update_comment`` against the ``MockPlatformClient`` fixture."""

    async def test_post_new_comment(self, mock_platform_client: MockPlatformClient) -> None:
        """With no prior comments, one comment carrying the given body is expected."""
        comment_text = f"Test comment {ARBITER_MARKER}"

        comment_url = await _post_or_update_comment(
            mock_platform_client, "owner/repo", 42, comment_text
        )

        assert comment_url is not None
        assert "owner/repo" in comment_url
        recorded = mock_platform_client._posted_comments
        assert len(recorded) == 1
        assert recorded[0]["body"] == comment_text

    async def test_update_existing_comment(self, mock_platform_client: MockPlatformClient) -> None:
        """An existing marker-bearing comment is expected to be updated, not duplicated."""
        # Seed the mock with a comment that already carries the Arbiter marker.
        previous = Comment(
            id="existing-123",
            body=f"Old review {ARBITER_MARKER}",
            author="arbiter-bot",
            url="https://github.com/owner/repo/pull/42#comment-existing-123",
            created_at=datetime.now(UTC),
        )
        mock_platform_client._comments = [previous]

        comment_text = f"Updated review {ARBITER_MARKER}"
        comment_url = await _post_or_update_comment(
            mock_platform_client, "owner/repo", 42, comment_text
        )

        assert comment_url is not None
        recorded = mock_platform_client._posted_comments
        assert len(recorded) == 1
        # The recorded call must reference the seeded comment, i.e. be an update.
        assert recorded[0].get("comment_id") == "existing-123"

    async def test_fallback_on_fetch_failure(
        self, mock_platform_client: MockPlatformClient
    ) -> None:
        """When ``get_comments`` is set to fail, a fresh comment is still expected."""
        mock_platform_client._fail_on.add("get_comments")
        comment_text = f"Test comment {ARBITER_MARKER}"

        comment_url = await _post_or_update_comment(
            mock_platform_client, "owner/repo", 42, comment_text
        )

        # Posting should still succeed even though the lookup failed.
        assert comment_url is not None
        recorded = mock_platform_client._posted_comments
        assert len(recorded) == 1
        # No comment_id means this was a new post rather than an update.
        assert recorded[0].get("comment_id") is None

    async def test_returns_none_on_post_failure(
        self, mock_platform_client: MockPlatformClient
    ) -> None:
        """When ``post_comment`` is set to fail, ``None`` is expected as the URL."""
        mock_platform_client._fail_on.add("post_comment")
        comment_text = f"Test comment {ARBITER_MARKER}"

        comment_url = await _post_or_update_comment(
            mock_platform_client, "owner/repo", 42, comment_text
        )

        assert comment_url is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessReview:
    """Exercises ``process_review`` with the pipeline, session factory and client patched."""

    @pytest.fixture
    def mock_deliberation_result(self) -> Any:
        """Build a ``DeliberationResult`` with one finding, one conflict and one step."""
        from arbiter.deliberation import DeliberationResult, DeliberationStep
        from arbiter.deliberation.conflicts import Conflict, ConflictNature
        from arbiter.deliberation.coordinator import StepType
        from arbiter.models import Finding

        sql_injection = Finding(
            id=str(uuid4()),
            agent=AgentName.SECURITY,
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection",
            description="User input concatenated into SQL",
            reasoning="Allows SQL injection attacks",
            prompt_version="security-v1.0",
        )
        trade_off = Conflict(
            id="conflict-1",
            finding_ids=["f1", "f2"],
            nature=ConflictNature.TRADE_OFF,
            description="Trade-off detected",
            severity_weight=0.5,
        )
        merge_step = DeliberationStep(
            step_type=StepType.MERGE,
            timestamp=datetime.now(UTC),
            description="Merged findings",
            details={"count": 1},
        )

        return DeliberationResult(
            verdict=Verdict.COMMENT,
            verdict_confidence=0.75,
            verdict_reasoning="Found security issues",
            findings=[sql_injection],
            conflicts=[trade_off],
            steps=[merge_step],
            tokens_used=500,
            cost_usd=0.005,
        )

    @pytest.fixture
    def mock_review_result(self) -> Any:
        """Build a security-agent ``ReviewResult`` that carries no findings."""
        from arbiter.models import ReviewResult

        return ReviewResult(
            agent_name=AgentName.SECURITY,
            findings=[],
            duration_ms=1000,
            tokens_used=500,
            cost_usd=0.005,
        )

    async def test_process_review_creates_review_record(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """A successful run is expected to leave a completed review row with the verdict."""
        # Stub out the review pipeline, the session factory and the client lookup.
        pipeline_patch = patch(
            "arbiter.worker.tasks._run_review_pipeline",
            return_value=([mock_review_result], mock_deliberation_result),
        )
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )
        job_kwargs = {
            "repository": "owner/repo",
            "pr_number": 42,
            "base_sha": "abc123",
            "head_sha": "def456",
            "pr_title": "Test PR",
            "author": "testuser",
            "diff_content": "mock diff",
        }

        with pipeline_patch, session_patch, client_patch:
            review_id = await process_review({}, **job_kwargs)

        # Load the row written for this run.
        query = select(ReviewModel).where(ReviewModel.id == review_id)
        stored = (await db_session.execute(query)).scalar_one()

        assert stored.repository == "owner/repo"
        assert stored.pr_number == 42
        assert stored.status == "completed"
        assert stored.verdict == Verdict.COMMENT

    async def test_process_review_stores_findings(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """The single finding in the deliberation result is expected to be persisted."""
        pipeline_patch = patch(
            "arbiter.worker.tasks._run_review_pipeline",
            return_value=([mock_review_result], mock_deliberation_result),
        )
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )
        job_kwargs = {
            "repository": "owner/repo",
            "pr_number": 42,
            "base_sha": "abc123",
            "head_sha": "def456",
            "diff_content": "mock diff",
        }

        with pipeline_patch, session_patch, client_patch:
            review_id = await process_review({}, **job_kwargs)

        # Load the findings attached to this review.
        query = select(FindingModel).where(FindingModel.review_id == review_id)
        stored_findings = (await db_session.execute(query)).scalars().all()

        assert len(stored_findings) == 1
        only_finding = stored_findings[0]
        assert only_finding.title == "SQL Injection"
        assert only_finding.severity == Severity.HIGH

    async def test_process_review_stores_conflicts(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """The single conflict in the deliberation result is expected to be persisted."""
        pipeline_patch = patch(
            "arbiter.worker.tasks._run_review_pipeline",
            return_value=([mock_review_result], mock_deliberation_result),
        )
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )
        job_kwargs = {
            "repository": "owner/repo",
            "pr_number": 42,
            "base_sha": "abc123",
            "head_sha": "def456",
            "diff_content": "mock diff",
        }

        with pipeline_patch, session_patch, client_patch:
            review_id = await process_review({}, **job_kwargs)

        # Load the conflicts attached to this review.
        query = select(ConflictModel).where(ConflictModel.review_id == review_id)
        stored_conflicts = (await db_session.execute(query)).scalars().all()

        assert len(stored_conflicts) == 1
        assert stored_conflicts[0].description == "Trade-off detected"

    async def test_process_review_stores_deliberation_steps(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """The single deliberation step in the result is expected to be persisted."""
        pipeline_patch = patch(
            "arbiter.worker.tasks._run_review_pipeline",
            return_value=([mock_review_result], mock_deliberation_result),
        )
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )
        job_kwargs = {
            "repository": "owner/repo",
            "pr_number": 42,
            "base_sha": "abc123",
            "head_sha": "def456",
            "diff_content": "mock diff",
        }

        with pipeline_patch, session_patch, client_patch:
            review_id = await process_review({}, **job_kwargs)

        # Load the deliberation steps attached to this review.
        query = select(DeliberationStepModel).where(
            DeliberationStepModel.review_id == review_id
        )
        stored_steps = (await db_session.execute(query)).scalars().all()

        assert len(stored_steps) == 1
        assert stored_steps[0].description == "Merged findings"

    async def test_process_review_handles_errors(
        self,
        db_session: AsyncSession,
    ) -> None:
        """A pipeline ``ValueError`` is expected to propagate and mark the review failed."""
        failing_pipeline_patch = patch(
            "arbiter.worker.tasks._run_review_pipeline",
            side_effect=ValueError("Test error"),
        )
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )
        expect_error = pytest.raises(ValueError, match="Test error")

        with failing_pipeline_patch, session_patch, client_patch, expect_error:
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # The call raised, so there is no returned id; look the row up by repository.
        query = select(ReviewModel).where(ReviewModel.repository == "owner/repo")
        stored = (await db_session.execute(query)).scalar_one()

        assert stored.status == "failed"
        assert "Test error" in (stored.error_message or "")

    async def test_process_review_posts_comment(
        self,
        db_session: AsyncSession,
        mock_platform_client: MockPlatformClient,
        mock_deliberation_result: Any,
        mock_review_result: Any,
        mock_settings: Any,
    ) -> None:
        """With a platform client available, one marker-bearing comment is expected."""
        pipeline_patch = patch(
            "arbiter.worker.tasks._run_review_pipeline",
            return_value=([mock_review_result], mock_deliberation_result),
        )
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=mock_platform_client,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=mock_settings,
        )

        with pipeline_patch, session_patch, client_patch, settings_patch:
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
                platform="github",
            )

        # Exactly one comment, and it must carry the Arbiter marker.
        recorded = mock_platform_client._posted_comments
        assert len(recorded) == 1
        assert ARBITER_MARKER in recorded[0]["body"]

    async def test_process_review_updates_status(
        self,
        db_session: AsyncSession,
        mock_platform_client: MockPlatformClient,
        mock_deliberation_result: Any,
        mock_review_result: Any,
        mock_settings: Any,
    ) -> None:
        """With a platform client available, the last status update is expected to be SUCCESS."""
        pipeline_patch = patch(
            "arbiter.worker.tasks._run_review_pipeline",
            return_value=([mock_review_result], mock_deliberation_result),
        )
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=mock_platform_client,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=mock_settings,
        )

        with pipeline_patch, session_patch, client_patch, settings_patch:
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
                platform="github",
            )

        # At least one status update must have been recorded (pending then final).
        assert len(mock_platform_client._status_updates) >= 1
        # The most recent update is the final status.
        latest = mock_platform_client._status_updates[-1]
        assert latest["status"] == CommitStatus.SUCCESS

    async def test_process_review_requires_diff(
        self,
        db_session: AsyncSession,
    ) -> None:
        """Omitting ``diff_content`` is expected to raise a ``ValueError`` naming it."""
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )
        expect_error = pytest.raises(ValueError, match="diff_content not provided")

        with session_patch, client_patch, expect_error:
            # No diff_content is passed on purpose.
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
            )
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessFollowup:
    """Covers the cases in which ``process_followup`` is expected to return ``None``."""

    async def test_process_followup_no_review(
        self,
        db_session: AsyncSession,
        mock_settings: Any,
    ) -> None:
        """A follow-up on PR 999, for which this test creates no review, should yield ``None``."""
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=mock_settings,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )

        with session_patch, settings_patch, client_patch:
            outcome = await process_followup(
                {},
                repository="owner/repo",
                pr_number=999,  # Non-existent PR
                comment_id="comment-123",
                comment_body="Why is this a security issue?",
                author="testuser",
                platform="github",
            )

        assert outcome is None

    async def test_process_followup_disabled(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
    ) -> None:
        """With ``followup_enabled = False`` in settings, ``None`` is expected."""

        class DisabledSettings:
            followup_enabled = False

        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=DisabledSettings(),
        )

        with session_patch, settings_patch:
            outcome = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="Why is this a security issue?",
                author="testuser",
                platform="github",
            )

        assert outcome is None

    async def test_process_followup_not_a_question(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
        mock_settings: Any,
    ) -> None:
        """A comment that is a statement rather than a question should yield ``None``."""
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=mock_settings,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )

        with session_patch, settings_patch, client_patch:
            outcome = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="This looks good to me.",  # Not a question
                author="testuser",
                platform="github",
            )

        assert outcome is None

    async def test_process_followup_low_confidence(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
    ) -> None:
        """With a 0.99 confidence threshold, the question is expected to yield ``None``."""

        class HighThresholdSettings:
            followup_enabled = True
            followup_confidence_threshold = 0.99  # Very high threshold
            llm_timeout = 60
            llm_max_retries = 3
            templates_dir = Path("templates")
            post_comments = False

        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=HighThresholdSettings(),
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )

        with session_patch, settings_patch, client_patch:
            outcome = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="What does this mean?",  # A question but low confidence
                author="testuser",
                platform="github",
            )

        assert outcome is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueFollowup:
    """Exercises ``enqueue_followup`` against a ``MockRedisForQueue`` stand-in pool."""

    @pytest.fixture
    def mock_redis_pool_followup(self, monkeypatch: pytest.MonkeyPatch) -> "MockRedisForQueue":
        """Provide a ``MockRedisForQueue`` and point ``get_redis_pool`` at it."""
        fake_pool = MockRedisForQueue()

        async def get_pool() -> MockRedisForQueue:
            return fake_pool

        # Replace the queue module's pool getter with the stub above.
        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool)
        return fake_pool

    async def test_enqueue_followup_creates_job(
        self, mock_redis_pool_followup: "MockRedisForQueue"
    ) -> None:
        """Enqueuing is expected to return an id and record one ``process_followup`` job."""
        from arbiter.worker.queue import enqueue_followup

        request = {
            "repository": "owner/repo",
            "pr_number": 42,
            "comment_id": "comment-123",
            "comment_body": "Why is this a security issue?",
            "author": "testuser",
            "platform": "github",
        }

        job_id = await enqueue_followup(**request)

        assert job_id is not None
        queued = mock_redis_pool_followup._jobs
        assert len(queued) == 1
        assert queued[0]["func"] == "process_followup"

    async def test_enqueue_followup_deduplication(
        self, mock_redis_pool_followup: "MockRedisForQueue"
    ) -> None:
        """With the job key already present, ``None`` and no new job are expected."""
        from arbiter.worker.queue import enqueue_followup, generate_followup_job_id

        # Mark the job as already known before enqueuing.
        existing_id = generate_followup_job_id("owner/repo", 42, "comment-123")
        mock_redis_pool_followup._data[f"arbiter:followup:{existing_id}"] = "pending"

        request = {
            "repository": "owner/repo",
            "pr_number": 42,
            "comment_id": "comment-123",
            "comment_body": "Why is this a security issue?",
            "author": "testuser",
            "platform": "github",
        }
        outcome = await enqueue_followup(**request)

        assert outcome is None
        # The job list must be untouched.
        assert len(mock_redis_pool_followup._jobs) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateFollowupJobId:
    """Checks that ``generate_followup_job_id`` is stable and sensitive to each argument."""

    def test_followup_job_id_stable(self) -> None:
        """Identical arguments are expected to produce identical ids."""
        from arbiter.worker.queue import generate_followup_job_id

        first = generate_followup_job_id("owner/repo", 42, "comment-123")
        second = generate_followup_job_id("owner/repo", 42, "comment-123")

        assert first == second

    def test_generate_followup_job_id_unique(self) -> None:
        """Changing the comment, PR number or repository is expected to change the id."""
        from arbiter.worker.queue import generate_followup_job_id

        variants = [
            ("owner/repo", 42, "comment-123"),
            ("owner/repo", 42, "comment-456"),
            ("owner/repo", 43, "comment-123"),
            ("other/repo", 42, "comment-123"),
        ]

        # Any collision between variants would shrink the set.
        distinct_ids = {generate_followup_job_id(*args) for args in variants}

        assert len(distinct_ids) == 4
|
||||||
Reference in New Issue
Block a user