conversation detection tests

This commit is contained in:
2025-05-27 20:13:02 +00:00
parent f31272736e
commit e4f814efff
3 changed files with 638 additions and 0 deletions

View File

@@ -0,0 +1,152 @@
"""Tests for question detection in PR comments."""
import pytest
from arbiter.conversation.detection import QuestionDetector
from arbiter.models import AgentName, Finding, Severity
@pytest.fixture
def detector() -> QuestionDetector:
    """Provide a ``QuestionDetector`` built with a 0.5 confidence threshold."""
    threshold = 0.5
    return QuestionDetector(confidence_threshold=threshold)
@pytest.fixture
def sample_findings() -> list[Finding]:
    """Provide two findings on ``src/auth.py``: one security, one style."""
    sql_injection = Finding(
        id="f1234567-1234-1234-1234-123456789abc",
        agent=AgentName.SECURITY,
        file="src/auth.py",
        line_start=10,
        line_end=15,
        severity=Severity.HIGH,
        confidence=0.9,
        title="SQL Injection vulnerability",
        description="User input directly concatenated into SQL query",
        reasoning="String concatenation allows SQL injection",
        prompt_version="security-v1.0",
    )
    naming = Finding(
        id="f2345678-2345-2345-2345-234567890bcd",
        agent=AgentName.STYLE,
        file="src/auth.py",
        line_start=20,
        line_end=25,
        severity=Severity.LOW,
        confidence=0.8,
        title="Inconsistent naming convention",
        description="Variable name does not follow snake_case",
        reasoning="PEP 8 recommends snake_case for variables",
        prompt_version="style-v1.0",
    )
    # Order matters to callers that index the list: security first, style second.
    return [sql_injection, naming]
class TestQuestionDetection:
    """Checks of the ``is_question`` flag returned by ``QuestionDetector.analyze``."""

    def test_detects_simple_question(self, detector: QuestionDetector) -> None:
        """A short sentence ending in a question mark is flagged."""
        comment = "Why is this a problem?"
        result = detector.analyze(comment)
        assert result.is_question is True

    def test_detects_question_with_why(self, detector: QuestionDetector) -> None:
        """A leading "Why" is enough even without a trailing question mark."""
        comment = "Why did you flag this line"
        result = detector.analyze(comment)
        assert result.is_question is True

    def test_detects_question_with_how(self, detector: QuestionDetector) -> None:
        """A "How ...?" comment is flagged."""
        comment = "How can I fix this issue?"
        result = detector.analyze(comment)
        assert result.is_question is True

    def test_detects_question_with_what(self, detector: QuestionDetector) -> None:
        """A "What ...?" comment is flagged."""
        comment = "What does this mean?"
        result = detector.analyze(comment)
        assert result.is_question is True

    def test_detects_explain_request(self, detector: QuestionDetector) -> None:
        """A request to explain counts as a question."""
        comment = "Please explain this finding"
        result = detector.analyze(comment)
        assert result.is_question is True

    def test_statement_not_detected(self, detector: QuestionDetector) -> None:
        """A plain declarative sentence is not flagged."""
        comment = "I fixed the issue."
        result = detector.analyze(comment)
        assert result.is_question is False

    def test_empty_string(self, detector: QuestionDetector) -> None:
        """Empty input is not flagged."""
        result = detector.analyze("")
        assert result.is_question is False
class TestArbiterDirected:
    """Checks of the ``is_directed_at_arbiter`` flag on the analysis result."""

    def test_at_arbiter_mention(self, detector: QuestionDetector) -> None:
        """An explicit ``@arbiter`` mention marks the comment as directed."""
        result = detector.analyze("@arbiter Why is this flagged?")
        directed = result.is_directed_at_arbiter
        assert directed is True

    def test_arbiter_keyword(self, detector: QuestionDetector) -> None:
        """Addressing "Arbiter" by name, without ``@``, also counts."""
        result = detector.analyze("Arbiter, can you explain?")
        directed = result.is_directed_at_arbiter
        assert directed is True

    def test_not_directed(self, detector: QuestionDetector) -> None:
        """A question that never names Arbiter is not directed at it."""
        result = detector.analyze("Why is this a problem?")
        directed = result.is_directed_at_arbiter
        assert directed is False
