feat(conversation): add question detection and routing
This commit is contained in:
11
src/arbiter/conversation/__init__.py
Normal file
11
src/arbiter/conversation/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""Conversation handling for follow-up questions."""
|
||||||
|
|
||||||
|
from arbiter.conversation.detection import QuestionAnalysis, QuestionDetector
|
||||||
|
from arbiter.conversation.router import AgentRouter, RouteResult
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentRouter",
|
||||||
|
"QuestionAnalysis",
|
||||||
|
"QuestionDetector",
|
||||||
|
"RouteResult",
|
||||||
|
]
|
||||||
234
src/arbiter/conversation/detection.py
Normal file
234
src/arbiter/conversation/detection.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""Question detection for PR comments."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from arbiter.models import AgentName, Finding
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionAnalysis(BaseModel):
|
||||||
|
"""Result of analyzing a PR comment for questions."""
|
||||||
|
|
||||||
|
is_question: bool = Field(description="Whether the comment appears to be a question")
|
||||||
|
is_directed_at_arbiter: bool = Field(
|
||||||
|
description="Whether the question is directed at Arbiter specifically"
|
||||||
|
)
|
||||||
|
question_text: str = Field(description="The extracted question text")
|
||||||
|
mentioned_agents: list[AgentName] = Field(
|
||||||
|
default_factory=list, description="Agents mentioned or implied"
|
||||||
|
)
|
||||||
|
mentioned_finding_ids: list[str] = Field(
|
||||||
|
default_factory=list, description="Finding IDs referenced in the comment"
|
||||||
|
)
|
||||||
|
confidence: float = Field(
|
||||||
|
ge=0.0, le=1.0, description="Confidence that this is a valid follow-up question"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionDetector:
|
||||||
|
"""Detects questions in PR comments that should be answered by Arbiter."""
|
||||||
|
|
||||||
|
# Patterns that indicate a question
|
||||||
|
QUESTION_PATTERNS: ClassVar[list[str]] = [
|
||||||
|
r"\?$", # Ends with question mark
|
||||||
|
r"\?\s*$", # Ends with question mark and optional whitespace
|
||||||
|
r"^(why|what|how|can you|could you|please explain|explain|tell me)",
|
||||||
|
r"\b(why|what|how|can|could|would|should)\b.*\?",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Patterns that indicate the comment is directed at Arbiter
|
||||||
|
ARBITER_PATTERNS: ClassVar[list[str]] = [
|
||||||
|
r"@arbiter\b",
|
||||||
|
r"\barbiter\b",
|
||||||
|
r"@security.?agent\b",
|
||||||
|
r"@style.?agent\b",
|
||||||
|
r"@complexity.?agent\b",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Keywords associated with each agent
|
||||||
|
AGENT_KEYWORDS: ClassVar[dict[AgentName, list[str]]] = {
|
||||||
|
AgentName.SECURITY: [
|
||||||
|
"security",
|
||||||
|
"vulnerability",
|
||||||
|
"injection",
|
||||||
|
"xss",
|
||||||
|
"csrf",
|
||||||
|
"auth",
|
||||||
|
"authentication",
|
||||||
|
"authorization",
|
||||||
|
"exploit",
|
||||||
|
"attack",
|
||||||
|
"risk",
|
||||||
|
"threat",
|
||||||
|
"sensitive",
|
||||||
|
"secret",
|
||||||
|
"credential",
|
||||||
|
"permission",
|
||||||
|
"access control",
|
||||||
|
"owasp",
|
||||||
|
"cwe",
|
||||||
|
],
|
||||||
|
AgentName.STYLE: [
|
||||||
|
"style",
|
||||||
|
"naming",
|
||||||
|
"readability",
|
||||||
|
"convention",
|
||||||
|
"format",
|
||||||
|
"formatting",
|
||||||
|
"consistent",
|
||||||
|
"consistency",
|
||||||
|
"pattern",
|
||||||
|
"readable",
|
||||||
|
"clean",
|
||||||
|
"code style",
|
||||||
|
"lint",
|
||||||
|
],
|
||||||
|
AgentName.COMPLEXITY: [
|
||||||
|
"complexity",
|
||||||
|
"architecture",
|
||||||
|
"design",
|
||||||
|
"refactor",
|
||||||
|
"maintainability",
|
||||||
|
"coupling",
|
||||||
|
"cohesion",
|
||||||
|
"abstraction",
|
||||||
|
"module",
|
||||||
|
"dependency",
|
||||||
|
"cyclomatic",
|
||||||
|
"cognitive",
|
||||||
|
"nested",
|
||||||
|
"simplify",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Patterns to extract finding IDs (UUIDs)
|
||||||
|
UUID_PATTERN: ClassVar[str] = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||||
|
|
||||||
|
def __init__(self, confidence_threshold: float = 0.5) -> None:
|
||||||
|
self.confidence_threshold = confidence_threshold
|
||||||
|
self._compiled_question_patterns = [
|
||||||
|
re.compile(p, re.IGNORECASE | re.MULTILINE) for p in self.QUESTION_PATTERNS
|
||||||
|
]
|
||||||
|
self._compiled_arbiter_patterns = [
|
||||||
|
re.compile(p, re.IGNORECASE) for p in self.ARBITER_PATTERNS
|
||||||
|
]
|
||||||
|
self._uuid_pattern = re.compile(self.UUID_PATTERN, re.IGNORECASE)
|
||||||
|
|
||||||
|
def analyze(
|
||||||
|
self,
|
||||||
|
comment_text: str,
|
||||||
|
findings: list[Finding] | None = None,
|
||||||
|
) -> QuestionAnalysis:
|
||||||
|
"""Analyze a PR comment to detect if it's a question for Arbiter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
comment_text: The comment text to analyze.
|
||||||
|
findings: Optional list of findings from the review to match against.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QuestionAnalysis with detection results.
|
||||||
|
"""
|
||||||
|
text = comment_text.strip()
|
||||||
|
text_lower = text.lower()
|
||||||
|
|
||||||
|
# Check if it's a question
|
||||||
|
is_question = self._is_question(text)
|
||||||
|
|
||||||
|
# Check if directed at Arbiter
|
||||||
|
is_directed = self._is_directed_at_arbiter(text)
|
||||||
|
|
||||||
|
# Extract mentioned agents
|
||||||
|
mentioned_agents = self._extract_mentioned_agents(text_lower)
|
||||||
|
|
||||||
|
# Extract mentioned finding IDs
|
||||||
|
mentioned_finding_ids = self._extract_finding_ids(text, findings)
|
||||||
|
|
||||||
|
# Calculate confidence
|
||||||
|
confidence = self._calculate_confidence(
|
||||||
|
is_question=is_question,
|
||||||
|
is_directed=is_directed,
|
||||||
|
has_agent_mentions=len(mentioned_agents) > 0,
|
||||||
|
has_finding_refs=len(mentioned_finding_ids) > 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the question text (clean up @mentions and excess whitespace)
|
||||||
|
question_text = self._extract_question_text(text)
|
||||||
|
|
||||||
|
return QuestionAnalysis(
|
||||||
|
is_question=is_question,
|
||||||
|
is_directed_at_arbiter=is_directed,
|
||||||
|
question_text=question_text,
|
||||||
|
mentioned_agents=mentioned_agents,
|
||||||
|
mentioned_finding_ids=mentioned_finding_ids,
|
||||||
|
confidence=confidence,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_question(self, text: str) -> bool:
|
||||||
|
return any(pattern.search(text) for pattern in self._compiled_question_patterns)
|
||||||
|
|
||||||
|
def _is_directed_at_arbiter(self, text: str) -> bool:
|
||||||
|
return any(pattern.search(text) for pattern in self._compiled_arbiter_patterns)
|
||||||
|
|
||||||
|
def _extract_mentioned_agents(self, text_lower: str) -> list[AgentName]:
|
||||||
|
"""Extract agents mentioned or implied in the text."""
|
||||||
|
mentioned = []
|
||||||
|
for agent, keywords in self.AGENT_KEYWORDS.items():
|
||||||
|
for keyword in keywords:
|
||||||
|
if keyword in text_lower:
|
||||||
|
if agent not in mentioned:
|
||||||
|
mentioned.append(agent)
|
||||||
|
break
|
||||||
|
return mentioned
|
||||||
|
|
||||||
|
def _extract_finding_ids(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
findings: list[Finding] | None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Extract finding IDs referenced in the text."""
|
||||||
|
# Find UUIDs in the text
|
||||||
|
uuid_matches = self._uuid_pattern.findall(text.lower())
|
||||||
|
|
||||||
|
if not findings:
|
||||||
|
return uuid_matches
|
||||||
|
|
||||||
|
# Filter to only valid finding IDs
|
||||||
|
valid_ids = {f.id.lower() for f in findings}
|
||||||
|
return [uid for uid in uuid_matches if uid in valid_ids]
|
||||||
|
|
||||||
|
def _calculate_confidence(
|
||||||
|
self,
|
||||||
|
is_question: bool,
|
||||||
|
is_directed: bool,
|
||||||
|
has_agent_mentions: bool,
|
||||||
|
has_finding_refs: bool,
|
||||||
|
) -> float:
|
||||||
|
"""Calculate confidence score for the analysis."""
|
||||||
|
if not is_question:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Base confidence for being a question
|
||||||
|
confidence = 0.4
|
||||||
|
|
||||||
|
# Boost for being directed at Arbiter
|
||||||
|
if is_directed:
|
||||||
|
confidence += 0.4
|
||||||
|
|
||||||
|
# Boost for mentioning specific agents
|
||||||
|
if has_agent_mentions:
|
||||||
|
confidence += 0.1
|
||||||
|
|
||||||
|
# Boost for referencing specific findings
|
||||||
|
if has_finding_refs:
|
||||||
|
confidence += 0.1
|
||||||
|
|
||||||
|
return min(confidence, 1.0)
|
||||||
|
|
||||||
|
def _extract_question_text(self, text: str) -> str:
|
||||||
|
# Remove @mentions
|
||||||
|
cleaned = re.sub(r"@\w+", "", text)
|
||||||
|
# Collapse multiple whitespace
|
||||||
|
cleaned = re.sub(r"\s+", " ", cleaned)
|
||||||
|
return cleaned.strip()
|
||||||
150
src/arbiter/conversation/router.py
Normal file
150
src/arbiter/conversation/router.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""Agent routing for follow-up questions."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from arbiter.agents import Agent, ComplexityAgent, SecurityAgent, StyleAgent
|
||||||
|
from arbiter.conversation.detection import QuestionAnalysis
|
||||||
|
from arbiter.llm.client import LLMClient
|
||||||
|
from arbiter.llm.prompts import PromptRegistry
|
||||||
|
from arbiter.models import AgentName, Finding
|
||||||
|
|
||||||
|
|
||||||
|
class RouteResult(BaseModel):
|
||||||
|
"""Result of routing a question to agents."""
|
||||||
|
|
||||||
|
agent_name: AgentName = Field(description="The agent to handle this route")
|
||||||
|
finding: Finding | None = Field(
|
||||||
|
default=None, description="Specific finding this route is about"
|
||||||
|
)
|
||||||
|
priority: int = Field(default=0, description="Priority order (lower = higher priority)")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRouter:
|
||||||
|
"""Routes follow-up questions to appropriate agents."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_client: LLMClient,
|
||||||
|
prompt_registry: PromptRegistry,
|
||||||
|
) -> None:
|
||||||
|
self.llm_client = llm_client
|
||||||
|
self.prompt_registry = prompt_registry
|
||||||
|
self._agents: dict[AgentName, Agent] = {}
|
||||||
|
|
||||||
|
def get_agent(self, agent_name: AgentName) -> Agent:
|
||||||
|
if agent_name not in self._agents:
|
||||||
|
if agent_name == AgentName.SECURITY:
|
||||||
|
self._agents[agent_name] = SecurityAgent(self.llm_client, self.prompt_registry)
|
||||||
|
elif agent_name == AgentName.STYLE:
|
||||||
|
self._agents[agent_name] = StyleAgent(self.llm_client, self.prompt_registry)
|
||||||
|
elif agent_name == AgentName.COMPLEXITY:
|
||||||
|
self._agents[agent_name] = ComplexityAgent(self.llm_client, self.prompt_registry)
|
||||||
|
return self._agents[agent_name]
|
||||||
|
|
||||||
|
def route(
|
||||||
|
self,
|
||||||
|
analysis: QuestionAnalysis,
|
||||||
|
findings: list[Finding],
|
||||||
|
) -> list[RouteResult]:
|
||||||
|
"""Determine which agents should respond to a question.
|
||||||
|
|
||||||
|
The routing logic follows this priority:
|
||||||
|
1. If specific finding IDs are referenced, route to those findings' agents
|
||||||
|
2. If specific agents are mentioned by keyword, route to those agents
|
||||||
|
3. If no specific agents/findings, route to all agents that have findings
|
||||||
|
|
||||||
|
Args:
|
||||||
|
analysis: The question analysis result.
|
||||||
|
findings: List of findings from the review.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of RouteResult indicating which agents should respond.
|
||||||
|
"""
|
||||||
|
routes: list[RouteResult] = []
|
||||||
|
|
||||||
|
# Build a map of finding ID to finding
|
||||||
|
findings_by_id = {f.id: f for f in findings}
|
||||||
|
|
||||||
|
# Priority 1: Specific findings referenced
|
||||||
|
if analysis.mentioned_finding_ids:
|
||||||
|
for i, finding_id in enumerate(analysis.mentioned_finding_ids):
|
||||||
|
if finding_id in findings_by_id:
|
||||||
|
finding = findings_by_id[finding_id]
|
||||||
|
routes.append(
|
||||||
|
RouteResult(
|
||||||
|
agent_name=finding.agent,
|
||||||
|
finding=finding,
|
||||||
|
priority=i,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if routes:
|
||||||
|
return routes
|
||||||
|
|
||||||
|
# Priority 2: Specific agents mentioned by keyword
|
||||||
|
if analysis.mentioned_agents:
|
||||||
|
# Find findings for each mentioned agent
|
||||||
|
findings_by_agent: dict[AgentName, list[Finding]] = {}
|
||||||
|
for finding in findings:
|
||||||
|
if finding.agent in analysis.mentioned_agents:
|
||||||
|
if finding.agent not in findings_by_agent:
|
||||||
|
findings_by_agent[finding.agent] = []
|
||||||
|
findings_by_agent[finding.agent].append(finding)
|
||||||
|
|
||||||
|
for priority, agent_name in enumerate(analysis.mentioned_agents):
|
||||||
|
agent_findings = findings_by_agent.get(agent_name, [])
|
||||||
|
if agent_findings:
|
||||||
|
# Route to most relevant finding (highest severity)
|
||||||
|
severity_order = ["critical", "high", "medium", "low", "info"]
|
||||||
|
sorted_findings = sorted(
|
||||||
|
agent_findings,
|
||||||
|
key=lambda f: severity_order.index(f.severity.value),
|
||||||
|
)
|
||||||
|
routes.append(
|
||||||
|
RouteResult(
|
||||||
|
agent_name=agent_name,
|
||||||
|
finding=sorted_findings[0],
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Agent mentioned but has no findings - still route
|
||||||
|
routes.append(
|
||||||
|
RouteResult(
|
||||||
|
agent_name=agent_name,
|
||||||
|
finding=None,
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if routes:
|
||||||
|
return routes
|
||||||
|
|
||||||
|
# Priority 3: Default - route to all agents with findings
|
||||||
|
agents_with_findings: dict[AgentName, Finding] = {}
|
||||||
|
for finding in findings:
|
||||||
|
if finding.agent not in agents_with_findings:
|
||||||
|
agents_with_findings[finding.agent] = finding
|
||||||
|
else:
|
||||||
|
# Keep highest severity finding
|
||||||
|
existing = agents_with_findings[finding.agent]
|
||||||
|
severity_order = ["critical", "high", "medium", "low", "info"]
|
||||||
|
if severity_order.index(finding.severity.value) < severity_order.index(
|
||||||
|
existing.severity.value
|
||||||
|
):
|
||||||
|
agents_with_findings[finding.agent] = finding
|
||||||
|
|
||||||
|
# Sort agents by standard order
|
||||||
|
agent_order = [AgentName.SECURITY, AgentName.STYLE, AgentName.COMPLEXITY]
|
||||||
|
priority = 0
|
||||||
|
for agent_name in agent_order:
|
||||||
|
if agent_name in agents_with_findings:
|
||||||
|
routes.append(
|
||||||
|
RouteResult(
|
||||||
|
agent_name=agent_name,
|
||||||
|
finding=agents_with_findings[agent_name],
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
priority += 1
|
||||||
|
|
||||||
|
return routes
|
||||||
Reference in New Issue
Block a user