feat(api): add conversations endpoints
This commit is contained in:
240
src/arbiter/api/routes/conversations.py
Normal file
240
src/arbiter/api/routes/conversations.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Conversation REST API endpoints."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from arbiter.api.deps import DbSession
|
||||
from arbiter.db.models import ConversationModel
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Response models
|
||||
class ConversationMessageResponse(BaseModel):
|
||||
"""Individual message in a conversation."""
|
||||
|
||||
id: str
|
||||
role: str
|
||||
platform_comment_id: str | None
|
||||
author: str | None
|
||||
content: str
|
||||
responding_agents: list[str] | None
|
||||
referenced_finding_ids: list[str] | None
|
||||
tokens_used: int
|
||||
cost_usd: float
|
||||
created_at: datetime
|
||||
sequence: int
|
||||
|
||||
|
||||
class ConversationSummary(BaseModel):
|
||||
"""Summary of a conversation for list endpoints."""
|
||||
|
||||
id: str
|
||||
review_id: str
|
||||
platform: str
|
||||
repository: str
|
||||
pr_number: int
|
||||
message_count: int = 0
|
||||
total_tokens: int
|
||||
total_cost_usd: float
|
||||
started_at: datetime
|
||||
last_activity: datetime
|
||||
|
||||
|
||||
class ConversationDetail(BaseModel):
|
||||
"""Full conversation detail with messages."""
|
||||
|
||||
id: str
|
||||
review_id: str
|
||||
platform: str
|
||||
repository: str
|
||||
pr_number: int
|
||||
total_tokens: int
|
||||
total_cost_usd: float
|
||||
started_at: datetime
|
||||
last_activity: datetime
|
||||
messages: list[ConversationMessageResponse]
|
||||
|
||||
|
||||
class ConversationListResponse(BaseModel):
|
||||
"""Paginated conversation list response."""
|
||||
|
||||
items: list[ConversationSummary]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
@router.get("", response_model=ConversationListResponse)
|
||||
async def list_conversations(
|
||||
db: DbSession,
|
||||
page: Annotated[int, Query(ge=1)] = 1,
|
||||
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
repository: str | None = None,
|
||||
review_id: str | None = None,
|
||||
) -> ConversationListResponse:
|
||||
"""List conversations with pagination and filtering."""
|
||||
# Build query
|
||||
query = select(ConversationModel)
|
||||
|
||||
if repository:
|
||||
query = query.where(ConversationModel.repository == repository)
|
||||
if review_id:
|
||||
query = query.where(ConversationModel.review_id == review_id)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = await db.scalar(count_query) or 0
|
||||
|
||||
# Add pagination and ordering
|
||||
query = query.order_by(ConversationModel.last_activity.desc())
|
||||
query = query.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
# Execute with eager loading
|
||||
query = query.options(selectinload(ConversationModel.messages))
|
||||
result = await db.execute(query)
|
||||
conversations = result.scalars().all()
|
||||
|
||||
# Build response
|
||||
items = []
|
||||
for conv in conversations:
|
||||
items.append(
|
||||
ConversationSummary(
|
||||
id=conv.id,
|
||||
review_id=conv.review_id,
|
||||
platform=conv.platform,
|
||||
repository=conv.repository,
|
||||
pr_number=conv.pr_number,
|
||||
message_count=len(conv.messages) if conv.messages else 0,
|
||||
total_tokens=conv.total_tokens,
|
||||
total_cost_usd=conv.total_cost_usd,
|
||||
started_at=conv.started_at,
|
||||
last_activity=conv.last_activity,
|
||||
)
|
||||
)
|
||||
|
||||
pages = (total + page_size - 1) // page_size if page_size else 1
|
||||
|
||||
return ConversationListResponse(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
pages=pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{conversation_id}", response_model=ConversationDetail)
|
||||
async def get_conversation(
|
||||
db: DbSession,
|
||||
conversation_id: str,
|
||||
) -> ConversationDetail:
|
||||
"""Get conversation detail with all messages."""
|
||||
query = (
|
||||
select(ConversationModel)
|
||||
.where(ConversationModel.id == conversation_id)
|
||||
.options(selectinload(ConversationModel.messages))
|
||||
)
|
||||
result = await db.execute(query)
|
||||
conversation = result.scalar_one_or_none()
|
||||
|
||||
if not conversation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Conversation {conversation_id} not found",
|
||||
)
|
||||
|
||||
# Sort messages by sequence
|
||||
sorted_messages = (
|
||||
sorted(conversation.messages, key=lambda m: m.sequence) if conversation.messages else []
|
||||
)
|
||||
|
||||
return ConversationDetail(
|
||||
id=conversation.id,
|
||||
review_id=conversation.review_id,
|
||||
platform=conversation.platform,
|
||||
repository=conversation.repository,
|
||||
pr_number=conversation.pr_number,
|
||||
total_tokens=conversation.total_tokens,
|
||||
total_cost_usd=conversation.total_cost_usd,
|
||||
started_at=conversation.started_at,
|
||||
last_activity=conversation.last_activity,
|
||||
messages=[
|
||||
ConversationMessageResponse(
|
||||
id=msg.id,
|
||||
role=msg.role,
|
||||
platform_comment_id=msg.platform_comment_id,
|
||||
author=msg.author,
|
||||
content=msg.content,
|
||||
responding_agents=msg.responding_agents,
|
||||
referenced_finding_ids=msg.referenced_finding_ids,
|
||||
tokens_used=msg.tokens_used,
|
||||
cost_usd=msg.cost_usd,
|
||||
created_at=msg.created_at,
|
||||
sequence=msg.sequence,
|
||||
)
|
||||
for msg in sorted_messages
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/review/{review_id}", response_model=ConversationDetail | None)
|
||||
async def get_conversation_for_review(
|
||||
db: DbSession,
|
||||
review_id: str,
|
||||
) -> ConversationDetail | None:
|
||||
"""Get conversation associated with a specific review.
|
||||
|
||||
Returns None if no conversation exists for this review.
|
||||
"""
|
||||
query = (
|
||||
select(ConversationModel)
|
||||
.where(ConversationModel.review_id == review_id)
|
||||
.options(selectinload(ConversationModel.messages))
|
||||
)
|
||||
result = await db.execute(query)
|
||||
conversation = result.scalar_one_or_none()
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
# Sort messages by sequence
|
||||
sorted_messages = (
|
||||
sorted(conversation.messages, key=lambda m: m.sequence) if conversation.messages else []
|
||||
)
|
||||
|
||||
return ConversationDetail(
|
||||
id=conversation.id,
|
||||
review_id=conversation.review_id,
|
||||
platform=conversation.platform,
|
||||
repository=conversation.repository,
|
||||
pr_number=conversation.pr_number,
|
||||
total_tokens=conversation.total_tokens,
|
||||
total_cost_usd=conversation.total_cost_usd,
|
||||
started_at=conversation.started_at,
|
||||
last_activity=conversation.last_activity,
|
||||
messages=[
|
||||
ConversationMessageResponse(
|
||||
id=msg.id,
|
||||
role=msg.role,
|
||||
platform_comment_id=msg.platform_comment_id,
|
||||
author=msg.author,
|
||||
content=msg.content,
|
||||
responding_agents=msg.responding_agents,
|
||||
referenced_finding_ids=msg.referenced_finding_ids,
|
||||
tokens_used=msg.tokens_used,
|
||||
cost_usd=msg.cost_usd,
|
||||
created_at=msg.created_at,
|
||||
sequence=msg.sequence,
|
||||
)
|
||||
for msg in sorted_messages
|
||||
],
|
||||
)
|
||||
@@ -9,7 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from arbiter.api.deps import close_redis, get_settings
|
||||
from arbiter.api.routes import health, reviews, webhooks
|
||||
from arbiter.api.routes import conversations, health, reviews, webhooks
|
||||
from arbiter.db.session import close_db, init_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -73,6 +73,7 @@ def create_app() -> FastAPI:
|
||||
app.include_router(health.router, tags=["health"])
|
||||
app.include_router(webhooks.router, prefix="/webhooks", tags=["webhooks"])
|
||||
app.include_router(reviews.router, prefix="/api/reviews", tags=["reviews"])
|
||||
app.include_router(conversations.router, prefix="/api/conversations", tags=["conversations"])
|
||||
|
||||
return app
|
||||
|
||||
|
||||
Reference in New Issue
Block a user