960 lines
32 KiB
Python
960 lines
32 KiB
Python
"""Tests for the worker module."""
|
|
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import patch
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from arbiter.db.models import (
|
|
ConflictModel,
|
|
DeliberationStepModel,
|
|
FindingModel,
|
|
ReviewModel,
|
|
)
|
|
from arbiter.integrations import ARBITER_MARKER
|
|
from arbiter.integrations.base import Comment, CommitStatus, Platform
|
|
from arbiter.models.enums import AgentName, Severity, Verdict
|
|
from arbiter.worker.queue import JobPriority, cancel_job, generate_job_id, get_job_status
|
|
from arbiter.worker.tasks import (
|
|
_post_or_update_comment,
|
|
_verdict_to_status,
|
|
detect_platform,
|
|
get_platform_client,
|
|
process_followup,
|
|
process_review,
|
|
)
|
|
from tests.conftest import MockPlatformClient
|
|
|
|
|
|
class TestJobQueue:
    """Job-id generation and priority ordering for the review queue."""

    def test_generate_job_id_deterministic(self) -> None:
        # Identical inputs must always produce the identical id.
        key = ("owner/repo", 42, "abc123")
        first, second = generate_job_id(*key), generate_job_id(*key)
        assert first == second

    def test_job_id_unique(self) -> None:
        # Changing any single component (SHA, PR number, repo) must change the id.
        variants = [
            ("owner/repo", 42, "abc123"),
            ("owner/repo", 42, "def456"),  # Different SHA
            ("owner/repo", 43, "abc123"),  # Different PR
            ("other/repo", 42, "abc123"),  # Different repo
        ]
        distinct_ids = {generate_job_id(*variant) for variant in variants}
        assert len(distinct_ids) == 4  # All unique

    def test_generate_job_id_format(self) -> None:
        job_id = generate_job_id("owner/repo", 42, "abc123")
        assert len(job_id) == 16
        # Only lowercase hexadecimal digits are allowed.
        assert set(job_id) <= set("0123456789abcdef")

    def test_job_priority_ordering(self) -> None:
        assert JobPriority.HIGH < JobPriority.NORMAL < JobPriority.LOW
        expected_values = (
            (JobPriority.HIGH, 1),
            (JobPriority.NORMAL, 2),
            (JobPriority.LOW, 3),
        )
        for priority, value in expected_values:
            assert int(priority) == value
|
|
|
|
|
|
class TestWorkerSettings:
    """Static attributes that ``WorkerSettings`` must expose."""

    def test_worker_settings_has_functions(self) -> None:
        from arbiter.worker.settings import WorkerSettings

        registered = WorkerSettings.functions
        assert registered is not None
        assert len(registered) > 0

    def test_worker_settings_has_cron_jobs(self) -> None:
        from arbiter.worker.settings import WorkerSettings

        assert WorkerSettings.cron_jobs is not None

    def test_worker_settings_lifecycle_hooks(self) -> None:
        from arbiter.worker.settings import WorkerSettings

        # Both the startup and the shutdown hook must be configured.
        for hook in (WorkerSettings.on_startup, WorkerSettings.on_shutdown):
            assert hook is not None
|
|
|
|
|
|
class TestReviewTask:
    """Lightweight checks on the ``process_review`` task entry point."""

    @pytest.fixture
    def mock_context(self) -> dict[str, Any]:
        """Minimal worker context with no settings or redis attached."""
        return {
            "settings": None,
            "redis": None,
        }

    async def test_review_task_requires_diff(self, mock_context: dict[str, Any]) -> None:  # noqa: ARG002
        """``process_review`` must expose a ``diff_content`` parameter.

        What happens when the diff is missing is covered with database fixtures
        in ``TestProcessReview.test_process_review_requires_diff``. This test only
        pins the task's signature, so it needs no database.
        """
        import inspect

        # process_review is already imported at module level; re-importing it
        # locally (as before) only shadowed that name.
        assert callable(process_review)
        assert "diff_content" in inspect.signature(process_review).parameters
