224 lines
7.3 KiB
Python
224 lines
7.3 KiB
Python
"""Tests for agent routing in conversations."""
|
|
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from arbiter.conversation.detection import QuestionAnalysis
|
|
from arbiter.conversation.router import AgentRouter
|
|
from arbiter.llm.prompts import PromptRegistry
|
|
from arbiter.models import AgentName, Finding, Severity
|
|
from tests.conftest import MockLLMClient
|
|
|
|
|
|
@pytest.fixture
def mock_llm() -> MockLLMClient:
    """Provide a fresh ``MockLLMClient`` for each test that requests it."""
    client = MockLLMClient()
    return client
|
|
|
|
|
|
@pytest.fixture
def prompt_registry_with_explain(tmp_path: Path) -> PromptRegistry:
    """Build a ``PromptRegistry`` over a temporary templates directory.

    The directory is created under ``tmp_path`` and populated with one
    review template and one explain template for each of the security,
    style and complexity agents.
    """
    templates_dir = tmp_path / "templates"
    templates_dir.mkdir()

    # Insertion order matches the order the files are written in:
    # the three review prompts first, then the three explain prompts.
    template_bodies = {
        "security-v1.0.md": "Security review: {{diff}}",
        "style-v1.0.md": "Style review: {{diff}}",
        "complexity-v1.0.md": "Complexity review: {{diff}}",
        "security-explain-v1.0.md": "Explain security: {{question}} {{finding_title}}",
        "style-explain-v1.0.md": "Explain style: {{question}} {{finding_title}}",
        "complexity-explain-v1.0.md": "Explain complexity: {{question}} {{finding_title}}",
    }
    for filename, body in template_bodies.items():
        (templates_dir / filename).write_text(body)

    return PromptRegistry(templates_dir)
|
|
|
|
|
|
@pytest.fixture
def router(mock_llm: MockLLMClient, prompt_registry_with_explain: PromptRegistry) -> AgentRouter:
    """Provide an ``AgentRouter`` wired to the mock LLM and the temporary prompt registry."""
    agent_router = AgentRouter(mock_llm, prompt_registry_with_explain)
    return agent_router
|
|
|
|
|
|
@pytest.fixture
def sample_findings() -> list[Finding]:
    """Return three findings: two from the security agent and one from the style agent.

    The security findings differ in severity (HIGH and CRITICAL); no
    complexity finding is included.
    """
    sql_injection = Finding(
        id="sec-finding-1",
        agent=AgentName.SECURITY,
        file="src/auth.py",
        line_start=10,
        line_end=15,
        severity=Severity.HIGH,
        confidence=0.9,
        title="SQL Injection vulnerability",
        description="User input directly concatenated into SQL query",
        reasoning="String concatenation allows SQL injection",
        prompt_version="security-v1.0",
    )
    inconsistent_naming = Finding(
        id="style-finding-1",
        agent=AgentName.STYLE,
        file="src/auth.py",
        line_start=20,
        line_end=25,
        severity=Severity.LOW,
        confidence=0.8,
        title="Inconsistent naming",
        description="Variable name does not follow snake_case",
        reasoning="PEP 8 recommends snake_case",
        prompt_version="style-v1.0",
    )
    hardcoded_credentials = Finding(
        id="sec-finding-2",
        agent=AgentName.SECURITY,
        file="src/api.py",
        line_start=30,
        line_end=35,
        severity=Severity.CRITICAL,
        confidence=0.95,
        title="Hardcoded credentials",
        description="API key is hardcoded in source",
        reasoning="Credentials should not be in source code",
        prompt_version="security-v1.0",
    )
    return [sql_injection, inconsistent_naming, hardcoded_credentials]
