"""Integration tests for the follow-up conversation flow."""
|
|
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from arbiter.db.models import (
|
|
ConversationMessageModel,
|
|
ConversationModel,
|
|
FindingModel,
|
|
ReviewModel,
|
|
)
|
|
from arbiter.models.enums import Severity, Verdict
|
|
|
|
|
|
@pytest.fixture
async def review_with_findings(db_session: AsyncSession) -> ReviewModel:
    """Persist a completed review plus two findings and return the review.

    The review is added and flushed first; the two findings (one from the
    ``security`` agent, one from the ``style`` agent) are then created with
    the review's ``id`` as their ``review_id`` and everything is committed.

    Args:
        db_session: Async SQLAlchemy session, provided by a fixture that is
            not defined in this file.

    Returns:
        The committed ``ReviewModel`` for ``owner/repo`` PR #42.
    """
    completed_review = ReviewModel(
        repository="owner/repo",
        pr_number=42,
        pr_title="Test PR",
        author="testuser",
        base_sha="abc123",
        head_sha="def456",
        status="completed",
        completed_at=datetime.now(UTC),
        verdict=Verdict.COMMENT,
        verdict_confidence=0.8,
        verdict_reasoning="Minor issues found",
    )
    db_session.add(completed_review)
    # Flush before building the findings: they reference the review's id below
    # (presumably only assigned once the row has been flushed).
    await db_session.flush()
    parent_id = completed_review.id

    security_finding = FindingModel(
        review_id=parent_id,
        agent="security",
        prompt_version="security-v1.0",
        file="src/auth.py",
        line_start=10,
        line_end=15,
        severity=Severity.HIGH,
        confidence=0.9,
        title="SQL Injection vulnerability",
        description="User input concatenated into SQL",
        reasoning="String concatenation allows injection",
    )
    style_finding = FindingModel(
        review_id=parent_id,
        agent="style",
        prompt_version="style-v1.0",
        file="src/utils.py",
        line_start=20,
        line_end=25,
        severity=Severity.LOW,
        confidence=0.8,
        title="Naming convention",
        description="Variable uses camelCase",
        reasoning="PEP 8 recommends snake_case",
    )
    db_session.add_all([security_finding, style_finding])
    await db_session.commit()

    return completed_review
|
|
|
|
|
|
class TestConversationModels:
    """Persistence tests for ``ConversationModel`` and ``ConversationMessageModel``."""

    @staticmethod
    def _new_conversation(review: ReviewModel, **extra_fields: object) -> ConversationModel:
        """Build (without persisting) the GitHub conversation shared by these tests.

        Args:
            review: Review whose ``id`` becomes the conversation's ``review_id``.
            **extra_fields: Additional ``ConversationModel`` constructor keywords.

        Returns:
            A new, unsaved ``ConversationModel`` for ``owner/repo`` PR #42.
        """
        return ConversationModel(
            review_id=review.id,
            platform="github",
            repository="owner/repo",
            pr_number=42,
            **extra_fields,
        )

    async def test_create_conversation(
        self, db_session: AsyncSession, review_with_findings: ReviewModel
    ) -> None:
        """A committed conversation can be selected back by its review id."""
        conversation = self._new_conversation(review_with_findings)
        db_session.add(conversation)
        await db_session.commit()

        # Verify it was saved
        result = await db_session.execute(
            select(ConversationModel).where(ConversationModel.review_id == review_with_findings.id)
        )
        saved = result.scalar_one()
        assert saved.platform == "github"
        assert saved.pr_number == 42

    async def test_add_messages_to_conversation(
        self, db_session: AsyncSession, review_with_findings: ReviewModel
    ) -> None:
        """User and assistant messages come back ordered by ``sequence``."""
        conversation = self._new_conversation(review_with_findings)
        db_session.add(conversation)
        # Flush before creating messages: they reference conversation.id below.
        await db_session.flush()

        # Add user message
        user_msg = ConversationMessageModel(
            conversation_id=conversation.id,
            role="user",
            platform_comment_id="123",
            author="testuser",
            content="Why is this a security issue?",
            sequence=0,
        )
        # Add assistant message
        assistant_msg = ConversationMessageModel(
            conversation_id=conversation.id,
            role="assistant",
            content="This is a SQL injection vulnerability...",
            responding_agents=["security"],
            tokens_used=150,
            cost_usd=0.002,
            sequence=1,
        )
        db_session.add_all([user_msg, assistant_msg])
        await db_session.commit()

        # Verify messages
        result = await db_session.execute(
            select(ConversationMessageModel)
            .where(ConversationMessageModel.conversation_id == conversation.id)
            .order_by(ConversationMessageModel.sequence)
        )
        messages = result.scalars().all()
        assert len(messages) == 2
        assert messages[0].role == "user"
        assert messages[1].role == "assistant"

    async def test_conversation_cost_tracking(
        self, db_session: AsyncSession, review_with_findings: ReviewModel
    ) -> None:
        """Token and cost totals set on a conversation survive a commit."""
        conversation = self._new_conversation(
            review_with_findings,
            total_tokens=500,
            total_cost_usd=0.01,
        )
        db_session.add(conversation)
        await db_session.commit()

        # Verify totals
        result = await db_session.execute(
            select(ConversationModel).where(ConversationModel.id == conversation.id)
        )
        saved = result.scalar_one()
        assert saved.total_tokens == 500
        # Compare the cost with a tolerance rather than exact float equality, so
        # the test does not depend on how the backend stores 0.01. float() keeps
        # the comparison valid if the column returns a Decimal (column type is
        # not visible here -- TODO confirm).
        assert float(saved.total_cost_usd) == pytest.approx(0.01)
