"""Tests for the worker module."""

from collections.abc import Iterator
from contextlib import ExitStack, contextmanager
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, patch
from uuid import uuid4

import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from arbiter.db.models import (
    ConflictModel,
    DeliberationStepModel,
    FindingModel,
    ReviewModel,
)
from arbiter.integrations import ARBITER_MARKER
from arbiter.integrations.base import Comment, CommitStatus, Platform
from arbiter.models.enums import AgentName, Severity, Verdict
from arbiter.worker.queue import JobPriority, cancel_job, generate_job_id, get_job_status
from arbiter.worker.tasks import (
    _post_or_update_comment,
    _verdict_to_status,
    detect_platform,
    get_platform_client,
    process_followup,
    process_review,
)
from tests.conftest import MockPlatformClient

# Sentinel meaning "leave this target unpatched". ``None`` cannot be used
# because several tests deliberately patch a target to return ``None``.
_UNPATCHED: Any = object()


@contextmanager
def _patched_tasks(
    db_session: AsyncSession,
    *,
    pipeline_result: Any = _UNPATCHED,
    pipeline_error: BaseException | None = None,
    platform_client: Any = _UNPATCHED,
    settings: Any = _UNPATCHED,
) -> Iterator[None]:
    """Patch the collaborators of ``arbiter.worker.tasks`` for one test.

    ``async_session_factory`` is always patched so that calling its return
    value yields ``db_session``. Every other target is patched only when the
    matching argument is supplied.

    Args:
        db_session: Session handed out by the patched session factory.
        pipeline_result: Return value for ``_run_review_pipeline``.
        pipeline_error: Exception raised by ``_run_review_pipeline``; takes
            precedence over ``pipeline_result``.
        platform_client: Return value for ``get_platform_client`` (may be
            ``None``).
        settings: Return value for ``get_settings``.
    """
    with ExitStack() as stack:
        if pipeline_error is not None:
            stack.enter_context(
                patch(
                    "arbiter.worker.tasks._run_review_pipeline",
                    side_effect=pipeline_error,
                )
            )
        elif pipeline_result is not _UNPATCHED:
            stack.enter_context(
                patch(
                    "arbiter.worker.tasks._run_review_pipeline",
                    return_value=pipeline_result,
                )
            )
        stack.enter_context(
            patch(
                "arbiter.worker.tasks.async_session_factory",
                return_value=lambda: db_session,
            )
        )
        if platform_client is not _UNPATCHED:
            stack.enter_context(
                patch(
                    "arbiter.worker.tasks.get_platform_client",
                    return_value=platform_client,
                )
            )
        if settings is not _UNPATCHED:
            stack.enter_context(
                patch(
                    "arbiter.worker.tasks.get_settings",
                    return_value=settings,
                )
            )
        yield


class TestJobQueue:
    """Tests for job ID generation and the ``JobPriority`` ordering."""

    def test_generate_job_id_deterministic(self) -> None:
        """Identical inputs produce identical job IDs."""
        id1 = generate_job_id("owner/repo", 42, "abc123")
        id2 = generate_job_id("owner/repo", 42, "abc123")
        assert id1 == id2

    def test_generate_job_id_unique(self) -> None:
        """Changing the SHA, PR number or repository changes the job ID."""
        id1 = generate_job_id("owner/repo", 42, "abc123")
        id2 = generate_job_id("owner/repo", 42, "def456")  # Different SHA
        id3 = generate_job_id("owner/repo", 43, "abc123")  # Different PR
        id4 = generate_job_id("other/repo", 42, "abc123")  # Different repo
        assert len({id1, id2, id3, id4}) == 4  # All unique

    def test_generate_job_id_format(self) -> None:
        """Job IDs are 16 lowercase hexadecimal characters."""
        job_id = generate_job_id("owner/repo", 42, "abc123")
        assert len(job_id) == 16
        assert all(c in "0123456789abcdef" for c in job_id)

    def test_job_priority_ordering(self) -> None:
        """Priorities compare HIGH < NORMAL < LOW with values 1, 2 and 3."""
        assert JobPriority.HIGH < JobPriority.NORMAL < JobPriority.LOW
        assert int(JobPriority.HIGH) == 1
        assert int(JobPriority.NORMAL) == 2
        assert int(JobPriority.LOW) == 3


