diff --git a/src/arbiter/db/migrations/versions/002_conversations.py b/src/arbiter/db/migrations/versions/002_conversations.py
new file mode 100644
index 0000000..267f794
--- /dev/null
+++ b/src/arbiter/db/migrations/versions/002_conversations.py
@@ -0,0 +1,101 @@
+"""Add conversations tables for follow-up Q&A.
+
+Revision ID: 002
+Revises: 001
+Create Date: 2024-01-20
+
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects import postgresql
+
+revision: str = "002"
+down_revision: str | None = "001"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    """Create the conversations and conversation_messages tables and their indexes."""
+    # Create conversations table
+    op.create_table(
+        "conversations",
+        sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
+        sa.Column(
+            "review_id",
+            postgresql.UUID(as_uuid=False),
+            sa.ForeignKey("reviews.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("platform", sa.String(50), nullable=False),
+        sa.Column("repository", sa.String(255), nullable=False),
+        sa.Column("pr_number", sa.Integer(), nullable=False),
+        sa.Column(
+            "started_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.Column(
+            "last_activity",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        # Counters are NOT NULL with a server-side default: a Python-side `default=`
+        # emits no DDL in a migration and would leave these columns nullable.
+        sa.Column("total_tokens", sa.Integer(), server_default=sa.text("0"), nullable=False),
+        sa.Column("total_cost_usd", sa.Float(), server_default=sa.text("0.0"), nullable=False),
+    )
+
+    # Create indexes for conversations
+    op.create_index("ix_conversations_review_id", "conversations", ["review_id"])
+    op.create_index("ix_conversations_repository_pr", "conversations", ["repository", "pr_number"])
+
+    # Create conversation_messages table
+    op.create_table(
+        "conversation_messages",
+        sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
+        sa.Column(
+            "conversation_id",
+            postgresql.UUID(as_uuid=False),
+            sa.ForeignKey("conversations.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("role", sa.String(20), nullable=False),
+        sa.Column("platform_comment_id", sa.String(100), nullable=True),
+        sa.Column("author", sa.String(255), nullable=True),
+        sa.Column("content", sa.Text(), nullable=False),
+        sa.Column("responding_agents", postgresql.JSON(), nullable=True),
+        sa.Column("referenced_finding_ids", postgresql.JSON(), nullable=True),
+        # NOT NULL with a server-side default, because a Python-side `default=` emits no DDL.
+        sa.Column("tokens_used", sa.Integer(), server_default=sa.text("0"), nullable=False),
+        sa.Column("cost_usd", sa.Float(), server_default=sa.text("0.0"), nullable=False),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.Column("sequence", sa.Integer(), nullable=False),
+    )
+
+    # Create indexes for conversation_messages
+    op.create_index(
+        "ix_conversation_messages_conversation_id", "conversation_messages", ["conversation_id"]
+    )
+    op.create_index(
+        "ix_conversation_messages_sequence",
+        "conversation_messages",
+        ["conversation_id", "sequence"],
+    )
+
+
+def downgrade() -> None:
+    """Drop both tables, child table first so its foreign key never dangles."""
+    # Drop tables in reverse order
+    op.drop_table("conversation_messages")
+    op.drop_table("conversations")
diff --git a/src/arbiter/db/models.py b/src/arbiter/db/models.py
index 3675d26..79af8bd 100644
--- a/src/arbiter/db/models.py
+++ b/src/arbiter/db/models.py
@@ -97,6 +97,9 @@ class ReviewModel(Base):
         "DeliberationStepModel", back_populates="review", cascade="all, delete-orphan"
     )
     policy: Mapped["PolicyModel | None"] = relationship("PolicyModel", back_populates="reviews")
+    conversations: Mapped[list["ConversationModel"]] = relationship(
+        "ConversationModel", back_populates="review", cascade="all, delete-orphan"
+    )
 
     __table_args__ = (
         Index("ix_reviews_repository_pr", "repository", "pr_number"),
@@ -236,3 +239,89 @@ class PolicyModel(Base):
 
     # Relationship
     reviews: Mapped[list["ReviewModel"]] = relationship("ReviewModel", back_populates="policy")
+
+
+class ConversationModel(Base):
+    """Database model for follow-up conversations linked to reviews."""
+
+    __tablename__ = "conversations"
+
+    id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4())
+    )
+    # ondelete mirrors the CASCADE foreign key created by migration 002.
+    review_id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False),
+        ForeignKey("reviews.id", ondelete="CASCADE"),
+        nullable=False,
+        index=True,
+    )
+
+    # PR context
+    platform: Mapped[str] = mapped_column(String(50), nullable=False)  # github/gitlab
+    repository: Mapped[str] = mapped_column(String(255), nullable=False)
+    pr_number: Mapped[int] = mapped_column(Integer, nullable=False)
+
+    # Timestamps
+    started_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), nullable=False
+    )
+    last_activity: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), nullable=False
+    )
+
+    # Cost tracking
+    total_tokens: Mapped[int] = mapped_column(Integer, default=0)
+    total_cost_usd: Mapped[float] = mapped_column(Float, default=0.0)
+
+    # Relationships
+    review: Mapped["ReviewModel"] = relationship("ReviewModel", back_populates="conversations")
+    messages: Mapped[list["ConversationMessageModel"]] = relationship(
+        "ConversationMessageModel", back_populates="conversation", cascade="all, delete-orphan"
+    )
+
+    __table_args__ = (Index("ix_conversations_repository_pr", "repository", "pr_number"),)
+
+
+class ConversationMessageModel(Base):
+    """Database model for individual messages in a conversation."""
+
+    __tablename__ = "conversation_messages"
+
+    id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4())
+    )
+    # ondelete mirrors the CASCADE foreign key created by migration 002.
+    conversation_id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False),
+        ForeignKey("conversations.id", ondelete="CASCADE"),
+        nullable=False,
+        index=True,
+    )
+
+    # Message data
+    role: Mapped[str] = mapped_column(String(20), nullable=False)  # "user" or "assistant"
+    platform_comment_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
+    author: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    content: Mapped[str] = mapped_column(Text, nullable=False)
+
+    # Agent routing info (for assistant messages)
+    responding_agents: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
+    referenced_finding_ids: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
+
+    # Cost tracking
+    tokens_used: Mapped[int] = mapped_column(Integer, default=0)
+    cost_usd: Mapped[float] = mapped_column(Float, default=0.0)
+
+    # Timestamps and ordering
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), nullable=False
+    )
+    sequence: Mapped[int] = mapped_column(Integer, nullable=False)
+
+    # Relationship
+    conversation: Mapped["ConversationModel"] = relationship(
+        "ConversationModel", back_populates="messages"
+    )
+
+    __table_args__ = (Index("ix_conversation_messages_sequence", "conversation_id", "sequence"),)