conversation detection tests
This commit is contained in:
263
tests/test_followup_flow.py
Normal file
263
tests/test_followup_flow.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""Integration tests for the follow-up conversation flow."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from arbiter.db.models import (
|
||||
ConversationMessageModel,
|
||||
ConversationModel,
|
||||
FindingModel,
|
||||
ReviewModel,
|
||||
)
|
||||
from arbiter.models.enums import Severity, Verdict
|
||||
|
||||
|
||||
@pytest.fixture
async def review_with_findings(db_session: AsyncSession) -> ReviewModel:
    """Persist a completed review with two findings and return the review.

    The review is flushed before the findings are built because each finding
    reads the review's ``id`` for its ``review_id``. Everything is committed
    before the fixture returns.
    """
    seeded_review = ReviewModel(
        repository="owner/repo",
        pr_number=42,
        pr_title="Test PR",
        base_sha="abc123",
        head_sha="def456",
        author="testuser",
        status="completed",
        verdict=Verdict.COMMENT,
        verdict_confidence=0.8,
        verdict_reasoning="Minor issues found",
        completed_at=datetime.now(UTC),
    )
    db_session.add(seeded_review)
    await db_session.flush()

    # One high-severity security finding and one low-severity style finding,
    # both attached to the review created above.
    parent_id = seeded_review.id
    seeded_findings = [
        FindingModel(
            review_id=parent_id,
            agent="security",
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection vulnerability",
            description="User input concatenated into SQL",
            reasoning="String concatenation allows injection",
            prompt_version="security-v1.0",
        ),
        FindingModel(
            review_id=parent_id,
            agent="style",
            file="src/utils.py",
            line_start=20,
            line_end=25,
            severity=Severity.LOW,
            confidence=0.8,
            title="Naming convention",
            description="Variable uses camelCase",
            reasoning="PEP 8 recommends snake_case",
            prompt_version="style-v1.0",
        ),
    ]
    db_session.add_all(seeded_findings)
    await db_session.commit()

    return seeded_review
|
||||
|
||||
|
||||
class TestConversationModels:
    """Persistence tests for conversations and their messages."""

    async def test_create_conversation(
        self, db_session: AsyncSession, review_with_findings: ReviewModel
    ) -> None:
        """A conversation linked to a review can be saved and read back."""
        conversation = ConversationModel(
            review_id=review_with_findings.id,
            platform="github",
            repository="owner/repo",
            pr_number=42,
        )
        db_session.add(conversation)
        await db_session.commit()

        # Verify it was saved
        result = await db_session.execute(
            select(ConversationModel).where(
                ConversationModel.review_id == review_with_findings.id
            )
        )
        saved = result.scalar_one()
        assert saved.platform == "github"
        assert saved.pr_number == 42

    async def test_add_messages_to_conversation(
        self, db_session: AsyncSession, review_with_findings: ReviewModel
    ) -> None:
        """User and assistant messages are stored and returned in sequence order."""
        conversation = ConversationModel(
            review_id=review_with_findings.id,
            platform="github",
            repository="owner/repo",
            pr_number=42,
        )
        db_session.add(conversation)
        # Flush before reading conversation.id for the messages below.
        await db_session.flush()

        # Add user message
        user_msg = ConversationMessageModel(
            conversation_id=conversation.id,
            role="user",
            platform_comment_id="123",
            author="testuser",
            content="Why is this a security issue?",
            sequence=0,
        )
        db_session.add(user_msg)

        # Add assistant message
        assistant_msg = ConversationMessageModel(
            conversation_id=conversation.id,
            role="assistant",
            content="This is a SQL injection vulnerability...",
            responding_agents=["security"],
            tokens_used=150,
            cost_usd=0.002,
            sequence=1,
        )
        db_session.add(assistant_msg)
        await db_session.commit()

        # Verify messages
        result = await db_session.execute(
            select(ConversationMessageModel)
            .where(ConversationMessageModel.conversation_id == conversation.id)
            .order_by(ConversationMessageModel.sequence)
        )
        messages = result.scalars().all()
        assert len(messages) == 2
        assert messages[0].role == "user"
        assert messages[1].role == "assistant"

    async def test_conversation_cost_tracking(
        self, db_session: AsyncSession, review_with_findings: ReviewModel
    ) -> None:
        """Token and cost totals set on a conversation survive a round trip."""
        conversation = ConversationModel(
            review_id=review_with_findings.id,
            platform="github",
            repository="owner/repo",
            pr_number=42,
            total_tokens=500,
            total_cost_usd=0.01,
        )
        db_session.add(conversation)
        await db_session.commit()

        # Verify totals
        result = await db_session.execute(
            select(ConversationModel).where(ConversationModel.id == conversation.id)
        )
        saved = result.scalar_one()
        assert saved.total_tokens == 500
        # Compare the cost approximately: the value has passed through the
        # database, so exact float equality depends on the column's precision.
        assert saved.total_cost_usd == pytest.approx(0.01)
