diff --git a/tests/test_conversation_detection.py b/tests/test_conversation_detection.py new file mode 100644 index 0000000..9d92697 --- /dev/null +++ b/tests/test_conversation_detection.py @@ -0,0 +1,152 @@ +"""Tests for question detection in PR comments.""" + +import pytest + +from arbiter.conversation.detection import QuestionDetector +from arbiter.models import AgentName, Finding, Severity + + +@pytest.fixture +def detector() -> QuestionDetector: + return QuestionDetector(confidence_threshold=0.5) + + +@pytest.fixture +def sample_findings() -> list[Finding]: + return [ + Finding( + id="f1234567-1234-1234-1234-123456789abc", + agent=AgentName.SECURITY, + file="src/auth.py", + line_start=10, + line_end=15, + severity=Severity.HIGH, + confidence=0.9, + title="SQL Injection vulnerability", + description="User input directly concatenated into SQL query", + reasoning="String concatenation allows SQL injection", + prompt_version="security-v1.0", + ), + Finding( + id="f2345678-2345-2345-2345-234567890bcd", + agent=AgentName.STYLE, + file="src/auth.py", + line_start=20, + line_end=25, + severity=Severity.LOW, + confidence=0.8, + title="Inconsistent naming convention", + description="Variable name does not follow snake_case", + reasoning="PEP 8 recommends snake_case for variables", + prompt_version="style-v1.0", + ), + ] + + +class TestQuestionDetection: + def test_detects_simple_question(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("Why is this a problem?") + assert analysis.is_question is True + + def test_detects_question_with_why(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("Why did you flag this line") + assert analysis.is_question is True + + def test_detects_question_with_how(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("How can I fix this issue?") + assert analysis.is_question is True + + def test_detects_question_with_what(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("What does this mean?") + assert analysis.is_question is True + + def test_detects_explain_request(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("Please explain this finding") + assert analysis.is_question is True + + def test_statement_not_detected(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("I fixed the issue.") + assert analysis.is_question is False + + def test_empty_string(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("") + assert analysis.is_question is False + + +class TestArbiterDirected: + def test_at_arbiter_mention(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("@arbiter Why is this flagged?") + assert analysis.is_directed_at_arbiter is True + + def test_arbiter_keyword(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("Arbiter, can you explain?") + assert analysis.is_directed_at_arbiter is True + + def test_not_directed(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("Why is this a problem?") + assert analysis.is_directed_at_arbiter is False + + +class TestAgentMentions: + def test_security_keywords(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("Why is this a security vulnerability?") + assert AgentName.SECURITY in analysis.mentioned_agents + + def test_style_keywords(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("Is this naming convention wrong?") + assert AgentName.STYLE in analysis.mentioned_agents + + def test_complexity_keywords(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("How can I refactor to reduce complexity?") + assert AgentName.COMPLEXITY in analysis.mentioned_agents + + def test_multiple_agents(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("Is this a security vulnerability or just a style issue?") + assert AgentName.SECURITY in analysis.mentioned_agents + assert AgentName.STYLE in analysis.mentioned_agents + + +class TestFindingReferences: + def test_extract_finding_id( + self, detector: QuestionDetector, sample_findings: list[Finding] + ) -> None: + analysis = detector.analyze( + "Can you explain finding f1234567-1234-1234-1234-123456789abc?", + findings=sample_findings, + ) + assert "f1234567-1234-1234-1234-123456789abc" in analysis.mentioned_finding_ids + + def test_invalid_finding_id( + self, detector: QuestionDetector, sample_findings: list[Finding] + ) -> None: + analysis = detector.analyze( + "What about finding 00000000-0000-0000-0000-000000000000?", + findings=sample_findings, + ) + assert len(analysis.mentioned_finding_ids) == 0 + + +class TestConfidence: + def test_high_confidence_directed_question(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("@arbiter Why is this a security issue?") + assert analysis.confidence >= 0.8 + + def test_lower_confidence_generic_question(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("Why is this flagged?") + assert 0.4 <= analysis.confidence < 0.8 + + def test_zero_confidence_non_question(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("I fixed the issue.") + assert analysis.confidence == 0.0 + + +class TestQuestionTextExtraction: + def test_removes_at_mentions(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("@arbiter @someone Why is this wrong?") + assert "@arbiter" not in analysis.question_text + assert "@someone" not in analysis.question_text + assert "Why is this wrong?" in analysis.question_text + + def test_collapses_whitespace(self, detector: QuestionDetector) -> None: + analysis = detector.analyze("Why is this wrong?") + assert " " not in analysis.question_text diff --git a/tests/test_conversation_router.py b/tests/test_conversation_router.py new file mode 100644 index 0000000..0018b05 --- /dev/null +++ b/tests/test_conversation_router.py @@ -0,0 +1,223 @@ +"""Tests for agent routing in conversations.""" + +from pathlib import Path + +import pytest + +from arbiter.conversation.detection import QuestionAnalysis +from arbiter.conversation.router import AgentRouter +from arbiter.llm.prompts import PromptRegistry +from arbiter.models import AgentName, Finding, Severity +from tests.conftest import MockLLMClient + + +@pytest.fixture +def mock_llm() -> MockLLMClient: + return MockLLMClient() + + +@pytest.fixture +def prompt_registry_with_explain(tmp_path: Path) -> PromptRegistry: + templates_dir = tmp_path / "templates" + templates_dir.mkdir() + + # Create review prompts + (templates_dir / "security-v1.0.md").write_text("Security review: {{diff}}") + (templates_dir / "style-v1.0.md").write_text("Style review: {{diff}}") + (templates_dir / "complexity-v1.0.md").write_text("Complexity review: {{diff}}") + + # Create explain prompts + (templates_dir / "security-explain-v1.0.md").write_text( + "Explain security: {{question}} {{finding_title}}" + ) + (templates_dir / "style-explain-v1.0.md").write_text( + "Explain style: {{question}} {{finding_title}}" + ) + (templates_dir / "complexity-explain-v1.0.md").write_text( + "Explain complexity: {{question}} {{finding_title}}" + ) + + return PromptRegistry(templates_dir) + + +@pytest.fixture +def router(mock_llm: MockLLMClient, prompt_registry_with_explain: PromptRegistry) -> AgentRouter: + return AgentRouter(mock_llm, prompt_registry_with_explain) + + +@pytest.fixture +def sample_findings() -> list[Finding]: + return [ + Finding( + id="sec-finding-1", + agent=AgentName.SECURITY, + file="src/auth.py", + line_start=10, + line_end=15, + severity=Severity.HIGH, + confidence=0.9, + title="SQL Injection vulnerability", + description="User input directly concatenated into SQL query", + reasoning="String concatenation allows SQL injection", + prompt_version="security-v1.0", + ), + Finding( + id="style-finding-1", + agent=AgentName.STYLE, + file="src/auth.py", + line_start=20, + line_end=25, + severity=Severity.LOW, + confidence=0.8, + title="Inconsistent naming", + description="Variable name does not follow snake_case", + reasoning="PEP 8 recommends snake_case", + prompt_version="style-v1.0", + ), + Finding( + id="sec-finding-2", + agent=AgentName.SECURITY, + file="src/api.py", + line_start=30, + line_end=35, + severity=Severity.CRITICAL, + confidence=0.95, + title="Hardcoded credentials", + description="API key is hardcoded in source", + reasoning="Credentials should not be in source code", + prompt_version="security-v1.0", + ), + ] + + +class TestRoutingPriority: + def test_routes_to_specific_finding( + self, router: AgentRouter, sample_findings: list[Finding] + ) -> None: + analysis = QuestionAnalysis( + is_question=True, + is_directed_at_arbiter=True, + question_text="Why is this a problem?", + mentioned_agents=[], + mentioned_finding_ids=["sec-finding-1"], + confidence=0.9, + ) + + routes = router.route(analysis, sample_findings) + + assert len(routes) == 1 + assert routes[0].agent_name == AgentName.SECURITY + assert routes[0].finding is not None + assert routes[0].finding.id == "sec-finding-1" + + def test_routes_to_mentioned_agents( + self, router: AgentRouter, sample_findings: list[Finding] + ) -> None: + analysis = QuestionAnalysis( + is_question=True, + is_directed_at_arbiter=True, + question_text="What about the style issue?", + mentioned_agents=[AgentName.STYLE], + mentioned_finding_ids=[], + confidence=0.8, + ) + + routes = router.route(analysis, sample_findings) + + assert len(routes) == 1 + assert routes[0].agent_name == AgentName.STYLE + + def test_routes_to_highest_severity_finding( + self, router: AgentRouter, sample_findings: list[Finding] + ) -> None: + analysis = QuestionAnalysis( + is_question=True, + is_directed_at_arbiter=True, + question_text="Explain the security issues", + mentioned_agents=[AgentName.SECURITY], + mentioned_finding_ids=[], + confidence=0.8, + ) + + routes = router.route(analysis, sample_findings) + + assert len(routes) == 1 + # Should route to critical finding, not high + assert routes[0].finding is not None + assert routes[0].finding.severity == Severity.CRITICAL + + def test_routes_to_all_agents_by_default( + self, router: AgentRouter, sample_findings: list[Finding] + ) -> None: + analysis = QuestionAnalysis( + is_question=True, + is_directed_at_arbiter=True, + question_text="Can you explain these findings?", + mentioned_agents=[], + mentioned_finding_ids=[], + confidence=0.7, + ) + + routes = router.route(analysis, sample_findings) + + # Should have routes for security and style (agents with findings) + agent_names = [r.agent_name for r in routes] + assert AgentName.SECURITY in agent_names + assert AgentName.STYLE in agent_names + + +class TestRouteResult: + def test_route_result_with_finding( + self, router: AgentRouter, sample_findings: list[Finding] + ) -> None: + analysis = QuestionAnalysis( + is_question=True, + is_directed_at_arbiter=True, + question_text="Why?", + mentioned_agents=[AgentName.SECURITY], + mentioned_finding_ids=[], + confidence=0.8, + ) + + routes = router.route(analysis, sample_findings) + + assert routes[0].finding is not None + assert routes[0].priority == 0 + + def test_route_without_finding( + self, router: AgentRouter, sample_findings: list[Finding] + ) -> None: + # Request complexity agent but no complexity findings exist + analysis = QuestionAnalysis( + is_question=True, + is_directed_at_arbiter=True, + question_text="What about complexity?", + mentioned_agents=[AgentName.COMPLEXITY], + mentioned_finding_ids=[], + confidence=0.8, + ) + + routes = router.route(analysis, sample_findings) + + assert len(routes) == 1 + assert routes[0].agent_name == AgentName.COMPLEXITY + assert routes[0].finding is None + + +class TestAgentInstantiation: + def test_get_security_agent(self, router: AgentRouter) -> None: + agent = router.get_agent(AgentName.SECURITY) + assert agent.name == AgentName.SECURITY + + def test_get_style_agent(self, router: AgentRouter) -> None: + agent = router.get_agent(AgentName.STYLE) + assert agent.name == AgentName.STYLE + + def test_get_complexity_agent(self, router: AgentRouter) -> None: + agent = router.get_agent(AgentName.COMPLEXITY) + assert agent.name == AgentName.COMPLEXITY + + def test_agent_caching(self, router: AgentRouter) -> None: + agent1 = router.get_agent(AgentName.SECURITY) + agent2 = router.get_agent(AgentName.SECURITY) + assert agent1 is agent2 diff --git a/tests/test_followup_flow.py b/tests/test_followup_flow.py new file mode 100644 index 0000000..1741dea --- /dev/null +++ b/tests/test_followup_flow.py @@ -0,0 +1,263 @@ +"""Integration tests for the follow-up conversation flow.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, patch + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from arbiter.db.models import ( + ConversationMessageModel, + ConversationModel, + FindingModel, + ReviewModel, +) +from arbiter.models.enums import Severity, Verdict + + +@pytest.fixture +async def review_with_findings(db_session: AsyncSession) -> ReviewModel: + review = ReviewModel( + repository="owner/repo", + pr_number=42, + pr_title="Test PR", + base_sha="abc123", + head_sha="def456", + author="testuser", + status="completed", + verdict=Verdict.COMMENT, + verdict_confidence=0.8, + verdict_reasoning="Minor issues found", + completed_at=datetime.now(UTC), + ) + db_session.add(review) + await db_session.flush() + + # Add findings + finding1 = FindingModel( + review_id=review.id, + agent="security", + file="src/auth.py", + line_start=10, + line_end=15, + severity=Severity.HIGH, + confidence=0.9, + title="SQL Injection vulnerability", + description="User input concatenated into SQL", + reasoning="String concatenation allows injection", + prompt_version="security-v1.0", + ) + finding2 = FindingModel( + review_id=review.id, + agent="style", + file="src/utils.py", + line_start=20, + line_end=25, + severity=Severity.LOW, + confidence=0.8, + title="Naming convention", + description="Variable uses camelCase", + reasoning="PEP 8 recommends snake_case", + prompt_version="style-v1.0", + ) + db_session.add_all([finding1, finding2]) + await db_session.commit() + + return review + + +class TestConversationModels: + async def test_create_conversation( + self, db_session: AsyncSession, review_with_findings: ReviewModel + ) -> None: + conversation = ConversationModel( + review_id=review_with_findings.id, + platform="github", + repository="owner/repo", + pr_number=42, + ) + db_session.add(conversation) + await db_session.commit() + + # Verify it was saved + result = await db_session.execute( + select(ConversationModel).where(ConversationModel.review_id == review_with_findings.id) + ) + saved = result.scalar_one() + assert saved.platform == "github" + assert saved.pr_number == 42 + + async def test_add_messages_to_conversation( + self, db_session: AsyncSession, review_with_findings: ReviewModel + ) -> None: + conversation = ConversationModel( + review_id=review_with_findings.id, + platform="github", + repository="owner/repo", + pr_number=42, + ) + db_session.add(conversation) + await db_session.flush() + + # Add user message + user_msg = ConversationMessageModel( + conversation_id=conversation.id, + role="user", + platform_comment_id="123", + author="testuser", + content="Why is this a security issue?", + sequence=0, + ) + db_session.add(user_msg) + + # Add assistant message + assistant_msg = ConversationMessageModel( + conversation_id=conversation.id, + role="assistant", + content="This is a SQL injection vulnerability...", + responding_agents=["security"], + tokens_used=150, + cost_usd=0.002, + sequence=1, + ) + db_session.add(assistant_msg) + await db_session.commit() + + # Verify messages + result = await db_session.execute( + select(ConversationMessageModel) + .where(ConversationMessageModel.conversation_id == conversation.id) + .order_by(ConversationMessageModel.sequence) + ) + messages = result.scalars().all() + assert len(messages) == 2 + assert messages[0].role == "user" + assert messages[1].role == "assistant" + + async def test_conversation_cost_tracking( + self, db_session: AsyncSession, review_with_findings: ReviewModel + ) -> None: + conversation = ConversationModel( + review_id=review_with_findings.id, + platform="github", + repository="owner/repo", + pr_number=42, + total_tokens=500, + total_cost_usd=0.01, + ) + db_session.add(conversation) + await db_session.commit() + + # Verify totals + result = await db_session.execute( + select(ConversationModel).where(ConversationModel.id == conversation.id) + ) + saved = result.scalar_one() + assert saved.total_tokens == 500 + assert saved.total_cost_usd == 0.01 + + +class TestWebhookCommentHandling: + @patch("arbiter.api.routes.webhooks.enqueue_followup", new_callable=AsyncMock) + async def test_github_comment_webhook(self, mock_enqueue, test_client) -> None: + mock_enqueue.return_value = "test-job-id" + + payload = { + "action": "created", + "issue": { + "number": 42, + "pull_request": {"url": "https://api.github.com/..."}, + }, + "comment": { + "id": 123456, + "body": "@arbiter Why is this flagged?", + "user": {"login": "testuser"}, + }, + "repository": {"full_name": "owner/repo"}, + } + + response = await test_client.post( + "/webhooks/github", + json=payload, + headers={"X-GitHub-Event": "issue_comment"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "queued" + assert data["job_id"] == "test-job-id" + mock_enqueue.assert_called_once() + + async def test_github_ignores_arbiter_comments(self, test_client) -> None: + # Include the Arbiter marker in the comment + payload = { + "action": "created", + "issue": { + "number": 42, + "pull_request": {"url": "https://api.github.com/..."}, + }, + "comment": { + "id": 123456, + "body": "Here is my explanation...\n\n", + "user": {"login": "arbiter-bot"}, + }, + "repository": {"full_name": "owner/repo"}, + } + + response = await test_client.post( + "/webhooks/github", + json=payload, + headers={"X-GitHub-Event": "issue_comment"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ignored" + assert "own comment" in data.get("reason", "") + + async def test_github_ignores_non_pr_comments(self, test_client) -> None: + payload = { + "action": "created", + "issue": { + "number": 42, + # No pull_request field = regular issue + }, + "comment": { + "id": 123456, + "body": "Why is this happening?", + "user": {"login": "testuser"}, + }, + "repository": {"full_name": "owner/repo"}, + } + + response = await test_client.post( + "/webhooks/github", + json=payload, + headers={"X-GitHub-Event": "issue_comment"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ignored" + assert "Not a pull request" in data.get("reason", "") + + +class TestConversationAPI: + async def test_list_conversations_empty(self, test_client) -> None: + response = await test_client.get("/api/conversations") + assert response.status_code == 200 + data = response.json() + assert data["items"] == [] + assert data["total"] == 0 + + async def test_get_conversation_not_found(self, test_client) -> None: + response = await test_client.get("/api/conversations/00000000-0000-0000-0000-000000000000") + assert response.status_code == 404 + + async def test_get_conversation_for_review_none(self, test_client) -> None: + response = await test_client.get( + "/api/conversations/review/00000000-0000-0000-0000-000000000000" + ) + assert response.status_code == 200 + assert response.json() is None