add arq worker with redis queue
This commit is contained in:
12
src/arbiter/worker/__init__.py
Normal file
12
src/arbiter/worker/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""Worker module for async job processing."""
|
||||||
|
|
||||||
|
from arbiter.worker.queue import JobPriority, enqueue_review
|
||||||
|
from arbiter.worker.settings import WorkerSettings
|
||||||
|
from arbiter.worker.tasks import process_review
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"JobPriority",
|
||||||
|
"WorkerSettings",
|
||||||
|
"enqueue_review",
|
||||||
|
"process_review",
|
||||||
|
]
|
||||||
161
src/arbiter/worker/queue.py
Normal file
161
src/arbiter/worker/queue.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""Queue utilities for job management."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from arq import ArqRedis, create_pool
|
||||||
|
from arq.connections import RedisSettings
|
||||||
|
|
||||||
|
from arbiter.config import get_settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class JobPriority(IntEnum):
    """Job priority levels (lower = higher priority).

    The integer value is interpolated into the arq queue name by the
    enqueue helpers in this module (``arbiter:queue:<value>``).
    """

    # Not assigned by any enqueue helper visible in this module.
    HIGH = 1
    # Used for non-draft reviews and for follow-up jobs.
    NORMAL = 2
    # Used for draft pull requests.
    LOW = 3
|
||||||
|
|
||||||
|
|
||||||
|
async def get_redis_pool() -> ArqRedis:
    """Create an arq Redis pool from the configured ``redis_url``.

    Every call builds a new pool via ``create_pool``; nothing is cached here,
    so the caller owns the returned pool and is responsible for closing it.

    Returns:
        A connected ``ArqRedis`` instance.
    """
    settings = get_settings()
    return await create_pool(RedisSettings.from_dsn(settings.redis_url))
|
||||||
|
|
||||||
|
|
||||||
|
def generate_job_id(repository: str, pr_number: int, head_sha: str) -> str:
    """Derive a deterministic job ID for a review of one commit of a PR.

    Args:
        repository: Repository name.
        pr_number: Pull request number.
        head_sha: Head commit SHA.

    Returns:
        The first 16 hex characters of the SHA-256 digest of
        ``"<repository>:<pr_number>:<head_sha>"``.
    """
    key = f"{repository}:{pr_number}:{head_sha}"
    return hashlib.sha256(key.encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
|
async def enqueue_review(
    repository: str,
    pr_number: int,
    base_sha: str,
    head_sha: str,
    pr_title: str | None = None,
    author: str | None = None,
    is_draft: bool = False,
    policy_name: str | None = None,
    platform: str | None = None,
) -> str | None:
    """Enqueue a review job with deduplication.

    A marker key ``arbiter:job:<job_id>`` is checked before enqueueing and
    written (value ``"pending"``, 1 hour TTL) after a successful enqueue.
    The job ID is also handed to arq as ``_job_id``.

    Args:
        repository: Repository name.
        pr_number: Pull request number.
        base_sha: Base commit SHA.
        head_sha: Head commit SHA (part of the deduplication key).
        pr_title: PR title.
        author: PR author.
        is_draft: Whether the PR is a draft; drafts are queued with LOW priority.
        policy_name: Name of the policy to use.
        platform: Source platform, forwarded to the task unchanged.

    Returns:
        The job ID if enqueued, None if a job for this commit already exists
        or arq declined to enqueue it.
    """
    redis = await get_redis_pool()
    try:
        job_id = generate_job_id(repository, pr_number, head_sha)
        marker_key = f"arbiter:job:{job_id}"

        # Check if job already exists
        existing = await redis.get(marker_key)
        if existing:
            logger.info("Job %s already exists for %s PR #%d", job_id, repository, pr_number)
            return None

        # Determine priority
        priority = JobPriority.LOW if is_draft else JobPriority.NORMAL

        # Enqueue the job
        # NOTE(review): the queue name carries a priority suffix, while the
        # WorkerSettings.queue_name in this package is "arbiter:queue" with no
        # suffix — confirm that a worker actually consumes the suffixed queues.
        job = await redis.enqueue_job(
            "process_review",
            repository=repository,
            pr_number=pr_number,
            base_sha=base_sha,
            head_sha=head_sha,
            pr_title=pr_title,
            author=author,
            is_draft=is_draft,
            policy_name=policy_name,
            platform=platform,
            _job_id=job_id,
            _queue_name=f"arbiter:queue:{priority}",
        )
        if not job:
            return None

        # Mark job as pending for deduplication
        await redis.set(marker_key, "pending", ex=3600)  # 1 hour TTL
        logger.info("Enqueued job %s for %s PR #%d", job_id, repository, pr_number)
        return job_id
    finally:
        # get_redis_pool() builds a fresh pool per call; release it so each
        # enqueue does not leave connections open.
        await redis.close(close_connection_pool=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_job_status(job_id: str) -> dict[str, Any] | None:
    """Look up the deduplication marker stored for a review job.

    Only the ``arbiter:job:<job_id>`` key is consulted; markers written under
    the ``arbiter:followup:`` prefix are not visible through this function.

    Args:
        job_id: Job ID as returned by ``enqueue_review``.

    Returns:
        ``{"job_id": ..., "status": ...}`` where ``status`` is the stored
        marker value decoded to ``str``, or None if no marker exists.
    """
    redis = await get_redis_pool()
    try:
        job = await redis.get(f"arbiter:job:{job_id}")
    finally:
        # The pool is created per call, so close it once the read is done.
        await redis.close(close_connection_pool=True)

    if not job:
        return None

    return {
        "job_id": job_id,
        "status": job.decode() if isinstance(job, bytes) else job,
    }
|
||||||
|
|
||||||
|
|
||||||
|
async def cancel_job(job_id: str) -> bool:
    """Delete the deduplication marker of a review job.

    This removes only the ``arbiter:job:<job_id>`` key; it issues no command
    against the arq queue itself.

    Args:
        job_id: Job ID as returned by ``enqueue_review``.

    Returns:
        True if a marker key was deleted, False if none existed.
    """
    redis = await get_redis_pool()
    try:
        deleted: int = await redis.delete(f"arbiter:job:{job_id}")
    finally:
        # The pool is created per call, so close it once the delete is done.
        await redis.close(close_connection_pool=True)
    return deleted > 0
|
||||||
|
|
||||||
|
|
||||||
|
def generate_followup_job_id(repository: str, pr_number: int, comment_id: str) -> str:
    """Derive a deterministic job ID for a follow-up on one PR comment.

    Args:
        repository: Repository name.
        pr_number: Pull request number.
        comment_id: Platform-specific comment ID.

    Returns:
        The first 16 hex characters of the SHA-256 digest of
        ``"followup:<repository>:<pr_number>:<comment_id>"``.
    """
    key = f"followup:{repository}:{pr_number}:{comment_id}"
    return hashlib.sha256(key.encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
|
async def enqueue_followup(
    repository: str,
    pr_number: int,
    comment_id: str,
    comment_body: str,
    author: str,
    platform: str,
) -> str | None:
    """Enqueue a follow-up processing job with deduplication.

    A marker key ``arbiter:followup:<job_id>`` is checked before enqueueing
    and written (value ``"pending"``, 30 minute TTL) after a successful
    enqueue. The job ID is also handed to arq as ``_job_id``.

    Args:
        repository: Repository name.
        pr_number: Pull request number.
        comment_id: Platform-specific comment ID (part of the dedup key).
        comment_body: The comment text.
        author: Comment author.
        platform: Source platform, forwarded to the task unchanged.

    Returns:
        The job ID if enqueued, None if a job for this comment already exists
        or arq declined to enqueue it.
    """
    redis = await get_redis_pool()
    try:
        job_id = generate_followup_job_id(repository, pr_number, comment_id)
        marker_key = f"arbiter:followup:{job_id}"

        # Check if job already exists
        existing = await redis.get(marker_key)
        if existing:
            logger.info(
                "Follow-up job %s already exists for %s PR #%d comment %s",
                job_id,
                repository,
                pr_number,
                comment_id,
            )
            return None

        # Enqueue the job with normal priority
        # NOTE(review): the queue name carries a priority suffix, while the
        # WorkerSettings.queue_name in this package is "arbiter:queue" with no
        # suffix — confirm that a worker actually consumes the suffixed queues.
        job = await redis.enqueue_job(
            "process_followup",
            repository=repository,
            pr_number=pr_number,
            comment_id=comment_id,
            comment_body=comment_body,
            author=author,
            platform=platform,
            _job_id=job_id,
            _queue_name=f"arbiter:queue:{JobPriority.NORMAL}",
        )
        if not job:
            return None

        # Mark job as pending for deduplication (shorter TTL for follow-ups)
        await redis.set(marker_key, "pending", ex=1800)  # 30 min TTL
        logger.info(
            "Enqueued follow-up job %s for %s PR #%d comment %s",
            job_id,
            repository,
            pr_number,
            comment_id,
        )
        return job_id
    finally:
        # get_redis_pool() builds a fresh pool per call; release it so each
        # enqueue does not leave connections open.
        await redis.close(close_connection_pool=True)
|
||||||
68
src/arbiter/worker/settings.py
Normal file
68
src/arbiter/worker/settings.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""arq worker settings for Arbiter."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable, Coroutine
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from arq.connections import RedisSettings
|
||||||
|
|
||||||
|
from arbiter.config import get_settings
|
||||||
|
from arbiter.db.session import close_db, init_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def startup(ctx: dict[str, Any]) -> None:
    """Worker startup hook: initialise the database and cache settings.

    Args:
        ctx: Worker context dict; the application settings are stored under
            the ``"settings"`` key.
    """
    logger.info("Worker starting up")
    await init_db()
    ctx["settings"] = get_settings()
    logger.info("Worker ready")
|
||||||
|
|
||||||
|
|
||||||
|
async def shutdown(_ctx: dict[str, Any]) -> None:
    """Worker shutdown hook: close the database.

    Args:
        _ctx: Worker context dict (unused).
    """
    logger.info("Worker shutting down")
    await close_db()
    logger.info("Worker stopped")
|
||||||
|
|
||||||
|
|
||||||
|
async def health_check(_ctx: dict[str, Any]) -> str:
    """Return a constant liveness marker.

    Args:
        _ctx: Worker context dict (unused).

    Returns:
        The literal string ``"healthy"``.
    """
    return "healthy"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerSettings:
    """arq worker settings.

    Class attributes are read by arq as worker configuration. The job limits
    are read from application settings when this class body is executed,
    i.e. at import time.
    """

    # Import functions here to avoid circular imports
    @staticmethod
    def _get_functions() -> list[Callable[..., Coroutine[Any, Any, Any]]]:
        """Return the task coroutines this worker can execute."""
        from arbiter.worker.tasks import process_followup, process_review

        return [process_review, process_followup]

    # Placeholder; replaced with the real task list right after the class body.
    functions: ClassVar[list[Callable[..., Coroutine[Any, Any, Any]]]] = []
    cron_jobs: ClassVar[list[Any]] = []
    on_startup = startup
    on_shutdown = shutdown

    # Redis connection
    # NOTE(review): arq's documented examples assign a RedisSettings *instance*
    # to ``redis_settings``; here it is a static method returning one. Confirm
    # that arq accepts this form before relying on it.
    @staticmethod
    def redis_settings() -> RedisSettings:
        """Build Redis connection settings from the configured ``redis_url``."""
        settings = get_settings()
        return RedisSettings.from_dsn(settings.redis_url)

    # Job settings
    max_jobs = get_settings().worker_max_jobs
    job_timeout = get_settings().worker_job_timeout
    max_tries = get_settings().worker_retry_attempts
    retry_jobs = True

    # Queue settings
    # NOTE(review): the enqueue helpers in arbiter.worker.queue target
    # "arbiter:queue:<priority>" — confirm this un-suffixed name is intended.
    queue_name = "arbiter:queue"
    health_check_interval = 60

    def __init_subclass__(cls) -> None:
        """Give every subclass its own resolved ``functions`` list."""
        super().__init_subclass__()
        cls.functions = cls._get_functions()


# Initialize functions list
WorkerSettings.functions = WorkerSettings._get_functions()
|
||||||
794
src/arbiter/worker/tasks.py
Normal file
794
src/arbiter/worker/tasks.py
Normal file
@@ -0,0 +1,794 @@
|
|||||||
|
"""Worker task definitions for Arbiter."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from arbiter.agents import ComplexityAgent, ExplainContext, ReviewContext, SecurityAgent, StyleAgent
|
||||||
|
from arbiter.analysis import DiffParser, StaticAnalysisRunner
|
||||||
|
from arbiter.config import Settings, get_settings
|
||||||
|
from arbiter.conversation import AgentRouter, QuestionDetector
|
||||||
|
from arbiter.db.models import (
|
||||||
|
ConflictModel,
|
||||||
|
ConversationMessageModel,
|
||||||
|
ConversationModel,
|
||||||
|
DeliberationStepModel,
|
||||||
|
FindingModel,
|
||||||
|
PolicyModel,
|
||||||
|
ReviewModel,
|
||||||
|
)
|
||||||
|
from arbiter.db.session import async_session_factory
|
||||||
|
from arbiter.deliberation import Coordinator, DeliberationResult
|
||||||
|
from arbiter.integrations import (
|
||||||
|
ARBITER_MARKER,
|
||||||
|
CommitStatus,
|
||||||
|
GitHubClient,
|
||||||
|
GitLabClient,
|
||||||
|
IntegrationError,
|
||||||
|
Platform,
|
||||||
|
PlatformClient,
|
||||||
|
ReviewCommentFormatter,
|
||||||
|
has_arbiter_marker,
|
||||||
|
)
|
||||||
|
from arbiter.llm import LiteLLMClient, PromptRegistry
|
||||||
|
from arbiter.models import AgentName, Finding, Policy, ReviewResult, Severity, Verdict
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_platform(repository: str, webhook_source: str | None = None) -> Platform:
    """Detect the platform from the webhook source.

    The match on ``webhook_source`` is case-insensitive and ignores
    surrounding whitespace. This function never raises: when the source is
    missing or not recognised it logs a warning and falls back to GitHub.

    Args:
        repository: Repository name (owner/repo format); used only in the
            warning log message.
        webhook_source: Explicit platform from webhook ("github" or "gitlab").

    Returns:
        Detected platform, or ``Platform.GITHUB`` as the default.
    """
    source = webhook_source.strip().lower() if webhook_source else ""
    if source == "github":
        return Platform.GITHUB
    if source == "gitlab":
        return Platform.GITLAB

    # Default to GitHub if not specified
    # In practice, webhooks should always provide the source
    if source:
        # A value was supplied but is not one we know; say so explicitly
        # instead of claiming nothing was specified.
        logger.warning(
            "Unrecognized platform %r, defaulting to GitHub for %s",
            webhook_source,
            repository,
        )
    else:
        logger.warning("Platform not specified, defaulting to GitHub for %s", repository)
    return Platform.GITHUB
