tests for api, worker, cache
This commit is contained in:
959
tests/test_worker.py
Normal file
959
tests/test_worker.py
Normal file
@@ -0,0 +1,959 @@
|
||||
"""Tests for the worker module."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from arbiter.db.models import (
|
||||
ConflictModel,
|
||||
DeliberationStepModel,
|
||||
FindingModel,
|
||||
ReviewModel,
|
||||
)
|
||||
from arbiter.integrations import ARBITER_MARKER
|
||||
from arbiter.integrations.base import Comment, CommitStatus, Platform
|
||||
from arbiter.models.enums import AgentName, Severity, Verdict
|
||||
from arbiter.worker.queue import JobPriority, cancel_job, generate_job_id, get_job_status
|
||||
from arbiter.worker.tasks import (
|
||||
_post_or_update_comment,
|
||||
_verdict_to_status,
|
||||
detect_platform,
|
||||
get_platform_client,
|
||||
process_followup,
|
||||
process_review,
|
||||
)
|
||||
from tests.conftest import MockPlatformClient
|
||||
|
||||
|
||||
class TestJobQueue:
    """Checks on ``generate_job_id`` output and ``JobPriority`` values."""

    def test_generate_job_id_deterministic(self) -> None:
        """The same (repository, PR, SHA) triple yields the same id twice."""
        first = generate_job_id("owner/repo", 42, "abc123")
        second = generate_job_id("owner/repo", 42, "abc123")
        assert first == second

    def test_job_id_unique(self) -> None:
        """Varying any one of SHA, PR number, or repository changes the id."""
        distinct_ids = {
            generate_job_id("owner/repo", 42, "abc123"),
            generate_job_id("owner/repo", 42, "def456"),  # Different SHA
            generate_job_id("owner/repo", 43, "abc123"),  # Different PR
            generate_job_id("other/repo", 42, "abc123"),  # Different repo
        }

        # A set collapses duplicates, so four members means all unique.
        assert len(distinct_ids) == 4

    def test_generate_job_id_format(self) -> None:
        """The id is 16 characters long and uses only lowercase hex digits."""
        job_id = generate_job_id("owner/repo", 42, "abc123")
        assert len(job_id) == 16
        assert set(job_id) <= set("0123456789abcdef")

    def test_job_priority_ordering(self) -> None:
        """HIGH sorts before NORMAL before LOW, with integer values 1, 2, 3."""
        high, normal, low = JobPriority.HIGH, JobPriority.NORMAL, JobPriority.LOW
        assert high < normal < low
        assert int(high) == 1
        assert int(normal) == 2
        assert int(low) == 3
|
||||
|
||||
|
||||
class TestWorkerSettings:
    """Presence checks on the attributes exposed by ``WorkerSettings``."""

    def test_worker_settings_has_functions(self) -> None:
        """``functions`` is set and non-empty."""
        from arbiter.worker.settings import WorkerSettings

        registered = WorkerSettings.functions
        assert registered is not None
        assert len(registered) > 0

    def test_worker_settings_has_cron_jobs(self) -> None:
        """``cron_jobs`` is set (it may be empty)."""
        from arbiter.worker.settings import WorkerSettings

        cron_jobs = WorkerSettings.cron_jobs
        assert cron_jobs is not None

    def test_worker_settings_lifecycle_hooks(self) -> None:
        """Both ``on_startup`` and ``on_shutdown`` are set."""
        from arbiter.worker.settings import WorkerSettings

        startup_hook = WorkerSettings.on_startup
        shutdown_hook = WorkerSettings.on_shutdown
        assert startup_hook is not None
        assert shutdown_hook is not None
|
||||
|
||||
|
||||
class TestReviewTask:
    """Placeholder checks for the review task entry point."""

    @pytest.fixture
    def mock_context(self) -> dict[str, Any]:
        """Return a context dict whose ``settings`` and ``redis`` entries are ``None``."""
        context: dict[str, Any] = {"settings": None, "redis": None}
        return context

    async def test_review_task_requires_diff(self, mock_context: dict[str, Any]) -> None:  # noqa: ARG002
        """Only verifies that ``process_review`` is callable.

        Note: This would need database fixtures to run fully, so for now the
        task itself is not invoked and ``mock_context`` goes unused.
        """
        # ``process_review`` is already imported at module level.
        task = process_review
        assert callable(task)
|
||||
|
||||
|
||||
class MockRedisForQueue:
|
||||
"""Mock Redis with enqueue_job support."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data: dict[str, Any] = {}
|
||||
self._jobs: list[dict[str, Any]] = []
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self._data.get(key)
|
||||
|
||||
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
||||
self._data[key] = value
|
||||
return True
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any:
|
||||
job = {"func": func_name, "kwargs": kwargs}
|
||||
self._jobs.append(job)
|
||||
return type("Job", (), {"job_id": kwargs.get("_job_id", "test-id")})()
|
||||
|
||||
|
||||
class TestEnqueueReview:
    """Tests for ``enqueue_review`` against an in-memory Redis mock."""

    @pytest.fixture
    def mock_redis_pool(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Make ``arbiter.worker.queue.get_redis_pool`` return a fresh mock."""
        mock = MockRedisForQueue()

        async def get_pool() -> MockRedisForQueue:
            return mock

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool)
        return mock

    async def test_enqueue_review_creates_job(self, mock_redis_pool: MockRedisForQueue) -> None:
        """One call enqueues exactly one ``process_review`` job and returns an id."""
        from arbiter.worker.queue import enqueue_review

        job_id = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            pr_title="Test PR",
            author="testuser",
            is_draft=False,
        )

        assert job_id is not None
        assert len(mock_redis_pool._jobs) == 1
        assert mock_redis_pool._jobs[0]["func"] == "process_review"

    async def test_enqueue_review_deduplication(
        self,
        mock_redis_pool: MockRedisForQueue,
    ) -> None:
        """A second call with identical parameters returns ``None`` and queues nothing."""
        from arbiter.worker.queue import enqueue_review

        # First call should succeed
        job_id1 = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
        )
        assert job_id1 is not None
        assert len(mock_redis_pool._jobs) == 1

        # Second call with same params should be deduplicated
        job_id2 = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
        )
        assert job_id2 is None
        # No new job should be added (mirrors the follow-up deduplication test).
        assert len(mock_redis_pool._jobs) == 1

    async def test_enqueue_review_draft_lower_priority(
        self, mock_redis_pool: MockRedisForQueue
    ) -> None:
        """A draft PR is enqueued with a queue name containing ``arbiter:queue:3``."""
        from arbiter.worker.queue import enqueue_review

        await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            is_draft=True,
        )

        assert len(mock_redis_pool._jobs) == 1
        job = mock_redis_pool._jobs[0]
        # Draft PRs should be in the low priority queue
        assert "arbiter:queue:3" in job["kwargs"]["_queue_name"]
|
||||
|
||||
|
||||
class TestJobStatusAndCancel:
    """Tests for ``get_job_status`` and ``cancel_job`` with one pre-seeded job."""

    @pytest.fixture
    def mock_redis_pool_with_jobs(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Patch ``get_redis_pool`` with a mock holding ``test-job-id`` as ``pending``."""
        seeded = MockRedisForQueue()
        seeded._data["arbiter:job:test-job-id"] = "pending"

        async def get_pool() -> MockRedisForQueue:
            return seeded

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool)
        return seeded

    async def test_get_job_status_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """The seeded job is reported with its id and ``pending`` status."""
        job_status = await get_job_status("test-job-id")

        assert job_status is not None
        assert job_status["job_id"] == "test-job-id"
        assert job_status["status"] == "pending"

    async def test_get_job_status_not_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """An unknown job id yields ``None``."""
        job_status = await get_job_status("nonexistent")
        assert job_status is None

    async def test_cancel_job_success(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """Cancelling the seeded job returns ``True``."""
        was_cancelled = await cancel_job("test-job-id")
        assert was_cancelled is True

    async def test_cancel_job_not_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """Cancelling an unknown job returns ``False``."""
        was_cancelled = await cancel_job("nonexistent")
        assert was_cancelled is False
|
||||
|
||||
|
||||
class TestWorkerStartupShutdown:
    """Tests for the worker lifecycle hooks, health check, and settings helpers."""

    async def test_startup_hook(self) -> None:
        """``startup`` awaits ``init_db`` once and stores settings on the context."""
        from unittest.mock import AsyncMock

        from arbiter.worker.settings import startup

        # Mock init_db and get_settings
        with (
            patch("arbiter.worker.settings.init_db", new_callable=AsyncMock) as mock_init,
            patch("arbiter.worker.settings.get_settings") as mock_settings,
        ):
            mock_settings.return_value = "mock_settings"
            ctx: dict[str, Any] = {}
            await startup(ctx)

        # assert_awaited_once rather than assert_called_once: the latter also
        # passes when the coroutine is created but never awaited.
        mock_init.assert_awaited_once()
        assert ctx["settings"] == "mock_settings"

    async def test_shutdown_hook(self) -> None:
        """``shutdown`` awaits ``close_db`` exactly once."""
        from unittest.mock import AsyncMock

        from arbiter.worker.settings import shutdown

        with patch("arbiter.worker.settings.close_db", new_callable=AsyncMock) as mock_close:
            ctx: dict[str, Any] = {}
            await shutdown(ctx)

        # See test_startup_hook for why this checks the await, not the call.
        mock_close.assert_awaited_once()

    async def test_health_check(self) -> None:
        """``health_check`` returns the string ``"healthy"`` for an empty context."""
        from arbiter.worker.settings import health_check

        ctx: dict[str, Any] = {}
        result = await health_check(ctx)
        assert result == "healthy"

    def test_worker_settings_redis_settings(self) -> None:
        """``WorkerSettings.redis_settings()`` returns something other than ``None``."""
        from arbiter.worker.settings import WorkerSettings

        redis_settings = WorkerSettings.redis_settings()
        assert redis_settings is not None

    def test_worker_settings_get_functions(self) -> None:
        """``_get_functions`` returns exactly the review and follow-up tasks."""
        from arbiter.worker.settings import WorkerSettings

        functions = WorkerSettings._get_functions()
        assert len(functions) == 2
        # Verify the functions are the expected ones
        func_names = [f.__name__ for f in functions]
        assert "process_review" in func_names
        assert "process_followup" in func_names
|
||||
|
||||
|
||||
class TestDetectPlatform:
    """Tests for ``detect_platform`` with and without an explicit platform hint."""

    def test_detect_platform_from_webhook_github(self) -> None:
        """The hint ``"github"`` resolves to ``Platform.GITHUB``."""
        detected = detect_platform("owner/repo", "github")
        assert detected == Platform.GITHUB

    def test_detect_platform_from_webhook_gitlab(self) -> None:
        """The hint ``"gitlab"`` resolves to ``Platform.GITLAB``."""
        detected = detect_platform("owner/repo", "gitlab")
        assert detected == Platform.GITLAB

    def test_detect_platform_case_insensitive(self) -> None:
        """Upper- and mixed-case hints resolve like their lowercase forms."""
        upper_case = detect_platform("owner/repo", "GITHUB")
        assert upper_case == Platform.GITHUB

        mixed_case = detect_platform("owner/repo", "GitLab")
        assert mixed_case == Platform.GITLAB

    def test_detect_platform_defaults_to_github(self) -> None:
        """With no hint at all the result is ``Platform.GITHUB``."""
        detected = detect_platform("owner/repo")
        assert detected == Platform.GITHUB
|
||||
|
||||
|
||||
class TestGetPlatformClient:
    """Tests for ``get_platform_client`` under the two settings fixtures."""

    def test_get_platform_client_github_no_token(self, mock_settings_no_github: Any) -> None:
        """With ``mock_settings_no_github`` no GitHub client is returned."""
        github_client = get_platform_client(Platform.GITHUB, mock_settings_no_github)
        assert github_client is None

    def test_get_platform_client_gitlab_no_token(self, mock_settings_no_github: Any) -> None:
        """With ``mock_settings_no_github`` no GitLab client is returned either."""
        # NOTE(review): presumably this fixture carries no GitLab token as well — confirm in conftest.
        gitlab_client = get_platform_client(Platform.GITLAB, mock_settings_no_github)
        assert gitlab_client is None

    def test_github_client_with_token(self, mock_settings: Any) -> None:
        """With ``mock_settings`` a ``GitHubClient`` instance is returned."""
        from arbiter.integrations import GitHubClient

        github_client = get_platform_client(Platform.GITHUB, mock_settings)

        assert github_client is not None
        assert isinstance(github_client, GitHubClient)

    def test_gitlab_client_with_token(self, mock_settings: Any) -> None:
        """With ``mock_settings`` a ``GitLabClient`` instance is returned."""
        from arbiter.integrations import GitLabClient

        gitlab_client = get_platform_client(Platform.GITLAB, mock_settings)

        assert gitlab_client is not None
        assert isinstance(gitlab_client, GitLabClient)
|
||||
|
||||
|
||||
class TestVerdictToStatus:
    """Mapping from review ``Verdict`` to ``CommitStatus`` via ``_verdict_to_status``."""

    def test_approve_returns_success(self) -> None:
        """APPROVE maps to SUCCESS."""
        status = _verdict_to_status(Verdict.APPROVE)
        assert status == CommitStatus.SUCCESS

    def test_request_changes_returns_failure(self) -> None:
        """REQUEST_CHANGES maps to FAILURE."""
        status = _verdict_to_status(Verdict.REQUEST_CHANGES)
        assert status == CommitStatus.FAILURE

    def test_comment_returns_success(self) -> None:
        """COMMENT maps to SUCCESS, the same as APPROVE."""
        status = _verdict_to_status(Verdict.COMMENT)
        assert status == CommitStatus.SUCCESS
|
||||
|
||||
|
||||
class TestPostOrUpdateComment:
    """Tests for ``_post_or_update_comment`` using the mock platform client."""

    async def test_post_new_comment(self, mock_platform_client: MockPlatformClient) -> None:
        """With no comment seeded in the test, the body is posted and a URL returned."""
        body = f"Test comment {ARBITER_MARKER}"

        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)
        posted = mock_platform_client._posted_comments

        assert url is not None
        assert "owner/repo" in url
        assert len(posted) == 1
        assert posted[0]["body"] == body

    async def test_update_existing_comment(self, mock_platform_client: MockPlatformClient) -> None:
        """An existing comment carrying the marker is updated rather than duplicated."""
        # Add an existing Arbiter comment
        existing = Comment(
            id="existing-123",
            body=f"Old review {ARBITER_MARKER}",
            author="arbiter-bot",
            url="https://github.com/owner/repo/pull/42#comment-existing-123",
            created_at=datetime.now(UTC),
        )
        mock_platform_client._comments = [existing]
        body = f"Updated review {ARBITER_MARKER}"

        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)
        posted = mock_platform_client._posted_comments

        assert url is not None
        assert len(posted) == 1
        # Should be an update, not a new post
        assert posted[0].get("comment_id") == "existing-123"

    async def test_fallback_on_fetch_failure(
        self, mock_platform_client: MockPlatformClient
    ) -> None:
        """When ``get_comments`` fails, a brand-new comment is still posted."""
        mock_platform_client._fail_on.add("get_comments")
        body = f"Test comment {ARBITER_MARKER}"

        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)
        posted = mock_platform_client._posted_comments

        # Should still post a new comment
        assert url is not None
        assert len(posted) == 1
        # Should be a new post since fetching failed
        assert posted[0].get("comment_id") is None

    async def test_returns_none_on_post_failure(
        self, mock_platform_client: MockPlatformClient
    ) -> None:
        """When ``post_comment`` fails, the helper returns ``None``."""
        mock_platform_client._fail_on.add("post_comment")
        body = f"Test comment {ARBITER_MARKER}"

        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)

        assert url is None