|
|
|
|
|
|
class TestWebhookCommentHandling:
    """Tests for ``issue_comment`` events posted to ``/webhooks/github``."""

    @staticmethod
    def _issue_comment_payload(
        body: str, author: str, *, on_pull_request: bool = True
    ) -> dict[str, object]:
        """Build the ``issue_comment`` webhook payload used by these tests.

        Args:
            body: Text of the comment.
            author: Login of the commenting user.
            on_pull_request: When true, the issue carries a ``pull_request``
                entry; when false that key is omitted (a regular issue).

        Returns:
            A JSON-serializable payload for issue #42 of ``owner/repo``.
        """
        issue: dict[str, object] = {"number": 42}
        if on_pull_request:
            issue["pull_request"] = {"url": "https://api.github.com/..."}
        return {
            "action": "created",
            "issue": issue,
            "comment": {
                "id": 123456,
                "body": body,
                "user": {"login": author},
            },
            "repository": {"full_name": "owner/repo"},
        }

    @staticmethod
    async def _post_issue_comment(test_client, payload):
        """POST ``payload`` to ``/webhooks/github`` as an ``issue_comment`` event.

        Returns whatever ``test_client.post`` resolves to; the tests read
        ``.status_code`` and ``.json()`` from it.
        """
        return await test_client.post(
            "/webhooks/github",
            json=payload,
            headers={"X-GitHub-Event": "issue_comment"},
        )

    @patch("arbiter.api.routes.webhooks.enqueue_followup", new_callable=AsyncMock)
    async def test_github_comment_webhook(self, mock_enqueue, test_client) -> None:
        """A PR comment addressed to ``@arbiter`` is queued as a follow-up job.

        ``mock_enqueue`` is the ``AsyncMock`` that replaces
        ``arbiter.api.routes.webhooks.enqueue_followup`` for this test.
        """
        mock_enqueue.return_value = "test-job-id"

        payload = self._issue_comment_payload("@arbiter Why is this flagged?", "testuser")
        response = await self._post_issue_comment(test_client, payload)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "queued"
        assert data["job_id"] == "test-job-id"
        # assert_awaited_once (not assert_called_once): a call whose coroutine
        # is never awaited would not actually enqueue anything.
        mock_enqueue.assert_awaited_once()

    async def test_github_ignores_arbiter_comments(self, test_client) -> None:
        """A comment carrying the Arbiter marker is ignored as Arbiter's own."""
        # Include the Arbiter marker in the comment
        payload = self._issue_comment_payload(
            "Here is my explanation...\n\n<!-- arbiter-review -->",
            "arbiter-bot",
        )
        response = await self._post_issue_comment(test_client, payload)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ignored"
        assert "own comment" in data.get("reason", "")

    async def test_github_ignores_non_pr_comments(self, test_client) -> None:
        """A comment on a regular issue (no ``pull_request`` field) is ignored."""
        # No pull_request field = regular issue
        payload = self._issue_comment_payload(
            "Why is this happening?",
            "testuser",
            on_pull_request=False,
        )
        response = await self._post_issue_comment(test_client, payload)

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ignored"
        assert "Not a pull request" in data.get("reason", "")
|
|
|
|
|
|
class TestConversationAPI:
    """Tests for the read-only ``/api/conversations`` endpoints."""

    async def test_list_conversations_empty(self, test_client) -> None:
        """Listing conversations yields no items and a total of zero."""
        resp = await test_client.get("/api/conversations")
        assert resp.status_code == 200

        listing = resp.json()
        assert listing["items"] == []
        assert listing["total"] == 0

    async def test_get_conversation_not_found(self, test_client) -> None:
        """Fetching a conversation by the all-zero UUID returns 404."""
        resp = await test_client.get("/api/conversations/00000000-0000-0000-0000-000000000000")
        assert resp.status_code == 404

    async def test_get_conversation_for_review_none(self, test_client) -> None:
        """Looking up by the all-zero review UUID returns 200 with a null body."""
        resp = await test_client.get(
            "/api/conversations/review/00000000-0000-0000-0000-000000000000"
        )
        assert resp.status_code == 200

        body = resp.json()
        assert body is None
|