class TestWorkerSettings:
    """Smoke tests for the attributes exposed by ``WorkerSettings``."""

    def test_worker_settings_has_functions(self) -> None:
        """``WorkerSettings.functions`` is set and non-empty."""
        from arbiter.worker.settings import WorkerSettings

        assert WorkerSettings.functions is not None
        assert len(WorkerSettings.functions) > 0

    def test_worker_settings_has_cron_jobs(self) -> None:
        """``WorkerSettings.cron_jobs`` is set."""
        from arbiter.worker.settings import WorkerSettings

        assert WorkerSettings.cron_jobs is not None

    def test_worker_settings_lifecycle_hooks(self) -> None:
        """Startup and shutdown hooks are both configured."""
        from arbiter.worker.settings import WorkerSettings

        assert WorkerSettings.on_startup is not None
        assert WorkerSettings.on_shutdown is not None


class TestReviewTask:
    """Placeholder tests for the review task entry point."""

    @pytest.fixture
    def mock_context(self) -> dict[str, Any]:
        """Return a worker context dict with empty settings and redis slots."""
        return {
            "settings": None,
            "redis": None,
        }

    async def test_review_task_requires_diff(self, mock_context: dict[str, Any]) -> None:  # noqa: ARG002
        """Smoke check only: ``process_review`` is importable and callable."""
        # Note: This would need database fixtures to run fully
        # For now, we just verify the function signature
        assert callable(process_review)


class MockRedisForQueue:
    """Mock Redis with enqueue_job support."""

    def __init__(self) -> None:
        # Key/value store backing get/set/delete.
        self._data: dict[str, Any] = {}
        # Every enqueue_job call, recorded as {"func": ..., "kwargs": ...}.
        self._jobs: list[dict[str, Any]] = []

    async def get(self, key: str) -> str | None:
        """Return the stored value for ``key``, or ``None`` when absent."""
        return self._data.get(key)

    async def set(self, key: str, value: str, ex: int | None = None) -> bool:  # noqa: ARG002
        """Store ``value`` under ``key``; the ``ex`` expiry is ignored."""
        self._data[key] = value
        return True

    async def delete(self, key: str) -> int:
        """Remove ``key`` and return 1, or return 0 when it was not stored."""
        if key in self._data:
            del self._data[key]
            return 1
        return 0

    async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any:
        """Record the job and return an object exposing a ``job_id`` attribute."""
        job = {"func": func_name, "kwargs": kwargs}
        self._jobs.append(job)
        # Ad-hoc stand-in for a job handle: only ``job_id`` is provided.
        return type("Job", (), {"job_id": kwargs.get("_job_id", "test-id")})()


def _install_mock_redis(monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
    """Replace ``arbiter.worker.queue.get_redis_pool`` with a mock-backed pool.

    Returns the ``MockRedisForQueue`` that the patched coroutine hands out so
    tests can pre-seed data and inspect recorded jobs.
    """
    mock = MockRedisForQueue()

    async def get_pool() -> MockRedisForQueue:
        return mock

    monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool)
    return mock