|
||||
|
||||
|
||||
class TestWebhookCommentHandling:
    """Tests that POST ``issue_comment`` events to ``/webhooks/github``."""

    @patch("arbiter.api.routes.webhooks.enqueue_followup", new_callable=AsyncMock)
    async def test_github_comment_webhook(self, mock_enqueue, test_client) -> None:
        """A user comment on a pull request is queued and the job id is echoed back."""
        mock_enqueue.return_value = "test-job-id"

        issue = {
            "number": 42,
            "pull_request": {"url": "https://api.github.com/..."},
        }
        comment = {
            "id": 123456,
            "body": "@arbiter Why is this flagged?",
            "user": {"login": "testuser"},
        }
        event = {
            "action": "created",
            "issue": issue,
            "comment": comment,
            "repository": {"full_name": "owner/repo"},
        }
        event_headers = {"X-GitHub-Event": "issue_comment"}

        resp = await test_client.post(
            "/webhooks/github", json=event, headers=event_headers
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "queued"
        assert body["job_id"] == "test-job-id"
        mock_enqueue.assert_called_once()

    async def test_github_ignores_arbiter_comments(self, test_client) -> None:
        """A comment carrying the Arbiter marker is ignored as the bot's own."""
        issue = {
            "number": 42,
            "pull_request": {"url": "https://api.github.com/..."},
        }
        # The body ends with the Arbiter marker and is authored by arbiter-bot.
        comment = {
            "id": 123456,
            "body": "Here is my explanation...\n\n<!-- arbiter-review -->",
            "user": {"login": "arbiter-bot"},
        }
        event = {
            "action": "created",
            "issue": issue,
            "comment": comment,
            "repository": {"full_name": "owner/repo"},
        }
        event_headers = {"X-GitHub-Event": "issue_comment"}

        resp = await test_client.post(
            "/webhooks/github", json=event, headers=event_headers
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ignored"
        assert "own comment" in body.get("reason", "")

    async def test_github_ignores_non_pr_comments(self, test_client) -> None:
        """A comment on a plain issue (no ``pull_request`` key) is ignored."""
        # Leaving out the pull_request field makes this a regular issue.
        issue = {
            "number": 42,
        }
        comment = {
            "id": 123456,
            "body": "Why is this happening?",
            "user": {"login": "testuser"},
        }
        event = {
            "action": "created",
            "issue": issue,
            "comment": comment,
            "repository": {"full_name": "owner/repo"},
        }
        event_headers = {"X-GitHub-Event": "issue_comment"}

        resp = await test_client.post(
            "/webhooks/github", json=event, headers=event_headers
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "ignored"
        assert "Not a pull request" in body.get("reason", "")
|
||||
|
||||
|
||||
class TestConversationAPI:
    """Tests for the ``/api/conversations`` read endpoints with no stored data."""

    async def test_list_conversations_empty(self, test_client) -> None:
        """Listing conversations yields an empty ``items`` list and a zero total."""
        resp = await test_client.get("/api/conversations")
        assert resp.status_code == 200
        listing = resp.json()
        assert listing["items"] == []
        assert listing["total"] == 0

    async def test_get_conversation_not_found(self, test_client) -> None:
        """Requesting the all-zero conversation id returns 404."""
        missing_url = "/api/conversations/00000000-0000-0000-0000-000000000000"
        resp = await test_client.get(missing_url)
        assert resp.status_code == 404

    async def test_get_conversation_for_review_none(self, test_client) -> None:
        """Looking up by the all-zero review id returns 200 with a null body."""
        by_review_url = "/api/conversations/review/00000000-0000-0000-0000-000000000000"
        resp = await test_client.get(by_review_url)
        assert resp.status_code == 200
        assert resp.json() is None
|
||||
Reference in New Issue
Block a user