|
||||
|
||||
|
||||
class TestProcessReview:
    """Tests for ``process_review`` with its collaborators patched out.

    Every test replaces ``async_session_factory`` and ``get_platform_client``
    in ``arbiter.worker.tasks``; most also replace ``_run_review_pipeline``.
    That shared setup lives in ``_patch_worker``.
    """

    @staticmethod
    def _patch_worker(
        db_session: AsyncSession,
        *,
        platform_client: Any = None,
        settings: Any = None,
        **pipeline_mock: Any,
    ) -> Any:
        """Patch the collaborators of ``arbiter.worker.tasks`` for one test.

        Args:
            db_session: Session handed out by the patched ``async_session_factory``.
            platform_client: Return value of the patched ``get_platform_client``
                (``None`` by default).
            settings: When not ``None``, ``get_settings`` is patched to return it;
                otherwise ``get_settings`` is left untouched.
            **pipeline_mock: ``return_value`` / ``side_effect`` forwarded to the
                ``_run_review_pipeline`` patch. When empty, the pipeline is not
                patched.

        Returns:
            A context manager (an ``ExitStack``) that undoes all patches on exit.
        """
        from contextlib import ExitStack

        with ExitStack() as stack:
            if pipeline_mock:
                stack.enter_context(
                    patch("arbiter.worker.tasks._run_review_pipeline", **pipeline_mock)
                )
            stack.enter_context(
                patch(
                    "arbiter.worker.tasks.async_session_factory",
                    return_value=lambda: db_session,
                )
            )
            stack.enter_context(
                patch(
                    "arbiter.worker.tasks.get_platform_client",
                    return_value=platform_client,
                )
            )
            if settings is not None:
                stack.enter_context(
                    patch("arbiter.worker.tasks.get_settings", return_value=settings)
                )
            # Transfer the active patches to a new stack so they outlive this
            # ``with`` and are undone by the caller's ``with`` instead.
            return stack.pop_all()

    @pytest.fixture
    def mock_deliberation_result(self) -> Any:
        """Build a ``DeliberationResult`` with one finding, one conflict, and one step."""
        from arbiter.deliberation import DeliberationResult, DeliberationStep
        from arbiter.deliberation.conflicts import Conflict, ConflictNature
        from arbiter.deliberation.coordinator import StepType
        from arbiter.models import Finding

        finding = Finding(
            id=str(uuid4()),
            agent=AgentName.SECURITY,
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection",
            description="User input concatenated into SQL",
            reasoning="Allows SQL injection attacks",
            prompt_version="security-v1.0",
        )

        return DeliberationResult(
            verdict=Verdict.COMMENT,
            verdict_confidence=0.75,
            verdict_reasoning="Found security issues",
            findings=[finding],
            conflicts=[
                Conflict(
                    id="conflict-1",
                    finding_ids=["f1", "f2"],
                    nature=ConflictNature.TRADE_OFF,
                    description="Trade-off detected",
                    severity_weight=0.5,
                )
            ],
            steps=[
                DeliberationStep(
                    step_type=StepType.MERGE,
                    timestamp=datetime.now(UTC),
                    description="Merged findings",
                    details={"count": 1},
                )
            ],
            tokens_used=500,
            cost_usd=0.005,
        )

    @pytest.fixture
    def mock_review_result(self) -> Any:
        """Build a ``ReviewResult`` for the security agent with no findings."""
        from arbiter.models import ReviewResult

        return ReviewResult(
            agent_name=AgentName.SECURITY,
            findings=[],
            duration_ms=1000,
            tokens_used=500,
            cost_usd=0.005,
        )

    async def test_process_review_creates_review_record(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """A successful run stores a completed review carrying the mocked verdict."""
        # Mock the review pipeline
        pipeline_output = ([mock_review_result], mock_deliberation_result)
        with self._patch_worker(db_session, return_value=pipeline_output):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                pr_title="Test PR",
                author="testuser",
                diff_content="mock diff",
            )

        # Verify review was created
        result = await db_session.execute(
            select(ReviewModel).where(ReviewModel.id == review_id)
        )
        review = result.scalar_one()

        assert review.repository == "owner/repo"
        assert review.pr_number == 42
        assert review.status == "completed"
        assert review.verdict == Verdict.COMMENT

    async def test_process_review_stores_findings(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """The single mocked finding is stored against the new review."""
        pipeline_output = ([mock_review_result], mock_deliberation_result)
        with self._patch_worker(db_session, return_value=pipeline_output):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify findings were stored
        result = await db_session.execute(
            select(FindingModel).where(FindingModel.review_id == review_id)
        )
        findings = result.scalars().all()

        assert len(findings) == 1
        assert findings[0].title == "SQL Injection"
        assert findings[0].severity == Severity.HIGH

    async def test_process_review_stores_conflicts(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """The single mocked conflict is stored against the new review."""
        pipeline_output = ([mock_review_result], mock_deliberation_result)
        with self._patch_worker(db_session, return_value=pipeline_output):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify conflicts were stored
        result = await db_session.execute(
            select(ConflictModel).where(ConflictModel.review_id == review_id)
        )
        conflicts = result.scalars().all()

        assert len(conflicts) == 1
        assert conflicts[0].description == "Trade-off detected"

    async def test_process_review_stores_deliberation_steps(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """The single mocked deliberation step is stored against the new review."""
        pipeline_output = ([mock_review_result], mock_deliberation_result)
        with self._patch_worker(db_session, return_value=pipeline_output):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify deliberation steps were stored
        result = await db_session.execute(
            select(DeliberationStepModel).where(DeliberationStepModel.review_id == review_id)
        )
        steps = result.scalars().all()

        assert len(steps) == 1
        assert steps[0].description == "Merged findings"

    async def test_process_review_handles_errors(
        self,
        db_session: AsyncSession,
    ) -> None:
        """A pipeline error propagates and the review row is marked as failed."""
        with (
            self._patch_worker(db_session, side_effect=ValueError("Test error")),
            pytest.raises(ValueError, match="Test error"),
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify review was marked as failed
        result = await db_session.execute(
            select(ReviewModel).where(ReviewModel.repository == "owner/repo")
        )
        review = result.scalar_one()

        assert review.status == "failed"
        assert "Test error" in (review.error_message or "")

    async def test_process_review_posts_comment(
        self,
        db_session: AsyncSession,
        mock_platform_client: MockPlatformClient,
        mock_deliberation_result: Any,
        mock_review_result: Any,
        mock_settings: Any,
    ) -> None:
        """With a platform client available, one marker-tagged comment is posted."""
        pipeline_output = ([mock_review_result], mock_deliberation_result)
        with self._patch_worker(
            db_session,
            platform_client=mock_platform_client,
            settings=mock_settings,
            return_value=pipeline_output,
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
                platform="github",
            )

        # Verify comment was posted
        assert len(mock_platform_client._posted_comments) == 1
        assert ARBITER_MARKER in mock_platform_client._posted_comments[0]["body"]

    async def test_process_review_updates_status(
        self,
        db_session: AsyncSession,
        mock_platform_client: MockPlatformClient,
        mock_deliberation_result: Any,
        mock_review_result: Any,
        mock_settings: Any,
    ) -> None:
        """The last commit status recorded for the mocked COMMENT verdict is SUCCESS."""
        pipeline_output = ([mock_review_result], mock_deliberation_result)
        with self._patch_worker(
            db_session,
            platform_client=mock_platform_client,
            settings=mock_settings,
            return_value=pipeline_output,
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
                platform="github",
            )

        # Verify status was updated (pending then final)
        assert len(mock_platform_client._status_updates) >= 1
        # Last update should be the final status
        final_update = mock_platform_client._status_updates[-1]
        assert final_update["status"] == CommitStatus.SUCCESS

    async def test_process_review_requires_diff(
        self,
        db_session: AsyncSession,
    ) -> None:
        """Calling without ``diff_content`` raises ``ValueError``."""
        # The pipeline is deliberately left unpatched here.
        with (
            self._patch_worker(db_session),
            pytest.raises(ValueError, match="diff_content not provided"),
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                # No diff_content
            )
|
||||
|
||||
|
||||
class TestProcessFollowup:
    """Cases in which ``process_followup`` is expected to return ``None``."""

    async def test_process_followup_no_review(
        self,
        db_session: AsyncSession,
        mock_settings: Any,
    ) -> None:
        """A comment on PR 999, for which no review fixture is loaded, yields ``None``."""
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=mock_settings,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )

        with session_patch, settings_patch, client_patch:
            reply = await process_followup(
                {},
                repository="owner/repo",
                pr_number=999,  # Non-existent PR
                comment_id="comment-123",
                comment_body="Why is this a security issue?",
                author="testuser",
                platform="github",
            )

        assert reply is None

    async def test_process_followup_disabled(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
    ) -> None:
        """With ``followup_enabled`` set to ``False`` the task yields ``None``."""

        class DisabledSettings:
            followup_enabled = False

        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=DisabledSettings(),
        )

        with session_patch, settings_patch:
            reply = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="Why is this a security issue?",
                author="testuser",
                platform="github",
            )

        assert reply is None

    async def test_process_followup_not_a_question(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
        mock_settings: Any,
    ) -> None:
        """A comment that is a statement rather than a question yields ``None``."""
        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=mock_settings,
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )

        with session_patch, settings_patch, client_patch:
            reply = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="This looks good to me.",  # Not a question
                author="testuser",
                platform="github",
            )

        assert reply is None

    async def test_process_followup_low_confidence(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
    ) -> None:
        """With a 0.99 confidence threshold configured, the question yields ``None``."""

        class HighThresholdSettings:
            followup_enabled = True
            followup_confidence_threshold = 0.99  # Very high threshold
            llm_timeout = 60
            llm_max_retries = 3
            templates_dir = Path("templates")
            post_comments = False

        session_patch = patch(
            "arbiter.worker.tasks.async_session_factory",
            return_value=lambda: db_session,
        )
        settings_patch = patch(
            "arbiter.worker.tasks.get_settings",
            return_value=HighThresholdSettings(),
        )
        client_patch = patch(
            "arbiter.worker.tasks.get_platform_client",
            return_value=None,
        )

        with session_patch, settings_patch, client_patch:
            reply = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="What does this mean?",  # A question but low confidence
                author="testuser",
                platform="github",
            )

        assert reply is None
