conversation detection tests
This commit is contained in:
223
tests/test_conversation_router.py
Normal file
223
tests/test_conversation_router.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Tests for agent routing in conversations."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from arbiter.conversation.detection import QuestionAnalysis
|
||||
from arbiter.conversation.router import AgentRouter
|
||||
from arbiter.llm.prompts import PromptRegistry
|
||||
from arbiter.models import AgentName, Finding, Severity
|
||||
from tests.conftest import MockLLMClient
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm() -> MockLLMClient:
    """Provide a newly constructed ``MockLLMClient`` to the requesting test."""
    client = MockLLMClient()
    return client
|
||||
|
||||
|
||||
@pytest.fixture
def prompt_registry_with_explain(tmp_path: Path) -> PromptRegistry:
    """Build a ``PromptRegistry`` over a temporary directory of templates.

    A ``templates`` directory is created under ``tmp_path`` and populated
    with three review prompt files and three explain prompt files (named for
    security, style and complexity) before the registry is constructed on it.
    """
    template_root = tmp_path / "templates"
    template_root.mkdir()

    # File name -> template body. Review prompts come first, then the
    # explain prompts; dict ordering keeps the write order stable.
    template_bodies = {
        "security-v1.0.md": "Security review: {{diff}}",
        "style-v1.0.md": "Style review: {{diff}}",
        "complexity-v1.0.md": "Complexity review: {{diff}}",
        "security-explain-v1.0.md": "Explain security: {{question}} {{finding_title}}",
        "style-explain-v1.0.md": "Explain style: {{question}} {{finding_title}}",
        "complexity-explain-v1.0.md": "Explain complexity: {{question}} {{finding_title}}",
    }
    for file_name, body in template_bodies.items():
        (template_root / file_name).write_text(body)

    return PromptRegistry(template_root)
|
||||
|
||||
|
||||
@pytest.fixture
def router(mock_llm: MockLLMClient, prompt_registry_with_explain: PromptRegistry) -> AgentRouter:
    """Provide an ``AgentRouter`` built from the mock LLM client and the temporary prompt registry."""
    agent_router = AgentRouter(mock_llm, prompt_registry_with_explain)
    return agent_router
|
||||
|
||||
|
||||
@pytest.fixture
def sample_findings() -> list[Finding]:
    """Return three findings: two from the security agent and one from style.

    The list order is: the HIGH security finding, the LOW style finding,
    then the CRITICAL security finding. No complexity finding is included.
    """
    sql_injection = Finding(
        id="sec-finding-1",
        agent=AgentName.SECURITY,
        file="src/auth.py",
        line_start=10,
        line_end=15,
        severity=Severity.HIGH,
        confidence=0.9,
        title="SQL Injection vulnerability",
        description="User input directly concatenated into SQL query",
        reasoning="String concatenation allows SQL injection",
        prompt_version="security-v1.0",
    )
    inconsistent_naming = Finding(
        id="style-finding-1",
        agent=AgentName.STYLE,
        file="src/auth.py",
        line_start=20,
        line_end=25,
        severity=Severity.LOW,
        confidence=0.8,
        title="Inconsistent naming",
        description="Variable name does not follow snake_case",
        reasoning="PEP 8 recommends snake_case",
        prompt_version="style-v1.0",
    )
    hardcoded_credentials = Finding(
        id="sec-finding-2",
        agent=AgentName.SECURITY,
        file="src/api.py",
        line_start=30,
        line_end=35,
        severity=Severity.CRITICAL,
        confidence=0.95,
        title="Hardcoded credentials",
        description="API key is hardcoded in source",
        reasoning="Credentials should not be in source code",
        prompt_version="security-v1.0",
    )
    return [sql_injection, inconsistent_naming, hardcoded_credentials]