class TestEnqueueReview:
    """Tests for ``enqueue_review`` against a mocked Redis pool."""

    @pytest.fixture
    def mock_redis_pool(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Install and return an empty mock Redis pool."""
        return _install_mock_redis(monkeypatch)

    async def test_enqueue_review_creates_job(self, mock_redis_pool: MockRedisForQueue) -> None:
        """A first enqueue returns a job ID and records a ``process_review`` job."""
        from arbiter.worker.queue import enqueue_review

        job_id = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            pr_title="Test PR",
            author="testuser",
            is_draft=False,
        )

        assert job_id is not None
        assert len(mock_redis_pool._jobs) == 1
        assert mock_redis_pool._jobs[0]["func"] == "process_review"

    async def test_enqueue_review_deduplication(
        self,
        mock_redis_pool: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """Enqueueing the same review twice returns ``None`` the second time."""
        from arbiter.worker.queue import enqueue_review

        # First call should succeed
        job_id1 = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
        )
        assert job_id1 is not None

        # Second call with same params should be deduplicated
        job_id2 = await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
        )
        assert job_id2 is None

    async def test_enqueue_review_draft_lower_priority(
        self, mock_redis_pool: MockRedisForQueue
    ) -> None:
        """Draft PRs are enqueued on the priority-3 queue."""
        from arbiter.worker.queue import enqueue_review

        await enqueue_review(
            repository="owner/repo",
            pr_number=42,
            base_sha="abc123",
            head_sha="def456",
            is_draft=True,
        )

        assert len(mock_redis_pool._jobs) == 1
        job = mock_redis_pool._jobs[0]
        # Draft PRs should be in the low priority queue
        assert "arbiter:queue:3" in job["kwargs"]["_queue_name"]


class TestJobStatusAndCancel:
    """Tests for ``get_job_status`` and ``cancel_job``."""

    @pytest.fixture
    def mock_redis_pool_with_jobs(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Install a mock Redis pool pre-seeded with one pending job."""
        mock = _install_mock_redis(monkeypatch)
        mock._data["arbiter:job:test-job-id"] = "pending"
        return mock

    async def test_get_job_status_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """A known job ID yields its ID and stored status."""
        status = await get_job_status("test-job-id")

        assert status is not None
        assert status["job_id"] == "test-job-id"
        assert status["status"] == "pending"

    async def test_get_job_status_not_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """An unknown job ID yields ``None``."""
        status = await get_job_status("nonexistent")

        assert status is None

    async def test_cancel_job_success(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """Cancelling a known job returns ``True``."""
        result = await cancel_job("test-job-id")

        assert result is True

    async def test_cancel_job_not_found(
        self,
        mock_redis_pool_with_jobs: MockRedisForQueue,  # noqa: ARG002
    ) -> None:
        """Cancelling an unknown job returns ``False``."""
        result = await cancel_job("nonexistent")

        assert result is False


class TestWorkerStartupShutdown:
    """Tests for the worker lifecycle hooks and settings helpers."""

    async def test_startup_hook(self) -> None:
        """``startup`` initialises the DB once and stores settings in the context."""
        from arbiter.worker.settings import startup

        # Mock init_db and get_settings
        with (
            patch("arbiter.worker.settings.init_db", new_callable=AsyncMock) as mock_init,
            patch("arbiter.worker.settings.get_settings") as mock_settings,
        ):
            mock_settings.return_value = "mock_settings"
            ctx: dict[str, Any] = {}
            await startup(ctx)

            mock_init.assert_called_once()
            assert ctx["settings"] == "mock_settings"

    async def test_shutdown_hook(self) -> None:
        """``shutdown`` closes the DB exactly once."""
        from arbiter.worker.settings import shutdown

        with patch("arbiter.worker.settings.close_db", new_callable=AsyncMock) as mock_close:
            ctx: dict[str, Any] = {}
            await shutdown(ctx)

            mock_close.assert_called_once()

    async def test_health_check(self) -> None:
        """``health_check`` reports ``"healthy"``."""
        from arbiter.worker.settings import health_check

        ctx: dict[str, Any] = {}
        result = await health_check(ctx)

        assert result == "healthy"

    def test_worker_settings_redis_settings(self) -> None:
        """``WorkerSettings.redis_settings()`` returns a value."""
        from arbiter.worker.settings import WorkerSettings

        redis_settings = WorkerSettings.redis_settings()

        assert redis_settings is not None

    def test_worker_settings_get_functions(self) -> None:
        """``_get_functions`` exposes exactly the review and follow-up tasks."""
        from arbiter.worker.settings import WorkerSettings

        functions = WorkerSettings._get_functions()

        assert len(functions) == 2
        # Verify the functions are the expected ones
        func_names = [f.__name__ for f in functions]
        assert "process_review" in func_names
        assert "process_followup" in func_names


class TestDetectPlatform:
    """Tests for ``detect_platform``."""

    def test_detect_platform_from_webhook_github(self) -> None:
        """The ``"github"`` hint maps to ``Platform.GITHUB``."""
        platform = detect_platform("owner/repo", "github")
        assert platform == Platform.GITHUB

    def test_detect_platform_from_webhook_gitlab(self) -> None:
        """The ``"gitlab"`` hint maps to ``Platform.GITLAB``."""
        platform = detect_platform("owner/repo", "gitlab")
        assert platform == Platform.GITLAB

    def test_detect_platform_case_insensitive(self) -> None:
        """Platform hints are matched regardless of letter case."""
        assert detect_platform("owner/repo", "GITHUB") == Platform.GITHUB
        assert detect_platform("owner/repo", "GitLab") == Platform.GITLAB

    def test_detect_platform_defaults_to_github(self) -> None:
        """Without a hint the platform is ``Platform.GITHUB``."""
        platform = detect_platform("owner/repo")
        assert platform == Platform.GITHUB


class TestGetPlatformClient:
    """Tests for ``get_platform_client``."""

    def test_get_platform_client_github_no_token(self, mock_settings_no_github: Any) -> None:
        """No GitHub client is returned for the ``mock_settings_no_github`` settings."""
        client = get_platform_client(Platform.GITHUB, mock_settings_no_github)
        assert client is None

    def test_get_platform_client_gitlab_no_token(self, mock_settings_no_github: Any) -> None:
        """No GitLab client is returned for the ``mock_settings_no_github`` settings."""
        client = get_platform_client(Platform.GITLAB, mock_settings_no_github)
        assert client is None

    def test_get_platform_client_github_with_token(self, mock_settings: Any) -> None:
        """A ``GitHubClient`` is returned for the ``mock_settings`` settings."""
        from arbiter.integrations import GitHubClient

        client = get_platform_client(Platform.GITHUB, mock_settings)

        assert client is not None
        assert isinstance(client, GitHubClient)

    def test_get_platform_client_gitlab_with_token(self, mock_settings: Any) -> None:
        """A ``GitLabClient`` is returned for the ``mock_settings`` settings."""
        from arbiter.integrations import GitLabClient

        client = get_platform_client(Platform.GITLAB, mock_settings)

        assert client is not None
        assert isinstance(client, GitLabClient)


class TestVerdictToStatus:
    """Tests for the verdict-to-commit-status mapping."""

    def test_approve_returns_success(self) -> None:
        """``APPROVE`` maps to ``SUCCESS``."""
        assert _verdict_to_status(Verdict.APPROVE) == CommitStatus.SUCCESS

    def test_request_changes_returns_failure(self) -> None:
        """``REQUEST_CHANGES`` maps to ``FAILURE``."""
        assert _verdict_to_status(Verdict.REQUEST_CHANGES) == CommitStatus.FAILURE

    def test_comment_returns_success(self) -> None:
        """``COMMENT`` maps to ``SUCCESS``."""
        assert _verdict_to_status(Verdict.COMMENT) == CommitStatus.SUCCESS


class TestPostOrUpdateComment:
    """Tests for ``_post_or_update_comment``."""

    async def test_post_new_comment(self, mock_platform_client: MockPlatformClient) -> None:
        """With no existing comment, the body is posted and a URL is returned."""
        body = f"Test comment {ARBITER_MARKER}"

        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)

        assert url is not None
        assert "owner/repo" in url
        assert len(mock_platform_client._posted_comments) == 1
        assert mock_platform_client._posted_comments[0]["body"] == body

    async def test_update_existing_comment(self, mock_platform_client: MockPlatformClient) -> None:
        """An existing marked comment is updated rather than duplicated."""
        # Add an existing Arbiter comment
        mock_platform_client._comments = [
            Comment(
                id="existing-123",
                body=f"Old review {ARBITER_MARKER}",
                author="arbiter-bot",
                url="https://github.com/owner/repo/pull/42#comment-existing-123",
                created_at=datetime.now(UTC),
            )
        ]
        body = f"Updated review {ARBITER_MARKER}"

        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)

        assert url is not None
        assert len(mock_platform_client._posted_comments) == 1
        # Should be an update, not a new post
        assert mock_platform_client._posted_comments[0].get("comment_id") == "existing-123"

    async def test_fallback_on_fetch_failure(
        self, mock_platform_client: MockPlatformClient
    ) -> None:
        """When fetching comments fails, a new comment is still posted."""
        mock_platform_client._fail_on.add("get_comments")
        body = f"Test comment {ARBITER_MARKER}"

        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)

        # Should still post a new comment
        assert url is not None
        assert len(mock_platform_client._posted_comments) == 1
        # Should be a new post since fetching failed
        assert mock_platform_client._posted_comments[0].get("comment_id") is None

    async def test_returns_none_on_post_failure(
        self, mock_platform_client: MockPlatformClient
    ) -> None:
        """When posting fails, ``None`` is returned."""
        mock_platform_client._fail_on.add("post_comment")
        body = f"Test comment {ARBITER_MARKER}"

        url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body)

        assert url is None