|
||||||
|
|
||||||
|
|
||||||
|
def get_platform_client(platform: Platform, settings: Settings) -> PlatformClient | None:
    """Create a platform client based on the detected platform.

    Args:
        platform: Target platform.
        settings: Application settings (tokens, base URLs, timeout, retries).

    Returns:
        Platform client, or None if no token is configured for the platform
        or the platform is neither GitHub nor GitLab.
    """
    if platform == Platform.GITHUB:
        if not settings.github_token:
            logger.warning("GitHub token not configured, skipping platform integration")
            return None
        return GitHubClient(
            token=settings.github_token.get_secret_value(),
            base_url=settings.github_base_url,
            timeout=settings.integration_timeout,
            max_retries=settings.integration_max_retries,
        )

    if platform == Platform.GITLAB:
        if not settings.gitlab_token:
            logger.warning("GitLab token not configured, skipping platform integration")
            return None
        return GitLabClient(
            token=settings.gitlab_token.get_secret_value(),
            base_url=settings.gitlab_base_url,
            timeout=settings.integration_timeout,
            max_retries=settings.integration_max_retries,
        )

    return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_status_safe(
    client: PlatformClient,
    repository: str,
    sha: str,
    status: CommitStatus,
    description: str,
    context: str,
    target_url: str | None = None,
) -> None:
    """Update commit status, logging errors without raising.

    Only ``IntegrationError`` is absorbed (logged as a warning); any other
    exception propagates to the caller.

    Args:
        client: Platform client.
        repository: Repository name.
        sha: Commit SHA.
        status: Status to set.
        description: Status description.
        context: Status context name.
        target_url: Optional URL for details.
    """
    try:
        await client.update_commit_status(
            repository=repository,
            sha=sha,
            status=status,
            description=description,
            context=context,
            target_url=target_url,
        )
        # Log only the first 8 characters of the SHA to keep lines short.
        logger.info("Updated commit status for %s@%s: %s", repository, sha[:8], status)
    except IntegrationError as e:
        logger.warning("Failed to update commit status: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
async def _post_or_update_comment(
    client: PlatformClient,
    repository: str,
    pr_number: int,
    body: str,
) -> str | None:
    """Post a new comment or update existing Arbiter comment.

    This provides idempotent comment behavior: re-reviews update the existing
    Arbiter comment instead of posting new ones (preventing comment spam).
    The first comment for which ``has_arbiter_marker`` is true is the one
    that gets updated.

    Args:
        client: Platform client.
        repository: Repository name.
        pr_number: PR/MR number.
        body: Comment body (must contain ARBITER_MARKER).

    Returns:
        Comment URL if successful, None if posting or updating raised
        ``IntegrationError``.
    """
    try:
        # Try to find existing Arbiter comment
        existing_comment_id: str | None = None
        try:
            comments = await client.get_comments(repository, pr_number)
            for comment in comments:
                if has_arbiter_marker(comment.body):
                    existing_comment_id = comment.id
                    logger.info(
                        "Found existing Arbiter comment %s on %s#%d",
                        existing_comment_id,
                        repository,
                        pr_number,
                    )
                    break
        except IntegrationError as e:
            # If we can't fetch comments, fall back to posting new
            logger.warning("Failed to fetch comments, will post new: %s", e)

        if existing_comment_id:
            # Update existing comment
            url = await client.update_comment(repository, pr_number, existing_comment_id, body)
            logger.info(
                "Updated comment %s on %s#%d: %s",
                existing_comment_id,
                repository,
                pr_number,
                url,
            )
            return url

        # Post new comment
        url = await client.post_comment(repository, pr_number, body)
        logger.info("Posted new comment on %s#%d: %s", repository, pr_number, url)
        return url

    except IntegrationError as e:
        logger.warning("Failed to post/update comment: %s", e)
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def _verdict_to_status(verdict: Verdict) -> CommitStatus:
    """Map a review verdict to the commit status reported to the platform.

    Args:
        verdict: Final verdict of the review.

    Returns:
        ``CommitStatus.FAILURE`` for ``REQUEST_CHANGES``; ``CommitStatus.SUCCESS``
        for every other verdict (APPROVE, and COMMENT still passes).
    """
    if verdict == Verdict.REQUEST_CHANGES:
        return CommitStatus.FAILURE
    return CommitStatus.SUCCESS
|
||||||
|
|
||||||
|
|
||||||
|
async def process_review(
    _ctx: dict[str, Any],
    repository: str,
    pr_number: int,
    base_sha: str,
    head_sha: str,
    pr_title: str | None = None,
    author: str | None = None,
    is_draft: bool = False,
    policy_name: str | None = None,
    diff_content: str | None = None,
    platform: str | None = None,
) -> str:
    """Process a code review job.

    This task runs the full review pipeline:
    1. Create the review record
    2. Set commit status to pending
    3. Fetch diff (from platform if not provided)
    4. Run agent reviews and deliberation
    5. Store results in database
    6. Post comment with results
    7. Update commit status to success/failure

    On any exception the review is marked ``failed``, the commit status is
    set to ERROR (when status updates are enabled) and the exception is
    re-raised.

    Args:
        _ctx: arq context (unused)
        repository: Repository name (owner/repo format)
        pr_number: Pull request number
        base_sha: Base commit SHA
        head_sha: Head commit SHA
        pr_title: PR title
        author: PR author
        is_draft: Whether PR is a draft
        policy_name: Name of policy to use
        diff_content: Optional pre-fetched diff content
        platform: Source platform ("github" or "gitlab")

    Returns:
        Review ID on success

    Raises:
        ValueError: If no diff was supplied and no platform client could be
            created to fetch one.
    """
    settings = get_settings()
    session_factory = async_session_factory()

    # Detect platform and create client
    detected_platform = detect_platform(repository, platform)
    client = get_platform_client(detected_platform, settings)

    async with session_factory() as session:
        # Create review record
        review = ReviewModel(
            repository=repository,
            pr_number=pr_number,
            pr_title=pr_title,
            base_sha=base_sha,
            head_sha=head_sha,
            author=author,
            is_draft=is_draft,
            status="running",
            started_at=datetime.now(UTC),
        )

        # Load policy if specified; fall back to the default Policy otherwise
        policy = Policy()
        if policy_name:
            result = await session.execute(
                select(PolicyModel).where(
                    PolicyModel.name == policy_name, PolicyModel.is_active.is_(True)
                )
            )
            policy_model = result.scalar_one_or_none()
            if policy_model:
                review.policy_id = policy_model.id
                policy = _policy_from_model(policy_model)

        session.add(review)
        await session.commit()
        await session.refresh(review)
        review_id = review.id

        logger.info("Started review %s for %s PR #%d", review_id, repository, pr_number)

        # Set pending status
        if client and settings.update_status:
            await _update_status_safe(
                client,
                repository,
                head_sha,
                CommitStatus.PENDING,
                "Arbiter review in progress...",
                settings.status_check_context,
            )

        try:
            # Get diff content from platform if not provided
            if not diff_content:
                if client:
                    logger.info(
                        "Fetching diff from %s for %s#%d", detected_platform, repository, pr_number
                    )
                    diff_content = await client.get_pr_diff(repository, pr_number)
                else:
                    raise ValueError(
                        f"diff_content not provided and no {detected_platform} token configured"
                    )

            # Run the review pipeline
            agent_results, deliberation_result = await _run_review_pipeline(
                diff_content=diff_content,
                policy=policy,
                templates_dir=settings.templates_dir,
            )

            # Update review with results
            review.status = "completed"
            review.completed_at = datetime.now(UTC)
            review.verdict = deliberation_result.verdict
            review.verdict_confidence = deliberation_result.verdict_confidence
            review.verdict_reasoning = deliberation_result.verdict_reasoning

            # Calculate and store costs
            total_tokens = (
                sum(r.tokens_used for r in agent_results) + deliberation_result.tokens_used
            )
            total_cost = sum(r.cost_usd for r in agent_results) + deliberation_result.cost_usd
            tokens_by_agent = {r.agent_name.value: r.tokens_used for r in agent_results}
            cost_by_agent = {r.agent_name.value: r.cost_usd for r in agent_results}

            review.total_tokens = total_tokens
            review.total_cost_usd = total_cost
            review.tokens_by_agent = tokens_by_agent
            review.cost_by_agent = cost_by_agent

            # Store findings
            for finding in deliberation_result.findings:
                finding_model = FindingModel(
                    review_id=review_id,
                    agent=finding.agent,
                    file=finding.file,
                    line_start=finding.line_start,
                    line_end=finding.line_end,
                    severity=finding.severity,
                    confidence=finding.confidence,
                    title=finding.title,
                    description=finding.description,
                    reasoning=finding.reasoning,
                    suggestion=finding.suggestion,
                    references=finding.references,
                    prompt_version=finding.prompt_version,
                    static_analysis_context=finding.static_analysis_context,
                )
                session.add(finding_model)

            # Store conflicts
            for conflict in deliberation_result.conflicts:
                conflict_model = ConflictModel(
                    review_id=review_id,
                    finding_ids=conflict.finding_ids,
                    nature=conflict.nature,
                    description=conflict.description,
                    severity_weight=conflict.severity_weight,
                    resolution=conflict.resolution,
                    winning_finding_id=conflict.winning_finding_id,
                )
                session.add(conflict_model)

            # Store deliberation steps, preserving their order via `sequence`
            for i, step in enumerate(deliberation_result.steps):
                step_model = DeliberationStepModel(
                    review_id=review_id,
                    step_type=step.step_type,
                    timestamp=step.timestamp,
                    description=step.description,
                    details=step.details,
                    sequence=i,
                )
                session.add(step_model)

            await session.commit()
            logger.info("Completed review %s: %s", review_id, review.verdict)

            # Post comment and update status
            if client:
                comment_url: str | None = None

                if settings.post_comments:
                    formatter = ReviewCommentFormatter(include_cost=True)
                    comment_body = formatter.format(deliberation_result)
                    comment_url = await _post_or_update_comment(
                        client, repository, pr_number, comment_body
                    )

                if settings.update_status:
                    final_status = _verdict_to_status(deliberation_result.verdict)
                    # Truncate the reasoning to 140 characters for the status line.
                    description = (
                        deliberation_result.verdict_reasoning[:140]
                        if deliberation_result.verdict_reasoning
                        else f"Review: {deliberation_result.verdict.value}"
                    )
                    await _update_status_safe(
                        client,
                        repository,
                        head_sha,
                        final_status,
                        description,
                        settings.status_check_context,
                        target_url=comment_url,
                    )

        except Exception as e:
            logger.exception("Review %s failed", review_id)

            # Record the failure. The session may hold a failed or half-built
            # transaction (the error can come from the commit above), so roll
            # it back first. Bookkeeping errors are logged rather than raised
            # so they can neither mask the original exception nor prevent the
            # ERROR commit status below.
            try:
                await session.rollback()
                review.status = "failed"
                review.error_message = str(e)
                review.completed_at = datetime.now(UTC)
                await session.commit()
            except Exception:
                logger.exception("Could not record failure for review %s", review_id)

            # Update status to error
            if client and settings.update_status:
                await _update_status_safe(
                    client,
                    repository,
                    head_sha,
                    CommitStatus.ERROR,
                    f"Review failed: {str(e)[:100]}",
                    settings.status_check_context,
                )

            raise

        finally:
            # Close the client
            if client:
                await client.close()

        return review_id
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_review_pipeline(
    diff_content: str,
    policy: Policy,
    templates_dir: Path,
    work_dir: Path | None = None,
) -> tuple[list[ReviewResult], DeliberationResult]:
    """Run the full review pipeline.

    This is extracted from cli.py to be reusable by both CLI and worker.

    Args:
        diff_content: Diff text to review.
        policy: Policy that selects which agents run.
        templates_dir: Directory handed to the prompt registry.
        work_dir: Optional directory for static analysis; static analysis is
            skipped when this is not given.

    Returns:
        A tuple of (results from the agents that did not fail, deliberation
        result). Agents that raised are logged and omitted.
    """
    settings = get_settings()
    llm_client = LiteLLMClient(
        timeout=settings.llm_timeout,
        max_retries=settings.llm_max_retries,
    )
    prompt_registry = PromptRegistry(templates_dir)

    # Parse the diff
    parser = DiffParser()
    parsed_diff = parser.parse(diff_content)

    # Run static analysis if work_dir is provided
    static_findings: list[Finding] = []
    if work_dir:
        runner = StaticAnalysisRunner()
        static_result = await runner.run(parsed_diff, work_dir)
        static_findings.extend(runner.convert_to_finding(sf) for sf in static_result.findings)

    context = ReviewContext(
        diff=diff_content,
        policy=policy,
        static_analysis_context=f"Found {len(static_findings)} static analysis findings."
        if static_findings
        else "No static analysis data available.",
    )

    # Create agents for enabled agent types; unknown names are ignored
    agents: list[SecurityAgent | StyleAgent | ComplexityAgent] = []
    for agent_name in policy.get_enabled_agents():
        if agent_name == AgentName.SECURITY:
            agents.append(SecurityAgent(llm_client, prompt_registry))
        elif agent_name == AgentName.STYLE:
            agents.append(StyleAgent(llm_client, prompt_registry))
        elif agent_name == AgentName.COMPLEXITY:
            agents.append(ComplexityAgent(llm_client, prompt_registry))

    # Run agents in parallel; exceptions are returned instead of raised so one
    # failing agent does not abort the others
    results = await asyncio.gather(
        *[agent.review(context) for agent in agents],
        return_exceptions=True,
    )

    # Filter out exceptions
    valid_results: list[ReviewResult] = []
    for result in results:
        if isinstance(result, BaseException):
            logger.warning("Agent error: %s", result)
        else:
            valid_results.append(result)

    # Run deliberation
    coordinator = Coordinator(llm_client=llm_client)
    deliberation_result = await coordinator.deliberate(
        valid_results,
        static_findings=static_findings if static_findings else None,
    )

    return valid_results, deliberation_result
|
||||||
|
|
||||||
|
|
||||||
|
def _policy_from_model(model: PolicyModel) -> Policy:
    """Convert database policy model to Pydantic policy.

    Only non-empty config sections are passed on; when every section is
    empty a default ``Policy()`` is returned.

    Args:
        model: Stored policy row.

    Returns:
        The validated policy.
    """
    # Reconstruct policy from stored JSON config
    data: dict[str, Any] = {}
    if model.agents_config:
        data["agents"] = model.agents_config
    if model.cost_controls:
        data["cost_controls"] = model.cost_controls
    if model.verdict_thresholds:
        # Stored as `verdict_thresholds`, but passed to Policy as "verdict".
        data["verdict"] = model.verdict_thresholds

    return Policy.model_validate(data) if data else Policy()
|
||||||
|
|
||||||
|
|
||||||
|
async def process_followup(
    _ctx: dict[str, Any],
    repository: str,
    pr_number: int,
    comment_id: str,
    comment_body: str,
    author: str,
    platform: str,
) -> str | None:
    """Process a follow-up question from a PR comment.

    This task:
    1. Finds the most recent completed review for this PR
    2. Analyzes the comment to detect if it's a question
    3. Routes to appropriate agent(s)
    4. Generates explanations
    5. Stores the conversation in the database
    6. Posts a response comment (only when a platform client is available
       and ``settings.post_comments`` is enabled)

    The conversation is committed *before* the comment is posted, so a
    failure to post (``IntegrationError``) is logged and does not discard
    the stored exchange.

    Args:
        _ctx: arq context (unused)
        repository: Repository name (owner/repo format)
        pr_number: Pull request number
        comment_id: Platform-specific comment ID
        comment_body: The comment text
        author: Comment author username
        platform: Source platform ("github" or "gitlab")

    Returns:
        ID of the stored assistant message if a response was generated,
        None if no response was needed (follow-ups disabled, no completed
        review, comment is not a confident question, or no agent routes).
    """
    settings = get_settings()
    session_factory = async_session_factory()

    # Check if follow-up is enabled
    if not settings.followup_enabled:
        logger.info("Follow-up processing is disabled")
        return None

    # Detect platform and create client
    detected_platform = detect_platform(repository, platform)
    client = get_platform_client(detected_platform, settings)

    # Everything after client creation runs under try/finally so the client
    # is released on every exit path (early returns, errors, posting disabled),
    # not only after a successful/failed post.
    try:
        async with session_factory() as session:
            # Find the most recent completed review for this PR
            review_query = (
                select(ReviewModel)
                .where(
                    ReviewModel.repository == repository,
                    ReviewModel.pr_number == pr_number,
                    ReviewModel.status == "completed",
                )
                .order_by(ReviewModel.completed_at.desc())
                .limit(1)
                .options(selectinload(ReviewModel.findings))
            )
            review = (await session.execute(review_query)).scalar_one_or_none()

            if not review:
                logger.info(
                    "No completed review found for %s PR #%d, skipping follow-up",
                    repository,
                    pr_number,
                )
                return None

            # Convert DB findings to Pydantic models. The DB layer may hand
            # back either enum members or raw values, so coerce both fields.
            findings: list[Finding] = [
                Finding(
                    id=f.id,
                    agent=f.agent if isinstance(f.agent, AgentName) else AgentName(f.agent),
                    file=f.file,
                    line_start=f.line_start,
                    line_end=f.line_end,
                    severity=(
                        f.severity if isinstance(f.severity, Severity) else Severity(f.severity)
                    ),
                    confidence=f.confidence,
                    title=f.title,
                    description=f.description,
                    reasoning=f.reasoning,
                    suggestion=f.suggestion,
                    references=f.references or [],
                    prompt_version=f.prompt_version,
                    static_analysis_context=f.static_analysis_context,
                )
                for f in review.findings
            ]

            # Analyze the comment
            detector = QuestionDetector(
                confidence_threshold=settings.followup_confidence_threshold
            )
            analysis = detector.analyze(comment_body, findings)

            if not analysis.is_question:
                logger.info("Comment does not appear to be a question, skipping")
                return None

            if analysis.confidence < settings.followup_confidence_threshold:
                logger.info(
                    "Question confidence %.2f below threshold %.2f, skipping",
                    analysis.confidence,
                    settings.followup_confidence_threshold,
                )
                return None

            logger.info(
                "Detected question for %s PR #%d with confidence %.2f",
                repository,
                pr_number,
                analysis.confidence,
            )

            # Get or create conversation
            conv_query = (
                select(ConversationModel)
                .where(
                    ConversationModel.review_id == review.id,
                    ConversationModel.repository == repository,
                    ConversationModel.pr_number == pr_number,
                )
                .options(selectinload(ConversationModel.messages))
            )
            conversation = (await session.execute(conv_query)).scalar_one_or_none()

            if conversation:
                # Eagerly loaded above via selectinload, safe to read here.
                prior_messages = sorted(
                    conversation.messages or [], key=lambda m: m.sequence
                )
            else:
                conversation = ConversationModel(
                    review_id=review.id,
                    platform=platform,
                    repository=repository,
                    pr_number=pr_number,
                )
                session.add(conversation)
                await session.flush()
                logger.info("Created new conversation %s", conversation.id)
                # A brand-new conversation has no messages. Do not touch
                # `conversation.messages` here: on a freshly flushed object
                # the relationship is unloaded, and reading it would trigger
                # an implicit lazy load, which is not allowed under
                # AsyncSession.
                prior_messages = []

            # Determine next sequence number
            next_sequence = len(prior_messages)

            # Store user message
            user_message = ConversationMessageModel(
                conversation_id=conversation.id,
                role="user",
                platform_comment_id=comment_id,
                author=author,
                content=comment_body,
                sequence=next_sequence,
            )
            session.add(user_message)
            next_sequence += 1

            # Prior exchange only; the current question travels separately
            # as `ExplainContext.question`.
            conversation_history: list[dict[str, str]] = [
                {"role": msg.role, "content": msg.content} for msg in prior_messages
            ]

            # Set up LLM client and router
            llm_client = LiteLLMClient(
                timeout=settings.llm_timeout,
                max_retries=settings.llm_max_retries,
            )
            prompt_registry = PromptRegistry(settings.templates_dir)
            router = AgentRouter(llm_client, prompt_registry)

            # Route to appropriate agents
            routes = router.route(analysis, findings)

            if not routes:
                logger.info("No agents to route to, skipping response")
                return None

            # Generate explanations from each routed agent. Every route
            # contributes exactly one entry to `responses`.
            responses: list[str] = []
            total_tokens = 0
            total_cost = 0.0
            responding_agents: list[str] = []
            referenced_finding_ids: list[str] = []

            for route in routes:
                agent = router.get_agent(route.agent_name)
                responding_agents.append(route.agent_name.value)
                agent_label = route.agent_name.value.title()

                if not route.finding:
                    # General question without specific finding
                    responses.append(
                        f"**{agent_label} Agent**: "
                        "No specific findings to explain for this query."
                    )
                    continue

                referenced_finding_ids.append(route.finding.id)
                explain_ctx = ExplainContext(
                    question=analysis.question_text,
                    finding=route.finding,
                    diff="",  # Could fetch diff if needed
                    # Fresh copy per agent so one agent cannot mutate
                    # another agent's view of the history.
                    conversation_history=list(conversation_history),
                )

                try:
                    explain_result = await agent.explain(explain_ctx)
                    responses.append(
                        f"**{agent_label} Agent** "
                        f"(regarding: {route.finding.title}):\n\n"
                        f"{explain_result.response}"
                    )
                    total_tokens += explain_result.tokens_used
                    total_cost += explain_result.cost_usd
                except Exception as e:
                    # Deliberate best-effort: one failing agent must not
                    # prevent the others from answering.
                    logger.warning("Agent %s failed to explain: %s", route.agent_name, e)
                    responses.append(
                        f"**{agent_label} Agent**: "
                        "Unable to provide explanation at this time."
                    )

            # Format the response
            response_body = "\n\n---\n\n".join(responses)
            full_response = f"{response_body}\n\n{ARBITER_MARKER}"

            # Store assistant message (without the marker; the marker is
            # only part of the posted comment).
            assistant_message = ConversationMessageModel(
                conversation_id=conversation.id,
                role="assistant",
                content=response_body,
                responding_agents=responding_agents,
                referenced_finding_ids=referenced_finding_ids,
                tokens_used=total_tokens,
                cost_usd=total_cost,
                sequence=next_sequence,
            )
            session.add(assistant_message)

            # Update conversation totals
            conversation.total_tokens += total_tokens
            conversation.total_cost_usd += total_cost
            conversation.last_activity = datetime.now(UTC)

            # Flush so the primary key is populated, and capture it before
            # commit: if the session expires attributes on commit, reading
            # `assistant_message.id` afterwards would require implicit IO.
            await session.flush()
            message_id = assistant_message.id

            await session.commit()
            logger.info(
                "Generated response for %s PR #%d, tokens=%d, cost=$%.4f",
                repository,
                pr_number,
                total_tokens,
                total_cost,
            )

        # Post the response comment
        if client and settings.post_comments:
            try:
                comment_url = await client.post_comment(repository, pr_number, full_response)
                logger.info("Posted follow-up response: %s", comment_url)
            except IntegrationError as e:
                logger.warning("Failed to post follow-up response: %s", e)

        return message_id
    finally:
        if client:
            await client.close()
|
||||||
Reference in New Issue
Block a user