|
||||
|
||||
|
||||
class TestRoutingPriority:
    """Checks on which agent/finding ``AgentRouter.route`` selects for a question."""

    def test_routes_to_specific_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """A question naming a finding id yields one route carrying that finding."""
        question = QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text="Why is this a problem?",
            mentioned_agents=[],
            mentioned_finding_ids=["sec-finding-1"],
            confidence=0.9,
        )

        routed = router.route(question, sample_findings)

        assert len(routed) == 1
        only_route = routed[0]
        assert only_route.agent_name == AgentName.SECURITY
        assert only_route.finding is not None
        assert only_route.finding.id == "sec-finding-1"

    def test_routes_to_mentioned_agents(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """A question mentioning only the style agent yields one style route."""
        question = QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text="What about the style issue?",
            mentioned_agents=[AgentName.STYLE],
            mentioned_finding_ids=[],
            confidence=0.8,
        )

        routed = router.route(question, sample_findings)

        assert len(routed) == 1
        assert routed[0].agent_name == AgentName.STYLE

    def test_routes_to_highest_severity_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """With two security findings available, the CRITICAL one is chosen."""
        question = QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text="Explain the security issues",
            mentioned_agents=[AgentName.SECURITY],
            mentioned_finding_ids=[],
            confidence=0.8,
        )

        routed = router.route(question, sample_findings)

        assert len(routed) == 1
        # The critical finding must win over the high-severity one.
        chosen = routed[0].finding
        assert chosen is not None
        assert chosen.severity == Severity.CRITICAL

    def test_routes_to_all_agents_by_default(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """With no agent or finding mentioned, security and style both get routes."""
        question = QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text="Can you explain these findings?",
            mentioned_agents=[],
            mentioned_finding_ids=[],
            confidence=0.7,
        )

        routed = router.route(question, sample_findings)

        # Security and style are the agents that have findings in the sample.
        routed_agents = [route.agent_name for route in routed]
        assert AgentName.SECURITY in routed_agents
        assert AgentName.STYLE in routed_agents
|
||||
|
||||
|
||||
class TestRouteResult:
    """Checks on the fields of the route objects returned by ``AgentRouter.route``."""

    def test_route_result_with_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """The first security route carries a finding and has priority 0."""
        question = QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text="Why?",
            mentioned_agents=[AgentName.SECURITY],
            mentioned_finding_ids=[],
            confidence=0.8,
        )

        routed = router.route(question, sample_findings)

        first_route = routed[0]
        assert first_route.finding is not None
        assert first_route.priority == 0

    def test_route_without_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """Asking for an agent with no findings still yields a route, with no finding."""
        # The sample findings contain no complexity entries, yet the
        # question explicitly mentions the complexity agent.
        question = QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text="What about complexity?",
            mentioned_agents=[AgentName.COMPLEXITY],
            mentioned_finding_ids=[],
            confidence=0.8,
        )

        routed = router.route(question, sample_findings)

        assert len(routed) == 1
        only_route = routed[0]
        assert only_route.agent_name == AgentName.COMPLEXITY
        assert only_route.finding is None
|
||||
|
||||
|
||||
class TestAgentInstantiation:
    """Checks on ``AgentRouter.get_agent`` for each agent name, plus instance reuse."""

    def test_get_security_agent(self, router: AgentRouter) -> None:
        """The agent returned for SECURITY reports SECURITY as its name."""
        security_agent = router.get_agent(AgentName.SECURITY)
        assert security_agent.name == AgentName.SECURITY

    def test_get_style_agent(self, router: AgentRouter) -> None:
        """The agent returned for STYLE reports STYLE as its name."""
        style_agent = router.get_agent(AgentName.STYLE)
        assert style_agent.name == AgentName.STYLE

    def test_get_complexity_agent(self, router: AgentRouter) -> None:
        """The agent returned for COMPLEXITY reports COMPLEXITY as its name."""
        complexity_agent = router.get_agent(AgentName.COMPLEXITY)
        assert complexity_agent.name == AgentName.COMPLEXITY

    def test_agent_caching(self, router: AgentRouter) -> None:
        """Two lookups for the same agent name return the very same object."""
        first_lookup = router.get_agent(AgentName.SECURITY)
        second_lookup = router.get_agent(AgentName.SECURITY)
        assert first_lookup is second_lookup
|
||||
Reference in New Issue
Block a user