class TestAgentMentions:
    """Checks of which agents end up in ``mentioned_agents`` for a comment."""

    def test_security_keywords(self, detector: QuestionDetector) -> None:
        """Security wording selects the security agent."""
        mentioned = detector.analyze("Why is this a security vulnerability?").mentioned_agents
        assert AgentName.SECURITY in mentioned

    def test_style_keywords(self, detector: QuestionDetector) -> None:
        """Naming-convention wording selects the style agent."""
        mentioned = detector.analyze("Is this naming convention wrong?").mentioned_agents
        assert AgentName.STYLE in mentioned

    def test_complexity_keywords(self, detector: QuestionDetector) -> None:
        """Refactor/complexity wording selects the complexity agent."""
        mentioned = detector.analyze("How can I refactor to reduce complexity?").mentioned_agents
        assert AgentName.COMPLEXITY in mentioned

    def test_multiple_agents(self, detector: QuestionDetector) -> None:
        """A comment touching two topics selects both agents."""
        comment = "Is this a security vulnerability or just a style issue?"
        result = detector.analyze(comment)
        assert AgentName.SECURITY in result.mentioned_agents
        assert AgentName.STYLE in result.mentioned_agents
class TestFindingReferences:
    """Checks of finding-id extraction against a supplied list of findings."""

    def test_extract_finding_id(
        self, detector: QuestionDetector, sample_findings: list[Finding]
    ) -> None:
        """An id that belongs to a supplied finding is reported."""
        known_id = "f1234567-1234-1234-1234-123456789abc"
        comment = "Can you explain finding f1234567-1234-1234-1234-123456789abc?"
        result = detector.analyze(comment, findings=sample_findings)
        assert known_id in result.mentioned_finding_ids

    def test_invalid_finding_id(
        self, detector: QuestionDetector, sample_findings: list[Finding]
    ) -> None:
        """An id-shaped token matching no supplied finding is not reported."""
        comment = "What about finding 00000000-0000-0000-0000-000000000000?"
        result = detector.analyze(comment, findings=sample_findings)
        reported = result.mentioned_finding_ids
        assert len(reported) == 0
class TestConfidence:
    """Checks of the ``confidence`` score attached to an analysis."""

    def test_high_confidence_directed_question(self, detector: QuestionDetector) -> None:
        """A directed question with an agent topic scores at least 0.8."""
        score = detector.analyze("@arbiter Why is this a security issue?").confidence
        assert score >= 0.8

    def test_lower_confidence_generic_question(self, detector: QuestionDetector) -> None:
        """An undirected, generic question scores in the range [0.4, 0.8)."""
        score = detector.analyze("Why is this flagged?").confidence
        assert 0.4 <= score < 0.8

    def test_zero_confidence_non_question(self, detector: QuestionDetector) -> None:
        """A plain statement scores exactly zero."""
        score = detector.analyze("I fixed the issue.").confidence
        assert score == 0.0
class TestQuestionTextExtraction:
    """Checks of the cleaned ``question_text`` produced by the detector."""

    def test_removes_at_mentions(self, detector: QuestionDetector) -> None:
        """Every ``@mention`` is stripped while the question itself survives."""
        analysis = detector.analyze("@arbiter @someone Why is this wrong?")
        assert "@arbiter" not in analysis.question_text
        assert "@someone" not in analysis.question_text
        assert "Why is this wrong?" in analysis.question_text

    def test_collapses_whitespace(self, detector: QuestionDetector) -> None:
        """Runs of spaces in the input do not survive into ``question_text``."""
        # The input must contain repeated spaces, otherwise nothing is
        # exercised. The check looks for a *double* space: single spaces
        # between words are expected to remain.
        analysis = detector.analyze("Why   is    this   wrong?")
        assert "  " not in analysis.question_text

