diff --git a/src/arbiter/worker/__init__.py b/src/arbiter/worker/__init__.py new file mode 100644 index 0000000..547bca0 --- /dev/null +++ b/src/arbiter/worker/__init__.py @@ -0,0 +1,12 @@ +"""Worker module for async job processing.""" + +from arbiter.worker.queue import JobPriority, enqueue_review +from arbiter.worker.settings import WorkerSettings +from arbiter.worker.tasks import process_review + +__all__ = [ + "JobPriority", + "WorkerSettings", + "enqueue_review", + "process_review", +] diff --git a/src/arbiter/worker/queue.py b/src/arbiter/worker/queue.py new file mode 100644 index 0000000..055bfad --- /dev/null +++ b/src/arbiter/worker/queue.py @@ -0,0 +1,161 @@ +"""Queue utilities for job management.""" + +import hashlib +import logging +from enum import IntEnum +from typing import Any + +from arq import ArqRedis, create_pool +from arq.connections import RedisSettings + +from arbiter.config import get_settings + +logger = logging.getLogger(__name__) + + +class JobPriority(IntEnum): + """Job priority levels (lower = higher priority).""" + + HIGH = 1 + NORMAL = 2 + LOW = 3 + + +async def get_redis_pool() -> ArqRedis: + settings = get_settings() + return await create_pool(RedisSettings.from_dsn(settings.redis_url)) + + +def generate_job_id(repository: str, pr_number: int, head_sha: str) -> str: + key = f"{repository}:{pr_number}:{head_sha}" + return hashlib.sha256(key.encode()).hexdigest()[:16] + + +async def enqueue_review( + repository: str, + pr_number: int, + base_sha: str, + head_sha: str, + pr_title: str | None = None, + author: str | None = None, + is_draft: bool = False, + policy_name: str | None = None, + platform: str | None = None, +) -> str | None: + """Enqueue a review job with deduplication. + + Returns the job ID if enqueued, None if a job for this commit already exists. + """ + redis = await get_redis_pool() + job_id = generate_job_id(repository, pr_number, head_sha) + + # Check if job already exists + existing = await redis.get(f"arbiter:job:{job_id}") + if existing: + logger.info("Job %s already exists for %s PR #%d", job_id, repository, pr_number) + return None + + # Determine priority + priority = JobPriority.LOW if is_draft else JobPriority.NORMAL + + # Enqueue the job + job = await redis.enqueue_job( + "process_review", + repository=repository, + pr_number=pr_number, + base_sha=base_sha, + head_sha=head_sha, + pr_title=pr_title, + author=author, + is_draft=is_draft, + policy_name=policy_name, + platform=platform, + _job_id=job_id, + _queue_name=f"arbiter:queue:{priority}", + ) + + if job: + # Mark job as pending for deduplication + await redis.set(f"arbiter:job:{job_id}", "pending", ex=3600) # 1 hour TTL + logger.info("Enqueued job %s for %s PR #%d", job_id, repository, pr_number) + return job_id + + return None + + +async def get_job_status(job_id: str) -> dict[str, Any] | None: + redis = await get_redis_pool() + job = await redis.get(f"arbiter:job:{job_id}") + if not job: + return None + + return { + "job_id": job_id, + "status": job.decode() if isinstance(job, bytes) else job, + } + + +async def cancel_job(job_id: str) -> bool: + redis = await get_redis_pool() + deleted: int = await redis.delete(f"arbiter:job:{job_id}") + return deleted > 0 + + +def generate_followup_job_id(repository: str, pr_number: int, comment_id: str) -> str: + key = f"followup:{repository}:{pr_number}:{comment_id}" + return hashlib.sha256(key.encode()).hexdigest()[:16] + + +async def enqueue_followup( + repository: str, + pr_number: int, + comment_id: str, + comment_body: str, + author: str, + platform: str, +) -> str | None: + """Enqueue a follow-up processing job with deduplication. + + Returns the job ID if enqueued, None if a job for this comment already exists. + """ + redis = await get_redis_pool() + job_id = generate_followup_job_id(repository, pr_number, comment_id) + + # Check if job already exists + existing = await redis.get(f"arbiter:followup:{job_id}") + if existing: + logger.info( + "Follow-up job %s already exists for %s PR #%d comment %s", + job_id, + repository, + pr_number, + comment_id, + ) + return None + + # Enqueue the job with normal priority + job = await redis.enqueue_job( + "process_followup", + repository=repository, + pr_number=pr_number, + comment_id=comment_id, + comment_body=comment_body, + author=author, + platform=platform, + _job_id=job_id, + _queue_name=f"arbiter:queue:{JobPriority.NORMAL}", + ) + + if job: + # Mark job as pending for deduplication (shorter TTL for follow-ups) + await redis.set(f"arbiter:followup:{job_id}", "pending", ex=1800) # 30 min TTL + logger.info( + "Enqueued follow-up job %s for %s PR #%d comment %s", + job_id, + repository, + pr_number, + comment_id, + ) + return job_id + + return None diff --git a/src/arbiter/worker/settings.py b/src/arbiter/worker/settings.py new file mode 100644 index 0000000..f740f2a --- /dev/null +++ b/src/arbiter/worker/settings.py @@ -0,0 +1,68 @@ +"""arq worker settings for Arbiter.""" + +import logging +from collections.abc import Callable, Coroutine +from typing import Any, ClassVar + +from arq.connections import RedisSettings + +from arbiter.config import get_settings +from arbiter.db.session import close_db, init_db + +logger = logging.getLogger(__name__) + + +async def startup(ctx: dict[str, Any]) -> None: + logger.info("Worker starting up") + await init_db() + ctx["settings"] = get_settings() + logger.info("Worker ready") + + +async def shutdown(_ctx: dict[str, Any]) -> None: + logger.info("Worker shutting down") + await close_db() + logger.info("Worker stopped") + + +async def health_check(_ctx: dict[str, Any]) -> str: + return "healthy" + + +class WorkerSettings: + """arq worker settings.""" + + # Import functions here to avoid circular imports + @staticmethod + def _get_functions() -> list[Callable[..., Coroutine[Any, Any, Any]]]: + from arbiter.worker.tasks import process_followup, process_review + + return [process_review, process_followup] + + functions: ClassVar[list[Callable[..., Coroutine[Any, Any, Any]]]] = [] + cron_jobs: ClassVar[list[Any]] = [] + on_startup = startup + on_shutdown = shutdown + + # Redis connection + @staticmethod + def redis_settings() -> RedisSettings: + settings = get_settings() + return RedisSettings.from_dsn(settings.redis_url) + + # Job settings + max_jobs = get_settings().worker_max_jobs + job_timeout = get_settings().worker_job_timeout + max_tries = get_settings().worker_retry_attempts + retry_jobs = True + + # Queue settings + queue_name = "arbiter:queue" + health_check_interval = 60 + + def __init_subclass__(cls) -> None: + cls.functions = cls._get_functions() + + +# Initialize functions list +WorkerSettings.functions = WorkerSettings._get_functions() diff --git a/src/arbiter/worker/tasks.py b/src/arbiter/worker/tasks.py new file mode 100644 index 0000000..15f9389 --- /dev/null +++ b/src/arbiter/worker/tasks.py @@ -0,0 +1,794 @@ +"""Worker task definitions for Arbiter.""" + +import asyncio +import logging +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from arbiter.agents import ComplexityAgent, ExplainContext, ReviewContext, SecurityAgent, StyleAgent +from arbiter.analysis import DiffParser, StaticAnalysisRunner +from arbiter.config import Settings, get_settings +from arbiter.conversation import AgentRouter, QuestionDetector +from arbiter.db.models import ( + ConflictModel, + ConversationMessageModel, + ConversationModel, + DeliberationStepModel, + FindingModel, + PolicyModel, + ReviewModel, +) +from arbiter.db.session import async_session_factory +from arbiter.deliberation import Coordinator, DeliberationResult +from arbiter.integrations import ( + ARBITER_MARKER, + CommitStatus, + GitHubClient, + GitLabClient, + IntegrationError, + Platform, + PlatformClient, + ReviewCommentFormatter, + has_arbiter_marker, +) +from arbiter.llm import LiteLLMClient, PromptRegistry +from arbiter.models import AgentName, Finding, Policy, ReviewResult, Severity, Verdict + +logger = logging.getLogger(__name__) + + +def detect_platform(repository: str, webhook_source: str | None = None) -> Platform: + """Detect the platform from repository format or webhook source. + + Args: + repository: Repository name (owner/repo format). + webhook_source: Explicit platform from webhook ("github" or "gitlab"). + + Returns: + Detected platform. + + Raises: + ValueError: If platform cannot be determined. + """ + if webhook_source: + if webhook_source.lower() == "github": + return Platform.GITHUB + if webhook_source.lower() == "gitlab": + return Platform.GITLAB + + # Default to GitHub if not specified + # In practice, webhooks should always provide the source + logger.warning("Platform not specified, defaulting to GitHub for %s", repository) + return Platform.GITHUB + + +def get_platform_client(platform: Platform, settings: Settings) -> PlatformClient | None: + """Create a platform client based on the detected platform. + + Args: + platform: Target platform. + settings: Application settings. + + Returns: + Platform client or None if no token is configured. + """ + if platform == Platform.GITHUB: + if not settings.github_token: + logger.warning("GitHub token not configured, skipping platform integration") + return None + return GitHubClient( + token=settings.github_token.get_secret_value(), + base_url=settings.github_base_url, + timeout=settings.integration_timeout, + max_retries=settings.integration_max_retries, + ) + + if platform == Platform.GITLAB: + if not settings.gitlab_token: + logger.warning("GitLab token not configured, skipping platform integration") + return None + return GitLabClient( + token=settings.gitlab_token.get_secret_value(), + base_url=settings.gitlab_base_url, + timeout=settings.integration_timeout, + max_retries=settings.integration_max_retries, + ) + + return None + + +async def _update_status_safe( + client: PlatformClient, + repository: str, + sha: str, + status: CommitStatus, + description: str, + context: str, + target_url: str | None = None, +) -> None: + """Update commit status, logging errors without raising. + + Args: + client: Platform client. + repository: Repository name. + sha: Commit SHA. + status: Status to set. + description: Status description. + context: Status context name. + target_url: Optional URL for details. + """ + try: + await client.update_commit_status( + repository=repository, + sha=sha, + status=status, + description=description, + context=context, + target_url=target_url, + ) + logger.info("Updated commit status for %s@%s: %s", repository, sha[:8], status) + except IntegrationError as e: + logger.warning("Failed to update commit status: %s", e) + + +async def _post_or_update_comment( + client: PlatformClient, + repository: str, + pr_number: int, + body: str, +) -> str | None: + """Post a new comment or update existing Arbiter comment. + + This provides idempotent comment behavior: re-reviews update the existing + Arbiter comment instead of posting new ones (preventing comment spam). + + Args: + client: Platform client. + repository: Repository name. + pr_number: PR/MR number. + body: Comment body (must contain ARBITER_MARKER). + + Returns: + Comment URL if successful, None otherwise. + """ + try: + # Try to find existing Arbiter comment + existing_comment_id: str | None = None + try: + comments = await client.get_comments(repository, pr_number) + for comment in comments: + if has_arbiter_marker(comment.body): + existing_comment_id = comment.id + logger.info( + "Found existing Arbiter comment %s on %s#%d", + existing_comment_id, + repository, + pr_number, + ) + break + except IntegrationError as e: + # If we can't fetch comments, fall back to posting new + logger.warning("Failed to fetch comments, will post new: %s", e) + + if existing_comment_id: + # Update existing comment + url = await client.update_comment(repository, pr_number, existing_comment_id, body) + logger.info( + "Updated comment %s on %s#%d: %s", + existing_comment_id, + repository, + pr_number, + url, + ) + return url + else: + # Post new comment + url = await client.post_comment(repository, pr_number, body) + logger.info("Posted new comment on %s#%d: %s", repository, pr_number, url) + return url + + except IntegrationError as e: + logger.warning("Failed to post/update comment: %s", e) + return None + + +def _verdict_to_status(verdict: Verdict) -> CommitStatus: + if verdict == Verdict.APPROVE: + return CommitStatus.SUCCESS + if verdict == Verdict.REQUEST_CHANGES: + return CommitStatus.FAILURE + return CommitStatus.SUCCESS # COMMENT still passes + + +async def process_review( + _ctx: dict[str, Any], + repository: str, + pr_number: int, + base_sha: str, + head_sha: str, + pr_title: str | None = None, + author: str | None = None, + is_draft: bool = False, + policy_name: str | None = None, + diff_content: str | None = None, + platform: str | None = None, +) -> str: + """Process a code review job. + + This task runs the full review pipeline: + 1. Create/update review record + 2. Set commit status to pending + 3. Fetch diff (from platform if not provided) + 4. Run static analysis + 5. Run agent reviews in parallel + 6. Run deliberation + 7. Store results in database + 8. Post comment with results + 9. Update commit status to success/failure + + Args: + ctx: arq context with settings and connections + repository: Repository name (owner/repo format) + pr_number: Pull request number + base_sha: Base commit SHA + head_sha: Head commit SHA + pr_title: PR title + author: PR author + is_draft: Whether PR is a draft + policy_name: Name of policy to use + diff_content: Optional pre-fetched diff content + platform: Source platform ("github" or "gitlab") + + Returns: + Review ID on success + """ + settings = get_settings() + session_factory = async_session_factory() + + # Detect platform and create client + detected_platform = detect_platform(repository, platform) + client = get_platform_client(detected_platform, settings) + + async with session_factory() as session: + # Create review record + review = ReviewModel( + repository=repository, + pr_number=pr_number, + pr_title=pr_title, + base_sha=base_sha, + head_sha=head_sha, + author=author, + is_draft=is_draft, + status="running", + started_at=datetime.now(UTC), + ) + + # Load policy if specified + policy = Policy() + if policy_name: + result = await session.execute( + select(PolicyModel).where( + PolicyModel.name == policy_name, PolicyModel.is_active.is_(True) + ) + ) + policy_model = result.scalar_one_or_none() + if policy_model: + review.policy_id = policy_model.id + policy = _policy_from_model(policy_model) + + session.add(review) + await session.commit() + await session.refresh(review) + review_id = review.id + + logger.info("Started review %s for %s PR #%d", review_id, repository, pr_number) + + # Set pending status + if client and settings.update_status: + await _update_status_safe( + client, + repository, + head_sha, + CommitStatus.PENDING, + "Arbiter review in progress...", + settings.status_check_context, + ) + + try: + # Get diff content from platform if not provided + if not diff_content: + if client: + logger.info( + "Fetching diff from %s for %s#%d", detected_platform, repository, pr_number + ) + diff_content = await client.get_pr_diff(repository, pr_number) + else: + raise ValueError( + f"diff_content not provided and no {detected_platform} token configured" + ) + + # Run the review pipeline + agent_results, deliberation_result = await _run_review_pipeline( + diff_content=diff_content, + policy=policy, + templates_dir=settings.templates_dir, + ) + + # Update review with results + review.status = "completed" + review.completed_at = datetime.now(UTC) + review.verdict = deliberation_result.verdict + review.verdict_confidence = deliberation_result.verdict_confidence + review.verdict_reasoning = deliberation_result.verdict_reasoning + + # Calculate and store costs + total_tokens = ( + sum(r.tokens_used for r in agent_results) + deliberation_result.tokens_used + ) + total_cost = sum(r.cost_usd for r in agent_results) + deliberation_result.cost_usd + tokens_by_agent = {r.agent_name.value: r.tokens_used for r in agent_results} + cost_by_agent = {r.agent_name.value: r.cost_usd for r in agent_results} + + review.total_tokens = total_tokens + review.total_cost_usd = total_cost + review.tokens_by_agent = tokens_by_agent + review.cost_by_agent = cost_by_agent + + # Store findings + for finding in deliberation_result.findings: + finding_model = FindingModel( + review_id=review_id, + agent=finding.agent, + file=finding.file, + line_start=finding.line_start, + line_end=finding.line_end, + severity=finding.severity, + confidence=finding.confidence, + title=finding.title, + description=finding.description, + reasoning=finding.reasoning, + suggestion=finding.suggestion, + references=finding.references, + prompt_version=finding.prompt_version, + static_analysis_context=finding.static_analysis_context, + ) + session.add(finding_model) + + # Store conflicts + for conflict in deliberation_result.conflicts: + conflict_model = ConflictModel( + review_id=review_id, + finding_ids=conflict.finding_ids, + nature=conflict.nature, + description=conflict.description, + severity_weight=conflict.severity_weight, + resolution=conflict.resolution, + winning_finding_id=conflict.winning_finding_id, + ) + session.add(conflict_model) + + # Store deliberation steps + for i, step in enumerate(deliberation_result.steps): + step_model = DeliberationStepModel( + review_id=review_id, + step_type=step.step_type, + timestamp=step.timestamp, + description=step.description, + details=step.details, + sequence=i, + ) + session.add(step_model) + + await session.commit() + logger.info("Completed review %s: %s", review_id, review.verdict) + + # Post comment and update status + if client: + comment_url: str | None = None + + if settings.post_comments: + formatter = ReviewCommentFormatter(include_cost=True) + comment_body = formatter.format(deliberation_result) + comment_url = await _post_or_update_comment( + client, repository, pr_number, comment_body + ) + + if settings.update_status: + final_status = _verdict_to_status(deliberation_result.verdict) + description = ( + deliberation_result.verdict_reasoning[:140] + if deliberation_result.verdict_reasoning + else f"Review: {deliberation_result.verdict.value}" + ) + await _update_status_safe( + client, + repository, + head_sha, + final_status, + description, + settings.status_check_context, + target_url=comment_url, + ) + + except Exception as e: + logger.exception("Review %s failed", review_id) + review.status = "failed" + review.error_message = str(e) + review.completed_at = datetime.now(UTC) + await session.commit() + + # Update status to error + if client and settings.update_status: + await _update_status_safe( + client, + repository, + head_sha, + CommitStatus.ERROR, + f"Review failed: {str(e)[:100]}", + settings.status_check_context, + ) + + raise + + finally: + # Close the client + if client: + await client.close() + + return review_id + + +async def _run_review_pipeline( + diff_content: str, + policy: Policy, + templates_dir: Path, + work_dir: Path | None = None, +) -> tuple[list[ReviewResult], DeliberationResult]: + """Run the full review pipeline. + + This is extracted from cli.py to be reusable by both CLI and worker. + """ + settings = get_settings() + llm_client = LiteLLMClient( + timeout=settings.llm_timeout, + max_retries=settings.llm_max_retries, + ) + prompt_registry = PromptRegistry(templates_dir) + + # Parse the diff + parser = DiffParser() + parsed_diff = parser.parse(diff_content) + + # Run static analysis if work_dir is provided + static_findings: list[Finding] = [] + if work_dir: + runner = StaticAnalysisRunner() + static_result = await runner.run(parsed_diff, work_dir) + for sf in static_result.findings: + static_findings.append(runner.convert_to_finding(sf)) + + context = ReviewContext( + diff=diff_content, + policy=policy, + static_analysis_context=f"Found {len(static_findings)} static analysis findings." + if static_findings + else "No static analysis data available.", + ) + + # Create agents for enabled agent types + agents: list[SecurityAgent | StyleAgent | ComplexityAgent] = [] + for agent_name in policy.get_enabled_agents(): + if agent_name == AgentName.SECURITY: + agents.append(SecurityAgent(llm_client, prompt_registry)) + elif agent_name == AgentName.STYLE: + agents.append(StyleAgent(llm_client, prompt_registry)) + elif agent_name == AgentName.COMPLEXITY: + agents.append(ComplexityAgent(llm_client, prompt_registry)) + + # Run agents in parallel + results = await asyncio.gather( + *[agent.review(context) for agent in agents], + return_exceptions=True, + ) + + # Filter out exceptions + valid_results: list[ReviewResult] = [] + for result in results: + if isinstance(result, BaseException): + logger.warning("Agent error: %s", result) + else: + valid_results.append(result) + + # Run deliberation + coordinator = Coordinator(llm_client=llm_client) + deliberation_result = await coordinator.deliberate( + valid_results, + static_findings=static_findings if static_findings else None, + ) + + return valid_results, deliberation_result + + +def _policy_from_model(model: PolicyModel) -> Policy: + """Convert database policy model to Pydantic policy.""" + # Reconstruct policy from stored JSON config + data: dict[str, Any] = {} + if model.agents_config: + data["agents"] = model.agents_config + if model.cost_controls: + data["cost_controls"] = model.cost_controls + if model.verdict_thresholds: + data["verdict"] = model.verdict_thresholds + + return Policy.model_validate(data) if data else Policy() + + +async def process_followup( + _ctx: dict[str, Any], + repository: str, + pr_number: int, + comment_id: str, + comment_body: str, + author: str, + platform: str, +) -> str | None: + """Process a follow-up question from a PR comment. + + This task: + 1. Finds the most recent completed review for this PR + 2. Analyzes the comment to detect if it's a question + 3. Routes to appropriate agent(s) + 4. Generates explanations + 5. Posts a response comment + 6. Stores the conversation in the database + + Args: + ctx: arq context + repository: Repository name (owner/repo format) + pr_number: Pull request number + comment_id: Platform-specific comment ID + comment_body: The comment text + author: Comment author username + platform: Source platform ("github" or "gitlab") + + Returns: + Conversation message ID if successful, None if no response needed + """ + settings = get_settings() + session_factory = async_session_factory() + + # Check if follow-up is enabled + if not settings.followup_enabled: + logger.info("Follow-up processing is disabled") + return None + + # Detect platform and create client + detected_platform = detect_platform(repository, platform) + client = get_platform_client(detected_platform, settings) + + async with session_factory() as session: + # Find the most recent completed review for this PR + review_query = ( + select(ReviewModel) + .where( + ReviewModel.repository == repository, + ReviewModel.pr_number == pr_number, + ReviewModel.status == "completed", + ) + .order_by(ReviewModel.completed_at.desc()) + .limit(1) + .options(selectinload(ReviewModel.findings)) + ) + result = await session.execute(review_query) + review = result.scalar_one_or_none() + + if not review: + logger.info( + "No completed review found for %s PR #%d, skipping follow-up", + repository, + pr_number, + ) + return None + + # Convert DB findings to Pydantic models + findings: list[Finding] = [] + for f in review.findings: + # SQLAlchemy returns enum values, need to convert to AgentName/Severity + agent_name = f.agent if isinstance(f.agent, AgentName) else AgentName(f.agent) + severity = f.severity if isinstance(f.severity, Severity) else Severity(f.severity) + findings.append( + Finding( + id=f.id, + agent=agent_name, + file=f.file, + line_start=f.line_start, + line_end=f.line_end, + severity=severity, + confidence=f.confidence, + title=f.title, + description=f.description, + reasoning=f.reasoning, + suggestion=f.suggestion, + references=f.references or [], + prompt_version=f.prompt_version, + static_analysis_context=f.static_analysis_context, + ) + ) + + # Analyze the comment + detector = QuestionDetector(confidence_threshold=settings.followup_confidence_threshold) + analysis = detector.analyze(comment_body, findings) + + if not analysis.is_question: + logger.info("Comment does not appear to be a question, skipping") + return None + + if analysis.confidence < settings.followup_confidence_threshold: + logger.info( + "Question confidence %.2f below threshold %.2f, skipping", + analysis.confidence, + settings.followup_confidence_threshold, + ) + return None + + logger.info( + "Detected question for %s PR #%d with confidence %.2f", + repository, + pr_number, + analysis.confidence, + ) + + # Get or create conversation + conv_query = ( + select(ConversationModel) + .where( + ConversationModel.review_id == review.id, + ConversationModel.repository == repository, + ConversationModel.pr_number == pr_number, + ) + .options(selectinload(ConversationModel.messages)) + ) + conv_result = await session.execute(conv_query) + conversation = conv_result.scalar_one_or_none() + + if not conversation: + conversation = ConversationModel( + review_id=review.id, + platform=platform, + repository=repository, + pr_number=pr_number, + ) + session.add(conversation) + await session.flush() + logger.info("Created new conversation %s", conversation.id) + + # Determine next sequence number + next_sequence = len(conversation.messages) if conversation.messages else 0 + + # Store user message + user_message = ConversationMessageModel( + conversation_id=conversation.id, + role="user", + platform_comment_id=comment_id, + author=author, + content=comment_body, + sequence=next_sequence, + ) + session.add(user_message) + next_sequence += 1 + + # Build conversation history for context + conversation_history: list[dict[str, str]] = [] + if conversation.messages: + for msg in sorted(conversation.messages, key=lambda m: m.sequence): + conversation_history.append({"role": msg.role, "content": msg.content}) + conversation_history.append({"role": "user", "content": comment_body}) + + # Set up LLM client and router + llm_client = LiteLLMClient( + timeout=settings.llm_timeout, + max_retries=settings.llm_max_retries, + ) + prompt_registry = PromptRegistry(settings.templates_dir) + router = AgentRouter(llm_client, prompt_registry) + + # Route to appropriate agents + routes = router.route(analysis, findings) + + if not routes: + logger.info("No agents to route to, skipping response") + return None + + # Generate explanations from each routed agent + responses: list[str] = [] + total_tokens = 0 + total_cost = 0.0 + responding_agents: list[str] = [] + referenced_finding_ids: list[str] = [] + + for route in routes: + agent = router.get_agent(route.agent_name) + responding_agents.append(route.agent_name.value) + + if route.finding: + referenced_finding_ids.append(route.finding.id) + # Create explain context + explain_ctx = ExplainContext( + question=analysis.question_text, + finding=route.finding, + diff="", # Could fetch diff if needed + conversation_history=conversation_history[:-1], # Exclude current question + ) + + try: + explain_result = await agent.explain(explain_ctx) + responses.append( + f"**{route.agent_name.value.title()} Agent** " + f"(regarding: {route.finding.title}):\n\n" + f"{explain_result.response}" + ) + total_tokens += explain_result.tokens_used + total_cost += explain_result.cost_usd + except Exception as e: + logger.warning("Agent %s failed to explain: %s", route.agent_name, e) + responses.append( + f"**{route.agent_name.value.title()} Agent**: " + f"Unable to provide explanation at this time." + ) + else: + # General question without specific finding + responses.append( + f"**{route.agent_name.value.title()} Agent**: " + f"No specific findings to explain for this query." + ) + + if not responses: + logger.info("No responses generated, skipping") + return None + + # Format the response + response_body = "\n\n---\n\n".join(responses) + full_response = f"{response_body}\n\n{ARBITER_MARKER}" + + # Store assistant message + assistant_message = ConversationMessageModel( + conversation_id=conversation.id, + role="assistant", + content=response_body, + responding_agents=responding_agents, + referenced_finding_ids=referenced_finding_ids, + tokens_used=total_tokens, + cost_usd=total_cost, + sequence=next_sequence, + ) + session.add(assistant_message) + + # Update conversation totals + conversation.total_tokens += total_tokens + conversation.total_cost_usd += total_cost + conversation.last_activity = datetime.now(UTC) + + await session.commit() + logger.info( + "Generated response for %s PR #%d, tokens=%d, cost=$%.4f", + repository, + pr_number, + total_tokens, + total_cost, + ) + + # Post the response comment + if client and settings.post_comments: + try: + comment_url = await client.post_comment(repository, pr_number, full_response) + logger.info("Posted follow-up response: %s", comment_url) + except IntegrationError as e: + logger.warning("Failed to post follow-up response: %s", e) + finally: + await client.close() + + return assistant_message.id