class TestProcessReview:
    """Tests for ``process_review`` with the pipeline and platform mocked out."""

    @pytest.fixture
    def mock_deliberation_result(self) -> Any:
        """Build a ``DeliberationResult`` with one finding, conflict and step."""
        from arbiter.deliberation import DeliberationResult, DeliberationStep
        from arbiter.deliberation.conflicts import Conflict, ConflictNature
        from arbiter.deliberation.coordinator import StepType
        from arbiter.models import Finding

        finding = Finding(
            id=str(uuid4()),
            agent=AgentName.SECURITY,
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection",
            description="User input concatenated into SQL",
            reasoning="Allows SQL injection attacks",
            prompt_version="security-v1.0",
        )

        return DeliberationResult(
            verdict=Verdict.COMMENT,
            verdict_confidence=0.75,
            verdict_reasoning="Found security issues",
            findings=[finding],
            conflicts=[
                Conflict(
                    id="conflict-1",
                    finding_ids=["f1", "f2"],
                    nature=ConflictNature.TRADE_OFF,
                    description="Trade-off detected",
                    severity_weight=0.5,
                )
            ],
            steps=[
                DeliberationStep(
                    step_type=StepType.MERGE,
                    timestamp=datetime.now(UTC),
                    description="Merged findings",
                    details={"count": 1},
                )
            ],
            tokens_used=500,
            cost_usd=0.005,
        )

    @pytest.fixture
    def mock_review_result(self) -> Any:
        """Build a ``ReviewResult`` for the security agent with no findings."""
        from arbiter.models import ReviewResult

        return ReviewResult(
            agent_name=AgentName.SECURITY,
            findings=[],
            duration_ms=1000,
            tokens_used=500,
            cost_usd=0.005,
        )

    async def test_process_review_creates_review_record(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """A completed review row is stored with the pipeline's verdict."""
        # Mock the review pipeline
        with _patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
            platform_client=None,
        ):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                pr_title="Test PR",
                author="testuser",
                diff_content="mock diff",
            )

        # Verify review was created
        result = await db_session.execute(
            select(ReviewModel).where(ReviewModel.id == review_id)
        )
        review = result.scalar_one()

        assert review.repository == "owner/repo"
        assert review.pr_number == 42
        assert review.status == "completed"
        assert review.verdict == Verdict.COMMENT

    async def test_process_review_stores_findings(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """The deliberation's finding is stored against the review."""
        with _patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
            platform_client=None,
        ):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify findings were stored
        result = await db_session.execute(
            select(FindingModel).where(FindingModel.review_id == review_id)
        )
        findings = result.scalars().all()

        assert len(findings) == 1
        assert findings[0].title == "SQL Injection"
        assert findings[0].severity == Severity.HIGH

    async def test_process_review_stores_conflicts(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """The deliberation's conflict is stored against the review."""
        with _patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
            platform_client=None,
        ):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify conflicts were stored
        result = await db_session.execute(
            select(ConflictModel).where(ConflictModel.review_id == review_id)
        )
        conflicts = result.scalars().all()

        assert len(conflicts) == 1
        assert conflicts[0].description == "Trade-off detected"

    async def test_process_review_stores_deliberation_steps(
        self,
        db_session: AsyncSession,
        mock_deliberation_result: Any,
        mock_review_result: Any,
    ) -> None:
        """The deliberation's step is stored against the review."""
        with _patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
            platform_client=None,
        ):
            review_id = await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify deliberation steps were stored
        result = await db_session.execute(
            select(DeliberationStepModel).where(DeliberationStepModel.review_id == review_id)
        )
        steps = result.scalars().all()

        assert len(steps) == 1
        assert steps[0].description == "Merged findings"

    async def test_process_review_handles_errors(
        self,
        db_session: AsyncSession,
    ) -> None:
        """A pipeline error propagates and the review is marked failed."""
        with (
            _patched_tasks(
                db_session,
                pipeline_error=ValueError("Test error"),
                platform_client=None,
            ),
            pytest.raises(ValueError, match="Test error"),
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
            )

        # Verify review was marked as failed
        result = await db_session.execute(
            select(ReviewModel).where(ReviewModel.repository == "owner/repo")
        )
        review = result.scalar_one()

        assert review.status == "failed"
        assert "Test error" in (review.error_message or "")

    async def test_process_review_posts_comment(
        self,
        db_session: AsyncSession,
        mock_platform_client: MockPlatformClient,
        mock_deliberation_result: Any,
        mock_review_result: Any,
        mock_settings: Any,
    ) -> None:
        """With a platform client available, one marked comment is posted."""
        with _patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
            platform_client=mock_platform_client,
            settings=mock_settings,
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
                platform="github",
            )

        # Verify comment was posted
        assert len(mock_platform_client._posted_comments) == 1
        assert ARBITER_MARKER in mock_platform_client._posted_comments[0]["body"]

    async def test_process_review_updates_status(
        self,
        db_session: AsyncSession,
        mock_platform_client: MockPlatformClient,
        mock_deliberation_result: Any,
        mock_review_result: Any,
        mock_settings: Any,
    ) -> None:
        """The last commit status reported for a COMMENT verdict is SUCCESS."""
        with _patched_tasks(
            db_session,
            pipeline_result=([mock_review_result], mock_deliberation_result),
            platform_client=mock_platform_client,
            settings=mock_settings,
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                diff_content="mock diff",
                platform="github",
            )

        # Verify status was updated (pending then final)
        assert len(mock_platform_client._status_updates) >= 1
        # Last update should be the final status
        final_update = mock_platform_client._status_updates[-1]
        assert final_update["status"] == CommitStatus.SUCCESS

    async def test_process_review_requires_diff(
        self,
        db_session: AsyncSession,
    ) -> None:
        """Omitting ``diff_content`` raises ``ValueError``."""
        with (
            _patched_tasks(db_session, platform_client=None),
            pytest.raises(ValueError, match="diff_content not provided"),
        ):
            await process_review(
                {},
                repository="owner/repo",
                pr_number=42,
                base_sha="abc123",
                head_sha="def456",
                # No diff_content
            )


