feat(conversation): add question detection and routing
This commit is contained in:
11
src/arbiter/conversation/__init__.py
Normal file
11
src/arbiter/conversation/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Conversation handling for follow-up questions."""
|
||||
|
||||
from arbiter.conversation.detection import QuestionAnalysis, QuestionDetector
|
||||
from arbiter.conversation.router import AgentRouter, RouteResult
|
||||
|
||||
__all__ = [
|
||||
"AgentRouter",
|
||||
"QuestionAnalysis",
|
||||
"QuestionDetector",
|
||||
"RouteResult",
|
||||
]
|
||||
234
src/arbiter/conversation/detection.py
Normal file
234
src/arbiter/conversation/detection.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Question detection for PR comments."""
|
||||
|
||||
import re
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from arbiter.models import AgentName, Finding
|
||||
|
||||
|
||||
class QuestionAnalysis(BaseModel):
|
||||
"""Result of analyzing a PR comment for questions."""
|
||||
|
||||
is_question: bool = Field(description="Whether the comment appears to be a question")
|
||||
is_directed_at_arbiter: bool = Field(
|
||||
description="Whether the question is directed at Arbiter specifically"
|
||||
)
|
||||
question_text: str = Field(description="The extracted question text")
|
||||
mentioned_agents: list[AgentName] = Field(
|
||||
default_factory=list, description="Agents mentioned or implied"
|
||||
)
|
||||
mentioned_finding_ids: list[str] = Field(
|
||||
default_factory=list, description="Finding IDs referenced in the comment"
|
||||
)
|
||||
confidence: float = Field(
|
||||
ge=0.0, le=1.0, description="Confidence that this is a valid follow-up question"
|
||||
)
|
||||
|
||||
|
||||
class QuestionDetector:
|
||||
"""Detects questions in PR comments that should be answered by Arbiter."""
|
||||
|
||||
# Patterns that indicate a question
|
||||
QUESTION_PATTERNS: ClassVar[list[str]] = [
|
||||
r"\?$", # Ends with question mark
|
||||
r"\?\s*$", # Ends with question mark and optional whitespace
|
||||
r"^(why|what|how|can you|could you|please explain|explain|tell me)",
|
||||
r"\b(why|what|how|can|could|would|should)\b.*\?",
|
||||
]
|
||||
|
||||
# Patterns that indicate the comment is directed at Arbiter
|
||||
ARBITER_PATTERNS: ClassVar[list[str]] = [
|
||||
r"@arbiter\b",
|
||||
r"\barbiter\b",
|
||||
r"@security.?agent\b",
|
||||
r"@style.?agent\b",
|
||||
r"@complexity.?agent\b",
|
||||
]
|
||||
|
||||
# Keywords associated with each agent
|
||||
AGENT_KEYWORDS: ClassVar[dict[AgentName, list[str]]] = {
|
||||
AgentName.SECURITY: [
|
||||
"security",
|
||||
"vulnerability",
|
||||
"injection",
|
||||
"xss",
|
||||
"csrf",
|
||||
"auth",
|
||||
"authentication",
|
||||
"authorization",
|
||||
"exploit",
|
||||
"attack",
|
||||
"risk",
|
||||
"threat",
|
||||
"sensitive",
|
||||
"secret",
|
||||
"credential",
|
||||
"permission",
|
||||
"access control",
|
||||
"owasp",
|
||||
"cwe",
|
||||
],
|
||||
AgentName.STYLE: [
|
||||
"style",
|
||||
"naming",
|
||||
"readability",
|
||||
"convention",
|
||||
"format",
|
||||
"formatting",
|
||||
"consistent",
|
||||
"consistency",
|
||||
"pattern",
|
||||
"readable",
|
||||
"clean",
|
||||
"code style",
|
||||
"lint",
|
||||
],
|
||||
AgentName.COMPLEXITY: [
|
||||
"complexity",
|
||||
"architecture",
|
||||
"design",
|
||||
"refactor",
|
||||
"maintainability",
|
||||
"coupling",
|
||||
"cohesion",
|
||||
"abstraction",
|
||||
"module",
|
||||
"dependency",
|
||||
"cyclomatic",
|
||||
"cognitive",
|
||||
"nested",
|
||||
"simplify",
|
||||
],
|
||||
}
|
||||
|
||||
# Patterns to extract finding IDs (UUIDs)
|
||||
UUID_PATTERN: ClassVar[str] = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
|
||||
def __init__(self, confidence_threshold: float = 0.5) -> None:
|
||||
self.confidence_threshold = confidence_threshold
|
||||
self._compiled_question_patterns = [
|
||||
re.compile(p, re.IGNORECASE | re.MULTILINE) for p in self.QUESTION_PATTERNS
|
||||
]
|
||||
self._compiled_arbiter_patterns = [
|
||||
re.compile(p, re.IGNORECASE) for p in self.ARBITER_PATTERNS
|
||||
]
|
||||
self._uuid_pattern = re.compile(self.UUID_PATTERN, re.IGNORECASE)
|
||||
|
||||
def analyze(
|
||||
self,
|
||||
comment_text: str,
|
||||
findings: list[Finding] | None = None,
|
||||
) -> QuestionAnalysis:
|
||||
"""Analyze a PR comment to detect if it's a question for Arbiter.
|
||||
|
||||
Args:
|
||||
comment_text: The comment text to analyze.
|
||||
findings: Optional list of findings from the review to match against.
|
||||
|
||||
Returns:
|
||||
QuestionAnalysis with detection results.
|
||||
"""
|
||||
text = comment_text.strip()
|
||||
text_lower = text.lower()
|
||||
|
||||
# Check if it's a question
|
||||
is_question = self._is_question(text)
|
||||
|
||||
# Check if directed at Arbiter
|
||||
is_directed = self._is_directed_at_arbiter(text)
|
||||
|
||||
# Extract mentioned agents
|
||||
mentioned_agents = self._extract_mentioned_agents(text_lower)
|
||||
|
||||
# Extract mentioned finding IDs
|
||||
mentioned_finding_ids = self._extract_finding_ids(text, findings)
|
||||
|
||||
# Calculate confidence
|
||||
confidence = self._calculate_confidence(
|
||||
is_question=is_question,
|
||||
is_directed=is_directed,
|
||||
has_agent_mentions=len(mentioned_agents) > 0,
|
||||
has_finding_refs=len(mentioned_finding_ids) > 0,
|
||||
)
|
||||
|
||||
# Extract the question text (clean up @mentions and excess whitespace)
|
||||
question_text = self._extract_question_text(text)
|
||||
|
||||
return QuestionAnalysis(
|
||||
is_question=is_question,
|
||||
is_directed_at_arbiter=is_directed,
|
||||
question_text=question_text,
|
||||
mentioned_agents=mentioned_agents,
|
||||
mentioned_finding_ids=mentioned_finding_ids,
|
||||
confidence=confidence,
|
||||
)
|
||||
|
||||
def _is_question(self, text: str) -> bool:
|
||||
return any(pattern.search(text) for pattern in self._compiled_question_patterns)
|
||||
|
||||
def _is_directed_at_arbiter(self, text: str) -> bool:
|
||||
return any(pattern.search(text) for pattern in self._compiled_arbiter_patterns)
|
||||
|
||||
def _extract_mentioned_agents(self, text_lower: str) -> list[AgentName]:
|
||||
"""Extract agents mentioned or implied in the text."""
|
||||
mentioned = []
|
||||
for agent, keywords in self.AGENT_KEYWORDS.items():
|
||||
for keyword in keywords:
|
||||
if keyword in text_lower:
|
||||
if agent not in mentioned:
|
||||
mentioned.append(agent)
|
||||
break
|
||||
return mentioned
|
||||
|
||||
def _extract_finding_ids(
|
||||
self,
|
||||
text: str,
|
||||
findings: list[Finding] | None,
|
||||
) -> list[str]:
|
||||
"""Extract finding IDs referenced in the text."""
|
||||
# Find UUIDs in the text
|
||||
uuid_matches = self._uuid_pattern.findall(text.lower())
|
||||
|
||||
if not findings:
|
||||
return uuid_matches
|
||||
|
||||
# Filter to only valid finding IDs
|
||||
valid_ids = {f.id.lower() for f in findings}
|
||||
return [uid for uid in uuid_matches if uid in valid_ids]
|
||||
|
||||
def _calculate_confidence(
|
||||
self,
|
||||
is_question: bool,
|
||||
is_directed: bool,
|
||||
has_agent_mentions: bool,
|
||||
has_finding_refs: bool,
|
||||
) -> float:
|
||||
"""Calculate confidence score for the analysis."""
|
||||
if not is_question:
|
||||
return 0.0
|
||||
|
||||
# Base confidence for being a question
|
||||
confidence = 0.4
|
||||
|
||||
# Boost for being directed at Arbiter
|
||||
if is_directed:
|
||||
confidence += 0.4
|
||||
|
||||
# Boost for mentioning specific agents
|
||||
if has_agent_mentions:
|
||||
confidence += 0.1
|
||||
|
||||
# Boost for referencing specific findings
|
||||
if has_finding_refs:
|
||||
confidence += 0.1
|
||||
|
||||
return min(confidence, 1.0)
|
||||
|
||||
def _extract_question_text(self, text: str) -> str:
|
||||
# Remove @mentions
|
||||
cleaned = re.sub(r"@\w+", "", text)
|
||||
# Collapse multiple whitespace
|
||||
cleaned = re.sub(r"\s+", " ", cleaned)
|
||||
return cleaned.strip()
|
||||
150
src/arbiter/conversation/router.py
Normal file
150
src/arbiter/conversation/router.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Agent routing for follow-up questions."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from arbiter.agents import Agent, ComplexityAgent, SecurityAgent, StyleAgent
|
||||
from arbiter.conversation.detection import QuestionAnalysis
|
||||
from arbiter.llm.client import LLMClient
|
||||
from arbiter.llm.prompts import PromptRegistry
|
||||
from arbiter.models import AgentName, Finding
|
||||
|
||||
|
||||
class RouteResult(BaseModel):
|
||||
"""Result of routing a question to agents."""
|
||||
|
||||
agent_name: AgentName = Field(description="The agent to handle this route")
|
||||
finding: Finding | None = Field(
|
||||
default=None, description="Specific finding this route is about"
|
||||
)
|
||||
priority: int = Field(default=0, description="Priority order (lower = higher priority)")
|
||||
|
||||
|
||||
class AgentRouter:
|
||||
"""Routes follow-up questions to appropriate agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: LLMClient,
|
||||
prompt_registry: PromptRegistry,
|
||||
) -> None:
|
||||
self.llm_client = llm_client
|
||||
self.prompt_registry = prompt_registry
|
||||
self._agents: dict[AgentName, Agent] = {}
|
||||
|
||||
def get_agent(self, agent_name: AgentName) -> Agent:
|
||||
if agent_name not in self._agents:
|
||||
if agent_name == AgentName.SECURITY:
|
||||
self._agents[agent_name] = SecurityAgent(self.llm_client, self.prompt_registry)
|
||||
elif agent_name == AgentName.STYLE:
|
||||
self._agents[agent_name] = StyleAgent(self.llm_client, self.prompt_registry)
|
||||
elif agent_name == AgentName.COMPLEXITY:
|
||||
self._agents[agent_name] = ComplexityAgent(self.llm_client, self.prompt_registry)
|
||||
return self._agents[agent_name]
|
||||
|
||||
def route(
|
||||
self,
|
||||
analysis: QuestionAnalysis,
|
||||
findings: list[Finding],
|
||||
) -> list[RouteResult]:
|
||||
"""Determine which agents should respond to a question.
|
||||
|
||||
The routing logic follows this priority:
|
||||
1. If specific finding IDs are referenced, route to those findings' agents
|
||||
2. If specific agents are mentioned by keyword, route to those agents
|
||||
3. If no specific agents/findings, route to all agents that have findings
|
||||
|
||||
Args:
|
||||
analysis: The question analysis result.
|
||||
findings: List of findings from the review.
|
||||
|
||||
Returns:
|
||||
List of RouteResult indicating which agents should respond.
|
||||
"""
|
||||
routes: list[RouteResult] = []
|
||||
|
||||
# Build a map of finding ID to finding
|
||||
findings_by_id = {f.id: f for f in findings}
|
||||
|
||||
# Priority 1: Specific findings referenced
|
||||
if analysis.mentioned_finding_ids:
|
||||
for i, finding_id in enumerate(analysis.mentioned_finding_ids):
|
||||
if finding_id in findings_by_id:
|
||||
finding = findings_by_id[finding_id]
|
||||
routes.append(
|
||||
RouteResult(
|
||||
agent_name=finding.agent,
|
||||
finding=finding,
|
||||
priority=i,
|
||||
)
|
||||
)
|
||||
if routes:
|
||||
return routes
|
||||
|
||||
# Priority 2: Specific agents mentioned by keyword
|
||||
if analysis.mentioned_agents:
|
||||
# Find findings for each mentioned agent
|
||||
findings_by_agent: dict[AgentName, list[Finding]] = {}
|
||||
for finding in findings:
|
||||
if finding.agent in analysis.mentioned_agents:
|
||||
if finding.agent not in findings_by_agent:
|
||||
findings_by_agent[finding.agent] = []
|
||||
findings_by_agent[finding.agent].append(finding)
|
||||
|
||||
for priority, agent_name in enumerate(analysis.mentioned_agents):
|
||||
agent_findings = findings_by_agent.get(agent_name, [])
|
||||
if agent_findings:
|
||||
# Route to most relevant finding (highest severity)
|
||||
severity_order = ["critical", "high", "medium", "low", "info"]
|
||||
sorted_findings = sorted(
|
||||
agent_findings,
|
||||
key=lambda f: severity_order.index(f.severity.value),
|
||||
)
|
||||
routes.append(
|
||||
RouteResult(
|
||||
agent_name=agent_name,
|
||||
finding=sorted_findings[0],
|
||||
priority=priority,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Agent mentioned but has no findings - still route
|
||||
routes.append(
|
||||
RouteResult(
|
||||
agent_name=agent_name,
|
||||
finding=None,
|
||||
priority=priority,
|
||||
)
|
||||
)
|
||||
|
||||
if routes:
|
||||
return routes
|
||||
|
||||
# Priority 3: Default - route to all agents with findings
|
||||
agents_with_findings: dict[AgentName, Finding] = {}
|
||||
for finding in findings:
|
||||
if finding.agent not in agents_with_findings:
|
||||
agents_with_findings[finding.agent] = finding
|
||||
else:
|
||||
# Keep highest severity finding
|
||||
existing = agents_with_findings[finding.agent]
|
||||
severity_order = ["critical", "high", "medium", "low", "info"]
|
||||
if severity_order.index(finding.severity.value) < severity_order.index(
|
||||
existing.severity.value
|
||||
):
|
||||
agents_with_findings[finding.agent] = finding
|
||||
|
||||
# Sort agents by standard order
|
||||
agent_order = [AgentName.SECURITY, AgentName.STYLE, AgentName.COMPLEXITY]
|
||||
priority = 0
|
||||
for agent_name in agent_order:
|
||||
if agent_name in agents_with_findings:
|
||||
routes.append(
|
||||
RouteResult(
|
||||
agent_name=agent_name,
|
||||
finding=agents_with_findings[agent_name],
|
||||
priority=priority,
|
||||
)
|
||||
)
|
||||
priority += 1
|
||||
|
||||
return routes
|
||||
Reference in New Issue
Block a user