View File

@@ -0,0 +1,223 @@
"""Tests for agent routing in conversations."""
from pathlib import Path
import pytest
from arbiter.conversation.detection import QuestionAnalysis
from arbiter.conversation.router import AgentRouter
from arbiter.llm.prompts import PromptRegistry
from arbiter.models import AgentName, Finding, Severity
from tests.conftest import MockLLMClient
@pytest.fixture
def mock_llm() -> MockLLMClient:
    """Provide a fresh ``MockLLMClient`` for each test."""
    client = MockLLMClient()
    return client
@pytest.fixture
def prompt_registry_with_explain(tmp_path: Path) -> PromptRegistry:
    """Provide a ``PromptRegistry`` over a temp directory of prompt templates.

    The directory holds a review template and an explain template for each of
    the security, style and complexity agents.
    """
    templates_dir = tmp_path / "templates"
    templates_dir.mkdir()

    # File name -> template body. Review prompts come first, then explain prompts.
    template_bodies = {
        "security-v1.0.md": "Security review: {{diff}}",
        "style-v1.0.md": "Style review: {{diff}}",
        "complexity-v1.0.md": "Complexity review: {{diff}}",
        "security-explain-v1.0.md": "Explain security: {{question}} {{finding_title}}",
        "style-explain-v1.0.md": "Explain style: {{question}} {{finding_title}}",
        "complexity-explain-v1.0.md": "Explain complexity: {{question}} {{finding_title}}",
    }
    for file_name, body in template_bodies.items():
        (templates_dir / file_name).write_text(body)

    return PromptRegistry(templates_dir)
@pytest.fixture
def router(mock_llm: MockLLMClient, prompt_registry_with_explain: PromptRegistry) -> AgentRouter:
    """Provide an ``AgentRouter`` built from the mock LLM and the temp prompt registry."""
    agent_router = AgentRouter(mock_llm, prompt_registry_with_explain)
    return agent_router
@pytest.fixture
def sample_findings() -> list[Finding]:
    """Provide three findings: two security (HIGH, CRITICAL) and one style (LOW)."""
    sql_injection = Finding(
        id="sec-finding-1",
        agent=AgentName.SECURITY,
        file="src/auth.py",
        line_start=10,
        line_end=15,
        severity=Severity.HIGH,
        confidence=0.9,
        title="SQL Injection vulnerability",
        description="User input directly concatenated into SQL query",
        reasoning="String concatenation allows SQL injection",
        prompt_version="security-v1.0",
    )
    naming = Finding(
        id="style-finding-1",
        agent=AgentName.STYLE,
        file="src/auth.py",
        line_start=20,
        line_end=25,
        severity=Severity.LOW,
        confidence=0.8,
        title="Inconsistent naming",
        description="Variable name does not follow snake_case",
        reasoning="PEP 8 recommends snake_case",
        prompt_version="style-v1.0",
    )
    hardcoded_credentials = Finding(
        id="sec-finding-2",
        agent=AgentName.SECURITY,
        file="src/api.py",
        line_start=30,
        line_end=35,
        severity=Severity.CRITICAL,
        confidence=0.95,
        title="Hardcoded credentials",
        description="API key is hardcoded in source",
        reasoning="Credentials should not be in source code",
        prompt_version="security-v1.0",
    )
    # The lower-severity security finding deliberately precedes the critical one.
    return [sql_injection, naming, hardcoded_credentials]
