From f1b7852a7ddf43c610d2a2e9b5e61c4f14311633 Mon Sep 17 00:00:00 2001
From: Kai Chappell
Date: Sat, 15 Mar 2025 14:07:50 +0000
Subject: [PATCH] feat(db): add postgres schema and migrations

---
 alembic.ini                                   |  45 ++++
 src/arbiter/db/__init__.py                    |  23 ++
 src/arbiter/db/migrations/env.py              |  83 ++++++
 src/arbiter/db/migrations/script.py.mako      |  26 ++
 .../migrations/versions/001_initial_schema.py | 219 ++++++++++++++++
 src/arbiter/db/models.py                      | 269 ++++++++++++++++++++
 src/arbiter/db/session.py                     |  86 +++++++
 7 files changed, 751 insertions(+)
 create mode 100644 alembic.ini
 create mode 100644 src/arbiter/db/__init__.py
 create mode 100644 src/arbiter/db/migrations/env.py
 create mode 100644 src/arbiter/db/migrations/script.py.mako
 create mode 100644 src/arbiter/db/migrations/versions/001_initial_schema.py
 create mode 100644 src/arbiter/db/models.py
 create mode 100644 src/arbiter/db/session.py

diff --git a/alembic.ini b/alembic.ini
new file mode 100644
index 0000000..bdfa344
--- /dev/null
+++ b/alembic.ini
@@ -0,0 +1,45 @@
+# Alembic configuration for Arbiter
+
+[alembic]
+script_location = src/arbiter/db/migrations
+prepend_sys_path = .
+version_path_separator = os
+
+# Database URL is loaded from environment via env.py
+sqlalchemy.url = driver://user:pass@localhost/dbname
+
+[post_write_hooks]
+
+[loggers]
+keys = root,sqlalchemy,alembic
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S
diff --git a/src/arbiter/db/__init__.py b/src/arbiter/db/__init__.py
new file mode 100644
index 0000000..011b491
--- /dev/null
+++ b/src/arbiter/db/__init__.py
@@ -0,0 +1,23 @@
+"""Database module for Arbiter."""
+
+from arbiter.db.models import (
+    Base,
+    ConflictModel,
+    DeliberationStepModel,
+    FindingModel,
+    PolicyModel,
+    ReviewModel,
+)
+from arbiter.db.session import async_session_factory, get_async_session, init_db
+
+__all__ = [
+    "Base",
+    "ConflictModel",
+    "DeliberationStepModel",
+    "FindingModel",
+    "PolicyModel",
+    "ReviewModel",
+    "async_session_factory",
+    "get_async_session",
+    "init_db",
+]
diff --git a/src/arbiter/db/migrations/env.py b/src/arbiter/db/migrations/env.py
new file mode 100644
index 0000000..c4427e9
--- /dev/null
+++ b/src/arbiter/db/migrations/env.py
@@ -0,0 +1,83 @@
+"""Alembic environment configuration for async SQLAlchemy."""
+
+import asyncio
+from logging.config import fileConfig
+
+from alembic import context
+from sqlalchemy import pool
+from sqlalchemy.engine import Connection
+from sqlalchemy.ext.asyncio import async_engine_from_config
+
+from arbiter.config import get_settings
+from arbiter.db.models import Base
+
+config = context.config
+
+if config.config_file_name is not None:
+    fileConfig(config.config_file_name)
+
+target_metadata = Base.metadata
+
+
+def get_url() -> str:
+    """Return the database URL from the application settings."""
+    return get_settings().database_url
+
+
+def run_migrations_offline() -> None:
+    """Run migrations in 'offline' mode.
+
+    This configures the context with just a URL
+    and not an Engine, though an Engine is acceptable
+    here as well. By skipping the Engine creation
+    we don't even need a DBAPI to be available.
+
+    Calls to context.execute() here emit the given string to the
+    script output.
+    """
+    url = get_url()
+    context.configure(
+        url=url,
+        target_metadata=target_metadata,
+        literal_binds=True,
+        dialect_opts={"paramstyle": "named"},
+    )
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+def do_run_migrations(connection: Connection) -> None:
+    """Configure the context with an open connection and run the migrations."""
+    context.configure(connection=connection, target_metadata=target_metadata)
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+async def run_async_migrations() -> None:
+    """Run migrations in async mode."""
+    configuration = config.get_section(config.config_ini_section, {})
+    configuration["sqlalchemy.url"] = get_url()
+
+    connectable = async_engine_from_config(
+        configuration,
+        prefix="sqlalchemy.",
+        poolclass=pool.NullPool,
+    )
+
+    async with connectable.connect() as connection:
+        await connection.run_sync(do_run_migrations)
+
+    await connectable.dispose()
+
+
+def run_migrations_online() -> None:
+    """Run migrations in 'online' mode by driving the async runner to completion."""
+    asyncio.run(run_async_migrations())
+
+
+if context.is_offline_mode():
+    run_migrations_offline()
+else:
+    run_migrations_online()
diff --git a/src/arbiter/db/migrations/script.py.mako b/src/arbiter/db/migrations/script.py.mako
new file mode 100644
index 0000000..196d563
--- /dev/null
+++ b/src/arbiter/db/migrations/script.py.mako
@@ -0,0 +1,26 @@
+"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+from alembic import op
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision: str = ${repr(up_revision)}
+down_revision: str | None = ${repr(down_revision)}
+branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
+depends_on: str | Sequence[str] | None = ${repr(depends_on)}
+
+
+def upgrade() -> None:
+    ${upgrades if upgrades else "pass"}
+
+
+def downgrade() -> None:
+    ${downgrades if downgrades else "pass"}
diff --git a/src/arbiter/db/migrations/versions/001_initial_schema.py b/src/arbiter/db/migrations/versions/001_initial_schema.py
new file mode 100644
index 0000000..b492a00
--- /dev/null
+++ b/src/arbiter/db/migrations/versions/001_initial_schema.py
@@ -0,0 +1,219 @@
+"""Initial schema for Arbiter.
+
+Revision ID: 001
+Revises:
+Create Date: 2024-01-15
+
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects import postgresql
+
+revision: str = "001"
+down_revision: str | None = None
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    """Create the enum types, tables and indexes of the initial schema.
+
+    Column defaults use ``server_default`` because a Python-side ``default`` is
+    never rendered into DDL, and those columns are NOT NULL to match the
+    non-optional ORM attributes in ``arbiter.db.models``.
+    """
+    # Create enum types
+    verdict_enum = postgresql.ENUM(
+        "approve", "request_changes", "comment", name="verdict_enum", create_type=False
+    )
+    verdict_enum.create(op.get_bind(), checkfirst=True)
+
+    agent_enum = postgresql.ENUM(
+        "security", "style", "complexity", name="agent_enum", create_type=False
+    )
+    agent_enum.create(op.get_bind(), checkfirst=True)
+
+    severity_enum = postgresql.ENUM(
+        "critical", "high", "medium", "low", "info", name="severity_enum", create_type=False
+    )
+    severity_enum.create(op.get_bind(), checkfirst=True)
+
+    conflict_nature_enum = postgresql.ENUM(
+        "contradictory", "trade_off", "overlapping", name="conflict_nature_enum", create_type=False
+    )
+    conflict_nature_enum.create(op.get_bind(), checkfirst=True)
+
+    step_type_enum = postgresql.ENUM(
+        "merge",
+        "conflict_detection",
+        "synthesis",
+        "verdict",
+        name="step_type_enum",
+        create_type=False,
+    )
+    step_type_enum.create(op.get_bind(), checkfirst=True)
+
+    # Create policies table
+    op.create_table(
+        "policies",
+        sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
+        sa.Column("name", sa.String(100), nullable=False, unique=True),
+        sa.Column("organization", sa.String(255), nullable=True),
+        sa.Column("description", sa.Text(), nullable=True),
+        sa.Column("agents_config", postgresql.JSON(), nullable=True),
+        sa.Column("cost_controls", postgresql.JSON(), nullable=True),
+        sa.Column("verdict_thresholds", postgresql.JSON(), nullable=True),
+        sa.Column("is_active", sa.Boolean(), server_default=sa.text("true"), nullable=False),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.Column(
+            "updated_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+    )
+
+    # Create reviews table
+    op.create_table(
+        "reviews",
+        sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
+        sa.Column("repository", sa.String(255), nullable=False),
+        sa.Column("pr_number", sa.Integer(), nullable=False),
+        sa.Column("pr_title", sa.String(500), nullable=True),
+        sa.Column("base_sha", sa.String(40), nullable=False),
+        sa.Column("head_sha", sa.String(40), nullable=False),
+        sa.Column("author", sa.String(255), nullable=True),
+        sa.Column("is_draft", sa.Boolean(), server_default=sa.text("false"), nullable=False),
+        sa.Column("status", sa.String(50), server_default="pending", nullable=False),
+        sa.Column("verdict", verdict_enum, nullable=True),
+        sa.Column("verdict_confidence", sa.Float(), nullable=True),
+        sa.Column("verdict_reasoning", sa.Text(), nullable=True),
+        sa.Column("total_tokens", sa.Integer(), server_default=sa.text("0"), nullable=False),
+        sa.Column("total_cost_usd", sa.Float(), server_default=sa.text("0"), nullable=False),
+        sa.Column("tokens_by_agent", postgresql.JSON(), nullable=True),
+        sa.Column("cost_by_agent", postgresql.JSON(), nullable=True),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("error_message", sa.Text(), nullable=True),
+        sa.Column(
+            "policy_id",
+            postgresql.UUID(as_uuid=False),
+            sa.ForeignKey("policies.id"),
+            nullable=True,
+        ),
+    )
+
+    # Create indexes for reviews
+    op.create_index("ix_reviews_status", "reviews", ["status"])
+    op.create_index("ix_reviews_repository_pr", "reviews", ["repository", "pr_number"])
+    op.create_index("ix_reviews_created_at", "reviews", ["created_at"])
+
+    # Create findings table
+    op.create_table(
+        "findings",
+        sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
+        sa.Column(
+            "review_id",
+            postgresql.UUID(as_uuid=False),
+            sa.ForeignKey("reviews.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("agent", agent_enum, nullable=False),
+        sa.Column("file", sa.String(500), nullable=False),
+        sa.Column("line_start", sa.Integer(), nullable=False),
+        sa.Column("line_end", sa.Integer(), nullable=False),
+        sa.Column("severity", severity_enum, nullable=False),
+        sa.Column("confidence", sa.Float(), nullable=False),
+        sa.Column("title", sa.String(500), nullable=False),
+        sa.Column("description", sa.Text(), nullable=False),
+        sa.Column("reasoning", sa.Text(), nullable=False),
+        sa.Column("suggestion", sa.Text(), nullable=True),
+        sa.Column("references", postgresql.JSON(), nullable=True),
+        sa.Column("prompt_version", sa.String(50), nullable=False),
+        sa.Column("static_analysis_context", postgresql.JSON(), nullable=True),
+    )
+
+    # Create indexes for findings
+    op.create_index("ix_findings_review_id", "findings", ["review_id"])
+    op.create_index("ix_findings_severity", "findings", ["severity"])
+    op.create_index("ix_findings_review_severity", "findings", ["review_id", "severity"])
+
+    # Create conflicts table
+    op.create_table(
+        "conflicts",
+        sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
+        sa.Column(
+            "review_id",
+            postgresql.UUID(as_uuid=False),
+            sa.ForeignKey("reviews.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("finding_ids", postgresql.JSON(), nullable=False),
+        sa.Column("nature", conflict_nature_enum, nullable=False),
+        sa.Column("description", sa.Text(), nullable=False),
+        sa.Column("severity_weight", sa.Float(), nullable=False),
+        sa.Column("resolution", sa.Text(), nullable=True),
+        sa.Column("winning_finding_id", sa.String(36), nullable=True),
+    )
+
+    # Create index for conflicts
+    op.create_index("ix_conflicts_review_id", "conflicts", ["review_id"])
+
+    # Create deliberation_steps table
+    op.create_table(
+        "deliberation_steps",
+        sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
+        sa.Column(
+            "review_id",
+            postgresql.UUID(as_uuid=False),
+            sa.ForeignKey("reviews.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("step_type", step_type_enum, nullable=False),
+        sa.Column(
+            "timestamp",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.Column("description", sa.Text(), nullable=False),
+        sa.Column("details", postgresql.JSON(), nullable=True),
+        sa.Column("sequence", sa.Integer(), nullable=False),
+    )
+
+    # Create indexes for deliberation_steps
+    op.create_index("ix_deliberation_steps_review_id", "deliberation_steps", ["review_id"])
+    op.create_index(
+        "ix_deliberation_steps_review_sequence", "deliberation_steps", ["review_id", "sequence"]
+    )
+
+
+def downgrade() -> None:
+    """Drop the tables and enum types created by ``upgrade``."""
+    # Drop tables in reverse order
+    op.drop_table("deliberation_steps")
+    op.drop_table("conflicts")
+    op.drop_table("findings")
+    op.drop_table("reviews")
+    op.drop_table("policies")
+
+    # Drop enum types
+    op.execute("DROP TYPE IF EXISTS step_type_enum")
+    op.execute("DROP TYPE IF EXISTS conflict_nature_enum")
+    op.execute("DROP TYPE IF EXISTS severity_enum")
+    op.execute("DROP TYPE IF EXISTS agent_enum")
+    op.execute("DROP TYPE IF EXISTS verdict_enum")
diff --git a/src/arbiter/db/models.py b/src/arbiter/db/models.py
new file mode 100644
index 0000000..3675d26
--- /dev/null
+++ b/src/arbiter/db/models.py
@@ -0,0 +1,269 @@
+"""SQLAlchemy database models for Arbiter."""
+
+import enum
+from datetime import datetime
+from typing import Any
+from uuid import uuid4
+
+from sqlalchemy import (
+    JSON,
+    DateTime,
+    Enum,
+    Float,
+    ForeignKey,
+    Index,
+    Integer,
+    String,
+    Text,
+    func,
+)
+from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
+
+from arbiter.deliberation.conflicts import ConflictNature
+from arbiter.deliberation.coordinator import StepType
+from arbiter.models.enums import AgentName, Severity, Verdict
+
+
+def _pg_enum(enum_cls: type[enum.Enum], name: str) -> Enum:
+    """Build the column type for a PostgreSQL enum created by migration 001.
+
+    SQLAlchemy persists enum member *names* unless ``values_callable`` is given,
+    while the migration creates the types with lower-case labels such as
+    ``"request_changes"``. Persisting ``member.value`` keeps the two in step.
+
+    NOTE(review): assumes every member's ``value`` is exactly the label created
+    by the migration -- confirm against the enum definitions.
+    """
+    return Enum(
+        enum_cls,
+        name=name,
+        create_constraint=True,
+        values_callable=lambda enum_type: [member.value for member in enum_type],
+    )
+
+
+class Base(DeclarativeBase):
+    """Base class for all database models."""
+
+    pass
+
+
+class ReviewModel(Base):
+    """Database model for code reviews."""
+
+    __tablename__ = "reviews"
+
+    id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4())
+    )
+
+    # PR metadata
+    repository: Mapped[str] = mapped_column(String(255), nullable=False)
+    pr_number: Mapped[int] = mapped_column(Integer, nullable=False)
+    pr_title: Mapped[str | None] = mapped_column(String(500), nullable=True)
+    base_sha: Mapped[str] = mapped_column(String(40), nullable=False)
+    head_sha: Mapped[str] = mapped_column(String(40), nullable=False)
+    author: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    is_draft: Mapped[bool] = mapped_column(default=False)
+
+    # Review status
+    status: Mapped[str] = mapped_column(
+        String(50),
+        default="pending",
+        index=True,
+    )  # pending, running, completed, failed
+
+    # Verdict
+    verdict: Mapped[str | None] = mapped_column(
+        _pg_enum(Verdict, "verdict_enum"),
+        nullable=True,
+    )
+    verdict_confidence: Mapped[float | None] = mapped_column(Float, nullable=True)
+    verdict_reasoning: Mapped[str | None] = mapped_column(Text, nullable=True)
+
+    # Cost tracking
+    total_tokens: Mapped[int] = mapped_column(Integer, default=0)
+    total_cost_usd: Mapped[float] = mapped_column(Float, default=0.0)
+    tokens_by_agent: Mapped[dict[str, int] | None] = mapped_column(JSON, nullable=True)
+    cost_by_agent: Mapped[dict[str, float] | None] = mapped_column(JSON, nullable=True)
+
+    # Timestamps
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True),
+        server_default=func.now(),
+        nullable=False,
+    )
+    started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+
+    # Error info
+    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
+
+    # Policy reference
+    policy_id: Mapped[str | None] = mapped_column(
+        UUID(as_uuid=False), ForeignKey("policies.id"), nullable=True
+    )
+
+    # Relationships
+    findings: Mapped[list["FindingModel"]] = relationship(
+        "FindingModel", back_populates="review", cascade="all, delete-orphan"
+    )
+    conflicts: Mapped[list["ConflictModel"]] = relationship(
+        "ConflictModel", back_populates="review", cascade="all, delete-orphan"
+    )
+    deliberation_steps: Mapped[list["DeliberationStepModel"]] = relationship(
+        "DeliberationStepModel", back_populates="review", cascade="all, delete-orphan"
+    )
+    policy: Mapped["PolicyModel | None"] = relationship("PolicyModel", back_populates="reviews")
+
+    __table_args__ = (
+        Index("ix_reviews_repository_pr", "repository", "pr_number"),
+        Index("ix_reviews_created_at", "created_at"),
+    )
+
+
+class FindingModel(Base):
+    """Database model for individual findings."""
+
+    __tablename__ = "findings"
+
+    id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4())
+    )
+    # ON DELETE CASCADE mirrors the foreign key created by migration 001.
+    review_id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False),
+        ForeignKey("reviews.id", ondelete="CASCADE"),
+        nullable=False,
+        index=True,
+    )
+
+    # Finding data
+    agent: Mapped[str] = mapped_column(
+        _pg_enum(AgentName, "agent_enum"),
+        nullable=False,
+    )
+    file: Mapped[str] = mapped_column(String(500), nullable=False)
+    line_start: Mapped[int] = mapped_column(Integer, nullable=False)
+    line_end: Mapped[int] = mapped_column(Integer, nullable=False)
+    severity: Mapped[str] = mapped_column(
+        _pg_enum(Severity, "severity_enum"),
+        nullable=False,
+        index=True,
+    )
+    confidence: Mapped[float] = mapped_column(Float, nullable=False)
+    title: Mapped[str] = mapped_column(String(500), nullable=False)
+    description: Mapped[str] = mapped_column(Text, nullable=False)
+    reasoning: Mapped[str] = mapped_column(Text, nullable=False)
+    suggestion: Mapped[str | None] = mapped_column(Text, nullable=True)
+    references: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
+    prompt_version: Mapped[str] = mapped_column(String(50), nullable=False)
+    static_analysis_context: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
+
+    # Relationship
+    review: Mapped["ReviewModel"] = relationship("ReviewModel", back_populates="findings")
+
+    __table_args__ = (Index("ix_findings_review_severity", "review_id", "severity"),)
+
+
+class ConflictModel(Base):
+    """Database model for detected conflicts between findings."""
+
+    __tablename__ = "conflicts"
+
+    id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4())
+    )
+    # ON DELETE CASCADE mirrors the foreign key created by migration 001.
+    review_id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False),
+        ForeignKey("reviews.id", ondelete="CASCADE"),
+        nullable=False,
+        index=True,
+    )
+
+    # Conflict data
+    finding_ids: Mapped[list[str]] = mapped_column(JSON, nullable=False)
+    nature: Mapped[str] = mapped_column(
+        _pg_enum(ConflictNature, "conflict_nature_enum"),
+        nullable=False,
+    )
+    description: Mapped[str] = mapped_column(Text, nullable=False)
+    severity_weight: Mapped[float] = mapped_column(Float, nullable=False)
+    resolution: Mapped[str | None] = mapped_column(Text, nullable=True)
+    winning_finding_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
+
+    # Relationship
+    review: Mapped["ReviewModel"] = relationship("ReviewModel", back_populates="conflicts")
+
+
+class DeliberationStepModel(Base):
+    """Database model for deliberation log entries."""
+
+    __tablename__ = "deliberation_steps"
+
+    id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4())
+    )
+    # ON DELETE CASCADE mirrors the foreign key created by migration 001.
+    review_id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False),
+        ForeignKey("reviews.id", ondelete="CASCADE"),
+        nullable=False,
+        index=True,
+    )
+
+    # Step data
+    step_type: Mapped[str] = mapped_column(
+        _pg_enum(StepType, "step_type_enum"),
+        nullable=False,
+    )
+    timestamp: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), nullable=False
+    )
+    description: Mapped[str] = mapped_column(Text, nullable=False)
+    details: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
+    sequence: Mapped[int] = mapped_column(Integer, nullable=False)
+
+    # Relationship
+    review: Mapped["ReviewModel"] = relationship("ReviewModel", back_populates="deliberation_steps")
+
+    __table_args__ = (Index("ix_deliberation_steps_review_sequence", "review_id", "sequence"),)
+
+
+class PolicyModel(Base):
+    """Database model for organization policies."""
+
+    __tablename__ = "policies"
+
+    id: Mapped[str] = mapped_column(
+        UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4())
+    )
+
+    # Policy identification
+    name: Mapped[str] = mapped_column(String(100), nullable=False, unique=True)
+    organization: Mapped[str | None] = mapped_column(String(255), nullable=True)
+    description: Mapped[str | None] = mapped_column(Text, nullable=True)
+
+    # Policy configuration (stored as JSON)
+    agents_config: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
+    cost_controls: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
+    verdict_thresholds: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
+
+    # Active status
+    is_active: Mapped[bool] = mapped_column(default=True)
+
+    # Timestamps
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), nullable=False
+    )
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True),
+        server_default=func.now(),
+        onupdate=func.now(),
+        nullable=False,
+    )
+
+    # Relationship
+    reviews: Mapped[list["ReviewModel"]] = relationship("ReviewModel", back_populates="policy")
diff --git a/src/arbiter/db/session.py b/src/arbiter/db/session.py
new file mode 100644
index 0000000..965e678
--- /dev/null
+++ b/src/arbiter/db/session.py
@@ -0,0 +1,86 @@
+"""Database session management for Arbiter."""
+
+from collections.abc import AsyncGenerator
+
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession,
+    async_sessionmaker,
+    create_async_engine,
+)
+
+from arbiter.config import get_settings
+from arbiter.db.models import Base
+
+# Create engine lazily to allow settings override during testing
+_engine: AsyncEngine | None = None
+_session_factory: async_sessionmaker[AsyncSession] | None = None
+
+
+def _get_engine() -> AsyncEngine:
+    """Get or create the async engine."""
+    global _engine
+    if _engine is None:
+        settings = get_settings()
+        _engine = create_async_engine(
+            settings.database_url,
+            pool_size=settings.database_pool_size,
+            max_overflow=settings.database_max_overflow,
+            echo=False,
+        )
+    return _engine
+
+
+def async_session_factory() -> async_sessionmaker[AsyncSession]:
+    """Get or create the shared session factory bound to the lazy engine."""
+    global _session_factory
+    if _session_factory is None:
+        _session_factory = async_sessionmaker(
+            bind=_get_engine(),
+            class_=AsyncSession,
+            expire_on_commit=False,
+        )
+    return _session_factory
+
+
+async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
+    """Dependency for getting an async database session."""
+    factory = async_session_factory()
+    async with factory() as session:
+        try:
+            yield session
+            await session.commit()
+        except Exception:
+            await session.rollback()
+            raise
+
+
+async def init_db() -> None:
+    """Create all tables registered on ``Base.metadata``.
+
+    NOTE(review): this bypasses the Alembic migrations -- presumably meant for
+    tests/local setup; confirm before using it against a migrated database.
+    """
+    engine = _get_engine()
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
+
+
+async def close_db() -> None:
+    """Dispose the engine, if one was created, and clear the cached factory."""
+    global _engine, _session_factory
+    if _engine is not None:
+        await _engine.dispose()
+        _engine = None
+        _session_factory = None
+
+
+def reset_engine() -> None:
+    """Forget the cached engine and session factory without disposing them.
+
+    The next access builds fresh objects from the current settings; use
+    ``close_db`` when the engine's connections must be released as well.
+    """
+    global _engine, _session_factory
+    _engine = None
+    _session_factory = None