feat(api): add conversations endpoints

This commit is contained in:
2025-05-24 12:09:38 +00:00
parent 56b528b2e3
commit f31272736e
2 changed files with 242 additions and 1 deletions

View File

@@ -0,0 +1,240 @@
"""Conversation REST API endpoints."""
import logging
from datetime import datetime
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.orm import selectinload
from arbiter.api.deps import DbSession
from arbiter.db.models import ConversationModel
router = APIRouter()
logger = logging.getLogger(__name__)
# Response models
class ConversationMessageResponse(BaseModel):
"""Individual message in a conversation."""
id: str
role: str
platform_comment_id: str | None
author: str | None
content: str
responding_agents: list[str] | None
referenced_finding_ids: list[str] | None
tokens_used: int
cost_usd: float
created_at: datetime
sequence: int
class ConversationSummary(BaseModel):
"""Summary of a conversation for list endpoints."""
id: str
review_id: str
platform: str
repository: str
pr_number: int
message_count: int = 0
total_tokens: int
total_cost_usd: float
started_at: datetime
last_activity: datetime
class ConversationDetail(BaseModel):
"""Full conversation detail with messages."""
id: str
review_id: str
platform: str
repository: str
pr_number: int
total_tokens: int
total_cost_usd: float
started_at: datetime
last_activity: datetime
messages: list[ConversationMessageResponse]
class ConversationListResponse(BaseModel):
"""Paginated conversation list response."""
items: list[ConversationSummary]
total: int
page: int
page_size: int
pages: int
@router.get("", response_model=ConversationListResponse)
async def list_conversations(
db: DbSession,
page: Annotated[int, Query(ge=1)] = 1,
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
repository: str | None = None,
review_id: str | None = None,
) -> ConversationListResponse:
"""List conversations with pagination and filtering."""
# Build query
query = select(ConversationModel)
if repository:
query = query.where(ConversationModel.repository == repository)
if review_id:
query = query.where(ConversationModel.review_id == review_id)
# Get total count
count_query = select(func.count()).select_from(query.subquery())
total = await db.scalar(count_query) or 0
# Add pagination and ordering
query = query.order_by(ConversationModel.last_activity.desc())
query = query.offset((page - 1) * page_size).limit(page_size)
# Execute with eager loading
query = query.options(selectinload(ConversationModel.messages))
result = await db.execute(query)
conversations = result.scalars().all()
# Build response
items = []
for conv in conversations:
items.append(
ConversationSummary(
id=conv.id,
review_id=conv.review_id,
platform=conv.platform,
repository=conv.repository,
pr_number=conv.pr_number,
message_count=len(conv.messages) if conv.messages else 0,
total_tokens=conv.total_tokens,
total_cost_usd=conv.total_cost_usd,
started_at=conv.started_at,
last_activity=conv.last_activity,
)
)
pages = (total + page_size - 1) // page_size if page_size else 1
return ConversationListResponse(
items=items,
total=total,
page=page,
page_size=page_size,
pages=pages,
)
@router.get("/{conversation_id}", response_model=ConversationDetail)
async def get_conversation(
db: DbSession,
conversation_id: str,
) -> ConversationDetail:
"""Get conversation detail with all messages."""
query = (
select(ConversationModel)
.where(ConversationModel.id == conversation_id)
.options(selectinload(ConversationModel.messages))
)
result = await db.execute(query)
conversation = result.scalar_one_or_none()
if not conversation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Conversation {conversation_id} not found",
)
# Sort messages by sequence
sorted_messages = (
sorted(conversation.messages, key=lambda m: m.sequence) if conversation.messages else []
)
return ConversationDetail(
id=conversation.id,
review_id=conversation.review_id,
platform=conversation.platform,
repository=conversation.repository,
pr_number=conversation.pr_number,
total_tokens=conversation.total_tokens,
total_cost_usd=conversation.total_cost_usd,
started_at=conversation.started_at,
last_activity=conversation.last_activity,
messages=[
ConversationMessageResponse(
id=msg.id,
role=msg.role,
platform_comment_id=msg.platform_comment_id,
author=msg.author,
content=msg.content,
responding_agents=msg.responding_agents,
referenced_finding_ids=msg.referenced_finding_ids,
tokens_used=msg.tokens_used,
cost_usd=msg.cost_usd,
created_at=msg.created_at,
sequence=msg.sequence,
)
for msg in sorted_messages
],
)
@router.get("/review/{review_id}", response_model=ConversationDetail | None)
async def get_conversation_for_review(
db: DbSession,
review_id: str,
) -> ConversationDetail | None:
"""Get conversation associated with a specific review.
Returns None if no conversation exists for this review.
"""
query = (
select(ConversationModel)
.where(ConversationModel.review_id == review_id)
.options(selectinload(ConversationModel.messages))
)
result = await db.execute(query)
conversation = result.scalar_one_or_none()
if not conversation:
return None
# Sort messages by sequence
sorted_messages = (
sorted(conversation.messages, key=lambda m: m.sequence) if conversation.messages else []
)
return ConversationDetail(
id=conversation.id,
review_id=conversation.review_id,
platform=conversation.platform,
repository=conversation.repository,
pr_number=conversation.pr_number,
total_tokens=conversation.total_tokens,
total_cost_usd=conversation.total_cost_usd,
started_at=conversation.started_at,
last_activity=conversation.last_activity,
messages=[
ConversationMessageResponse(
id=msg.id,
role=msg.role,
platform_comment_id=msg.platform_comment_id,
author=msg.author,
content=msg.content,
responding_agents=msg.responding_agents,
referenced_finding_ids=msg.referenced_finding_ids,
tokens_used=msg.tokens_used,
cost_usd=msg.cost_usd,
created_at=msg.created_at,
sequence=msg.sequence,
)
for msg in sorted_messages
],
)

View File

@@ -9,7 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from arbiter.api.deps import close_redis, get_settings
from arbiter.api.routes import health, reviews, webhooks
from arbiter.api.routes import conversations, health, reviews, webhooks
from arbiter.db.session import close_db, init_db
logger = logging.getLogger(__name__)
@@ -73,6 +73,7 @@ def create_app() -> FastAPI:
app.include_router(health.router, tags=["health"])
app.include_router(webhooks.router, prefix="/webhooks", tags=["webhooks"])
app.include_router(reviews.router, prefix="/api/reviews", tags=["reviews"])
app.include_router(conversations.router, prefix="/api/conversations", tags=["conversations"])
return app