|
||||
|
||||
|
||||
class TestEnqueueFollowup:
    """Tests for ``enqueue_followup`` against an in-memory Redis mock."""

    @pytest.fixture
    def mock_redis_pool_followup(self, monkeypatch: pytest.MonkeyPatch) -> "MockRedisForQueue":
        """Make ``arbiter.worker.queue.get_redis_pool`` return a fresh mock."""
        pool = MockRedisForQueue()

        async def get_pool() -> MockRedisForQueue:
            return pool

        monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool)
        return pool

    async def test_enqueue_followup_creates_job(
        self, mock_redis_pool_followup: "MockRedisForQueue"
    ) -> None:
        """One call enqueues exactly one ``process_followup`` job and returns an id."""
        from arbiter.worker.queue import enqueue_followup

        job_id = await enqueue_followup(
            repository="owner/repo",
            pr_number=42,
            comment_id="comment-123",
            comment_body="Why is this a security issue?",
            author="testuser",
            platform="github",
        )
        queued = mock_redis_pool_followup._jobs

        assert job_id is not None
        assert len(queued) == 1
        assert queued[0]["func"] == "process_followup"

    async def test_enqueue_followup_deduplication(
        self, mock_redis_pool_followup: "MockRedisForQueue"
    ) -> None:
        """When the follow-up key already exists, nothing is enqueued and ``None`` is returned."""
        from arbiter.worker.queue import enqueue_followup, generate_followup_job_id

        # Pre-set the job as existing
        existing_id = generate_followup_job_id("owner/repo", 42, "comment-123")
        mock_redis_pool_followup._data[f"arbiter:followup:{existing_id}"] = "pending"

        outcome = await enqueue_followup(
            repository="owner/repo",
            pr_number=42,
            comment_id="comment-123",
            comment_body="Why is this a security issue?",
            author="testuser",
            platform="github",
        )

        assert outcome is None
        # No new job should be added
        assert len(mock_redis_pool_followup._jobs) == 0
|
||||
|
||||
|
||||
class TestGenerateFollowupJobId:
    """Checks on the ids produced by ``generate_followup_job_id``."""

    def test_followup_job_id_stable(self) -> None:
        """The same (repository, PR, comment) triple yields the same id twice."""
        from arbiter.worker.queue import generate_followup_job_id

        first = generate_followup_job_id("owner/repo", 42, "comment-123")
        second = generate_followup_job_id("owner/repo", 42, "comment-123")
        assert first == second

    def test_generate_followup_job_id_unique(self) -> None:
        """Varying any one of comment id, PR number, or repository changes the id."""
        from arbiter.worker.queue import generate_followup_job_id

        distinct_ids = {
            generate_followup_job_id("owner/repo", 42, "comment-123"),
            generate_followup_job_id("owner/repo", 42, "comment-456"),
            generate_followup_job_id("owner/repo", 43, "comment-123"),
            generate_followup_job_id("other/repo", 42, "comment-123"),
        }

        # A set collapses duplicates, so four members means all unique.
        assert len(distinct_ids) == 4
|
||||
Reference in New Issue
Block a user