From 56b528b2e36d588ff05d477c57fba8f042cdb65a Mon Sep 17 00:00:00 2001 From: Kai Chappell Date: Sat, 24 May 2025 11:18:59 +0000 Subject: [PATCH] feat(conversation): add question detection and routing --- src/arbiter/conversation/__init__.py | 11 ++ src/arbiter/conversation/detection.py | 234 ++++++++++++++++++++++++++ src/arbiter/conversation/router.py | 150 +++++++++++++++++ 3 files changed, 395 insertions(+) create mode 100644 src/arbiter/conversation/__init__.py create mode 100644 src/arbiter/conversation/detection.py create mode 100644 src/arbiter/conversation/router.py diff --git a/src/arbiter/conversation/__init__.py b/src/arbiter/conversation/__init__.py new file mode 100644 index 0000000..f84fdfa --- /dev/null +++ b/src/arbiter/conversation/__init__.py @@ -0,0 +1,11 @@ +"""Conversation handling for follow-up questions.""" + +from arbiter.conversation.detection import QuestionAnalysis, QuestionDetector +from arbiter.conversation.router import AgentRouter, RouteResult + +__all__ = [ + "AgentRouter", + "QuestionAnalysis", + "QuestionDetector", + "RouteResult", +] diff --git a/src/arbiter/conversation/detection.py b/src/arbiter/conversation/detection.py new file mode 100644 index 0000000..06c62e2 --- /dev/null +++ b/src/arbiter/conversation/detection.py @@ -0,0 +1,234 @@ +"""Question detection for PR comments.""" + +import re +from typing import ClassVar + +from pydantic import BaseModel, Field + +from arbiter.models import AgentName, Finding + + +class QuestionAnalysis(BaseModel): + """Result of analyzing a PR comment for questions.""" + + is_question: bool = Field(description="Whether the comment appears to be a question") + is_directed_at_arbiter: bool = Field( + description="Whether the question is directed at Arbiter specifically" + ) + question_text: str = Field(description="The extracted question text") + mentioned_agents: list[AgentName] = Field( + default_factory=list, description="Agents mentioned or implied" + ) + mentioned_finding_ids: list[str] = Field( + default_factory=list, description="Finding IDs referenced in the comment" + ) + confidence: float = Field( + ge=0.0, le=1.0, description="Confidence that this is a valid follow-up question" + ) + + +class QuestionDetector: + """Detects questions in PR comments that should be answered by Arbiter.""" + + # Patterns that indicate a question + QUESTION_PATTERNS: ClassVar[list[str]] = [ + r"\?$", # Ends with question mark + r"\?\s*$", # Ends with question mark and optional whitespace + r"^(why|what|how|can you|could you|please explain|explain|tell me)", + r"\b(why|what|how|can|could|would|should)\b.*\?", + ] + + # Patterns that indicate the comment is directed at Arbiter + ARBITER_PATTERNS: ClassVar[list[str]] = [ + r"@arbiter\b", + r"\barbiter\b", + r"@security.?agent\b", + r"@style.?agent\b", + r"@complexity.?agent\b", + ] + + # Keywords associated with each agent + AGENT_KEYWORDS: ClassVar[dict[AgentName, list[str]]] = { + AgentName.SECURITY: [ + "security", + "vulnerability", + "injection", + "xss", + "csrf", + "auth", + "authentication", + "authorization", + "exploit", + "attack", + "risk", + "threat", + "sensitive", + "secret", + "credential", + "permission", + "access control", + "owasp", + "cwe", + ], + AgentName.STYLE: [ + "style", + "naming", + "readability", + "convention", + "format", + "formatting", + "consistent", + "consistency", + "pattern", + "readable", + "clean", + "code style", + "lint", + ], + AgentName.COMPLEXITY: [ + "complexity", + "architecture", + "design", + "refactor", + "maintainability", + "coupling", + "cohesion", + "abstraction", + "module", + "dependency", + "cyclomatic", + "cognitive", + "nested", + "simplify", + ], + } + + # Patterns to extract finding IDs (UUIDs) + UUID_PATTERN: ClassVar[str] = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" + + def __init__(self, confidence_threshold: float = 0.5) -> None: + self.confidence_threshold = confidence_threshold + self._compiled_question_patterns = [ + re.compile(p, re.IGNORECASE | re.MULTILINE) for p in self.QUESTION_PATTERNS + ] + self._compiled_arbiter_patterns = [ + re.compile(p, re.IGNORECASE) for p in self.ARBITER_PATTERNS + ] + self._uuid_pattern = re.compile(self.UUID_PATTERN, re.IGNORECASE) + + def analyze( + self, + comment_text: str, + findings: list[Finding] | None = None, + ) -> QuestionAnalysis: + """Analyze a PR comment to detect if it's a question for Arbiter. + + Args: + comment_text: The comment text to analyze. + findings: Optional list of findings from the review to match against. + + Returns: + QuestionAnalysis with detection results. + """ + text = comment_text.strip() + text_lower = text.lower() + + # Check if it's a question + is_question = self._is_question(text) + + # Check if directed at Arbiter + is_directed = self._is_directed_at_arbiter(text) + + # Extract mentioned agents + mentioned_agents = self._extract_mentioned_agents(text_lower) + + # Extract mentioned finding IDs + mentioned_finding_ids = self._extract_finding_ids(text, findings) + + # Calculate confidence + confidence = self._calculate_confidence( + is_question=is_question, + is_directed=is_directed, + has_agent_mentions=len(mentioned_agents) > 0, + has_finding_refs=len(mentioned_finding_ids) > 0, + ) + + # Extract the question text (clean up @mentions and excess whitespace) + question_text = self._extract_question_text(text) + + return QuestionAnalysis( + is_question=is_question, + is_directed_at_arbiter=is_directed, + question_text=question_text, + mentioned_agents=mentioned_agents, + mentioned_finding_ids=mentioned_finding_ids, + confidence=confidence, + ) + + def _is_question(self, text: str) -> bool: + return any(pattern.search(text) for pattern in self._compiled_question_patterns) + + def _is_directed_at_arbiter(self, text: str) -> bool: + return any(pattern.search(text) for pattern in self._compiled_arbiter_patterns) + + def _extract_mentioned_agents(self, text_lower: str) -> list[AgentName]: + """Extract agents mentioned or implied in the text.""" + mentioned = [] + for agent, keywords in self.AGENT_KEYWORDS.items(): + for keyword in keywords: + if keyword in text_lower: + if agent not in mentioned: + mentioned.append(agent) + break + return mentioned + + def _extract_finding_ids( + self, + text: str, + findings: list[Finding] | None, + ) -> list[str]: + """Extract finding IDs referenced in the text.""" + # Find UUIDs in the text + uuid_matches = self._uuid_pattern.findall(text.lower()) + + if not findings: + return uuid_matches + + # Filter to only valid finding IDs + valid_ids = {f.id.lower() for f in findings} + return [uid for uid in uuid_matches if uid in valid_ids] + + def _calculate_confidence( + self, + is_question: bool, + is_directed: bool, + has_agent_mentions: bool, + has_finding_refs: bool, + ) -> float: + """Calculate confidence score for the analysis.""" + if not is_question: + return 0.0 + + # Base confidence for being a question + confidence = 0.4 + + # Boost for being directed at Arbiter + if is_directed: + confidence += 0.4 + + # Boost for mentioning specific agents + if has_agent_mentions: + confidence += 0.1 + + # Boost for referencing specific findings + if has_finding_refs: + confidence += 0.1 + + return min(confidence, 1.0) + + def _extract_question_text(self, text: str) -> str: + # Remove @mentions + cleaned = re.sub(r"@\w+", "", text) + # Collapse multiple whitespace + cleaned = re.sub(r"\s+", " ", cleaned) + return cleaned.strip() diff --git a/src/arbiter/conversation/router.py b/src/arbiter/conversation/router.py new file mode 100644 index 0000000..53be9d8 --- /dev/null +++ b/src/arbiter/conversation/router.py @@ -0,0 +1,150 @@ +"""Agent routing for follow-up questions.""" + +from pydantic import BaseModel, Field + +from arbiter.agents import Agent, ComplexityAgent, SecurityAgent, StyleAgent +from arbiter.conversation.detection import QuestionAnalysis +from arbiter.llm.client import LLMClient +from arbiter.llm.prompts import PromptRegistry +from arbiter.models import AgentName, Finding + + +class RouteResult(BaseModel): + """Result of routing a question to agents.""" + + agent_name: AgentName = Field(description="The agent to handle this route") + finding: Finding | None = Field( + default=None, description="Specific finding this route is about" + ) + priority: int = Field(default=0, description="Priority order (lower = higher priority)") + + +class AgentRouter: + """Routes follow-up questions to appropriate agents.""" + + def __init__( + self, + llm_client: LLMClient, + prompt_registry: PromptRegistry, + ) -> None: + self.llm_client = llm_client + self.prompt_registry = prompt_registry + self._agents: dict[AgentName, Agent] = {} + + def get_agent(self, agent_name: AgentName) -> Agent: + if agent_name not in self._agents: + if agent_name == AgentName.SECURITY: + self._agents[agent_name] = SecurityAgent(self.llm_client, self.prompt_registry) + elif agent_name == AgentName.STYLE: + self._agents[agent_name] = StyleAgent(self.llm_client, self.prompt_registry) + elif agent_name == AgentName.COMPLEXITY: + self._agents[agent_name] = ComplexityAgent(self.llm_client, self.prompt_registry) + return self._agents[agent_name] + + def route( + self, + analysis: QuestionAnalysis, + findings: list[Finding], + ) -> list[RouteResult]: + """Determine which agents should respond to a question. + + The routing logic follows this priority: + 1. If specific finding IDs are referenced, route to those findings' agents + 2. If specific agents are mentioned by keyword, route to those agents + 3. If no specific agents/findings, route to all agents that have findings + + Args: + analysis: The question analysis result. + findings: List of findings from the review. + + Returns: + List of RouteResult indicating which agents should respond. + """ + routes: list[RouteResult] = [] + + # Build a map of finding ID to finding + findings_by_id = {f.id: f for f in findings} + + # Priority 1: Specific findings referenced + if analysis.mentioned_finding_ids: + for i, finding_id in enumerate(analysis.mentioned_finding_ids): + if finding_id in findings_by_id: + finding = findings_by_id[finding_id] + routes.append( + RouteResult( + agent_name=finding.agent, + finding=finding, + priority=i, + ) + ) + if routes: + return routes + + # Priority 2: Specific agents mentioned by keyword + if analysis.mentioned_agents: + # Find findings for each mentioned agent + findings_by_agent: dict[AgentName, list[Finding]] = {} + for finding in findings: + if finding.agent in analysis.mentioned_agents: + if finding.agent not in findings_by_agent: + findings_by_agent[finding.agent] = [] + findings_by_agent[finding.agent].append(finding) + + for priority, agent_name in enumerate(analysis.mentioned_agents): + agent_findings = findings_by_agent.get(agent_name, []) + if agent_findings: + # Route to most relevant finding (highest severity) + severity_order = ["critical", "high", "medium", "low", "info"] + sorted_findings = sorted( + agent_findings, + key=lambda f: severity_order.index(f.severity.value), + ) + routes.append( + RouteResult( + agent_name=agent_name, + finding=sorted_findings[0], + priority=priority, + ) + ) + else: + # Agent mentioned but has no findings - still route + routes.append( + RouteResult( + agent_name=agent_name, + finding=None, + priority=priority, + ) + ) + + if routes: + return routes + + # Priority 3: Default - route to all agents with findings + agents_with_findings: dict[AgentName, Finding] = {} + for finding in findings: + if finding.agent not in agents_with_findings: + agents_with_findings[finding.agent] = finding + else: + # Keep highest severity finding + existing = agents_with_findings[finding.agent] + severity_order = ["critical", "high", "medium", "low", "info"] + if severity_order.index(finding.severity.value) < severity_order.index( + existing.severity.value + ): + agents_with_findings[finding.agent] = finding + + # Sort agents by standard order + agent_order = [AgentName.SECURITY, AgentName.STYLE, AgentName.COMPLEXITY] + priority = 0 + for agent_name in agent_order: + if agent_name in agents_with_findings: + routes.append( + RouteResult( + agent_name=agent_name, + finding=agents_with_findings[agent_name], + priority=priority, + ) + ) + priority += 1 + + return routes