|
|
|
|
|
|
class MockRedisForQueue:
|
|
"""Mock Redis with enqueue_job support."""
|
|
|
|
def __init__(self) -> None:
|
|
self._data: dict[str, Any] = {}
|
|
self._jobs: list[dict[str, Any]] = []
|
|
|
|
async def get(self, key: str) -> str | None:
|
|
return self._data.get(key)
|
|
|
|
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
|
self._data[key] = value
|
|
return True
|
|
|
|
async def delete(self, key: str) -> int:
|
|
if key in self._data:
|
|
del self._data[key]
|
|
return 1
|
|
return 0
|
|
|
|
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any:
|
|
job = {"func": func_name, "kwargs": kwargs}
|
|
self._jobs.append(job)
|
|
return type("Job", (), {"job_id": kwargs.get("_job_id", "test-id")})()
|
|
|
|
|
|
class TestEnqueueReview:
    """Enqueueing review jobs: job creation, deduplication and draft priority."""

    @pytest.fixture
    def mock_redis_pool(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Patch ``arbiter.worker.queue.get_redis_pool`` to return an in-memory mock."""
        mock = MockRedisForQueue()

        async def get_pool() -> MockRedisForQueue:
            return mock

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool)
        return mock

    async def test_enqueue_review_creates_job(self, mock_redis_pool: MockRedisForQueue) -> None:
        from arbiter.worker.queue import enqueue_review

        job_id = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            pr_title="Test PR",
            author="testuser",
            is_draft=False,
        )

        assert job_id is not None
        assert len(mock_redis_pool._jobs) == 1
        assert mock_redis_pool._jobs[0]["func"] == "process_review"

    async def test_enqueue_review_deduplication(
        self,
        mock_redis_pool: MockRedisForQueue,
    ) -> None:
        from arbiter.worker.queue import enqueue_review

        # First call should succeed
        job_id1 = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
        )
        assert job_id1 is not None

        # Second call with same params should be deduplicated...
        job_id2 = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
        )
        assert job_id2 is None
        # ...and must not have submitted a second job to the queue.
        assert len(mock_redis_pool._jobs) == 1

    async def test_enqueue_review_draft_lower_priority(
        self, mock_redis_pool: MockRedisForQueue
    ) -> None:
        from arbiter.worker.queue import enqueue_review

        await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            is_draft=True,
        )

        assert len(mock_redis_pool._jobs) == 1
        job = mock_redis_pool._jobs[0]
        # Draft PRs should be in the low priority queue. Derive the queue
        # suffix from the enum rather than hard-coding its numeric value.
        assert f"arbiter:queue:{int(JobPriority.LOW)}" in job["kwargs"]["_queue_name"]
|
|
|
|
|
|
class TestJobStatusAndCancel:
    """Status lookup and cancellation against a pool that tracks one pending job."""

    @pytest.fixture
    def mock_redis_pool_with_jobs(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Patch the queue's pool getter with a mock already holding ``test-job-id``."""
        pool = MockRedisForQueue()
        pool._data["arbiter:job:test-job-id"] = "pending"

        async def fake_get_redis_pool() -> MockRedisForQueue:
            return pool

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", fake_get_redis_pool)
        return pool

    async def test_get_job_status_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        status = await get_job_status("test-job-id")

        assert status is not None
        assert status["job_id"] == "test-job-id"
        assert status["status"] == "pending"

    async def test_get_job_status_not_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        assert await get_job_status("nonexistent") is None

    async def test_cancel_job_success(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        assert await cancel_job("test-job-id") is True

    async def test_cancel_job_not_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        assert await cancel_job("nonexistent") is False
|
|
|
|
|
|
class TestWorkerStartupShutdown:
    """Worker lifecycle hooks, the health check and ``WorkerSettings`` helpers."""

    async def test_startup_hook(self) -> None:
        from unittest.mock import AsyncMock

        from arbiter.worker.settings import startup

        # Mock init_db and get_settings
        with (
            patch("arbiter.worker.settings.init_db", new_callable=AsyncMock) as mock_init,
            patch("arbiter.worker.settings.get_settings") as mock_get_settings,
        ):
            mock_get_settings.return_value = "mock_settings"
            ctx: dict[str, Any] = {}
            await startup(ctx)

            # assert_awaited_once (rather than assert_called_once) also fails
            # if startup calls init_db but forgets to await the coroutine.
            mock_init.assert_awaited_once()
            assert ctx["settings"] == "mock_settings"

    async def test_shutdown_hook(self) -> None:
        from unittest.mock import AsyncMock

        from arbiter.worker.settings import shutdown

        with patch("arbiter.worker.settings.close_db", new_callable=AsyncMock) as mock_close:
            ctx: dict[str, Any] = {}
            await shutdown(ctx)
            # Verify the coroutine was actually awaited, not merely created.
            mock_close.assert_awaited_once()

    async def test_health_check(self) -> None:
        from arbiter.worker.settings import health_check

        ctx: dict[str, Any] = {}
        result = await health_check(ctx)
        assert result == "healthy"

    def test_worker_settings_redis_settings(self) -> None:
        from arbiter.worker.settings import WorkerSettings

        redis_settings = WorkerSettings.redis_settings()
        assert redis_settings is not None

    def test_worker_settings_get_functions(self) -> None:
        from arbiter.worker.settings import WorkerSettings

        functions = WorkerSettings._get_functions()
        assert len(functions) == 2

        # Verify the functions are the expected ones
        func_names = [f.__name__ for f in functions]
        assert "process_review" in func_names
        assert "process_followup" in func_names
|
|
|
|
|
|
class TestDetectPlatform:
    """Platform detection from an optional, case-insensitive source hint."""

    def test_detect_platform_from_webhook_github(self) -> None:
        assert detect_platform("owner/repo", "github") == Platform.GITHUB

    def test_detect_platform_from_webhook_gitlab(self) -> None:
        assert detect_platform("owner/repo", "gitlab") == Platform.GITLAB

    def test_detect_platform_case_insensitive(self) -> None:
        mixed_case_hints = (("GITHUB", Platform.GITHUB), ("GitLab", Platform.GITLAB))
        for hint, expected in mixed_case_hints:
            assert detect_platform("owner/repo", hint) == expected

    def test_detect_platform_defaults_to_github(self) -> None:
        # Without a hint, GitHub is assumed.
        assert detect_platform("owner/repo") == Platform.GITHUB
|
|
|
|
|
|
class TestGetPlatformClient:
    """Client construction for each platform with and without configured tokens."""

    def test_get_platform_client_github_no_token(self, mock_settings_no_github: Any) -> None:
        assert get_platform_client(Platform.GITHUB, mock_settings_no_github) is None

    def test_get_platform_client_gitlab_no_token(self, mock_settings_no_github: Any) -> None:
        # NOTE(review): relies on the ``mock_settings_no_github`` fixture also
        # lacking a GitLab token; confirm in conftest.
        assert get_platform_client(Platform.GITLAB, mock_settings_no_github) is None

    def test_github_client_with_token(self, mock_settings: Any) -> None:
        from arbiter.integrations import GitHubClient

        # isinstance() already rules out None.
        assert isinstance(get_platform_client(Platform.GITHUB, mock_settings), GitHubClient)

    def test_gitlab_client_with_token(self, mock_settings: Any) -> None:
        from arbiter.integrations import GitLabClient

        assert isinstance(get_platform_client(Platform.GITLAB, mock_settings), GitLabClient)
|
|
|
|
|
|
class TestVerdictToStatus:
    """Mapping of review verdicts onto commit statuses."""

    def test_approve_returns_success(self) -> None:
        status = _verdict_to_status(Verdict.APPROVE)
        assert status == CommitStatus.SUCCESS

    def test_request_changes_returns_failure(self) -> None:
        status = _verdict_to_status(Verdict.REQUEST_CHANGES)
        assert status == CommitStatus.FAILURE

    def test_comment_returns_success(self) -> None:
        # A plain comment verdict is reported as success, like an approval.
        status = _verdict_to_status(Verdict.COMMENT)
        assert status == CommitStatus.SUCCESS
|
|
|
|
|
|
class TestPostOrUpdateComment:
    """Posting a fresh Arbiter comment versus editing the one already on the PR."""

    async def test_post_new_comment(self, mock_platform_client: MockPlatformClient) -> None:
        body = f"Test comment {ARBITER_MARKER}"
        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)

        assert url is not None
        assert "owner/repo" in url
        posted = mock_platform_client._posted_comments
        assert len(posted) == 1
        assert posted[0]["body"] == body

    async def test_update_existing_comment(self, mock_platform_client: MockPlatformClient) -> None:
        # Seed the PR with an earlier comment that carries the Arbiter marker.
        previous_review = Comment(
            id="existing-123",
            body=f"Old review {ARBITER_MARKER}",
            author="arbiter-bot",
            url="https://github.com/owner/repo/pull/42#comment-existing-123",
            created_at=datetime.now(UTC),
        )
        mock_platform_client._comments = [previous_review]

        body = f"Updated review {ARBITER_MARKER}"
        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)

        assert url is not None
        posted = mock_platform_client._posted_comments
        assert len(posted) == 1
        # The earlier comment is edited in place instead of a new one being posted.
        assert posted[0].get("comment_id") == "existing-123"

    async def test_fallback_on_fetch_failure(
        self, mock_platform_client: MockPlatformClient
    ) -> None:
        mock_platform_client._fail_on.add("get_comments")

        body = f"Test comment {ARBITER_MARKER}"
        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)

        # Listing comments failed, so a brand-new comment is posted instead.
        assert url is not None
        posted = mock_platform_client._posted_comments
        assert len(posted) == 1
        assert posted[0].get("comment_id") is None

    async def test_returns_none_on_post_failure(
        self, mock_platform_client: MockPlatformClient
    ) -> None:
        mock_platform_client._fail_on.add("post_comment")

        body = f"Test comment {ARBITER_MARKER}"
        assert await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body) is None
