feat(api): add conversations endpoints
This commit is contained in:
240
src/arbiter/api/routes/conversations.py
Normal file
240
src/arbiter/api/routes/conversations.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""Conversation REST API endpoints."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from arbiter.api.deps import DbSession
|
||||||
|
from arbiter.db.models import ConversationModel
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Response models
|
||||||
|
class ConversationMessageResponse(BaseModel):
|
||||||
|
"""Individual message in a conversation."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
role: str
|
||||||
|
platform_comment_id: str | None
|
||||||
|
author: str | None
|
||||||
|
content: str
|
||||||
|
responding_agents: list[str] | None
|
||||||
|
referenced_finding_ids: list[str] | None
|
||||||
|
tokens_used: int
|
||||||
|
cost_usd: float
|
||||||
|
created_at: datetime
|
||||||
|
sequence: int
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationSummary(BaseModel):
|
||||||
|
"""Summary of a conversation for list endpoints."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
review_id: str
|
||||||
|
platform: str
|
||||||
|
repository: str
|
||||||
|
pr_number: int
|
||||||
|
message_count: int = 0
|
||||||
|
total_tokens: int
|
||||||
|
total_cost_usd: float
|
||||||
|
started_at: datetime
|
||||||
|
last_activity: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationDetail(BaseModel):
|
||||||
|
"""Full conversation detail with messages."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
review_id: str
|
||||||
|
platform: str
|
||||||
|
repository: str
|
||||||
|
pr_number: int
|
||||||
|
total_tokens: int
|
||||||
|
total_cost_usd: float
|
||||||
|
started_at: datetime
|
||||||
|
last_activity: datetime
|
||||||
|
messages: list[ConversationMessageResponse]
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationListResponse(BaseModel):
|
||||||
|
"""Paginated conversation list response."""
|
||||||
|
|
||||||
|
items: list[ConversationSummary]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
page_size: int
|
||||||
|
pages: int
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=ConversationListResponse)
|
||||||
|
async def list_conversations(
|
||||||
|
db: DbSession,
|
||||||
|
page: Annotated[int, Query(ge=1)] = 1,
|
||||||
|
page_size: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||||
|
repository: str | None = None,
|
||||||
|
review_id: str | None = None,
|
||||||
|
) -> ConversationListResponse:
|
||||||
|
"""List conversations with pagination and filtering."""
|
||||||
|
# Build query
|
||||||
|
query = select(ConversationModel)
|
||||||
|
|
||||||
|
if repository:
|
||||||
|
query = query.where(ConversationModel.repository == repository)
|
||||||
|
if review_id:
|
||||||
|
query = query.where(ConversationModel.review_id == review_id)
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_query = select(func.count()).select_from(query.subquery())
|
||||||
|
total = await db.scalar(count_query) or 0
|
||||||
|
|
||||||
|
# Add pagination and ordering
|
||||||
|
query = query.order_by(ConversationModel.last_activity.desc())
|
||||||
|
query = query.offset((page - 1) * page_size).limit(page_size)
|
||||||
|
|
||||||
|
# Execute with eager loading
|
||||||
|
query = query.options(selectinload(ConversationModel.messages))
|
||||||
|
result = await db.execute(query)
|
||||||
|
conversations = result.scalars().all()
|
||||||
|
|
||||||
|
# Build response
|
||||||
|
items = []
|
||||||
|
for conv in conversations:
|
||||||
|
items.append(
|
||||||
|
ConversationSummary(
|
||||||
|
id=conv.id,
|
||||||
|
review_id=conv.review_id,
|
||||||
|
platform=conv.platform,
|
||||||
|
repository=conv.repository,
|
||||||
|
pr_number=conv.pr_number,
|
||||||
|
message_count=len(conv.messages) if conv.messages else 0,
|
||||||
|
total_tokens=conv.total_tokens,
|
||||||
|
total_cost_usd=conv.total_cost_usd,
|
||||||
|
started_at=conv.started_at,
|
||||||
|
last_activity=conv.last_activity,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
pages = (total + page_size - 1) // page_size if page_size else 1
|
||||||
|
|
||||||
|
return ConversationListResponse(
|
||||||
|
items=items,
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
pages=pages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{conversation_id}", response_model=ConversationDetail)
|
||||||
|
async def get_conversation(
|
||||||
|
db: DbSession,
|
||||||
|
conversation_id: str,
|
||||||
|
) -> ConversationDetail:
|
||||||
|
"""Get conversation detail with all messages."""
|
||||||
|
query = (
|
||||||
|
select(ConversationModel)
|
||||||
|
.where(ConversationModel.id == conversation_id)
|
||||||
|
.options(selectinload(ConversationModel.messages))
|
||||||
|
)
|
||||||
|
result = await db.execute(query)
|
||||||
|
conversation = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Conversation {conversation_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sort messages by sequence
|
||||||
|
sorted_messages = (
|
||||||
|
sorted(conversation.messages, key=lambda m: m.sequence) if conversation.messages else []
|
||||||
|
)
|
||||||
|
|
||||||
|
return ConversationDetail(
|
||||||
|
id=conversation.id,
|
||||||
|
review_id=conversation.review_id,
|
||||||
|
platform=conversation.platform,
|
||||||
|
repository=conversation.repository,
|
||||||
|
pr_number=conversation.pr_number,
|
||||||
|
total_tokens=conversation.total_tokens,
|
||||||
|
total_cost_usd=conversation.total_cost_usd,
|
||||||
|
started_at=conversation.started_at,
|
||||||
|
last_activity=conversation.last_activity,
|
||||||
|
messages=[
|
||||||
|
ConversationMessageResponse(
|
||||||
|
id=msg.id,
|
||||||
|
role=msg.role,
|
||||||
|
platform_comment_id=msg.platform_comment_id,
|
||||||
|
author=msg.author,
|
||||||
|
content=msg.content,
|
||||||
|
responding_agents=msg.responding_agents,
|
||||||
|
referenced_finding_ids=msg.referenced_finding_ids,
|
||||||
|
tokens_used=msg.tokens_used,
|
||||||
|
cost_usd=msg.cost_usd,
|
||||||
|
created_at=msg.created_at,
|
||||||
|
sequence=msg.sequence,
|
||||||
|
)
|
||||||
|
for msg in sorted_messages
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/review/{review_id}", response_model=ConversationDetail | None)
|
||||||
|
async def get_conversation_for_review(
|
||||||
|
db: DbSession,
|
||||||
|
review_id: str,
|
||||||
|
) -> ConversationDetail | None:
|
||||||
|
"""Get conversation associated with a specific review.
|
||||||
|
|
||||||
|
Returns None if no conversation exists for this review.
|
||||||
|
"""
|
||||||
|
query = (
|
||||||
|
select(ConversationModel)
|
||||||
|
.where(ConversationModel.review_id == review_id)
|
||||||
|
.options(selectinload(ConversationModel.messages))
|
||||||
|
)
|
||||||
|
result = await db.execute(query)
|
||||||
|
conversation = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Sort messages by sequence
|
||||||
|
sorted_messages = (
|
||||||
|
sorted(conversation.messages, key=lambda m: m.sequence) if conversation.messages else []
|
||||||
|
)
|
||||||
|
|
||||||
|
return ConversationDetail(
|
||||||
|
id=conversation.id,
|
||||||
|
review_id=conversation.review_id,
|
||||||
|
platform=conversation.platform,
|
||||||
|
repository=conversation.repository,
|
||||||
|
pr_number=conversation.pr_number,
|
||||||
|
total_tokens=conversation.total_tokens,
|
||||||
|
total_cost_usd=conversation.total_cost_usd,
|
||||||
|
started_at=conversation.started_at,
|
||||||
|
last_activity=conversation.last_activity,
|
||||||
|
messages=[
|
||||||
|
ConversationMessageResponse(
|
||||||
|
id=msg.id,
|
||||||
|
role=msg.role,
|
||||||
|
platform_comment_id=msg.platform_comment_id,
|
||||||
|
author=msg.author,
|
||||||
|
content=msg.content,
|
||||||
|
responding_agents=msg.responding_agents,
|
||||||
|
referenced_finding_ids=msg.referenced_finding_ids,
|
||||||
|
tokens_used=msg.tokens_used,
|
||||||
|
cost_usd=msg.cost_usd,
|
||||||
|
created_at=msg.created_at,
|
||||||
|
sequence=msg.sequence,
|
||||||
|
)
|
||||||
|
for msg in sorted_messages
|
||||||
|
],
|
||||||
|
)
|
||||||
@@ -9,7 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from arbiter.api.deps import close_redis, get_settings
|
from arbiter.api.deps import close_redis, get_settings
|
||||||
from arbiter.api.routes import health, reviews, webhooks
|
from arbiter.api.routes import conversations, health, reviews, webhooks
|
||||||
from arbiter.db.session import close_db, init_db
|
from arbiter.db.session import close_db, init_db
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -73,6 +73,7 @@ def create_app() -> FastAPI:
|
|||||||
app.include_router(health.router, tags=["health"])
|
app.include_router(health.router, tags=["health"])
|
||||||
app.include_router(webhooks.router, prefix="/webhooks", tags=["webhooks"])
|
app.include_router(webhooks.router, prefix="/webhooks", tags=["webhooks"])
|
||||||
app.include_router(reviews.router, prefix="/api/reviews", tags=["reviews"])
|
app.include_router(reviews.router, prefix="/api/reviews", tags=["reviews"])
|
||||||
|
app.include_router(conversations.router, prefix="/api/conversations", tags=["conversations"])
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user