class TestRoutingPriority:
    """Checks of how ``AgentRouter.route`` picks targets from a ``QuestionAnalysis``."""

    @staticmethod
    def _directed_question(
        text: str,
        *,
        agents: list[AgentName],
        finding_ids: list[str],
        confidence: float,
    ) -> QuestionAnalysis:
        """Build an analysis that is always a question directed at Arbiter."""
        return QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text=text,
            mentioned_agents=agents,
            mentioned_finding_ids=finding_ids,
            confidence=confidence,
        )

    def test_routes_to_specific_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """An explicit finding id yields a single route to that finding's agent."""
        question = self._directed_question(
            "Why is this a problem?",
            agents=[],
            finding_ids=["sec-finding-1"],
            confidence=0.9,
        )
        routes = router.route(question, sample_findings)
        assert len(routes) == 1
        only_route = routes[0]
        assert only_route.agent_name == AgentName.SECURITY
        assert only_route.finding is not None
        assert only_route.finding.id == "sec-finding-1"

    def test_routes_to_mentioned_agents(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """Mentioning only the style agent yields a single style route."""
        question = self._directed_question(
            "What about the style issue?",
            agents=[AgentName.STYLE],
            finding_ids=[],
            confidence=0.8,
        )
        routes = router.route(question, sample_findings)
        assert len(routes) == 1
        assert routes[0].agent_name == AgentName.STYLE

    def test_routes_to_highest_severity_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """With two security findings, the CRITICAL one is chosen over the HIGH one."""
        question = self._directed_question(
            "Explain the security issues",
            agents=[AgentName.SECURITY],
            finding_ids=[],
            confidence=0.8,
        )
        routes = router.route(question, sample_findings)
        assert len(routes) == 1
        chosen = routes[0].finding
        assert chosen is not None
        assert chosen.severity == Severity.CRITICAL

    def test_routes_to_all_agents_by_default(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """With no agent or finding mentioned, agents that have findings get routes."""
        question = self._directed_question(
            "Can you explain these findings?",
            agents=[],
            finding_ids=[],
            confidence=0.7,
        )
        routes = router.route(question, sample_findings)
        # The fixture only holds security and style findings.
        routed_agents = [route.agent_name for route in routes]
        assert AgentName.SECURITY in routed_agents
        assert AgentName.STYLE in routed_agents
class TestRouteResult:
    """Checks of the fields carried by individual route results."""

    def test_route_result_with_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """The first route for an agent with findings has a finding and priority 0."""
        security_question = QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text="Why?",
            mentioned_agents=[AgentName.SECURITY],
            mentioned_finding_ids=[],
            confidence=0.8,
        )
        first_route = router.route(security_question, sample_findings)[0]
        assert first_route.finding is not None
        assert first_route.priority == 0

    def test_route_without_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """An agent with no findings still gets a route, with ``finding`` unset."""
        # The fixture contains no complexity findings, yet complexity is requested.
        complexity_question = QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text="What about complexity?",
            mentioned_agents=[AgentName.COMPLEXITY],
            mentioned_finding_ids=[],
            confidence=0.8,
        )
        routes = router.route(complexity_question, sample_findings)
        assert len(routes) == 1
        only_route = routes[0]
        assert only_route.agent_name == AgentName.COMPLEXITY
        assert only_route.finding is None
class TestAgentInstantiation:
    """Checks of ``AgentRouter.get_agent``: correct agent per name, and reuse."""

    def test_get_security_agent(self, router: AgentRouter) -> None:
        """Requesting SECURITY returns an agent named SECURITY."""
        wanted = AgentName.SECURITY
        assert router.get_agent(wanted).name == wanted

    def test_get_style_agent(self, router: AgentRouter) -> None:
        """Requesting STYLE returns an agent named STYLE."""
        wanted = AgentName.STYLE
        assert router.get_agent(wanted).name == wanted

    def test_get_complexity_agent(self, router: AgentRouter) -> None:
        """Requesting COMPLEXITY returns an agent named COMPLEXITY."""
        wanted = AgentName.COMPLEXITY
        assert router.get_agent(wanted).name == wanted

    def test_agent_caching(self, router: AgentRouter) -> None:
        """Two requests for the same name return the very same object."""
        first = router.get_agent(AgentName.SECURITY)
        second = router.get_agent(AgentName.SECURITY)
        assert first is second

263
tests/test_followup_flow.py Normal file
View File