|
|
|
|
|
|
class TestProcessReview:
    """Tests for ``process_review`` with the review pipeline mocked out.

    Every test routes the task module's collaborators through
    ``_patched_tasks`` instead of repeating the same ``patch(...)`` stack.
    """

    @staticmethod
    def _patched_tasks(
        db_session: AsyncSession,
        *,
        pipeline_result: Any = None,
        pipeline_error: Exception | None = None,
        platform_client: Any = None,
        settings: Any = None,
    ) -> Any:
        """Return an entered ``ExitStack`` holding the patches a test needs.

        Args:
            db_session: session that ``async_session_factory`` will hand out.
            pipeline_result: return value for ``_run_review_pipeline``.
            pipeline_error: exception raised by ``_run_review_pipeline``; takes
                precedence over *pipeline_result*. The pipeline is left
                unpatched when neither is given.
            platform_client: return value of ``get_platform_client``
                (``None`` means no platform interaction).
            settings: return value of ``get_settings``; left unpatched when
                ``None``.

        Use as ``with self._patched_tasks(...):``. Leaving the ``with``
        undoes every patch.
        """
        from contextlib import ExitStack

        with ExitStack() as stack:
            if pipeline_error is not None:
                stack.enter_context(
                    patch(
                        "arbiter.worker.tasks._run_review_pipeline",
                        side_effect=pipeline_error,
                    )
                )
            elif pipeline_result is not None:
                stack.enter_context(
                    patch(
                        "arbiter.worker.tasks._run_review_pipeline",
                        return_value=pipeline_result,
                    )
                )
            stack.enter_context(
                patch(
                    "arbiter.worker.tasks.async_session_factory",
                    return_value=lambda: db_session,
                )
            )
            stack.enter_context(
                patch(
                    "arbiter.worker.tasks.get_platform_client",
                    return_value=platform_client,
                )
            )
            if settings is not None:
                stack.enter_context(
                    patch("arbiter.worker.tasks.get_settings", return_value=settings)
                )
            # Transfer the live patches to the caller. They are only unwound
            # here if building the stack itself raised.
            return stack.pop_all()

    @pytest.fixture
    def mock_deliberation_result(self) -> Any:
        """A deliberation result with one finding, one conflict and one step."""
        from arbiter.deliberation import DeliberationResult, DeliberationStep
        from arbiter.deliberation.conflicts import Conflict, ConflictNature
        from arbiter.deliberation.coordinator import StepType
        from arbiter.models import Finding

        finding = Finding(
            id=str(uuid4()),
            agent=AgentName.SECURITY,
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection",
            description="User input concatenated into SQL",
            reasoning="Allows SQL injection attacks",
            prompt_version="security-v1.0",
        )

        return DeliberationResult(
            verdict=Verdict.COMMENT,
            verdict_confidence=0.75,
            verdict_reasoning="Found security issues",
            findings=[finding],
            conflicts=[
                Conflict(
                    id="conflict-1",
                    finding_ids=["f1", "f2"],
                    nature=ConflictNature.TRADE_OFF,
                    description="Trade-off detected",
                    severity_weight=0.5,
                )
            ],
            steps=[
                DeliberationStep(
                    step_type=StepType.MERGE,
                    timestamp=datetime.now(UTC),
                    description="Merged findings",
                    details={"count": 1},
                )
            ],
            tokens_used=500,
            cost_usd=0.005,
        )

    @pytest.fixture
    def mock_review_result(self) -> Any:
        """A single security-agent result with no findings of its own."""
        from arbiter.models import ReviewResult

        return ReviewResult(
            agent_name=AgentName.SECURITY,
            findings=[],
            duration_ms=1000,
            tokens_used=500,
            cost_usd=0.005,
        )

    async def test_process_review_creates_review_record(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        # Mock the review pipeline
        with self._patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
        ):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                pr_title="Test PR",
                author="testuser",
                diff_content="mock diff",
            )

        # Verify review was created
        result = await db_session.execute(
            select(ReviewModel).where(ReviewModel.id == review_id)
        )
        review = result.scalar_one()

        assert review.repository == "owner/repo"
        assert review.pr_number == 42
        assert review.status == "completed"
        assert review.verdict == Verdict.COMMENT

    async def test_process_review_stores_findings(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        with self._patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
        ):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify findings were stored
        result = await db_session.execute(
            select(FindingModel).where(FindingModel.review_id == review_id)
        )
        findings = result.scalars().all()

        assert len(findings) == 1
        assert findings[0].title == "SQL Injection"
        assert findings[0].severity == Severity.HIGH

    async def test_process_review_stores_conflicts(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        with self._patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
        ):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify conflicts were stored
        result = await db_session.execute(
            select(ConflictModel).where(ConflictModel.review_id == review_id)
        )
        conflicts = result.scalars().all()

        assert len(conflicts) == 1
        assert conflicts[0].description == "Trade-off detected"

    async def test_process_review_stores_deliberation_steps(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        with self._patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
        ):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify deliberation steps were stored
        result = await db_session.execute(
            select(DeliberationStepModel).where(DeliberationStepModel.review_id == review_id)
        )
        steps = result.scalars().all()

        assert len(steps) == 1
        assert steps[0].description == "Merged findings"

    async def test_process_review_handles_errors(
        self,
        db_session: AsyncSession,
    ) -> None:
        with (
            self._patched_tasks(db_session, pipeline_error=ValueError("Test error")),
            pytest.raises(ValueError, match="Test error"),
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify review was marked as failed
        result = await db_session.execute(
            select(ReviewModel).where(ReviewModel.repository == "owner/repo")
        )
        review = result.scalar_one()

        assert review.status == "failed"
        assert "Test error" in (review.error_message or "")

    async def test_process_review_posts_comment(
        self,
        db_session: AsyncSession,
        mock_platform_client: MockPlatformClient,
        mock_deliberation_result: Any,
        mock_review_result: Any,
        mock_settings: Any,
    ) -> None:
        with self._patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
            platform_client=mock_platform_client,
            settings=mock_settings,
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
                platform="github",
            )

        # Verify comment was posted
        assert len(mock_platform_client._posted_comments) == 1
        assert ARBITER_MARKER in mock_platform_client._posted_comments[0]["body"]

    async def test_process_review_updates_status(
        self,
        db_session: AsyncSession,
        mock_platform_client: MockPlatformClient,
        mock_deliberation_result: Any,
        mock_review_result: Any,
        mock_settings: Any,
    ) -> None:
        with self._patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
            platform_client=mock_platform_client,
            settings=mock_settings,
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
                platform="github",
            )

        # At least one status update must have been reported...
        assert len(mock_platform_client._status_updates) >= 1
        # ...and the last one is the final status for the COMMENT verdict.
        final_update = mock_platform_client._status_updates[-1]
        assert final_update["status"] == CommitStatus.SUCCESS

    async def test_process_review_requires_diff(
        self,
        db_session: AsyncSession,
    ) -> None:
        # The pipeline is deliberately left unpatched: validation must fail first.
        with (
            self._patched_tasks(db_session),
            pytest.raises(ValueError, match="diff_content not provided"),
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                # No diff_content
            )
