diff --git a/src/arbiter/api/__init__.py b/src/arbiter/api/__init__.py new file mode 100644 index 0000000..8e62c71 --- /dev/null +++ b/src/arbiter/api/__init__.py @@ -0,0 +1,9 @@ +"""API module for Arbiter.""" + +from arbiter.api.deps import get_db, get_redis, get_settings + +__all__ = [ + "get_db", + "get_redis", + "get_settings", +] diff --git a/src/arbiter/api/deps.py b/src/arbiter/api/deps.py new file mode 100644 index 0000000..381042c --- /dev/null +++ b/src/arbiter/api/deps.py @@ -0,0 +1,54 @@ +"""FastAPI dependency injection for Arbiter.""" + +from collections.abc import AsyncGenerator +from functools import lru_cache +from typing import Annotated + +from fastapi import Depends +from redis.asyncio import Redis +from sqlalchemy.ext.asyncio import AsyncSession + +from arbiter.config import Settings +from arbiter.config import get_settings as _get_settings +from arbiter.db.session import get_async_session + +# Redis client cache +_redis_client: Redis | None = None + + +@lru_cache +def get_settings() -> Settings: + return _get_settings() + + +async def get_db() -> AsyncGenerator[AsyncSession, None]: + async for session in get_async_session(): + yield session + + +async def get_redis() -> AsyncGenerator[Redis, None]: + """Get Redis client dependency.""" + global _redis_client + settings = get_settings() + + if _redis_client is None: + _redis_client = Redis.from_url( + settings.redis_url, + max_connections=settings.redis_max_connections, + decode_responses=True, + ) + + yield _redis_client + + +async def close_redis() -> None: + global _redis_client + if _redis_client is not None: + await _redis_client.close() + _redis_client = None + + +# Type aliases for cleaner route signatures +DbSession = Annotated[AsyncSession, Depends(get_db)] +RedisClient = Annotated[Redis, Depends(get_redis)] +AppSettings = Annotated[Settings, Depends(get_settings)] diff --git a/src/arbiter/api/routes/__init__.py b/src/arbiter/api/routes/__init__.py new file mode 100644 index 0000000..4b3cdb1 --- /dev/null +++ b/src/arbiter/api/routes/__init__.py @@ -0,0 +1,5 @@ +"""API routes for Arbiter.""" + +from arbiter.api.routes import health, reviews, webhooks + +__all__ = ["health", "reviews", "webhooks"] diff --git a/src/arbiter/api/routes/health.py b/src/arbiter/api/routes/health.py new file mode 100644 index 0000000..300a6c9 --- /dev/null +++ b/src/arbiter/api/routes/health.py @@ -0,0 +1,162 @@ +"""Health and metrics endpoints for Arbiter.""" + +import logging +from typing import Any + +from fastapi import APIRouter, Response, status +from prometheus_client import ( + CONTENT_TYPE_LATEST, + Counter, + Gauge, + Histogram, + generate_latest, +) +from pydantic import BaseModel +from sqlalchemy import text + +from arbiter.api.deps import DbSession, RedisClient + +router = APIRouter() +logger = logging.getLogger(__name__) + + +# Prometheus metrics +REVIEWS_TOTAL = Counter( + "arbiter_reviews_total", + "Total number of reviews processed", + ["status", "verdict"], +) + +REVIEW_DURATION = Histogram( + "arbiter_review_duration_seconds", + "Review processing duration in seconds", + buckets=[1, 5, 10, 30, 60, 120, 300, 600], +) + +FINDINGS_TOTAL = Counter( + "arbiter_findings_total", + "Total number of findings by severity", + ["severity", "agent"], +) + +LLM_TOKENS_TOTAL = Counter( + "arbiter_llm_tokens_total", + "Total LLM tokens used", + ["direction"], # input/output +) + +LLM_COST_TOTAL = Counter( + "arbiter_llm_cost_usd_total", + "Total LLM cost in USD", +) + +QUEUE_SIZE = Gauge( + "arbiter_queue_size", + "Number of jobs in the review queue", + ["priority"], +) + +DB_CONNECTIONS = Gauge( + "arbiter_db_connections", + "Number of active database connections", +) + + +class HealthCheck(BaseModel): + """Health check response.""" + + status: str + version: str + + +class ReadinessCheck(BaseModel): + """Readiness check response with component status.""" + + status: str + components: dict[str, dict[str, Any]] + + +@router.get("/health", response_model=HealthCheck) +async def health_check() -> HealthCheck: + from arbiter import __version__ + + return HealthCheck(status="healthy", version=__version__) + + +@router.get("/health/ready", response_model=ReadinessCheck) +async def readiness_check( + db: DbSession, + redis: RedisClient, +) -> ReadinessCheck: + """Readiness check - verifies database and redis connectivity.""" + components: dict[str, dict[str, Any]] = {} + overall_healthy = True + + # Check database + try: + await db.execute(text("SELECT 1")) + components["database"] = {"status": "healthy"} + except Exception as e: + logger.warning("Database health check failed: %s", e) + components["database"] = {"status": "unhealthy", "error": str(e)} + overall_healthy = False + + # Check Redis + try: + await redis.ping() + components["redis"] = {"status": "healthy"} + except Exception as e: + logger.warning("Redis health check failed: %s", e) + components["redis"] = {"status": "unhealthy", "error": str(e)} + overall_healthy = False + + # Check worker queue (via Redis) + try: + llen_result = redis.llen("arbiter:queue:2") + queue_size = await llen_result # type: ignore[misc] + components["worker"] = {"status": "healthy", "queue_size": queue_size} + QUEUE_SIZE.labels(priority="normal").set(queue_size) + except Exception as e: + logger.warning("Worker queue check failed: %s", e) + components["worker"] = {"status": "unknown", "error": str(e)} + + return ReadinessCheck( + status="ready" if overall_healthy else "not_ready", + components=components, + ) + + +@router.get("/metrics") +async def metrics() -> Response: + return Response( + content=generate_latest(), + media_type=CONTENT_TYPE_LATEST, + ) + + +@router.get("/health/live", status_code=status.HTTP_200_OK) +async def liveness() -> dict[str, str]: + return {"status": "alive"} + + +# Helper functions to record metrics (called from worker tasks) +def record_review_completed( + duration_seconds: float, + review_status: str, + verdict: str | None, + findings_by_severity: dict[str, int], + tokens_in: int, + tokens_out: int, + cost_usd: float, +) -> None: + """Record metrics for a completed review.""" + REVIEWS_TOTAL.labels(status=review_status, verdict=verdict or "none").inc() + REVIEW_DURATION.observe(duration_seconds) + + for severity, count in findings_by_severity.items(): + for _ in range(count): + FINDINGS_TOTAL.labels(severity=severity, agent="all").inc() + + LLM_TOKENS_TOTAL.labels(direction="input").inc(tokens_in) + LLM_TOKENS_TOTAL.labels(direction="output").inc(tokens_out) + LLM_COST_TOTAL.inc(cost_usd) diff --git a/src/arbiter/api/routes/reviews.py b/src/arbiter/api/routes/reviews.py new file mode 100644 index 0000000..530a66b --- /dev/null +++ b/src/arbiter/api/routes/reviews.py @@ -0,0 +1,387 @@ +"""Review REST API endpoints.""" + +import logging +from datetime import datetime +from typing import Annotated, Any + +from fastapi import APIRouter, HTTPException, Query, status +from pydantic import BaseModel, Field +from sqlalchemy import func, select +from sqlalchemy.orm import selectinload + +from arbiter.api.deps import DbSession, RedisClient +from arbiter.db.models import ( + DeliberationStepModel, + ReviewModel, +) +from arbiter.models.enums import Severity, Verdict +from arbiter.worker.queue import enqueue_review + +router = APIRouter() +logger = logging.getLogger(__name__) + + +# Response models +class FindingResponse(BaseModel): + """Finding response schema.""" + + id: str + agent: str + file: str + line_start: int + line_end: int + severity: str + confidence: float + title: str + description: str + reasoning: str + suggestion: str | None + references: list[str] | None + prompt_version: str + + +class ConflictResponse(BaseModel): + """Conflict response schema.""" + + id: str + finding_ids: list[str] + nature: str + description: str + severity_weight: float + resolution: str | None + winning_finding_id: str | None + + +class DeliberationStepResponse(BaseModel): + """Deliberation step response schema.""" + + id: str + step_type: str + timestamp: datetime + description: str + details: dict[str, Any] | None + sequence: int + + +class ReviewSummary(BaseModel): + """Review summary for list endpoints.""" + + id: str + repository: str + pr_number: int + pr_title: str | None + author: str | None + status: str + verdict: str | None + verdict_confidence: float | None + finding_count: int = 0 + critical_count: int = 0 + high_count: int = 0 + total_cost_usd: float + created_at: datetime + completed_at: datetime | None + + +class ReviewDetail(BaseModel): + """Full review detail response.""" + + id: str + repository: str + pr_number: int + pr_title: str | None + base_sha: str + head_sha: str + author: str | None + is_draft: bool + status: str + verdict: str | None + verdict_confidence: float | None + verdict_reasoning: str | None + total_tokens: int + total_cost_usd: float + tokens_by_agent: dict[str, int] | None + cost_by_agent: dict[str, float] | None + created_at: datetime + started_at: datetime | None + completed_at: datetime | None + error_message: str | None + findings: list[FindingResponse] + conflicts: list[ConflictResponse] + + +class ReviewListResponse(BaseModel): + """Paginated review list response.""" + + items: list[ReviewSummary] + total: int + page: int + page_size: int + pages: int + + +class DeliberationLogResponse(BaseModel): + """Deliberation log response.""" + + review_id: str + steps: list[DeliberationStepResponse] + + +class ManualReviewRequest(BaseModel): + """Request body for manual review trigger.""" + + repository: str = Field(description="Repository name (owner/repo)") + pr_number: int = Field(ge=1, description="Pull request number") + base_sha: str = Field(min_length=7, max_length=40, description="Base commit SHA") + head_sha: str = Field(min_length=7, max_length=40, description="Head commit SHA") + pr_title: str | None = Field(default=None, description="PR title") + author: str | None = Field(default=None, description="PR author") + is_draft: bool = Field(default=False, description="Whether PR is a draft") + policy_name: str | None = Field(default=None, description="Policy name to use") + diff_content: str | None = Field(default=None, description="Diff content to review") + + +class ManualReviewResponse(BaseModel): + """Response for manual review trigger.""" + + status: str + job_id: str | None + review_id: str | None + message: str + + +@router.get("", response_model=ReviewListResponse) +async def list_reviews( + db: DbSession, + page: Annotated[int, Query(ge=1)] = 1, + page_size: Annotated[int, Query(ge=1, le=100)] = 20, + repository: str | None = None, + status: str | None = None, + verdict: Verdict | None = None, + author: str | None = None, +) -> ReviewListResponse: + """List reviews with pagination and filtering.""" + # Build query + query = select(ReviewModel) + + if repository: + query = query.where(ReviewModel.repository == repository) + if status: + query = query.where(ReviewModel.status == status) + if verdict: + query = query.where(ReviewModel.verdict == verdict) + if author: + query = query.where(ReviewModel.author == author) + + # Get total count + count_query = select(func.count()).select_from(query.subquery()) + total = await db.scalar(count_query) or 0 + + # Add pagination and ordering + query = query.order_by(ReviewModel.created_at.desc()) + query = query.offset((page - 1) * page_size).limit(page_size) + + # Execute with eager loading + query = query.options(selectinload(ReviewModel.findings)) + result = await db.execute(query) + reviews = result.scalars().all() + + # Build response + items = [] + for review in reviews: + # Count findings by severity + critical_count = sum(1 for f in review.findings if f.severity == Severity.CRITICAL) + high_count = sum(1 for f in review.findings if f.severity == Severity.HIGH) + + items.append( + ReviewSummary( + id=review.id, + repository=review.repository, + pr_number=review.pr_number, + pr_title=review.pr_title, + author=review.author, + status=review.status, + verdict=review.verdict if review.verdict else None, + verdict_confidence=review.verdict_confidence, + finding_count=len(review.findings), + critical_count=critical_count, + high_count=high_count, + total_cost_usd=review.total_cost_usd, + created_at=review.created_at, + completed_at=review.completed_at, + ) + ) + + pages = (total + page_size - 1) // page_size + + return ReviewListResponse( + items=items, + total=total, + page=page, + page_size=page_size, + pages=pages, + ) + + +@router.get("/{review_id}", response_model=ReviewDetail) +async def get_review( + db: DbSession, + review_id: str, +) -> ReviewDetail: + """Get review detail with all findings and conflicts.""" + query = ( + select(ReviewModel) + .where(ReviewModel.id == review_id) + .options( + selectinload(ReviewModel.findings), + selectinload(ReviewModel.conflicts), + ) + ) + result = await db.execute(query) + review = result.scalar_one_or_none() + + if not review: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Review {review_id} not found", + ) + + return ReviewDetail( + id=review.id, + repository=review.repository, + pr_number=review.pr_number, + pr_title=review.pr_title, + base_sha=review.base_sha, + head_sha=review.head_sha, + author=review.author, + is_draft=review.is_draft, + status=review.status, + verdict=review.verdict if review.verdict else None, + verdict_confidence=review.verdict_confidence, + verdict_reasoning=review.verdict_reasoning, + total_tokens=review.total_tokens, + total_cost_usd=review.total_cost_usd, + tokens_by_agent=review.tokens_by_agent, + cost_by_agent=review.cost_by_agent, + created_at=review.created_at, + started_at=review.started_at, + completed_at=review.completed_at, + error_message=review.error_message, + findings=[ + FindingResponse( + id=f.id, + agent=f.agent.value if hasattr(f.agent, "value") else str(f.agent), + file=f.file, + line_start=f.line_start, + line_end=f.line_end, + severity=f.severity.value if hasattr(f.severity, "value") else str(f.severity), + confidence=f.confidence, + title=f.title, + description=f.description, + reasoning=f.reasoning, + suggestion=f.suggestion, + references=f.references, + prompt_version=f.prompt_version, + ) + for f in review.findings + ], + conflicts=[ + ConflictResponse( + id=c.id, + finding_ids=c.finding_ids, + nature=c.nature.value if hasattr(c.nature, "value") else str(c.nature), + description=c.description, + severity_weight=c.severity_weight, + resolution=c.resolution, + winning_finding_id=c.winning_finding_id, + ) + for c in review.conflicts + ], + ) + + +@router.get("/{review_id}/deliberation", response_model=DeliberationLogResponse) +async def get_deliberation_log( + db: DbSession, + review_id: str, +) -> DeliberationLogResponse: + """Get the deliberation log for a review.""" + # Verify review exists + review_query = select(ReviewModel.id).where(ReviewModel.id == review_id) + review_result = await db.execute(review_query) + if not review_result.scalar_one_or_none(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Review {review_id} not found", + ) + + # Get deliberation steps + query = ( + select(DeliberationStepModel) + .where(DeliberationStepModel.review_id == review_id) + .order_by(DeliberationStepModel.sequence) + ) + result = await db.execute(query) + steps = result.scalars().all() + + return DeliberationLogResponse( + review_id=review_id, + steps=[ + DeliberationStepResponse( + id=s.id, + step_type=s.step_type.value if hasattr(s.step_type, "value") else str(s.step_type), + timestamp=s.timestamp, + description=s.description, + details=s.details, + sequence=s.sequence, + ) + for s in steps + ], + ) + + +@router.post("", response_model=ManualReviewResponse, status_code=status.HTTP_202_ACCEPTED) +async def trigger_manual_review( + _db: DbSession, + _redis: RedisClient, + request: ManualReviewRequest, +) -> ManualReviewResponse: + """Trigger a manual review (for testing or re-review).""" + # Check for required diff_content + if not request.diff_content: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="diff_content is required for manual reviews", + ) + + # Queue the review + job_id = await enqueue_review( + repository=request.repository, + pr_number=request.pr_number, + base_sha=request.base_sha, + head_sha=request.head_sha, + pr_title=request.pr_title, + author=request.author, + is_draft=request.is_draft, + policy_name=request.policy_name, + ) + + if job_id: + logger.info( + "Manual review queued: %s for %s PR #%d", + job_id, + request.repository, + request.pr_number, + ) + return ManualReviewResponse( + status="queued", + job_id=job_id, + review_id=None, + message=f"Review queued with job ID {job_id}", + ) + else: + return ManualReviewResponse( + status="duplicate", + job_id=None, + review_id=None, + message="Review already queued for this commit", + ) diff --git a/src/arbiter/api/routes/webhooks.py b/src/arbiter/api/routes/webhooks.py new file mode 100644 index 0000000..f031cac --- /dev/null +++ b/src/arbiter/api/routes/webhooks.py @@ -0,0 +1,413 @@ +"""Webhook routes for GitHub and GitLab integrations.""" + +import hashlib +import hmac +import logging +from typing import Any + +from fastapi import APIRouter, Header, HTTPException, Request, status + +from arbiter.api.deps import AppSettings, RedisClient +from arbiter.integrations import has_arbiter_marker +from arbiter.worker.queue import enqueue_followup, enqueue_review + +router = APIRouter() +logger = logging.getLogger(__name__) + + +def _verify_github_signature( + payload: bytes, + signature: str | None, + secret: str, +) -> bool: + """Verify GitHub webhook HMAC-SHA256 signature. + + GitHub sends signatures in format: sha256= + """ + if not signature: + return False + + if not signature.startswith("sha256="): + return False + + expected_signature = signature[7:] # Remove 'sha256=' prefix + computed = hmac.new( + secret.encode(), + payload, + hashlib.sha256, + ).hexdigest() + + return hmac.compare_digest(computed, expected_signature) + + +def _verify_gitlab_token( + token: str | None, + expected_token: str, +) -> bool: + if not token: + return False + return hmac.compare_digest(token, expected_token) + + +@router.post("/github") +async def github_webhook( + request: Request, + settings: AppSettings, + _redis: RedisClient, + x_hub_signature_256: str | None = Header(None), + x_github_event: str | None = Header(None), +) -> dict[str, Any]: + """Handle GitHub webhook events. + + Accepts pull_request events and queues reviews for opened/synchronized PRs. + Also handles issue_comment events for follow-up questions. + """ + # Read raw body for signature verification + body = await request.body() + + # Verify signature if secret is configured + if settings.github_webhook_secret: + secret = settings.github_webhook_secret.get_secret_value() + if not _verify_github_signature(body, x_hub_signature_256, secret): + logger.warning("Invalid GitHub webhook signature") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid signature", + ) + + # Parse payload + try: + payload = await request.json() + except Exception as e: + logger.warning("Invalid JSON payload: %s", e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid JSON payload", + ) from e + + # Handle issue_comment events (PR comments) + if x_github_event == "issue_comment": + return await _handle_github_comment(payload, settings) + + # Only process pull_request events + if x_github_event != "pull_request": + logger.debug("Ignoring GitHub event: %s", x_github_event) + return {"status": "ignored", "reason": f"Event type '{x_github_event}' not processed"} + + # Only process relevant actions + action = payload.get("action") + if action not in ("opened", "synchronize", "reopened"): + logger.debug("Ignoring PR action: %s", action) + return {"status": "ignored", "reason": f"Action '{action}' not processed"} + + # Extract PR metadata + pr = payload.get("pull_request", {}) + repo = payload.get("repository", {}) + + repository = repo.get("full_name") + pr_number = pr.get("number") + pr_title = pr.get("title") + base_sha = pr.get("base", {}).get("sha") + head_sha = pr.get("head", {}).get("sha") + author = pr.get("user", {}).get("login") + is_draft = pr.get("draft", False) + + if not all([repository, pr_number, base_sha, head_sha]): + logger.warning("Missing required fields in GitHub payload") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Missing required fields", + ) + + # Queue the review + job_id = await enqueue_review( + repository=repository, + pr_number=pr_number, + base_sha=base_sha, + head_sha=head_sha, + pr_title=pr_title, + author=author, + is_draft=is_draft, + platform="github", + ) + + if job_id: + logger.info("Queued review %s for %s PR #%d", job_id, repository, pr_number) + return { + "status": "queued", + "job_id": job_id, + "repository": repository, + "pr_number": pr_number, + } + else: + return { + "status": "duplicate", + "reason": "Review already queued for this commit", + "repository": repository, + "pr_number": pr_number, + } + + +async def _handle_github_comment( + payload: dict[str, Any], + settings: AppSettings, +) -> dict[str, Any]: + """Handle GitHub issue_comment events for follow-up questions. + + Args: + payload: The webhook payload. + settings: Application settings. + + Returns: + Response dict indicating how the comment was handled. + """ + # Only process created comments + action = payload.get("action") + if action != "created": + logger.debug("Ignoring comment action: %s", action) + return {"status": "ignored", "reason": f"Comment action '{action}' not processed"} + + # Check if this is a PR comment (not a regular issue comment) + issue = payload.get("issue", {}) + if "pull_request" not in issue: + logger.debug("Ignoring non-PR comment") + return {"status": "ignored", "reason": "Not a pull request comment"} + + # Check if follow-up is enabled + if not settings.followup_enabled: + logger.debug("Follow-up processing is disabled") + return {"status": "ignored", "reason": "Follow-up processing disabled"} + + # Extract comment data + comment = payload.get("comment", {}) + comment_body = comment.get("body", "") + comment_id = str(comment.get("id", "")) + author = comment.get("user", {}).get("login", "") + + # Ignore our own comments (they contain the Arbiter marker) + if has_arbiter_marker(comment_body): + logger.debug("Ignoring Arbiter's own comment") + return {"status": "ignored", "reason": "Arbiter's own comment"} + + # Extract PR metadata + repo = payload.get("repository", {}) + repository = repo.get("full_name", "") + pr_number = issue.get("number", 0) + + if not repository or not pr_number: + logger.warning("Missing repository or PR number in comment payload") + return {"status": "ignored", "reason": "Missing required fields"} + + # Queue the follow-up processing + job_id = await enqueue_followup( + repository=repository, + pr_number=pr_number, + comment_id=comment_id, + comment_body=comment_body, + author=author, + platform="github", + ) + + if job_id: + logger.info( + "Queued follow-up %s for %s PR #%d comment %s", + job_id, + repository, + pr_number, + comment_id, + ) + return { + "status": "queued", + "job_id": job_id, + "repository": repository, + "pr_number": pr_number, + "comment_id": comment_id, + } + else: + return { + "status": "duplicate", + "reason": "Follow-up already queued for this comment", + "repository": repository, + "pr_number": pr_number, + } + + +@router.post("/gitlab") +async def gitlab_webhook( + request: Request, + settings: AppSettings, + _redis: RedisClient, + x_gitlab_token: str | None = Header(None), +) -> dict[str, Any]: + """Handle GitLab webhook events. + + Accepts merge_request events and queues reviews for new/updated MRs. + Also handles note events for follow-up questions. + """ + # Verify token if configured + if settings.gitlab_webhook_token: + token = settings.gitlab_webhook_token.get_secret_value() + if not _verify_gitlab_token(x_gitlab_token, token): + logger.warning("Invalid GitLab webhook token") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token", + ) + + # Parse payload + try: + payload = await request.json() + except Exception as e: + logger.warning("Invalid JSON payload: %s", e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid JSON payload", + ) from e + + # Handle note events (MR comments) + object_kind = payload.get("object_kind") + if object_kind == "note": + return await _handle_gitlab_note(payload, settings) + + # Only process merge_request events + if object_kind != "merge_request": + logger.debug("Ignoring GitLab event: %s", object_kind) + return {"status": "ignored", "reason": f"Event type '{object_kind}' not processed"} + + # Only process relevant actions + object_attrs = payload.get("object_attributes", {}) + action = object_attrs.get("action") + if action not in ("open", "reopen", "update"): + logger.debug("Ignoring MR action: %s", action) + return {"status": "ignored", "reason": f"Action '{action}' not processed"} + + # Extract MR metadata + project = payload.get("project", {}) + + repository = project.get("path_with_namespace") + pr_number = object_attrs.get("iid") + pr_title = object_attrs.get("title") + base_sha = object_attrs.get("target_branch") + head_sha = object_attrs.get("last_commit", {}).get("id") + author = payload.get("user", {}).get("username") + is_draft = object_attrs.get("work_in_progress", False) + + if not all([repository, pr_number, head_sha]): + logger.warning("Missing required fields in GitLab payload") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Missing required fields", + ) + + # Use target branch as base_sha if actual SHA not available + if not base_sha: + base_sha = object_attrs.get("target_branch", "main") + + # Queue the review + job_id = await enqueue_review( + repository=repository, + pr_number=pr_number, + base_sha=base_sha, + head_sha=head_sha, + pr_title=pr_title, + author=author, + is_draft=is_draft, + platform="gitlab", + ) + + if job_id: + logger.info("Queued review %s for %s MR !%d", job_id, repository, pr_number) + return { + "status": "queued", + "job_id": job_id, + "repository": repository, + "mr_number": pr_number, + } + else: + return { + "status": "duplicate", + "reason": "Review already queued for this commit", + "repository": repository, + "mr_number": pr_number, + } + + +async def _handle_gitlab_note( + payload: dict[str, Any], + settings: AppSettings, +) -> dict[str, Any]: + """Handle GitLab note events for follow-up questions on MRs. + + Args: + payload: The webhook payload. + settings: Application settings. + + Returns: + Response dict indicating how the note was handled. + """ + # Check if this is a MR note (not issue or other note type) + object_attrs = payload.get("object_attributes", {}) + noteable_type = object_attrs.get("noteable_type") + + if noteable_type != "MergeRequest": + logger.debug("Ignoring non-MR note: %s", noteable_type) + return {"status": "ignored", "reason": f"Note type '{noteable_type}' not processed"} + + # Check if follow-up is enabled + if not settings.followup_enabled: + logger.debug("Follow-up processing is disabled") + return {"status": "ignored", "reason": "Follow-up processing disabled"} + + # Extract note data + note_body = object_attrs.get("note", "") + note_id = str(object_attrs.get("id", "")) + author = payload.get("user", {}).get("username", "") + + # Ignore our own comments (they contain the Arbiter marker) + if has_arbiter_marker(note_body): + logger.debug("Ignoring Arbiter's own note") + return {"status": "ignored", "reason": "Arbiter's own comment"} + + # Extract MR metadata + project = payload.get("project", {}) + merge_request = payload.get("merge_request", {}) + + repository = project.get("path_with_namespace", "") + pr_number = merge_request.get("iid", 0) + + if not repository or not pr_number: + logger.warning("Missing repository or MR number in note payload") + return {"status": "ignored", "reason": "Missing required fields"} + + # Queue the follow-up processing + job_id = await enqueue_followup( + repository=repository, + pr_number=pr_number, + comment_id=note_id, + comment_body=note_body, + author=author, + platform="gitlab", + ) + + if job_id: + logger.info( + "Queued follow-up %s for %s MR !%d note %s", + job_id, + repository, + pr_number, + note_id, + ) + return { + "status": "queued", + "job_id": job_id, + "repository": repository, + "mr_number": pr_number, + "note_id": note_id, + } + else: + return { + "status": "duplicate", + "reason": "Follow-up already queued for this note", + "repository": repository, + "mr_number": pr_number, + } diff --git a/src/arbiter/main.py b/src/arbiter/main.py new file mode 100644 index 0000000..1e9d640 --- /dev/null +++ b/src/arbiter/main.py @@ -0,0 +1,81 @@ +"""FastAPI application entry point for Arbiter.""" + +import logging +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Request, status +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from arbiter.api.deps import close_redis, get_settings +from arbiter.api.routes import health, reviews, webhooks +from arbiter.db.session import close_db, init_db + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: + """Application lifespan handler for startup/shutdown.""" + # Startup + logger.info("Starting Arbiter API") + await init_db() + logger.info("Database initialized") + yield + # Shutdown + logger.info("Shutting down Arbiter API") + await close_db() + await close_redis() + logger.info("Connections closed") + + +def create_app() -> FastAPI: + """Create and configure the FastAPI application.""" + settings = get_settings() + + app = FastAPI( + title=settings.api_title, + version=settings.api_version, + description="Multi-agent code review system that shows its work", + lifespan=lifespan, + docs_url="/docs", + redoc_url="/redoc", + openapi_url="/openapi.json", + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Register exception handlers + @app.exception_handler(ValueError) + async def value_error_handler(_request: Request, exc: ValueError) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(exc)}, + ) + + @app.exception_handler(Exception) + async def general_exception_handler(_request: Request, _exc: Exception) -> JSONResponse: + logger.exception("Unhandled exception") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "Internal server error"}, + ) + + # Register routers + app.include_router(health.router, tags=["health"]) + app.include_router(webhooks.router, prefix="/webhooks", tags=["webhooks"]) + app.include_router(reviews.router, prefix="/api/reviews", tags=["reviews"]) + + return app + + +# Create the app instance +app = create_app()