@@ -0,0 +1,263 @@
"""Integration tests for the follow-up conversation flow."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from arbiter.db.models import (
ConversationMessageModel,
ConversationModel,
FindingModel,
ReviewModel,
)
from arbiter.models.enums import Severity, Verdict
@pytest.fixture
async def review_with_findings(db_session: AsyncSession) -> ReviewModel:
    """Persist a completed review for PR 42 with one security and one style finding.

    Returns the committed ``ReviewModel``.
    """
    review = ReviewModel(
        repository="owner/repo",
        pr_number=42,
        pr_title="Test PR",
        base_sha="abc123",
        head_sha="def456",
        author="testuser",
        status="completed",
        verdict=Verdict.COMMENT,
        verdict_confidence=0.8,
        verdict_reasoning="Minor issues found",
        completed_at=datetime.now(UTC),
    )
    db_session.add(review)
    # Flush before building the findings: they reference ``review.id`` below.
    await db_session.flush()

    security_finding = FindingModel(
        review_id=review.id,
        agent="security",
        file="src/auth.py",
        line_start=10,
        line_end=15,
        severity=Severity.HIGH,
        confidence=0.9,
        title="SQL Injection vulnerability",
        description="User input concatenated into SQL",
        reasoning="String concatenation allows injection",
        prompt_version="security-v1.0",
    )
    style_finding = FindingModel(
        review_id=review.id,
        agent="style",
        file="src/utils.py",
        line_start=20,
        line_end=25,
        severity=Severity.LOW,
        confidence=0.8,
        title="Naming convention",
        description="Variable uses camelCase",
        reasoning="PEP 8 recommends snake_case",
        prompt_version="style-v1.0",
    )
    findings = [security_finding, style_finding]
    db_session.add_all(findings)
    await db_session.commit()
    return review
class TestConversationModels:
    """Round-trip persistence checks for conversation and message rows."""

    async def test_create_conversation(
        self, db_session: AsyncSession, review_with_findings: ReviewModel
    ) -> None:
        """A committed conversation can be read back by its review id."""
        new_conversation = ConversationModel(
            review_id=review_with_findings.id,
            platform="github",
            repository="owner/repo",
            pr_number=42,
        )
        db_session.add(new_conversation)
        await db_session.commit()

        # Read the row back through a fresh query.
        by_review = select(ConversationModel).where(
            ConversationModel.review_id == review_with_findings.id
        )
        stored = (await db_session.execute(by_review)).scalar_one()
        assert stored.platform == "github"
        assert stored.pr_number == 42

    async def test_add_messages_to_conversation(
        self, db_session: AsyncSession, review_with_findings: ReviewModel
    ) -> None:
        """A user and an assistant message come back ordered by ``sequence``."""
        conversation = ConversationModel(
            review_id=review_with_findings.id,
            platform="github",
            repository="owner/repo",
            pr_number=42,
        )
        db_session.add(conversation)
        # Flush before building messages: they reference ``conversation.id`` below.
        await db_session.flush()

        question = ConversationMessageModel(
            conversation_id=conversation.id,
            role="user",
            platform_comment_id="123",
            author="testuser",
            content="Why is this a security issue?",
            sequence=0,
        )
        answer = ConversationMessageModel(
            conversation_id=conversation.id,
            role="assistant",
            content="This is a SQL injection vulnerability...",
            responding_agents=["security"],
            tokens_used=150,
            cost_usd=0.002,
            sequence=1,
        )
        db_session.add(question)
        db_session.add(answer)
        await db_session.commit()

        ordered_messages = (
            select(ConversationMessageModel)
            .where(ConversationMessageModel.conversation_id == conversation.id)
            .order_by(ConversationMessageModel.sequence)
        )
        rows = (await db_session.execute(ordered_messages)).scalars().all()
        assert len(rows) == 2
        assert [row.role for row in rows] == ["user", "assistant"]

    async def test_conversation_cost_tracking(
        self, db_session: AsyncSession, review_with_findings: ReviewModel
    ) -> None:
        """Token and cost totals set at creation are stored unchanged."""
        conversation = ConversationModel(
            review_id=review_with_findings.id,
            platform="github",
            repository="owner/repo",
            pr_number=42,
            total_tokens=500,
            total_cost_usd=0.01,
        )
        db_session.add(conversation)
        await db_session.commit()

        by_id = select(ConversationModel).where(ConversationModel.id == conversation.id)
        stored = (await db_session.execute(by_id)).scalar_one()
        assert stored.total_tokens == 500
        # NOTE(review): exact equality on a float literal; presumably the column
        # returns the value unchanged. Confirm against the column type.
        assert stored.total_cost_usd == 0.01
class TestWebhookCommentHandling:
    """Checks of how ``/webhooks/github`` answers ``issue_comment`` deliveries."""

    @staticmethod
    def _comment_payload(body: str, login: str, *, on_pull_request: bool = True) -> dict:
        """Build a "created" issue-comment payload for issue 42 of ``owner/repo``."""
        issue: dict = {"number": 42}
        if on_pull_request:
            # GitHub marks PR comments with a ``pull_request`` object on the issue.
            issue["pull_request"] = {"url": "https://api.github.com/..."}
        return {
            "action": "created",
            "issue": issue,
            "comment": {
                "id": 123456,
                "body": body,
                "user": {"login": login},
            },
            "repository": {"full_name": "owner/repo"},
        }

    @staticmethod
    async def _deliver(test_client, payload: dict):
        """POST ``payload`` to the GitHub webhook route as an ``issue_comment`` event."""
        return await test_client.post(
            "/webhooks/github",
            json=payload,
            headers={"X-GitHub-Event": "issue_comment"},
        )

    @patch("arbiter.api.routes.webhooks.enqueue_followup", new_callable=AsyncMock)
    async def test_github_comment_webhook(self, mock_enqueue, test_client) -> None:
        """A user's PR comment is queued once and the job id is echoed back."""
        mock_enqueue.return_value = "test-job-id"
        payload = self._comment_payload("@arbiter Why is this flagged?", "testuser")

        response = await self._deliver(test_client, payload)

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "queued"
        assert body["job_id"] == "test-job-id"
        mock_enqueue.assert_called_once()

    async def test_github_ignores_arbiter_comments(self, test_client) -> None:
        """A comment carrying the Arbiter marker is ignored as Arbiter's own."""
        # The trailing HTML comment is the marker identifying an Arbiter comment.
        payload = self._comment_payload(
            "Here is my explanation...\n\n<!-- arbiter-review -->",
            "arbiter-bot",
        )

        response = await self._deliver(test_client, payload)

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "ignored"
        assert "own comment" in body.get("reason", "")

    async def test_github_ignores_non_pr_comments(self, test_client) -> None:
        """A comment on a plain issue (no ``pull_request`` field) is ignored."""
        payload = self._comment_payload(
            "Why is this happening?",
            "testuser",
            on_pull_request=False,
        )

        response = await self._deliver(test_client, payload)

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "ignored"
        assert "Not a pull request" in body.get("reason", "")
class TestConversationAPI:
    """Checks of the read-only ``/api/conversations`` endpoints with no stored data."""

    async def test_list_conversations_empty(self, test_client) -> None:
        """Listing returns 200 with an empty ``items`` list and ``total`` of 0."""
        reply = await test_client.get("/api/conversations")
        assert reply.status_code == 200
        listing = reply.json()
        assert listing["items"] == []
        assert listing["total"] == 0

    async def test_get_conversation_not_found(self, test_client) -> None:
        """Fetching an unknown conversation id returns 404."""
        unknown_url = "/api/conversations/00000000-0000-0000-0000-000000000000"
        reply = await test_client.get(unknown_url)
        assert reply.status_code == 404

    async def test_get_conversation_for_review_none(self, test_client) -> None:
        """Looking up by an unknown review id returns 200 with a JSON ``null`` body."""
        unknown_review_url = "/api/conversations/review/00000000-0000-0000-0000-000000000000"
        reply = await test_client.get(unknown_review_url)
        assert reply.status_code == 200
        assert reply.json() is None