|
|
|
|
|
|
class TestProcessFollowup:
    """Paths on which ``process_followup`` returns ``None`` without replying."""

    @staticmethod
    def _patched_tasks(db_session: AsyncSession, settings: Any) -> Any:
        """Return an entered ``ExitStack`` patching the task module's collaborators.

        ``async_session_factory`` hands out *db_session*, ``get_settings``
        returns *settings*, and ``get_platform_client`` returns ``None``, so no
        real platform client is ever built from the (often minimal) settings
        stub. Use as ``with self._patched_tasks(...):``.
        """
        from contextlib import ExitStack

        with ExitStack() as stack:
            for target, value in (
                ("arbiter.worker.tasks.async_session_factory", lambda: db_session),
                ("arbiter.worker.tasks.get_settings", settings),
                ("arbiter.worker.tasks.get_platform_client", None),
            ):
                stack.enter_context(patch(target, return_value=value))
            # Transfer the live patches to the caller's ``with`` statement.
            return stack.pop_all()

    async def test_process_followup_no_review(
        self,
        db_session: AsyncSession,
        mock_settings: Any,
    ) -> None:
        with self._patched_tasks(db_session, mock_settings):
            result = await process_followup(
                {},
                repository="owner/repo",
                pr_number=999,  # Non-existent PR
                comment_id="comment-123",
                comment_body="Why is this a security issue?",
                author="testuser",
                platform="github",
            )

        assert result is None

    async def test_process_followup_disabled(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
    ) -> None:
        class DisabledSettings:
            followup_enabled = False

        # get_platform_client is patched here too (it previously was not), so
        # this stub, which has no token attributes, can never reach the real
        # client factory.
        with self._patched_tasks(db_session, DisabledSettings()):
            result = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="Why is this a security issue?",
                author="testuser",
                platform="github",
            )

        assert result is None

    async def test_process_followup_not_a_question(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
        mock_settings: Any,
    ) -> None:
        with self._patched_tasks(db_session, mock_settings):
            result = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="This looks good to me.",  # Not a question
                author="testuser",
                platform="github",
            )

        assert result is None

    async def test_process_followup_low_confidence(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
    ) -> None:
        class HighThresholdSettings:
            followup_enabled = True
            followup_confidence_threshold = 0.99  # Very high threshold
            llm_timeout = 60
            llm_max_retries = 3
            templates_dir = Path("templates")
            post_comments = False

        with self._patched_tasks(db_session, HighThresholdSettings()):
            result = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="What does this mean?",  # A question but low confidence
                author="testuser",
                platform="github",
            )

        assert result is None