class TestProcessFollowup:
    """Tests for the cases in which ``process_followup`` returns ``None``."""

    async def test_process_followup_no_review(
        self,
        db_session: AsyncSession,
        mock_settings: Any,
    ) -> None:
        """A follow-up on a PR with no stored review returns ``None``."""
        with _patched_tasks(db_session, settings=mock_settings, platform_client=None):
            result = await process_followup(
                {},
                repository="owner/repo",
                pr_number=999,  # Non-existent PR
                comment_id="comment-123",
                comment_body="Why is this a security issue?",
                author="testuser",
                platform="github",
            )

        assert result is None

    async def test_process_followup_disabled(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
    ) -> None:
        """With ``followup_enabled`` off, the follow-up returns ``None``."""

        class DisabledSettings:
            followup_enabled = False

        with _patched_tasks(db_session, settings=DisabledSettings()):
            result = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="Why is this a security issue?",
                author="testuser",
                platform="github",
            )

        assert result is None

    async def test_process_followup_not_a_question(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
        mock_settings: Any,
    ) -> None:
        """A comment that is not a question returns ``None``."""
        with _patched_tasks(db_session, settings=mock_settings, platform_client=None):
            result = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="This looks good to me.",  # Not a question
                author="testuser",
                platform="github",
            )

        assert result is None

    async def test_process_followup_low_confidence(
        self,
        db_session: AsyncSession,
        completed_review_fixture: ReviewModel,  # noqa: ARG002
    ) -> None:
        """A question below the confidence threshold returns ``None``."""

        class HighThresholdSettings:
            followup_enabled = True
            followup_confidence_threshold = 0.99  # Very high threshold
            llm_timeout = 60
            llm_max_retries = 3
            templates_dir = Path("templates")
            post_comments = False

        with _patched_tasks(
            db_session,
            settings=HighThresholdSettings(),
            platform_client=None,
        ):
            result = await process_followup(
                {},
                repository="owner/repo",
                pr_number=42,
                comment_id="comment-123",
                comment_body="What does this mean?",  # A question but low confidence
                author="testuser",
                platform="github",
            )

        assert result is None