|
|
|
|
|
|
class TestRoutingPriority:
    """Tests for which agents and findings ``AgentRouter.route`` selects."""

    @staticmethod
    def _directed_question(
        text: str,
        *,
        agents: list[AgentName],
        finding_ids: list[str],
        confidence: float,
    ) -> QuestionAnalysis:
        """Build a ``QuestionAnalysis`` for a question directed at arbiter."""
        return QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text=text,
            mentioned_agents=agents,
            mentioned_finding_ids=finding_ids,
            confidence=confidence,
        )

    def test_routes_to_specific_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """A question naming ``sec-finding-1`` yields one security route for that finding."""
        question = self._directed_question(
            "Why is this a problem?",
            agents=[],
            finding_ids=["sec-finding-1"],
            confidence=0.9,
        )

        routes = router.route(question, sample_findings)

        assert len(routes) == 1
        selected = routes[0]
        assert selected.agent_name == AgentName.SECURITY
        assert selected.finding is not None
        assert selected.finding.id == "sec-finding-1"

    def test_routes_to_mentioned_agents(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """A question mentioning only the style agent yields one style route."""
        question = self._directed_question(
            "What about the style issue?",
            agents=[AgentName.STYLE],
            finding_ids=[],
            confidence=0.8,
        )

        routes = router.route(question, sample_findings)

        assert len(routes) == 1
        assert routes[0].agent_name == AgentName.STYLE

    def test_routes_to_highest_severity_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """A security question with no finding id picks the CRITICAL finding over the HIGH one."""
        question = self._directed_question(
            "Explain the security issues",
            agents=[AgentName.SECURITY],
            finding_ids=[],
            confidence=0.8,
        )

        routes = router.route(question, sample_findings)

        assert len(routes) == 1
        # Two security findings exist (HIGH and CRITICAL); CRITICAL must win.
        selected = routes[0]
        assert selected.finding is not None
        assert selected.finding.severity == Severity.CRITICAL

    def test_routes_to_all_agents_by_default(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """With no agents or findings mentioned, security and style are both routed."""
        question = self._directed_question(
            "Can you explain these findings?",
            agents=[],
            finding_ids=[],
            confidence=0.7,
        )

        routes = router.route(question, sample_findings)

        # Security and style are the agents that have findings in the fixture.
        routed_agents = []
        for route in routes:
            routed_agents.append(route.agent_name)
        assert AgentName.SECURITY in routed_agents
        assert AgentName.STYLE in routed_agents
|
|
|
|
|
|
class TestRouteResult:
    """Tests for the contents of the route objects returned by ``AgentRouter.route``."""

    @staticmethod
    def _ask_agent(agent: AgentName, text: str) -> QuestionAnalysis:
        """Build a directed question that mentions one agent and no finding ids."""
        return QuestionAnalysis(
            is_question=True,
            is_directed_at_arbiter=True,
            question_text=text,
            mentioned_agents=[agent],
            mentioned_finding_ids=[],
            confidence=0.8,
        )

    def test_route_result_with_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """The first security route carries a finding and has priority 0."""
        question = self._ask_agent(AgentName.SECURITY, "Why?")

        routes = router.route(question, sample_findings)

        first = routes[0]
        assert first.finding is not None
        assert first.priority == 0

    def test_route_without_finding(
        self, router: AgentRouter, sample_findings: list[Finding]
    ) -> None:
        """Asking the complexity agent, which has no findings, yields a route with no finding."""
        # The fixture holds only security and style findings.
        question = self._ask_agent(AgentName.COMPLEXITY, "What about complexity?")

        routes = router.route(question, sample_findings)

        assert len(routes) == 1
        only = routes[0]
        assert only.agent_name == AgentName.COMPLEXITY
        assert only.finding is None
|
|
|
|
|
|
class TestAgentInstantiation:
    """Tests for ``AgentRouter.get_agent``."""

    def test_get_security_agent(self, router: AgentRouter) -> None:
        """The agent returned for SECURITY reports the SECURITY name."""
        security_agent = router.get_agent(AgentName.SECURITY)
        assert security_agent.name == AgentName.SECURITY

    def test_get_style_agent(self, router: AgentRouter) -> None:
        """The agent returned for STYLE reports the STYLE name."""
        style_agent = router.get_agent(AgentName.STYLE)
        assert style_agent.name == AgentName.STYLE

    def test_get_complexity_agent(self, router: AgentRouter) -> None:
        """The agent returned for COMPLEXITY reports the COMPLEXITY name."""
        complexity_agent = router.get_agent(AgentName.COMPLEXITY)
        assert complexity_agent.name == AgentName.COMPLEXITY

    def test_agent_caching(self, router: AgentRouter) -> None:
        """Two lookups of the same agent name return the same object."""
        first_lookup = router.get_agent(AgentName.SECURITY)
        second_lookup = router.get_agent(AgentName.SECURITY)
        assert first_lookup is second_lookup
|