From 9a23e2c9c4a657f8214160dfc6ec975c12d74d90 Mon Sep 17 00:00:00 2001 From: Kai Chappell Date: Sat, 22 Mar 2025 11:04:46 +0000 Subject: [PATCH] tests for api, worker, cache --- tests/test_api.py | 1153 +++++++++++++++++++++++++++++++++++++++ tests/test_cache.py | 164 ++++++ tests/test_cost.py | 157 ++++++ tests/test_db_models.py | 222 ++++++++ tests/test_worker.py | 959 ++++++++++++++++++++++++++++++++ 5 files changed, 2655 insertions(+) create mode 100644 tests/test_api.py create mode 100644 tests/test_cache.py create mode 100644 tests/test_cost.py create mode 100644 tests/test_db_models.py create mode 100644 tests/test_worker.py diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..7df334a --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,1153 @@ +"""Tests for the API endpoints.""" + +from datetime import UTC, datetime +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from arbiter.api.routes.health import record_review_completed +from arbiter.db.models import ( + ConversationMessageModel, + ConversationModel, + DeliberationStepModel, + ReviewModel, +) +from arbiter.deliberation.coordinator import StepType +from tests.conftest import MockRedis + + +class TestHealthEndpoints: + async def test_liveness_check(self, test_client: AsyncClient) -> None: + response = await test_client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "version" in data + + async def test_liveness_probe(self, test_client: AsyncClient) -> None: + response = await test_client.get("/health/live") + assert response.status_code == 200 + assert response.json()["status"] == "alive" + + +class TestWebhookSignatureValidation: + def test_github_signature_verification(self) -> None: + import hashlib + import hmac + + from arbiter.api.routes.webhooks import _verify_github_signature + + secret = "test-secret" + payload = b'{"test": "data"}' + expected_sig = "sha256=" + hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest() + + assert _verify_github_signature(payload, expected_sig, secret) + assert not _verify_github_signature(payload, "sha256=invalid", secret) + assert not _verify_github_signature(payload, None, secret) + assert not _verify_github_signature(payload, "invalid-format", secret) + + def test_gitlab_token_verification(self) -> None: + from arbiter.api.routes.webhooks import _verify_gitlab_token + + assert _verify_gitlab_token("correct-token", "correct-token") + assert not _verify_gitlab_token("wrong-token", "correct-token") + assert not _verify_gitlab_token(None, "correct-token") + + +class TestApiResponseModels: + def test_finding_response_model(self) -> None: + from arbiter.api.routes.reviews import FindingResponse + + finding = FindingResponse( + id="test-id", + agent="security", + file="src/test.py", + line_start=10, + line_end=15, + severity="high", + confidence=0.9, + title="Test finding", + description="Test description", + reasoning="Test reasoning", + suggestion="Test suggestion", + references=["http://example.com"], + prompt_version="security-v1.0", + ) + + assert finding.id == "test-id" + assert finding.severity == "high" + assert finding.confidence == 0.9 + + def test_review_summary_model(self) -> None: + from datetime import UTC, datetime + + from arbiter.api.routes.reviews import ReviewSummary + + summary = ReviewSummary( + id="test-id", + repository="owner/repo", + pr_number=42, + pr_title="Test PR", + author="testuser", + status="completed", + verdict="comment", + verdict_confidence=0.75, + finding_count=5, + critical_count=0, + high_count=2, + total_cost_usd=0.015, + created_at=datetime.now(UTC), + completed_at=datetime.now(UTC), + ) + + assert summary.repository == "owner/repo" + assert summary.pr_number == 42 + assert summary.finding_count == 5 + + def test_manual_review_request_validation(self) -> None: + from arbiter.api.routes.reviews import ManualReviewRequest + + # Valid request + request = ManualReviewRequest( + repository="owner/repo", + pr_number=1, + base_sha="abc1234", + head_sha="def5678", + diff_content="test diff", + ) + assert request.repository == "owner/repo" + + # Invalid SHA (too short) + with pytest.raises(ValueError): + ManualReviewRequest( + repository="owner/repo", + pr_number=1, + base_sha="abc", # Too short + head_sha="def5678", + ) + + +class TestMockRedis: + async def test_mock_redis_basic_operations(self) -> None: + redis = MockRedis() + + # Test set and get + await redis.set("key", "value") + result = await redis.get("key") + assert result == "value" + + # Test get nonexistent key + result = await redis.get("nonexistent") + assert result is None + + # Test delete + deleted = await redis.delete("key") + assert deleted == 1 + deleted = await redis.delete("key") + assert deleted == 0 + + # Test ping + assert await redis.ping() is True + + # Test llen + length = await redis.llen("queue") + assert length == 0 + + +class TestDeliberationLogResponse: + def test_deliberation_step_response_model(self) -> None: + from datetime import UTC, datetime + + from arbiter.api.routes.reviews import DeliberationStepResponse + + step = DeliberationStepResponse( + id="step-id", + step_type="merge", + timestamp=datetime.now(UTC), + description="Merged findings", + details={"groups": 3}, + sequence=0, + ) + assert step.step_type == "merge" + assert step.sequence == 0 + + def test_deliberation_log_response_model(self) -> None: + from arbiter.api.routes.reviews import DeliberationLogResponse + + response = DeliberationLogResponse(review_id="review-123", steps=[]) + assert response.review_id == "review-123" + assert response.steps == [] + + +class TestConflictResponse: + def test_conflict_response_model(self) -> None: + from arbiter.api.routes.reviews import ConflictResponse + + conflict = ConflictResponse( + id="conflict-id", + finding_ids=["finding-1", "finding-2"], + nature="trade_off", + description="Security vs simplicity", + severity_weight=0.7, + resolution="Prefer security", + winning_finding_id="finding-1", + ) + assert conflict.nature == "trade_off" + assert len(conflict.finding_ids) == 2 + + +class TestReviewListResponse: + def test_review_list_response_model(self) -> None: + from arbiter.api.routes.reviews import ReviewListResponse + + response = ReviewListResponse( + items=[], + total=0, + page=1, + page_size=20, + pages=0, + ) + assert response.total == 0 + assert response.page == 1 + + +class TestReviewDetail: + def test_review_detail_model(self) -> None: + from datetime import UTC, datetime + + from arbiter.api.routes.reviews import ReviewDetail + + detail = ReviewDetail( + id="review-id", + repository="owner/repo", + pr_number=42, + pr_title="Test PR", + base_sha="abc1234", + head_sha="def5678", + author="testuser", + is_draft=False, + status="completed", + verdict="comment", + verdict_confidence=0.75, + verdict_reasoning="Found issues", + total_tokens=1500, + total_cost_usd=0.015, + tokens_by_agent={"security": 500}, + cost_by_agent={"security": 0.005}, + created_at=datetime.now(UTC), + started_at=datetime.now(UTC), + completed_at=datetime.now(UTC), + error_message=None, + findings=[], + conflicts=[], + ) + assert detail.repository == "owner/repo" + assert detail.verdict == "comment" + + +class TestManualReviewResponse: + def test_manual_review_response_model(self) -> None: + from arbiter.api.routes.reviews import ManualReviewResponse + + response = ManualReviewResponse( + status="queued", + job_id="job-123", + review_id=None, + message="Review queued", + ) + assert response.status == "queued" + assert response.job_id == "job-123" + + +class TestHealthMetrics: + def test_record_review_completed(self) -> None: + # This should not raise any errors + record_review_completed( + duration_seconds=10.5, + review_status="completed", + verdict="comment", + findings_by_severity={"high": 2, "medium": 3}, + tokens_in=1000, + tokens_out=500, + cost_usd=0.015, + ) + + def test_record_review_completed_no_verdict(self) -> None: + record_review_completed( + duration_seconds=5.0, + review_status="failed", + verdict=None, + findings_by_severity={}, + tokens_in=100, + tokens_out=50, + cost_usd=0.001, + ) + + +class TestHealthCheckModels: + def test_health_check_model(self) -> None: + from arbiter.api.routes.health import HealthCheck + + check = HealthCheck(status="healthy", version="0.3.0") + assert check.status == "healthy" + assert check.version == "0.3.0" + + def test_readiness_check_model(self) -> None: + from arbiter.api.routes.health import ReadinessCheck + + check = ReadinessCheck( + status="ready", + components={ + "database": {"status": "healthy"}, + "redis": {"status": "healthy"}, + }, + ) + assert check.status == "ready" + assert check.components["database"]["status"] == "healthy" + + +class TestReviewsListEndpoint: + async def test_list_reviews_empty(self, test_client: AsyncClient) -> None: + response = await test_client.get("/api/reviews") + assert response.status_code == 200 + data = response.json() + assert data["items"] == [] + assert data["total"] == 0 + assert data["page"] == 1 + assert data["pages"] == 0 + + async def test_list_reviews_with_data( + self, + test_client: AsyncClient, + sample_reviews_fixture: list[ReviewModel], # noqa: ARG002 + ) -> None: + response = await test_client.get("/api/reviews") + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 5 + assert data["total"] == 5 + # Should be ordered by created_at desc + assert data["items"][0]["repository"] in ["owner/repo", "other/repo"] + + async def test_list_reviews_filter_by_repository( + self, + test_client: AsyncClient, + sample_reviews_fixture: list[ReviewModel], # noqa: ARG002 + ) -> None: + response = await test_client.get("/api/reviews?repository=owner/repo") + assert response.status_code == 200 + data = response.json() + assert data["total"] == 3 + for item in data["items"]: + assert item["repository"] == "owner/repo" + + async def test_list_reviews_filter_by_status( + self, + test_client: AsyncClient, + sample_reviews_fixture: list[ReviewModel], # noqa: ARG002 + ) -> None: + response = await test_client.get("/api/reviews?status=completed") + assert response.status_code == 200 + data = response.json() + assert data["total"] == 4 + for item in data["items"]: + assert item["status"] == "completed" + + async def test_list_reviews_filter_by_verdict( + self, + test_client: AsyncClient, + sample_reviews_fixture: list[ReviewModel], # noqa: ARG002 + ) -> None: + response = await test_client.get("/api/reviews?verdict=approve") + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 + assert data["items"][0]["verdict"] == "approve" + + async def test_list_reviews_pagination( + self, + test_client: AsyncClient, + sample_reviews_fixture: list[ReviewModel], # noqa: ARG002 + ) -> None: + response = await test_client.get("/api/reviews?page=1&page_size=2") + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 2 + assert data["total"] == 5 + assert data["page"] == 1 + assert data["page_size"] == 2 + assert data["pages"] == 3 # 5 items / 2 per page = 3 pages + + +class TestReviewDetailEndpoint: + async def test_get_review_success( + self, test_client: AsyncClient, completed_review_fixture: ReviewModel + ) -> None: + response = await test_client.get(f"/api/reviews/{completed_review_fixture.id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == completed_review_fixture.id + assert data["repository"] == "owner/repo" + assert data["pr_number"] == 42 + assert data["verdict"] == "comment" + + async def test_get_review_not_found(self, test_client: AsyncClient) -> None: + fake_id = str(uuid4()) + response = await test_client.get(f"/api/reviews/{fake_id}") + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + async def test_get_review_includes_findings( + self, test_client: AsyncClient, completed_review_fixture: ReviewModel + ) -> None: + response = await test_client.get(f"/api/reviews/{completed_review_fixture.id}") + assert response.status_code == 200 + data = response.json() + assert len(data["findings"]) == 2 + # Check finding structure + finding = data["findings"][0] + assert "id" in finding + assert "agent" in finding + assert "severity" in finding + assert "title" in finding + + +class TestMetricsEndpoint: + async def test_get_metrics_empty(self, test_client: AsyncClient) -> None: + response = await test_client.get("/api/reviews/metrics") + assert response.status_code == 200 + data = response.json() + assert data["total_reviews"] == 0 + assert data["completed_reviews"] == 0 + assert data["average_cost_usd"] == 0.0 + + async def test_get_metrics_with_reviews( + self, + test_client: AsyncClient, + sample_reviews_fixture: list[ReviewModel], # noqa: ARG002 + ) -> None: + response = await test_client.get("/api/reviews/metrics") + assert response.status_code == 200 + data = response.json() + assert data["total_reviews"] == 5 + assert data["completed_reviews"] == 4 + assert data["average_cost_usd"] > 0 + assert "verdict_counts" in data + assert "severity_counts" in data + + +class TestManualReviewEndpoint: + async def test_trigger_manual_review_missing_diff(self, test_client: AsyncClient) -> None: + response = await test_client.post( + "/api/reviews", + json={ + "repository": "owner/repo", + "pr_number": 1, + "base_sha": "abc1234", + "head_sha": "def5678", + # No diff_content + }, + ) + assert response.status_code == 400 + assert "diff_content" in response.json()["detail"].lower() + + async def test_trigger_manual_review_success(self, test_client: AsyncClient) -> None: + with patch("arbiter.api.routes.reviews.enqueue_review", return_value="job-123"): + response = await test_client.post( + "/api/reviews", + json={ + "repository": "owner/repo", + "pr_number": 1, + "base_sha": "abc1234", + "head_sha": "def5678", + "diff_content": "mock diff content", + }, + ) + assert response.status_code == 202 + data = response.json() + assert data["status"] == "queued" + assert data["job_id"] == "job-123" + + async def test_trigger_manual_review_duplicate(self, test_client: AsyncClient) -> None: + with patch("arbiter.api.routes.reviews.enqueue_review", return_value=None): + response = await test_client.post( + "/api/reviews", + json={ + "repository": "owner/repo", + "pr_number": 1, + "base_sha": "abc1234", + "head_sha": "def5678", + "diff_content": "mock diff content", + }, + ) + assert response.status_code == 202 + data = response.json() + assert data["status"] == "duplicate" + + +class TestDeliberationLogEndpoint: + async def test_get_deliberation_log_not_found(self, test_client: AsyncClient) -> None: + fake_id = str(uuid4()) + response = await test_client.get(f"/api/reviews/{fake_id}/deliberation") + assert response.status_code == 404 + + async def test_get_deliberation_log_success( + self, + test_client: AsyncClient, + completed_review_fixture: ReviewModel, + db_session: AsyncSession, + ) -> None: + # Add deliberation steps + step = DeliberationStepModel( + id=str(uuid4()), + review_id=completed_review_fixture.id, + step_type=StepType.MERGE, + timestamp=datetime.now(UTC), + description="Merged findings from all agents", + details={"groups": 2}, + sequence=0, + ) + db_session.add(step) + await db_session.commit() + + response = await test_client.get(f"/api/reviews/{completed_review_fixture.id}/deliberation") + assert response.status_code == 200 + data = response.json() + assert data["review_id"] == completed_review_fixture.id + assert len(data["steps"]) == 1 + assert data["steps"][0]["step_type"] == "merge" + + +class TestGitHubWebhookExpanded: + async def test_github_pr_opened(self, test_client: AsyncClient) -> None: + with patch("arbiter.api.routes.webhooks.enqueue_review", return_value="job-123"): + response = await test_client.post( + "/webhooks/github", + json={ + "action": "opened", + "pull_request": { + "number": 1, + "title": "Test PR", + "base": {"sha": "abc123"}, + "head": {"sha": "def456"}, + "user": {"login": "testuser"}, + "draft": False, + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "pull_request"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "queued" + assert data["job_id"] == "job-123" + + async def test_github_pr_synchronize(self, test_client: AsyncClient) -> None: + with patch("arbiter.api.routes.webhooks.enqueue_review", return_value="job-456"): + response = await test_client.post( + "/webhooks/github", + json={ + "action": "synchronize", + "pull_request": { + "number": 2, + "title": "Updated PR", + "base": {"sha": "abc123"}, + "head": {"sha": "newsha789"}, + "user": {"login": "testuser"}, + "draft": False, + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "pull_request"}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "queued" + + async def test_github_pr_reopened(self, test_client: AsyncClient) -> None: + with patch("arbiter.api.routes.webhooks.enqueue_review", return_value="job-789"): + response = await test_client.post( + "/webhooks/github", + json={ + "action": "reopened", + "pull_request": { + "number": 3, + "title": "Reopened PR", + "base": {"sha": "abc123"}, + "head": {"sha": "def456"}, + "user": {"login": "testuser"}, + "draft": False, + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "pull_request"}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "queued" + + async def test_github_pr_ignored_actions(self, test_client: AsyncClient) -> None: + for action in ["closed", "edited", "labeled", "unlabeled", "assigned"]: + response = await test_client.post( + "/webhooks/github", + json={ + "action": action, + "pull_request": { + "number": 1, + "title": "Test PR", + "base": {"sha": "abc123"}, + "head": {"sha": "def456"}, + "user": {"login": "testuser"}, + "draft": False, + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "pull_request"}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ignored" + + async def test_github_non_pr_events_ignored(self, test_client: AsyncClient) -> None: + for event_type in ["push", "release", "workflow_run", "issues"]: + response = await test_client.post( + "/webhooks/github", + json={"action": "created"}, + headers={"X-GitHub-Event": event_type}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ignored" + + async def test_github_missing_fields(self, test_client: AsyncClient) -> None: + response = await test_client.post( + "/webhooks/github", + json={ + "action": "opened", + "pull_request": { + "number": 1, + # Missing base and head SHA + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "pull_request"}, + ) + assert response.status_code == 400 + + async def test_github_duplicate_review(self, test_client: AsyncClient) -> None: + with patch("arbiter.api.routes.webhooks.enqueue_review", return_value=None): + response = await test_client.post( + "/webhooks/github", + json={ + "action": "opened", + "pull_request": { + "number": 1, + "title": "Test PR", + "base": {"sha": "abc123"}, + "head": {"sha": "def456"}, + "user": {"login": "testuser"}, + "draft": False, + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "pull_request"}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "duplicate" + + +class TestGitHubCommentWebhook: + async def test_github_pr_comment_created(self, test_client: AsyncClient) -> None: + with patch("arbiter.api.routes.webhooks.enqueue_followup", return_value="followup-123"): + response = await test_client.post( + "/webhooks/github", + json={ + "action": "created", + "issue": { + "number": 42, + "pull_request": {"url": "..."}, # Indicates this is a PR + }, + "comment": { + "id": 12345, + "body": "Why is this a security issue?", + "user": {"login": "reviewer"}, + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "issue_comment"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "queued" + + async def test_github_issue_comment_ignored(self, test_client: AsyncClient) -> None: + response = await test_client.post( + "/webhooks/github", + json={ + "action": "created", + "issue": { + "number": 42, + # No pull_request key = regular issue + }, + "comment": { + "id": 12345, + "body": "Some comment", + "user": {"login": "reviewer"}, + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "issue_comment"}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ignored" + + async def test_github_own_comment_ignored(self, test_client: AsyncClient) -> None: + from arbiter.integrations import ARBITER_MARKER + + response = await test_client.post( + "/webhooks/github", + json={ + "action": "created", + "issue": { + "number": 42, + "pull_request": {"url": "..."}, + }, + "comment": { + "id": 12345, + "body": f"Review results {ARBITER_MARKER}", + "user": {"login": "arbiter-bot"}, + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "issue_comment"}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ignored" + + +class TestGitLabWebhookExpanded: + async def test_gitlab_mr_open(self, test_client: AsyncClient) -> None: + with patch("arbiter.api.routes.webhooks.enqueue_review", return_value="job-gl-123"): + response = await test_client.post( + "/webhooks/gitlab", + json={ + "object_kind": "merge_request", + "object_attributes": { + "action": "open", + "iid": 1, + "title": "Test MR", + "target_branch": "main", + "last_commit": {"id": "def456"}, + "work_in_progress": False, + }, + "project": {"path_with_namespace": "group/project"}, + "user": {"username": "developer"}, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "queued" + + async def test_gitlab_mr_update(self, test_client: AsyncClient) -> None: + with patch("arbiter.api.routes.webhooks.enqueue_review", return_value="job-gl-456"): + response = await test_client.post( + "/webhooks/gitlab", + json={ + "object_kind": "merge_request", + "object_attributes": { + "action": "update", + "iid": 2, + "title": "Updated MR", + "target_branch": "develop", + "last_commit": {"id": "newsha789"}, + "work_in_progress": False, + }, + "project": {"path_with_namespace": "group/project"}, + "user": {"username": "developer"}, + }, + ) + assert response.status_code == 200 + assert response.json()["status"] == "queued" + + async def test_gitlab_mr_ignored_actions(self, test_client: AsyncClient) -> None: + for action in ["close", "merge", "approved", "unapproved"]: + response = await test_client.post( + "/webhooks/gitlab", + json={ + "object_kind": "merge_request", + "object_attributes": { + "action": action, + "iid": 1, + "title": "Test MR", + "target_branch": "main", + "last_commit": {"id": "def456"}, + }, + "project": {"path_with_namespace": "group/project"}, + "user": {"username": "developer"}, + }, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ignored" + + async def test_gitlab_non_mr_events_ignored(self, test_client: AsyncClient) -> None: + for event_type in ["push", "issue", "pipeline", "build"]: + response = await test_client.post( + "/webhooks/gitlab", + json={"object_kind": event_type}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ignored" + + +class TestGitLabNoteWebhook: + async def test_gitlab_mr_note_created(self, test_client: AsyncClient) -> None: + with patch("arbiter.api.routes.webhooks.enqueue_followup", return_value="followup-gl-123"): + response = await test_client.post( + "/webhooks/gitlab", + json={ + "object_kind": "note", + "object_attributes": { + "id": 54321, + "note": "Can you explain this finding?", + "noteable_type": "MergeRequest", + }, + "project": {"path_with_namespace": "group/project"}, + "merge_request": {"iid": 42}, + "user": {"username": "reviewer"}, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "queued" + + async def test_gitlab_non_mr_note_ignored(self, test_client: AsyncClient) -> None: + for noteable_type in ["Issue", "Commit", "Snippet"]: + response = await test_client.post( + "/webhooks/gitlab", + json={ + "object_kind": "note", + "object_attributes": { + "id": 54321, + "note": "Some comment", + "noteable_type": noteable_type, + }, + "project": {"path_with_namespace": "group/project"}, + "user": {"username": "reviewer"}, + }, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ignored" + + async def test_gitlab_own_note_ignored(self, test_client: AsyncClient) -> None: + from arbiter.integrations import ARBITER_MARKER + + response = await test_client.post( + "/webhooks/gitlab", + json={ + "object_kind": "note", + "object_attributes": { + "id": 54321, + "note": f"Review results {ARBITER_MARKER}", + "noteable_type": "MergeRequest", + }, + "project": {"path_with_namespace": "group/project"}, + "merge_request": {"iid": 42}, + "user": {"username": "arbiter-bot"}, + }, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ignored" + + +class TestHealthEndpointsExpanded: + async def test_readiness_healthy(self, test_client: AsyncClient) -> None: + response = await test_client.get("/health/ready") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ready" + assert "database" in data["components"] + assert "redis" in data["components"] + assert data["components"]["database"]["status"] == "healthy" + + async def test_prometheus_metrics_endpoint(self, test_client: AsyncClient) -> None: + response = await test_client.get("/metrics") + assert response.status_code == 200 + # Check content type + assert "text/plain" in response.headers.get("content-type", "") + # Check for expected metric names + content = response.text + assert "arbiter_reviews_total" in content or "# HELP" in content + + async def test_readiness_check_includes_worker_status(self, test_client: AsyncClient) -> None: + response = await test_client.get("/health/ready") + assert response.status_code == 200 + data = response.json() + # Worker component should be present (may show unknown if llen fails) + assert "worker" in data["components"] + + +class TestWebhookInvalidPayload: + async def test_github_invalid_json(self, test_client: AsyncClient) -> None: + response = await test_client.post( + "/webhooks/github", + content=b"not valid json", + headers={ + "X-GitHub-Event": "pull_request", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 400 + + async def test_gitlab_invalid_json(self, test_client: AsyncClient) -> None: + response = await test_client.post( + "/webhooks/gitlab", + content=b"not valid json", + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 400 + + async def test_github_missing_required_fields(self, test_client: AsyncClient) -> None: + response = await test_client.post( + "/webhooks/github", + json={ + "action": "opened", + "pull_request": { + "number": 1, + "title": "Test", + # Missing base, head SHAs + }, + "repository": {}, # Missing full_name + }, + headers={"X-GitHub-Event": "pull_request"}, + ) + assert response.status_code == 400 + + +class TestWebhookCommentActions: + async def test_github_comment_edited_ignored(self, test_client: AsyncClient) -> None: + response = await test_client.post( + "/webhooks/github", + json={ + "action": "edited", + "issue": { + "number": 42, + "pull_request": {"url": "..."}, + }, + "comment": { + "id": 12345, + "body": "Updated comment", + "user": {"login": "reviewer"}, + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "issue_comment"}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ignored" + + async def test_github_comment_deleted_ignored(self, test_client: AsyncClient) -> None: + response = await test_client.post( + "/webhooks/github", + json={ + "action": "deleted", + "issue": { + "number": 42, + "pull_request": {"url": "..."}, + }, + "comment": { + "id": 12345, + "body": "Deleted comment", + "user": {"login": "reviewer"}, + }, + "repository": {"full_name": "owner/repo"}, + }, + headers={"X-GitHub-Event": "issue_comment"}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ignored" + + +class TestAppExceptionHandlers: + async def test_value_error_returns_400(self, test_client: AsyncClient) -> None: + # The ManualReviewRequest validation can raise ValueError + response = await test_client.post( + "/api/reviews", + json={ + "repository": "owner/repo", + "pr_number": 1, + "base_sha": "abc", # Too short, but this raises pydantic validation error + "head_sha": "def5678", + }, + ) + # Pydantic validation errors return 422, but direct ValueError returns 400 + assert response.status_code in [400, 422] + + +class TestConversationsEndpoints: + async def test_list_conversations_empty(self, test_client: AsyncClient) -> None: + response = await test_client.get("/api/conversations") + assert response.status_code == 200 + data = response.json() + assert data["items"] == [] + assert data["total"] == 0 + + async def test_list_conversations_with_data( + self, + test_client: AsyncClient, + completed_review_fixture: ReviewModel, + db_session: AsyncSession, + ) -> None: + # Create a conversation + conv = ConversationModel( + id=str(uuid4()), + review_id=completed_review_fixture.id, + platform="github", + repository="owner/repo", + pr_number=42, + total_tokens=100, + total_cost_usd=0.001, + ) + db_session.add(conv) + await db_session.commit() + + response = await test_client.get("/api/conversations") + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 1 + assert data["items"][0]["repository"] == "owner/repo" + + async def test_list_conversations_filter_by_repository( + self, + test_client: AsyncClient, + completed_review_fixture: ReviewModel, + db_session: AsyncSession, + ) -> None: + # Create a conversation + conv = ConversationModel( + id=str(uuid4()), + review_id=completed_review_fixture.id, + platform="github", + repository="owner/repo", + pr_number=42, + total_tokens=100, + total_cost_usd=0.001, + ) + db_session.add(conv) + await db_session.commit() + + response = await test_client.get("/api/conversations?repository=owner/repo") + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 1 + + response = await test_client.get("/api/conversations?repository=other/repo") + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 0 + + async def test_list_conversations_filter_by_review_id( + self, + test_client: AsyncClient, + completed_review_fixture: ReviewModel, + db_session: AsyncSession, + ) -> None: + # Create a conversation + conv = ConversationModel( + id=str(uuid4()), + review_id=completed_review_fixture.id, + platform="github", + repository="owner/repo", + pr_number=42, + total_tokens=100, + total_cost_usd=0.001, + ) + db_session.add(conv) + await db_session.commit() + + response = await test_client.get( + f"/api/conversations?review_id={completed_review_fixture.id}" + ) + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 1 + + # Filter by non-existent review_id + fake_id = str(uuid4()) + response = await test_client.get(f"/api/conversations?review_id={fake_id}") + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 0 + + async def test_get_conversation_not_found(self, test_client: AsyncClient) -> None: + fake_id = str(uuid4()) + response = await test_client.get(f"/api/conversations/{fake_id}") + assert response.status_code == 404 + + async def test_get_conversation_success( + self, + test_client: AsyncClient, + completed_review_fixture: ReviewModel, + db_session: AsyncSession, + ) -> None: + # Create a conversation with messages + conv = ConversationModel( + id=str(uuid4()), + review_id=completed_review_fixture.id, + platform="github", + repository="owner/repo", + pr_number=42, + total_tokens=100, + total_cost_usd=0.001, + ) + db_session.add(conv) + await db_session.flush() + + msg = ConversationMessageModel( + id=str(uuid4()), + conversation_id=conv.id, + role="user", + content="Why is this a security issue?", + sequence=0, + ) + db_session.add(msg) + await db_session.commit() + + response = await test_client.get(f"/api/conversations/{conv.id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == conv.id + assert len(data["messages"]) == 1 + + async def test_get_conversation_for_review( + self, + test_client: AsyncClient, + completed_review_fixture: ReviewModel, + db_session: AsyncSession, + ) -> None: + # Create a conversation + conv = ConversationModel( + id=str(uuid4()), + review_id=completed_review_fixture.id, + platform="github", + repository="owner/repo", + pr_number=42, + total_tokens=100, + total_cost_usd=0.001, + ) + db_session.add(conv) + await db_session.commit() + + response = await test_client.get(f"/api/conversations/review/{completed_review_fixture.id}") + assert response.status_code == 200 + data = response.json() + assert data["review_id"] == completed_review_fixture.id + + async def test_get_conversation_for_review_not_found( + self, + test_client: AsyncClient, + completed_review_fixture: ReviewModel, # noqa: ARG002 + ) -> None: + # Use a different review ID + fake_id = str(uuid4()) + response = await test_client.get(f"/api/conversations/review/{fake_id}") + assert response.status_code == 200 + assert response.json() is None diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000..5be99f1 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,164 @@ +"""Tests for the LLM cache module.""" + +import pytest + +from arbiter.llm.cache import LLMCache, compute_policy_hash +from arbiter.llm.client import LLMResponse + + +class TestComputePolicyHash: + def test_compute_policy_hash_deterministic(self) -> None: + policy = {"agents": {"security": {"enabled": True}}} + hash1 = compute_policy_hash(policy) + hash2 = compute_policy_hash(policy) + assert hash1 == hash2 + + def test_policy_hash_varies(self) -> None: + policy1 = {"agents": {"security": {"enabled": True}}} + policy2 = {"agents": {"security": {"enabled": False}}} + assert compute_policy_hash(policy1) != compute_policy_hash(policy2) + + def test_compute_policy_hash_format(self) -> None: + policy = {"test": "data"} + hash_value = compute_policy_hash(policy) + assert len(hash_value) == 16 + assert all(c in "0123456789abcdef" for c in hash_value) + + +class MockRedisForCache: + """Mock Redis client for cache testing.""" + + def __init__(self) -> None: + self._data: dict[str, str] = {} + + async def get(self, key: str) -> str | None: + return self._data.get(key) + + async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002 + self._data[key] = value + return True + + async def delete(self, key: str) -> int: + if key in self._data: + del self._data[key] + return 1 + return 0 + + def scan_iter(self, match: str | None = None): # noqa: ARG002 + async def _gen(): + for key in list(self._data.keys()): + yield key + + return _gen() + + +class TestLLMCache: + @pytest.fixture + def cache(self) -> LLMCache: + mock_redis = MockRedisForCache() + return LLMCache(mock_redis) # type: ignore[arg-type] + + def test_compute_key(self, cache: LLMCache) -> None: + key = cache._compute_key("diff content", "security", "v1.0", "policy123") + assert key.startswith("arbiter:llm:cache:") + assert len(key) > 20 # prefix + hash + + def test_compute_key_deterministic(self, cache: LLMCache) -> None: + key1 = cache._compute_key("diff", "security", "v1.0") + key2 = cache._compute_key("diff", "security", "v1.0") + assert key1 == key2 + + def test_compute_key_unique(self, cache: LLMCache) -> None: + key1 = cache._compute_key("diff1", "security", "v1.0") + key2 = cache._compute_key("diff2", "security", "v1.0") + key3 = cache._compute_key("diff1", "style", "v1.0") + key4 = cache._compute_key("diff1", "security", "v2.0") + + assert len({key1, key2, key3, key4}) == 4 + + def test_serialize_deserialize_response(self, cache: LLMCache) -> None: + response = LLMResponse( + content="test content", + model="gpt-4o", + tokens_in=100, + tokens_out=50, + cost_usd=0.01, + ) + + serialized = cache._serialize_response(response) + deserialized = cache._deserialize_response(serialized) + + assert deserialized.content == response.content + assert deserialized.model == response.model + assert deserialized.tokens_in == response.tokens_in + assert deserialized.tokens_out == response.tokens_out + assert deserialized.cost_usd == response.cost_usd + + async def test_cache_get_miss(self, cache: LLMCache) -> None: + result = await cache.get("diff", "security", "v1.0") + assert result is None + assert cache._misses == 1 + assert cache._hits == 0 + + async def test_cache_set_and_get(self, cache: LLMCache) -> None: + response = LLMResponse( + content="cached content", + model="gpt-4o", + tokens_in=100, + tokens_out=50, + cost_usd=0.01, + ) + + await cache.set("diff", "security", "v1.0", response) + result = await cache.get("diff", "security", "v1.0") + + assert result is not None + assert result.content == "cached content" + assert cache._hits == 1 + + async def test_cache_invalidate(self, cache: LLMCache) -> None: + response = LLMResponse( + content="test", + model="gpt-4o", + tokens_in=100, + tokens_out=50, + cost_usd=0.01, + ) + + await cache.set("diff", "security", "v1.0", response) + deleted = await cache.invalidate("diff", "security", "v1.0") + assert deleted is True + + result = await cache.get("diff", "security", "v1.0") + assert result is None + + async def test_cache_invalidate_nonexistent(self, cache: LLMCache) -> None: + deleted = await cache.invalidate("nonexistent", "security", "v1.0") + assert deleted is False + + def test_get_stats(self, cache: LLMCache) -> None: + stats = cache.get_stats() + assert stats["hits"] == 0 + assert stats["misses"] == 0 + assert stats["total"] == 0 + assert stats["hit_rate"] == 0.0 + + async def test_get_stats_after_operations(self, cache: LLMCache) -> None: + await cache.get("key1", "agent", "v1") # miss + await cache.get("key2", "agent", "v1") # miss + + response = LLMResponse( + content="test", + model="gpt-4o", + tokens_in=100, + tokens_out=50, + cost_usd=0.01, + ) + await cache.set("key1", "agent", "v1", response) + await cache.get("key1", "agent", "v1") # hit + + stats = cache.get_stats() + assert stats["hits"] == 1 + assert stats["misses"] == 2 + assert stats["total"] == 3 + assert stats["hit_rate"] == pytest.approx(1 / 3) diff --git a/tests/test_cost.py b/tests/test_cost.py new file mode 100644 index 0000000..fb75cde --- /dev/null +++ b/tests/test_cost.py @@ -0,0 +1,157 @@ +"""Tests for the cost tracking module.""" + +import pytest + +from arbiter.models.cost import AgentCost, CostEstimate, ReviewCost +from arbiter.models.enums import AgentName + + +class TestAgentCost: + def test_agent_cost_creation(self) -> None: + cost = AgentCost( + agent=AgentName.SECURITY, + tokens_in=100, + tokens_out=50, + total_tokens=150, + cost_usd=0.01, + ) + assert cost.agent == AgentName.SECURITY + assert cost.total_tokens == 150 + assert cost.cost_usd == 0.01 + + def test_agent_cost_defaults(self) -> None: + cost = AgentCost(agent=AgentName.STYLE) + assert cost.tokens_in == 0 + assert cost.tokens_out == 0 + assert cost.total_tokens == 0 + assert cost.cost_usd == 0.0 + + +class TestReviewCost: + def test_review_cost_defaults(self) -> None: + cost = ReviewCost() + assert cost.total_tokens == 0 + assert cost.total_cost_usd == 0.0 + assert cost.agent_costs == [] + assert cost.cache_hits == 0 + assert cost.cache_misses == 0 + + def test_add_agent_cost(self) -> None: + cost = ReviewCost() + cost.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01) + cost.add_agent_cost(AgentName.STYLE, tokens_in=80, tokens_out=40, cost_usd=0.008) + + assert len(cost.agent_costs) == 2 + assert cost.total_tokens_in == 180 + assert cost.total_tokens_out == 90 + assert cost.total_tokens == 270 + assert cost.total_cost_usd == pytest.approx(0.018) + + def test_add_deliberation_cost(self) -> None: + cost = ReviewCost() + cost.add_deliberation_cost(tokens_in=50, tokens_out=100, cost_usd=0.005) + + assert cost.deliberation_tokens_in == 50 + assert cost.deliberation_tokens_out == 100 + assert cost.deliberation_cost_usd == 0.005 + assert cost.total_tokens == 150 + + def test_combined_costs(self) -> None: + cost = ReviewCost() + cost.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01) + cost.add_deliberation_cost(tokens_in=50, tokens_out=25, cost_usd=0.005) + + assert cost.total_tokens_in == 150 + assert cost.total_tokens_out == 75 + assert cost.total_tokens == 225 + assert cost.total_cost_usd == pytest.approx(0.015) + + def test_to_agent_dict(self) -> None: + cost = ReviewCost() + cost.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01) + cost.add_agent_cost(AgentName.STYLE, tokens_in=80, tokens_out=40, cost_usd=0.008) + + agent_dict = cost.to_agent_dict() + assert agent_dict == {"security": 150, "style": 120} + + def test_to_cost_dict(self) -> None: + cost = ReviewCost() + cost.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01) + cost.add_agent_cost(AgentName.STYLE, tokens_in=80, tokens_out=40, cost_usd=0.008) + + cost_dict = cost.to_cost_dict() + assert cost_dict == {"security": 0.01, "style": 0.008} + + def test_is_within_budget_true(self) -> None: + cost = ReviewCost() + cost.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=0.01) + + assert cost.is_within_budget(max_tokens=1000, max_cost_usd=0.50) is True + + def test_is_within_budget_false_tokens(self) -> None: + cost = ReviewCost() + cost.add_agent_cost(AgentName.SECURITY, tokens_in=1000, tokens_out=500, cost_usd=0.01) + + assert cost.is_within_budget(max_tokens=1000, max_cost_usd=0.50) is False + + def test_is_within_budget_false_cost(self) -> None: + cost = ReviewCost() + cost.add_agent_cost(AgentName.SECURITY, tokens_in=100, tokens_out=50, cost_usd=1.0) + + assert cost.is_within_budget(max_tokens=10000, max_cost_usd=0.50) is False + + +class TestCostEstimate: + def test_estimate_small_diff(self) -> None: + estimate = CostEstimate.estimate( + diff_size=1000, + agents=[AgentName.SECURITY, AgentName.STYLE], + model="gpt-4o-mini", + ) + + assert estimate.estimated_tokens > 0 + assert estimate.estimated_cost_usd > 0 + assert estimate.agents_enabled == [AgentName.SECURITY, AgentName.STYLE] + assert estimate.model == "gpt-4o-mini" + assert estimate.within_budget is True + + def test_estimate_large_diff(self) -> None: + estimate = CostEstimate.estimate( + diff_size=100000, + agents=[AgentName.SECURITY, AgentName.STYLE, AgentName.COMPLEXITY], + model="gpt-4o", + max_tokens=10000, + max_cost_usd=0.10, + ) + + # Large diff with expensive model should exceed budget + assert estimate.within_budget is False + + def test_estimate_gpt4o_vs_mini(self) -> None: + estimate_4o = CostEstimate.estimate( + diff_size=10000, + agents=[AgentName.SECURITY], + model="gpt-4o", + ) + estimate_mini = CostEstimate.estimate( + diff_size=10000, + agents=[AgentName.SECURITY], + model="gpt-4o-mini", + ) + + assert estimate_4o.estimated_cost_usd > estimate_mini.estimated_cost_usd + + def test_estimate_more_agents_higher_cost(self) -> None: + estimate_one = CostEstimate.estimate( + diff_size=5000, + agents=[AgentName.SECURITY], + model="gpt-4o", + ) + estimate_three = CostEstimate.estimate( + diff_size=5000, + agents=[AgentName.SECURITY, AgentName.STYLE, AgentName.COMPLEXITY], + model="gpt-4o", + ) + + assert estimate_three.estimated_tokens > estimate_one.estimated_tokens + assert estimate_three.estimated_cost_usd > estimate_one.estimated_cost_usd diff --git a/tests/test_db_models.py b/tests/test_db_models.py new file mode 100644 index 0000000..c782120 --- /dev/null +++ b/tests/test_db_models.py @@ -0,0 +1,222 @@ +"""Tests for database models.""" + +from uuid import uuid4 + +from arbiter.db.models import ( + Base, + ConflictModel, + DeliberationStepModel, + FindingModel, + PolicyModel, + ReviewModel, +) +from arbiter.deliberation.conflicts import ConflictNature +from arbiter.deliberation.coordinator import StepType +from arbiter.models.enums import AgentName, Severity, Verdict + + +class TestReviewModel: + def test_review_model_creation(self) -> None: + review = ReviewModel( + id=str(uuid4()), + repository="owner/repo", + pr_number=42, + pr_title="Test PR", + base_sha="abc1234567890123456789012345678901234567", + head_sha="def1234567890123456789012345678901234567", + author="testuser", + is_draft=False, + status="pending", + ) + + assert review.repository == "owner/repo" + assert review.pr_number == 42 + assert review.status == "pending" + assert review.is_draft is False + + def test_review_model_with_verdict(self) -> None: + review = ReviewModel( + id=str(uuid4()), + repository="owner/repo", + pr_number=1, + base_sha="a" * 40, + head_sha="b" * 40, + status="completed", + verdict=Verdict.COMMENT, + verdict_confidence=0.75, + verdict_reasoning="Found some issues", + ) + + assert review.verdict == Verdict.COMMENT + assert review.verdict_confidence == 0.75 + + def test_review_model_cost_tracking(self) -> None: + review = ReviewModel( + id=str(uuid4()), + repository="owner/repo", + pr_number=1, + base_sha="a" * 40, + head_sha="b" * 40, + total_tokens=1500, + total_cost_usd=0.015, + tokens_by_agent={"security": 500, "style": 500, "complexity": 500}, + cost_by_agent={"security": 0.005, "style": 0.005, "complexity": 0.005}, + ) + + assert review.total_tokens == 1500 + assert review.total_cost_usd == 0.015 + assert review.tokens_by_agent["security"] == 500 + + +class TestFindingModel: + def test_finding_model_creation(self) -> None: + finding = FindingModel( + id=str(uuid4()), + review_id=str(uuid4()), + agent=AgentName.SECURITY, + file="src/auth.py", + line_start=10, + line_end=15, + severity=Severity.HIGH, + confidence=0.9, + title="SQL Injection", + description="User input concatenated in SQL", + reasoning="String concatenation allows injection", + suggestion="Use parameterized queries", + references=["https://owasp.org"], + prompt_version="security-v1.0", + ) + + assert finding.agent == AgentName.SECURITY + assert finding.severity == Severity.HIGH + assert finding.confidence == 0.9 + assert finding.line_start == 10 + assert finding.line_end == 15 + + +class TestConflictModel: + def test_conflict_model_creation(self) -> None: + conflict = ConflictModel( + id=str(uuid4()), + review_id=str(uuid4()), + finding_ids=["finding-1", "finding-2"], + nature=ConflictNature.TRADE_OFF, + description="Security vs simplicity trade-off", + severity_weight=0.7, + ) + + assert conflict.nature == ConflictNature.TRADE_OFF + assert len(conflict.finding_ids) == 2 + assert conflict.severity_weight == 0.7 + + def test_conflict_model_with_resolution(self) -> None: + conflict = ConflictModel( + id=str(uuid4()), + review_id=str(uuid4()), + finding_ids=["finding-1", "finding-2"], + nature=ConflictNature.CONTRADICTORY, + description="Opposing recommendations", + severity_weight=0.8, + resolution="Security takes precedence", + winning_finding_id="finding-1", + ) + + assert conflict.resolution is not None + assert conflict.winning_finding_id == "finding-1" + + +class TestDeliberationStepModel: + def test_deliberation_step_creation(self) -> None: + step = DeliberationStepModel( + id=str(uuid4()), + review_id=str(uuid4()), + step_type=StepType.MERGE, + description="Merged 5 findings", + details={"groups": 3, "unique": 5}, + sequence=0, + ) + + assert step.step_type == StepType.MERGE + assert step.sequence == 0 + assert step.details["groups"] == 3 + + def test_all_step_types(self) -> None: + review_id = str(uuid4()) + + steps = [ + DeliberationStepModel( + id=str(uuid4()), + review_id=review_id, + step_type=StepType.MERGE, + description="Merge step", + sequence=0, + ), + DeliberationStepModel( + id=str(uuid4()), + review_id=review_id, + step_type=StepType.CONFLICT_DETECTION, + description="Conflict detection step", + sequence=1, + ), + DeliberationStepModel( + id=str(uuid4()), + review_id=review_id, + step_type=StepType.SYNTHESIS, + description="Synthesis step", + sequence=2, + ), + DeliberationStepModel( + id=str(uuid4()), + review_id=review_id, + step_type=StepType.VERDICT, + description="Verdict step", + sequence=3, + ), + ] + + assert len(steps) == 4 + assert steps[0].step_type == StepType.MERGE + assert steps[3].step_type == StepType.VERDICT + + +class TestPolicyModel: + def test_policy_model_creation(self) -> None: + policy = PolicyModel( + id=str(uuid4()), + name="default", + organization="test-org", + description="Default policy", + is_active=True, + ) + + assert policy.name == "default" + assert policy.organization == "test-org" + assert policy.is_active is True + + def test_policy_model_with_config(self) -> None: + policy = PolicyModel( + id=str(uuid4()), + name="strict", + agents_config={ + "security": {"enabled": True, "model": "gpt-4o"}, + "style": {"enabled": True}, + "complexity": {"enabled": False}, + }, + cost_controls={ + "max_tokens": 50000, + "max_cost_usd": 0.50, + }, + verdict_thresholds={ + "critical_threshold": 1, + "high_threshold": 3, + }, + ) + + assert policy.agents_config["security"]["model"] == "gpt-4o" + assert policy.cost_controls["max_tokens"] == 50000 + + +class TestBase: + def test_base_is_declarative_base(self) -> None: + assert hasattr(Base, "metadata") + assert hasattr(Base, "registry") diff --git a/tests/test_worker.py b/tests/test_worker.py new file mode 100644 index 0000000..5056558 --- /dev/null +++ b/tests/test_worker.py @@ -0,0 +1,959 @@ +"""Tests for the worker module.""" + +from datetime import UTC, datetime +from pathlib import Path +from typing import Any +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from arbiter.db.models import ( + ConflictModel, + DeliberationStepModel, + FindingModel, + ReviewModel, +) +from arbiter.integrations import ARBITER_MARKER +from arbiter.integrations.base import Comment, CommitStatus, Platform +from arbiter.models.enums import AgentName, Severity, Verdict +from arbiter.worker.queue import JobPriority, cancel_job, generate_job_id, get_job_status +from arbiter.worker.tasks import ( + _post_or_update_comment, + _verdict_to_status, + detect_platform, + get_platform_client, + process_followup, + process_review, +) +from tests.conftest import MockPlatformClient + + +class TestJobQueue: + def test_generate_job_id_deterministic(self) -> None: + id1 = generate_job_id("owner/repo", 42, "abc123") + id2 = generate_job_id("owner/repo", 42, "abc123") + assert id1 == id2 + + def test_job_id_unique(self) -> None: + id1 = generate_job_id("owner/repo", 42, "abc123") + id2 = generate_job_id("owner/repo", 42, "def456") # Different SHA + id3 = generate_job_id("owner/repo", 43, "abc123") # Different PR + id4 = generate_job_id("other/repo", 42, "abc123") # Different repo + + assert len({id1, id2, id3, id4}) == 4 # All unique + + def test_generate_job_id_format(self) -> None: + job_id = generate_job_id("owner/repo", 42, "abc123") + assert len(job_id) == 16 + assert all(c in "0123456789abcdef" for c in job_id) + + def test_job_priority_ordering(self) -> None: + assert JobPriority.HIGH < JobPriority.NORMAL < JobPriority.LOW + assert int(JobPriority.HIGH) == 1 + assert int(JobPriority.NORMAL) == 2 + assert int(JobPriority.LOW) == 3 + + +class TestWorkerSettings: + def test_worker_settings_has_functions(self) -> None: + from arbiter.worker.settings import WorkerSettings + + assert WorkerSettings.functions is not None + assert len(WorkerSettings.functions) > 0 + + def test_worker_settings_has_cron_jobs(self) -> None: + from arbiter.worker.settings import WorkerSettings + + assert WorkerSettings.cron_jobs is not None + + def test_worker_settings_lifecycle_hooks(self) -> None: + from arbiter.worker.settings import WorkerSettings + + assert WorkerSettings.on_startup is not None + assert WorkerSettings.on_shutdown is not None + + +class TestReviewTask: + @pytest.fixture + def mock_context(self) -> dict[str, Any]: + return { + "settings": None, + "redis": None, + } + + async def test_review_task_requires_diff(self, mock_context: dict[str, Any]) -> None: # noqa: ARG002 + from arbiter.worker.tasks import process_review + + # Note: This would need database fixtures to run fully + # For now, we just verify the function signature + assert callable(process_review) + + +class MockRedisForQueue: + """Mock Redis with enqueue_job support.""" + + def __init__(self) -> None: + self._data: dict[str, Any] = {} + self._jobs: list[dict[str, Any]] = [] + + async def get(self, key: str) -> str | None: + return self._data.get(key) + + async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002 + self._data[key] = value + return True + + async def delete(self, key: str) -> int: + if key in self._data: + del self._data[key] + return 1 + return 0 + + async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: + job = {"func": func_name, "kwargs": kwargs} + self._jobs.append(job) + return type("Job", (), {"job_id": kwargs.get("_job_id", "test-id")})() + + +class TestEnqueueReview: + @pytest.fixture + def mock_redis_pool(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue: + mock = MockRedisForQueue() + + async def get_pool() -> MockRedisForQueue: + return mock + + monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool) + return mock + + async def test_enqueue_review_creates_job(self, mock_redis_pool: MockRedisForQueue) -> None: + from arbiter.worker.queue import enqueue_review + + job_id = await enqueue_review( + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + pr_title="Test PR", + author="testuser", + is_draft=False, + ) + + assert job_id is not None + assert len(mock_redis_pool._jobs) == 1 + assert mock_redis_pool._jobs[0]["func"] == "process_review" + + async def test_enqueue_review_deduplication( + self, + mock_redis_pool: MockRedisForQueue, # noqa: ARG002 + ) -> None: + from arbiter.worker.queue import enqueue_review + + # First call should succeed + job_id1 = await enqueue_review( + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + ) + assert job_id1 is not None + + # Second call with same params should be deduplicated + job_id2 = await enqueue_review( + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + ) + assert job_id2 is None + + async def test_enqueue_review_draft_lower_priority( + self, mock_redis_pool: MockRedisForQueue + ) -> None: + from arbiter.worker.queue import enqueue_review + + await enqueue_review( + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + is_draft=True, + ) + + assert len(mock_redis_pool._jobs) == 1 + job = mock_redis_pool._jobs[0] + # Draft PRs should be in the low priority queue + assert "arbiter:queue:3" in job["kwargs"]["_queue_name"] + + +class TestJobStatusAndCancel: + @pytest.fixture + def mock_redis_pool_with_jobs(self, monkeypatch: pytest.MonkeyPatch) -> MockRedisForQueue: + mock = MockRedisForQueue() + mock._data["arbiter:job:test-job-id"] = "pending" + + async def get_pool() -> MockRedisForQueue: + return mock + + monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool) + return mock + + async def test_get_job_status_found( + self, + mock_redis_pool_with_jobs: MockRedisForQueue, # noqa: ARG002 + ) -> None: + status = await get_job_status("test-job-id") + assert status is not None + assert status["job_id"] == "test-job-id" + assert status["status"] == "pending" + + async def test_get_job_status_not_found( + self, + mock_redis_pool_with_jobs: MockRedisForQueue, # noqa: ARG002 + ) -> None: + status = await get_job_status("nonexistent") + assert status is None + + async def test_cancel_job_success( + self, + mock_redis_pool_with_jobs: MockRedisForQueue, # noqa: ARG002 + ) -> None: + result = await cancel_job("test-job-id") + assert result is True + + async def test_cancel_job_not_found( + self, + mock_redis_pool_with_jobs: MockRedisForQueue, # noqa: ARG002 + ) -> None: + result = await cancel_job("nonexistent") + assert result is False + + +class TestWorkerStartupShutdown: + async def test_startup_hook(self) -> None: + from unittest.mock import AsyncMock, patch + + from arbiter.worker.settings import startup + + # Mock init_db and get_settings + with ( + patch("arbiter.worker.settings.init_db", new_callable=AsyncMock) as mock_init, + patch("arbiter.worker.settings.get_settings") as mock_settings, + ): + mock_settings.return_value = "mock_settings" + ctx: dict[str, Any] = {} + await startup(ctx) + + mock_init.assert_called_once() + assert ctx["settings"] == "mock_settings" + + async def test_shutdown_hook(self) -> None: + from unittest.mock import AsyncMock, patch + + from arbiter.worker.settings import shutdown + + with patch("arbiter.worker.settings.close_db", new_callable=AsyncMock) as mock_close: + ctx: dict[str, Any] = {} + await shutdown(ctx) + mock_close.assert_called_once() + + async def test_health_check(self) -> None: + from arbiter.worker.settings import health_check + + ctx: dict[str, Any] = {} + result = await health_check(ctx) + assert result == "healthy" + + def test_worker_settings_redis_settings(self) -> None: + from arbiter.worker.settings import WorkerSettings + + redis_settings = WorkerSettings.redis_settings() + assert redis_settings is not None + + def test_worker_settings_get_functions(self) -> None: + from arbiter.worker.settings import WorkerSettings + + functions = WorkerSettings._get_functions() + assert len(functions) == 2 + # Verify the functions are the expected ones + func_names = [f.__name__ for f in functions] + assert "process_review" in func_names + assert "process_followup" in func_names + + +class TestDetectPlatform: + def test_detect_platform_from_webhook_github(self) -> None: + platform = detect_platform("owner/repo", "github") + assert platform == Platform.GITHUB + + def test_detect_platform_from_webhook_gitlab(self) -> None: + platform = detect_platform("owner/repo", "gitlab") + assert platform == Platform.GITLAB + + def test_detect_platform_case_insensitive(self) -> None: + assert detect_platform("owner/repo", "GITHUB") == Platform.GITHUB + assert detect_platform("owner/repo", "GitLab") == Platform.GITLAB + + def test_detect_platform_defaults_to_github(self) -> None: + platform = detect_platform("owner/repo") + assert platform == Platform.GITHUB + + +class TestGetPlatformClient: + def test_get_platform_client_github_no_token(self, mock_settings_no_github: Any) -> None: + client = get_platform_client(Platform.GITHUB, mock_settings_no_github) + assert client is None + + def test_get_platform_client_gitlab_no_token(self, mock_settings_no_github: Any) -> None: + client = get_platform_client(Platform.GITLAB, mock_settings_no_github) + assert client is None + + def test_github_client_with_token(self, mock_settings: Any) -> None: + from arbiter.integrations import GitHubClient + + client = get_platform_client(Platform.GITHUB, mock_settings) + assert client is not None + assert isinstance(client, GitHubClient) + + def test_gitlab_client_with_token(self, mock_settings: Any) -> None: + from arbiter.integrations import GitLabClient + + client = get_platform_client(Platform.GITLAB, mock_settings) + assert client is not None + assert isinstance(client, GitLabClient) + + +class TestVerdictToStatus: + def test_approve_returns_success(self) -> None: + assert _verdict_to_status(Verdict.APPROVE) == CommitStatus.SUCCESS + + def test_request_changes_returns_failure(self) -> None: + assert _verdict_to_status(Verdict.REQUEST_CHANGES) == CommitStatus.FAILURE + + def test_comment_returns_success(self) -> None: + assert _verdict_to_status(Verdict.COMMENT) == CommitStatus.SUCCESS + + +class TestPostOrUpdateComment: + async def test_post_new_comment(self, mock_platform_client: MockPlatformClient) -> None: + body = f"Test comment {ARBITER_MARKER}" + url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body) + + assert url is not None + assert "owner/repo" in url + assert len(mock_platform_client._posted_comments) == 1 + assert mock_platform_client._posted_comments[0]["body"] == body + + async def test_update_existing_comment(self, mock_platform_client: MockPlatformClient) -> None: + # Add an existing Arbiter comment + mock_platform_client._comments = [ + Comment( + id="existing-123", + body=f"Old review {ARBITER_MARKER}", + author="arbiter-bot", + url="https://github.com/owner/repo/pull/42#comment-existing-123", + created_at=datetime.now(UTC), + ) + ] + + body = f"Updated review {ARBITER_MARKER}" + url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body) + + assert url is not None + assert len(mock_platform_client._posted_comments) == 1 + # Should be an update, not a new post + assert mock_platform_client._posted_comments[0].get("comment_id") == "existing-123" + + async def test_fallback_on_fetch_failure( + self, mock_platform_client: MockPlatformClient + ) -> None: + mock_platform_client._fail_on.add("get_comments") + + body = f"Test comment {ARBITER_MARKER}" + url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body) + + # Should still post a new comment + assert url is not None + assert len(mock_platform_client._posted_comments) == 1 + # Should be a new post since fetching failed + assert mock_platform_client._posted_comments[0].get("comment_id") is None + + async def test_returns_none_on_post_failure( + self, mock_platform_client: MockPlatformClient + ) -> None: + mock_platform_client._fail_on.add("post_comment") + + body = f"Test comment {ARBITER_MARKER}" + url = await _post_or_update_comment(mock_platform_client, "owner/repo", 42, body) + + assert url is None + + +class TestProcessReview: + @pytest.fixture + def mock_deliberation_result(self) -> Any: + from arbiter.deliberation import DeliberationResult, DeliberationStep + from arbiter.deliberation.conflicts import Conflict, ConflictNature + from arbiter.deliberation.coordinator import StepType + from arbiter.models import Finding + + finding = Finding( + id=str(uuid4()), + agent=AgentName.SECURITY, + file="src/auth.py", + line_start=10, + line_end=15, + severity=Severity.HIGH, + confidence=0.9, + title="SQL Injection", + description="User input concatenated into SQL", + reasoning="Allows SQL injection attacks", + prompt_version="security-v1.0", + ) + + return DeliberationResult( + verdict=Verdict.COMMENT, + verdict_confidence=0.75, + verdict_reasoning="Found security issues", + findings=[finding], + conflicts=[ + Conflict( + id="conflict-1", + finding_ids=["f1", "f2"], + nature=ConflictNature.TRADE_OFF, + description="Trade-off detected", + severity_weight=0.5, + ) + ], + steps=[ + DeliberationStep( + step_type=StepType.MERGE, + timestamp=datetime.now(UTC), + description="Merged findings", + details={"count": 1}, + ) + ], + tokens_used=500, + cost_usd=0.005, + ) + + @pytest.fixture + def mock_review_result(self) -> Any: + from arbiter.models import ReviewResult + + return ReviewResult( + agent_name=AgentName.SECURITY, + findings=[], + duration_ms=1000, + tokens_used=500, + cost_usd=0.005, + ) + + async def test_process_review_creates_review_record( + self, + db_session: AsyncSession, + mock_deliberation_result: Any, + mock_review_result: Any, + ) -> None: + # Mock the review pipeline + with ( + patch( + "arbiter.worker.tasks._run_review_pipeline", + return_value=([mock_review_result], mock_deliberation_result), + ), + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=None, + ), + ): + review_id = await process_review( + {}, + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + pr_title="Test PR", + author="testuser", + diff_content="mock diff", + ) + + # Verify review was created + result = await db_session.execute( + select(ReviewModel).where(ReviewModel.id == review_id) + ) + review = result.scalar_one() + + assert review.repository == "owner/repo" + assert review.pr_number == 42 + assert review.status == "completed" + assert review.verdict == Verdict.COMMENT + + async def test_process_review_stores_findings( + self, + db_session: AsyncSession, + mock_deliberation_result: Any, + mock_review_result: Any, + ) -> None: + with ( + patch( + "arbiter.worker.tasks._run_review_pipeline", + return_value=([mock_review_result], mock_deliberation_result), + ), + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=None, + ), + ): + review_id = await process_review( + {}, + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + diff_content="mock diff", + ) + + # Verify findings were stored + result = await db_session.execute( + select(FindingModel).where(FindingModel.review_id == review_id) + ) + findings = result.scalars().all() + + assert len(findings) == 1 + assert findings[0].title == "SQL Injection" + assert findings[0].severity == Severity.HIGH + + async def test_process_review_stores_conflicts( + self, + db_session: AsyncSession, + mock_deliberation_result: Any, + mock_review_result: Any, + ) -> None: + with ( + patch( + "arbiter.worker.tasks._run_review_pipeline", + return_value=([mock_review_result], mock_deliberation_result), + ), + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=None, + ), + ): + review_id = await process_review( + {}, + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + diff_content="mock diff", + ) + + # Verify conflicts were stored + result = await db_session.execute( + select(ConflictModel).where(ConflictModel.review_id == review_id) + ) + conflicts = result.scalars().all() + + assert len(conflicts) == 1 + assert conflicts[0].description == "Trade-off detected" + + async def test_process_review_stores_deliberation_steps( + self, + db_session: AsyncSession, + mock_deliberation_result: Any, + mock_review_result: Any, + ) -> None: + with ( + patch( + "arbiter.worker.tasks._run_review_pipeline", + return_value=([mock_review_result], mock_deliberation_result), + ), + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=None, + ), + ): + review_id = await process_review( + {}, + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + diff_content="mock diff", + ) + + # Verify deliberation steps were stored + result = await db_session.execute( + select(DeliberationStepModel).where(DeliberationStepModel.review_id == review_id) + ) + steps = result.scalars().all() + + assert len(steps) == 1 + assert steps[0].description == "Merged findings" + + async def test_process_review_handles_errors( + self, + db_session: AsyncSession, + ) -> None: + with ( + patch( + "arbiter.worker.tasks._run_review_pipeline", + side_effect=ValueError("Test error"), + ), + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=None, + ), + ): + with pytest.raises(ValueError, match="Test error"): + await process_review( + {}, + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + diff_content="mock diff", + ) + + # Verify review was marked as failed + result = await db_session.execute( + select(ReviewModel).where(ReviewModel.repository == "owner/repo") + ) + review = result.scalar_one() + + assert review.status == "failed" + assert "Test error" in (review.error_message or "") + + async def test_process_review_posts_comment( + self, + db_session: AsyncSession, + mock_platform_client: MockPlatformClient, + mock_deliberation_result: Any, + mock_review_result: Any, + mock_settings: Any, + ) -> None: + with ( + patch( + "arbiter.worker.tasks._run_review_pipeline", + return_value=([mock_review_result], mock_deliberation_result), + ), + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=mock_platform_client, + ), + patch( + "arbiter.worker.tasks.get_settings", + return_value=mock_settings, + ), + ): + await process_review( + {}, + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + diff_content="mock diff", + platform="github", + ) + + # Verify comment was posted + assert len(mock_platform_client._posted_comments) == 1 + assert ARBITER_MARKER in mock_platform_client._posted_comments[0]["body"] + + async def test_process_review_updates_status( + self, + db_session: AsyncSession, + mock_platform_client: MockPlatformClient, + mock_deliberation_result: Any, + mock_review_result: Any, + mock_settings: Any, + ) -> None: + with ( + patch( + "arbiter.worker.tasks._run_review_pipeline", + return_value=([mock_review_result], mock_deliberation_result), + ), + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=mock_platform_client, + ), + patch( + "arbiter.worker.tasks.get_settings", + return_value=mock_settings, + ), + ): + await process_review( + {}, + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + diff_content="mock diff", + platform="github", + ) + + # Verify status was updated (pending then final) + assert len(mock_platform_client._status_updates) >= 1 + # Last update should be the final status + final_update = mock_platform_client._status_updates[-1] + assert final_update["status"] == CommitStatus.SUCCESS + + async def test_process_review_requires_diff( + self, + db_session: AsyncSession, + ) -> None: + with ( + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=None, + ), + pytest.raises(ValueError, match="diff_content not provided"), + ): + await process_review( + {}, + repository="owner/repo", + pr_number=42, + base_sha="abc123", + head_sha="def456", + # No diff_content + ) + + +class TestProcessFollowup: + async def test_process_followup_no_review( + self, + db_session: AsyncSession, + mock_settings: Any, + ) -> None: + with ( + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_settings", + return_value=mock_settings, + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=None, + ), + ): + result = await process_followup( + {}, + repository="owner/repo", + pr_number=999, # Non-existent PR + comment_id="comment-123", + comment_body="Why is this a security issue?", + author="testuser", + platform="github", + ) + + assert result is None + + async def test_process_followup_disabled( + self, + db_session: AsyncSession, + completed_review_fixture: ReviewModel, # noqa: ARG002 + ) -> None: + class DisabledSettings: + followup_enabled = False + + with ( + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_settings", + return_value=DisabledSettings(), + ), + ): + result = await process_followup( + {}, + repository="owner/repo", + pr_number=42, + comment_id="comment-123", + comment_body="Why is this a security issue?", + author="testuser", + platform="github", + ) + + assert result is None + + async def test_process_followup_not_a_question( + self, + db_session: AsyncSession, + completed_review_fixture: ReviewModel, # noqa: ARG002 + mock_settings: Any, + ) -> None: + with ( + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_settings", + return_value=mock_settings, + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=None, + ), + ): + result = await process_followup( + {}, + repository="owner/repo", + pr_number=42, + comment_id="comment-123", + comment_body="This looks good to me.", # Not a question + author="testuser", + platform="github", + ) + + assert result is None + + async def test_process_followup_low_confidence( + self, + db_session: AsyncSession, + completed_review_fixture: ReviewModel, # noqa: ARG002 + ) -> None: + class HighThresholdSettings: + followup_enabled = True + followup_confidence_threshold = 0.99 # Very high threshold + llm_timeout = 60 + llm_max_retries = 3 + templates_dir = Path("templates") + post_comments = False + + with ( + patch( + "arbiter.worker.tasks.async_session_factory", + return_value=lambda: db_session, + ), + patch( + "arbiter.worker.tasks.get_settings", + return_value=HighThresholdSettings(), + ), + patch( + "arbiter.worker.tasks.get_platform_client", + return_value=None, + ), + ): + result = await process_followup( + {}, + repository="owner/repo", + pr_number=42, + comment_id="comment-123", + comment_body="What does this mean?", # A question but low confidence + author="testuser", + platform="github", + ) + + assert result is None + + +class TestEnqueueFollowup: + @pytest.fixture + def mock_redis_pool_followup(self, monkeypatch: pytest.MonkeyPatch) -> "MockRedisForQueue": + mock = MockRedisForQueue() + + async def get_pool() -> MockRedisForQueue: + return mock + + monkeypatch.setattr("arbiter.worker.queue.get_redis_pool", get_pool) + return mock + + async def test_enqueue_followup_creates_job( + self, mock_redis_pool_followup: "MockRedisForQueue" + ) -> None: + from arbiter.worker.queue import enqueue_followup + + job_id = await enqueue_followup( + repository="owner/repo", + pr_number=42, + comment_id="comment-123", + comment_body="Why is this a security issue?", + author="testuser", + platform="github", + ) + + assert job_id is not None + assert len(mock_redis_pool_followup._jobs) == 1 + assert mock_redis_pool_followup._jobs[0]["func"] == "process_followup" + + async def test_enqueue_followup_deduplication( + self, mock_redis_pool_followup: "MockRedisForQueue" + ) -> None: + from arbiter.worker.queue import enqueue_followup, generate_followup_job_id + + # Pre-set the job as existing + job_id = generate_followup_job_id("owner/repo", 42, "comment-123") + mock_redis_pool_followup._data[f"arbiter:followup:{job_id}"] = "pending" + + result = await enqueue_followup( + repository="owner/repo", + pr_number=42, + comment_id="comment-123", + comment_body="Why is this a security issue?", + author="testuser", + platform="github", + ) + + assert result is None + # No new job should be added + assert len(mock_redis_pool_followup._jobs) == 0 + + +class TestGenerateFollowupJobId: + def test_followup_job_id_stable(self) -> None: + from arbiter.worker.queue import generate_followup_job_id + + id1 = generate_followup_job_id("owner/repo", 42, "comment-123") + id2 = generate_followup_job_id("owner/repo", 42, "comment-123") + assert id1 == id2 + + def test_generate_followup_job_id_unique(self) -> None: + from arbiter.worker.queue import generate_followup_job_id + + id1 = generate_followup_job_id("owner/repo", 42, "comment-123") + id2 = generate_followup_job_id("owner/repo", 42, "comment-456") + id3 = generate_followup_job_id("owner/repo", 43, "comment-123") + id4 = generate_followup_job_id("other/repo", 42, "comment-123") + + assert len({id1, id2, id3, id4}) == 4