class TestEnqueueFollowup:
    """Tests for ``enqueue_followup`` against a mocked Redis pool."""

    @pytest.fixture
    def mock_redis_pool_followup(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue:
        """Install and return an empty mock Redis pool."""
        return _install_mock_redis(monkeypatch)

    async def test_enqueue_followup_creates_job(
        self, mock_redis_pool_followup: MockRedisForQueue
    ) -> None:
        """A first enqueue returns a job ID and records a ``process_followup`` job."""
        from arbiter.worker.queue import enqueue_followup

        job_id = await enqueue_followup(
            repository="owner/repo",
            pr_number=42,
            comment_id="comment-123",
            comment_body="Why is this a security issue?",
            author="testuser",
            platform="github",
        )

        assert job_id is not None
        assert len(mock_redis_pool_followup._jobs) == 1
        assert mock_redis_pool_followup._jobs[0]["func"] == "process_followup"

    async def test_enqueue_followup_deduplication(
        self, mock_redis_pool_followup: MockRedisForQueue
    ) -> None:
        """An already-tracked follow-up is not enqueued again."""
        from arbiter.worker.queue import enqueue_followup, generate_followup_job_id

        # Pre-set the job as existing
        job_id = generate_followup_job_id("owner/repo", 42, "comment-123")
        mock_redis_pool_followup._data[f"arbiter:followup:{job_id}"] = "pending"

        result = await enqueue_followup(
            repository="owner/repo",
            pr_number=42,
            comment_id="comment-123",
            comment_body="Why is this a security issue?",
            author="testuser",
            platform="github",
        )

        assert result is None
        # No new job should be added
        assert len(mock_redis_pool_followup._jobs) == 0


class TestGenerateFollowupJobId:
    """Tests for ``generate_followup_job_id``."""

    def test_generate_followup_job_id_deterministic(self) -> None:
        """Identical inputs produce identical follow-up job IDs."""
        from arbiter.worker.queue import generate_followup_job_id

        id1 = generate_followup_job_id("owner/repo", 42, "comment-123")
        id2 = generate_followup_job_id("owner/repo", 42, "comment-123")

        assert id1 == id2

    def test_generate_followup_job_id_unique(self) -> None:
        """Changing the comment, PR number or repository changes the job ID."""
        from arbiter.worker.queue import generate_followup_job_id

        id1 = generate_followup_job_id("owner/repo", 42, "comment-123")
        id2 = generate_followup_job_id("owner/repo", 42, "comment-456")
        id3 = generate_followup_job_id("owner/repo", 43, "comment-123")
        id4 = generate_followup_job_id("other/repo", 42, "comment-123")

        assert len({id1, id2, id3, id4}) == 4