|
|
|
|
|
|
class TestEnqueueFollowup:
    """Enqueueing follow-up jobs and their per-comment deduplication."""

    @pytest.fixture
    def mock_redis_pool_followup(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Point the queue module's pool getter at a fresh in-memory mock."""
        fake_pool = MockRedisForQueue()

        async def fake_get_redis_pool() -> MockRedisForQueue:
            return fake_pool

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", fake_get_redis_pool)
        return fake_pool

    async def test_enqueue_followup_creates_job(
        self, mock_redis_pool_followup: MockRedisForQueue
    ) -> None:
        from arbiter.worker.queue import enqueue_followup

        job_id = await enqueue_followup(
            repository="owner/repo",
            pr_number=42,
            comment_id="comment-123",
            comment_body="Why is this a security issue?",
            author="testuser",
            platform="github",
        )

        submitted = mock_redis_pool_followup._jobs
        assert job_id is not None
        assert len(submitted) == 1
        assert submitted[0]["func"] == "process_followup"

    async def test_enqueue_followup_deduplication(
        self, mock_redis_pool_followup: MockRedisForQueue
    ) -> None:
        from arbiter.worker.queue import enqueue_followup, generate_followup_job_id

        # Mark a follow-up for this exact comment as already pending.
        existing_id = generate_followup_job_id("owner/repo", 42, "comment-123")
        mock_redis_pool_followup._data[f"arbiter:followup:{existing_id}"] = "pending"

        result = await enqueue_followup(
            repository="owner/repo",
            pr_number=42,
            comment_id="comment-123",
            comment_body="Why is this a security issue?",
            author="testuser",
            platform="github",
        )

        assert result is None
        # Nothing may be submitted while the earlier job is still tracked.
        assert mock_redis_pool_followup._jobs == []
|
|
|
|
|
|
class TestGenerateFollowupJobId:
    """Follow-up job ids are deterministic and sensitive to every input."""

    def test_followup_job_id_stable(self) -> None:
        from arbiter.worker.queue import generate_followup_job_id

        key = ("owner/repo", 42, "comment-123")
        assert generate_followup_job_id(*key) == generate_followup_job_id(*key)

    def test_generate_followup_job_id_unique(self) -> None:
        from arbiter.worker.queue import generate_followup_job_id

        inputs = [
            ("owner/repo", 42, "comment-123"),
            ("owner/repo", 42, "comment-456"),  # different comment
            ("owner/repo", 43, "comment-123"),  # different PR
            ("other/repo", 42, "comment-123"),  # different repo
        ]
        assert len({generate_followup_job_id(*args) for args in inputs}) == 4
|