feat(agents): implement agent framework and CLI
This commit is contained in:
490
tests/conftest.py
Normal file
490
tests/conftest.py
Normal file
@@ -0,0 +1,490 @@
|
||||
"""Pytest configuration and fixtures."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from arbiter.db.models import Base, FindingModel, ReviewModel
|
||||
from arbiter.integrations.base import Comment, CommitStatus, Platform, PlatformClient
|
||||
from arbiter.llm.client import LLMClient, LLMResponse
|
||||
from arbiter.llm.prompts import PromptRegistry
|
||||
from arbiter.models import Policy
|
||||
from arbiter.models.enums import AgentName, Severity, Verdict
|
||||
|
||||
|
||||
class MockLLMClient(LLMClient):
    """Scripted stand-in for ``LLMClient`` used by the test suite.

    Every ``complete`` call is logged in ``calls`` and answered with the next
    entry of ``responses``. Once the script is exhausted the returned content
    is an empty string.
    """

    def __init__(self, responses: list[str] | None = None) -> None:
        """Store the scripted responses and start with an empty call log.

        Args:
            responses: Canned response bodies, returned in order. ``None``
                (or an empty list) means every call yields empty content.
        """
        self._call_index = 0
        self.calls: list[dict[str, Any]] = []
        self.responses = responses if responses else []

    async def complete(
        self,
        messages: list[dict[str, str]],
        model: str,
        **kwargs: Any,
    ) -> LLMResponse:
        """Log the request and hand back the next scripted response.

        Args:
            messages: Chat messages of the request; stored as-is in ``calls``.
            model: Model name; stored in ``calls`` and echoed in the response.
            **kwargs: Any extra options; stored under the ``"kwargs"`` key.

        Returns:
            An ``LLMResponse`` with the scripted content (or ``""`` when no
            scripted response is left) and fixed token/cost figures.
        """
        self.calls.append(dict(messages=messages, model=model, kwargs=kwargs))

        position = self._call_index
        if position < len(self.responses):
            content = self.responses[position]
            # Only advance while there is still something left to replay.
            self._call_index = position + 1
        else:
            content = ""

        return LLMResponse(
            content=content,
            model=model,
            tokens_in=100,
            tokens_out=50,
            cost_usd=0.001,
        )

    def reset(self) -> None:
        """Forget recorded calls and rewind to the first scripted response."""
        # Rebind (rather than clear in place) so earlier references to the old
        # call list are left untouched.
        self.calls, self._call_index = [], 0
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm() -> MockLLMClient:
    """Provide a ``MockLLMClient`` with no scripted responses."""
    client = MockLLMClient()
    return client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm_with_findings() -> MockLLMClient:
    """Provide a ``MockLLMClient`` scripted with one fenced-JSON finding list.

    The single scripted response is a Markdown ``json`` code fence containing
    one high-severity SQL-injection finding for ``src/auth.py``.
    """
    fenced_findings = """```json
[
  {
    "file": "src/auth.py",
    "line_start": 10,
    "line_end": 15,
    "severity": "high",
    "confidence": 0.9,
    "title": "SQL Injection vulnerability",
    "description": "User input is directly concatenated into SQL query",
    "reasoning": "String concatenation in SQL queries allows attackers to inject malicious SQL",
    "suggestion": "Use parameterized queries instead",
    "references": ["https://owasp.org/www-community/attacks/SQL_Injection"]
  }
]
```"""
    scripted = [fenced_findings]
    return MockLLMClient(responses=scripted)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_diff() -> str:
    """Provide a small unified diff of ``src/auth.py`` for tests.

    The added lines build a SQL query by string concatenation, giving tests
    a diff that contains that pattern.
    """
    unified_diff = """diff --git a/src/auth.py b/src/auth.py
index 1234567..abcdefg 100644
--- a/src/auth.py
+++ b/src/auth.py
@@ -8,6 +8,12 @@ def authenticate(username, password):
     if not username or not password:
         return None

+    # Check user in database
+    query = "SELECT * FROM users WHERE username = '" + username + "'"
+    cursor.execute(query)
+    user = cursor.fetchone()
+    if user and user.password == password:
+        return user
     return None
"""
    return unified_diff
|
||||
|
||||
|
||||
@pytest.fixture
def sample_policy() -> Policy:
    """Provide a ``Policy`` constructed with no arguments."""
    default_policy = Policy()
    return default_policy
|
||||
|
||||
|
||||
@pytest.fixture
def prompt_registry(tmp_path: Path) -> PromptRegistry:
    """Provide a ``PromptRegistry`` backed by throwaway template files.

    A ``templates`` directory is created under pytest's ``tmp_path`` and
    filled with one ``<agent>-v1.0.md`` file each for the security, style and
    complexity agents.

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.

    Returns:
        A registry pointed at the freshly written templates directory.
    """
    root = tmp_path / "templates"
    root.mkdir()

    # All three templates share the same placeholder tail; only the agent
    # name in the first line differs.
    placeholders = "{{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
    for agent in ("security", "style", "complexity"):
        template_file = root / (agent + "-v1.0.md")
        template_file.write_text("Review for " + agent + ": " + placeholders)

    return PromptRegistry(root)
|
||||
|
||||
|
||||
# Database fixtures for integration tests
|
||||
@pytest.fixture
async def async_engine():
    """Yield an async SQLite in-memory engine with all tables created.

    The schema comes from ``Base.metadata``; the engine is disposed once the
    consuming test has finished.
    """
    memory_engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)

    # ``create_all`` is a synchronous API, so it is run through ``run_sync``
    # on the async connection.
    async with memory_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield memory_engine

    await memory_engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to the ``async_engine`` fixture.

    The session factory is configured with ``expire_on_commit=False``. The
    session is closed by its context manager when the test is done.
    """
    make_session = async_sessionmaker(
        bind=async_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with make_session() as active_session:
        yield active_session
|
||||
|
||||
|
||||
@pytest.fixture
async def test_client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client wired to the app with test dependencies.

    The app's ``get_db`` dependency is overridden to yield the ``db_session``
    fixture and ``get_redis`` is overridden to yield a ``MockRedis``. Requests
    go through an in-process ``ASGITransport``.

    The overrides are removed in a ``finally`` block so they cannot leak onto
    the shared ``app`` object if the client context raises or the fixture
    generator is closed early.

    Args:
        db_session: Database session every request handler will receive.

    Yields:
        An ``AsyncClient`` with base URL ``http://test``.
    """
    # Imported inside the fixture rather than at module level; presumably to
    # avoid loading the application unless a test asks for the client.
    from arbiter.api.deps import get_db, get_redis
    from arbiter.main import app

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    async def override_get_redis() -> AsyncGenerator[MockRedis, None]:
        # NOTE(review): a new MockRedis is built each time this override is
        # resolved, so values stored during one resolution are not visible
        # in the next — confirm that is intended.
        yield MockRedis()

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_redis] = override_get_redis

    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Always restore the global app, even when teardown is abnormal.
        app.dependency_overrides.clear()
|
||||
|
||||
|
||||
class MockRedis:
|
||||
"""Mock Redis client for testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data: dict[str, Any] = {}
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return self._data.get(key)
|
||||
|
||||
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
||||
self._data[key] = value
|
||||
return True
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def ping(self) -> bool:
|
||||
return True
|
||||
|
||||
async def llen(self, key: str) -> int:
|
||||
data = self._data.get(key, [])
|
||||
return len(data) if isinstance(data, list) else 0
|
||||
|
||||
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: # noqa: ARG002
|
||||
return type("Job", (), {"job_id": "test-job-id"})()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_redis() -> MockRedis:
    """Provide a fresh, empty ``MockRedis``."""
    redis_double = MockRedis()
    return redis_double
|
||||
|
||||
|
||||
class MockPlatformClient(PlatformClient):
    """In-memory ``PlatformClient`` double for tests.

    Write operations are recorded in private lists. Any operation whose
    method name has been added to ``_fail_on`` raises ``IntegrationError``
    instead of doing its work.
    """

    def __init__(self) -> None:
        """Set up empty recordings, a canned diff and no forced failures."""
        self._diff = "mock diff content"
        self._closed = False
        self._comments: list[Comment] = []
        self._posted_comments: list[dict[str, Any]] = []
        self._status_updates: list[dict[str, Any]] = []
        # Names of methods that should raise when called.
        self._fail_on: set[str] = set()

    def _raise_if_failing(self, operation: str) -> None:
        """Raise ``IntegrationError`` when ``operation`` is listed in ``_fail_on``."""
        if operation not in self._fail_on:
            return
        # Imported only on the failure path, as each method did inline before.
        from arbiter.integrations import IntegrationError

        raise IntegrationError(f"Mock failure: {operation}")

    @property
    def platform(self) -> Platform:
        """Always report ``Platform.GITHUB``."""
        return Platform.GITHUB

    async def get_pr_diff(
        self,
        repository: str,  # noqa: ARG002
        pr_number: int,  # noqa: ARG002
    ) -> str:
        """Return the canned diff, whatever repository and PR are asked for."""
        self._raise_if_failing("get_pr_diff")
        return self._diff

    async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
        """Record a posted comment and return its fabricated URL."""
        self._raise_if_failing("post_comment")
        comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-123"
        record = {
            "repository": repository,
            "pr_number": pr_number,
            "body": body,
            "url": comment_url,
        }
        self._posted_comments.append(record)
        return comment_url

    async def update_commit_status(
        self,
        repository: str,
        sha: str,
        status: CommitStatus,
        description: str,
        context: str,
        target_url: str | None = None,
    ) -> None:
        """Record a commit-status update with all of its arguments."""
        self._raise_if_failing("update_commit_status")
        record = {
            "repository": repository,
            "sha": sha,
            "status": status,
            "description": description,
            "context": context,
            "target_url": target_url,
        }
        self._status_updates.append(record)

    async def get_pr_info(self, repository: str, pr_number: int) -> Any:
        """Return a ``PullRequestInfo`` with fixed placeholder values.

        Only ``repository``, ``pr_number`` and the derived URL vary with the
        arguments.
        """
        self._raise_if_failing("get_pr_info")
        from arbiter.integrations.base import PullRequestInfo

        pr_url = f"https://github.com/{repository}/pull/{pr_number}"
        return PullRequestInfo(
            platform=Platform.GITHUB,
            repository=repository,
            pr_number=pr_number,
            head_sha="abc123",
            base_sha="def456",
            head_ref="feature",
            base_ref="main",
            title="Test PR",
            author="testuser",
            url=pr_url,
            is_draft=False,
        )

    async def get_comments(
        self,
        repository: str,  # noqa: ARG002
        pr_number: int,  # noqa: ARG002
    ) -> list[Comment]:
        """Return the internal comment list (the same list object, not a copy)."""
        self._raise_if_failing("get_comments")
        return self._comments

    async def update_comment(
        self, repository: str, pr_number: int, comment_id: str, body: str
    ) -> str:
        """Record a comment update and return its fabricated URL.

        The update is appended to the same list as posted comments, with an
        extra ``"comment_id"`` key.
        """
        self._raise_if_failing("update_comment")
        comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-{comment_id}"
        record = {
            "repository": repository,
            "pr_number": pr_number,
            "comment_id": comment_id,
            "body": body,
            "url": comment_url,
        }
        self._posted_comments.append(record)
        return comment_url

    async def close(self) -> None:
        """Mark the client as closed."""
        self._closed = True
|
||||
|
||||
|
||||
@pytest.fixture
def mock_platform_client() -> MockPlatformClient:
    """Provide a fresh ``MockPlatformClient`` with nothing recorded."""
    platform_double = MockPlatformClient()
    return platform_double
|
||||
|
||||
|
||||
@pytest.fixture
async def completed_review_fixture(db_session: AsyncSession) -> ReviewModel:
    """Persist one completed review with two findings and return it.

    The review has status ``"completed"`` and verdict ``Verdict.COMMENT``. Its
    findings are a high-severity security finding and a medium-severity style
    finding, both in ``src/auth.py``.

    Args:
        db_session: Session the rows are added to and committed through.

    Returns:
        The review instance, refreshed after the commit.
    """
    tokens_per_agent = {"security": 500, "style": 500, "complexity": 500}
    cost_per_agent = {"security": 0.005, "style": 0.005, "complexity": 0.005}

    review = ReviewModel(
        id=str(uuid4()),
        repository="owner/repo",
        pr_number=42,
        pr_title="Test PR",
        base_sha="abc1234567890",
        head_sha="def0987654321",
        author="testuser",
        is_draft=False,
        status="completed",
        verdict=Verdict.COMMENT,
        verdict_confidence=0.75,
        verdict_reasoning="Found some issues to discuss",
        total_tokens=1500,
        total_cost_usd=0.015,
        tokens_by_agent=tokens_per_agent,
        cost_by_agent=cost_per_agent,
        created_at=datetime.now(UTC),
        started_at=datetime.now(UTC),
        completed_at=datetime.now(UTC),
    )
    db_session.add(review)

    security_finding = FindingModel(
        id=str(uuid4()),
        review_id=review.id,
        agent=AgentName.SECURITY,
        file="src/auth.py",
        line_start=10,
        line_end=15,
        severity=Severity.HIGH,
        confidence=0.9,
        title="SQL Injection vulnerability",
        description="User input is directly concatenated into SQL query",
        reasoning="String concatenation in SQL queries allows attackers to inject malicious SQL",
        suggestion="Use parameterized queries instead",
        references=["https://owasp.org/www-community/attacks/SQL_Injection"],
        prompt_version="security-v1.0",
    )
    style_finding = FindingModel(
        id=str(uuid4()),
        review_id=review.id,
        agent=AgentName.STYLE,
        file="src/auth.py",
        line_start=20,
        line_end=25,
        severity=Severity.MEDIUM,
        confidence=0.85,
        title="Long function",
        description="Function exceeds 50 lines",
        reasoning="Long functions are harder to test and maintain",
        suggestion="Consider breaking into smaller functions",
        references=[],
        prompt_version="style-v1.0",
    )
    for finding in (security_finding, style_finding):
        db_session.add(finding)

    await db_session.commit()
    await db_session.refresh(review)
    return review
|
||||
|
||||
|
||||
@pytest.fixture
async def sample_reviews_fixture(db_session: AsyncSession) -> list[ReviewModel]:
    """Persist five varied reviews and return them in creation order.

    The first three belong to ``owner/repo`` and the last two to
    ``other/repo``. Authors alternate between ``testuser`` and ``otheruser``.
    The first four are ``"completed"`` (the first approved, the rest
    commented) and each carries one security finding. The fifth is
    ``"failed"`` with no verdict, no completion time and no finding.

    Args:
        db_session: Session the rows are added to and committed through.

    Returns:
        The five review instances, each refreshed after the commit.
    """
    created: list[ReviewModel] = []

    for index in range(5):
        number = index + 1
        is_completed = index < 4

        if index == 0:
            verdict = Verdict.APPROVE
        elif is_completed:
            verdict = Verdict.COMMENT
        else:
            verdict = None

        review = ReviewModel(
            id=str(uuid4()),
            repository="owner/repo" if index < 3 else "other/repo",
            pr_number=number,
            pr_title=f"Test PR #{number}",
            # NOTE(review): these values are 44 characters long (4-char prefix
            # plus 40 zero-padded digits) — confirm the SHA columns allow that.
            base_sha=f"base{index:040d}",
            head_sha=f"head{index:040d}",
            author="testuser" if index % 2 == 0 else "otheruser",
            is_draft=False,
            status="completed" if is_completed else "failed",
            verdict=verdict,
            verdict_confidence=0.8 if is_completed else None,
            total_tokens=1000 * number,
            total_cost_usd=0.01 * number,
            created_at=datetime.now(UTC),
            completed_at=datetime.now(UTC) if is_completed else None,
        )
        db_session.add(review)
        created.append(review)

        if not is_completed:
            continue

        # Completed reviews each get a single finding; only the first one is
        # critical.
        db_session.add(
            FindingModel(
                id=str(uuid4()),
                review_id=review.id,
                agent=AgentName.SECURITY,
                file="src/test.py",
                line_start=10,
                line_end=15,
                severity=Severity.CRITICAL if index == 0 else Severity.HIGH,
                confidence=0.9,
                title=f"Finding {number}",
                description="Test finding",
                reasoning="Test reasoning",
                prompt_version="security-v1.0",
            )
        )

    await db_session.commit()
    for persisted in created:
        await db_session.refresh(persisted)
    return created
|
||||
|
||||
|
||||
@pytest.fixture
def mock_settings() -> Any:
    """Provide a settings-like object with tokens and every feature enabled.

    Secret values are wrapped in a small local class exposing
    ``get_secret_value()``. All other values are plain class attributes.
    """

    class MockSecretStr:
        """Minimal secret wrapper: holds a string, returns it on request."""

        def __init__(self, value: str) -> None:
            self._value = value

        def get_secret_value(self) -> str:
            """Return the wrapped string unchanged."""
            return self._value

    class MockSettings:
        """Static bag of configuration values used by the tests."""

        # Platform credentials and endpoints.
        github_token = MockSecretStr("ghp_test_token")
        gitlab_token = MockSecretStr("glpat_test_token")
        github_base_url = "https://api.github.com"
        gitlab_base_url = "https://gitlab.com/api/v4"
        github_webhook_secret = MockSecretStr("webhook_secret")
        gitlab_webhook_token = MockSecretStr("gitlab_token")

        # Timeouts and retry counts.
        integration_timeout = 30
        integration_max_retries = 3
        llm_timeout = 60
        llm_max_retries = 3

        # Feature switches and related values.
        post_comments = True
        update_status = True
        status_check_context = "arbiter"
        templates_dir = Path("templates")
        followup_enabled = True
        followup_confidence_threshold = 0.5

    settings = MockSettings()
    return settings
|
||||
|
||||
|
||||
@pytest.fixture
def mock_settings_no_github() -> Any:
    """Provide a settings-like object with no tokens and posting disabled.

    Both ``github_token`` and ``gitlab_token`` are ``None``, and both
    ``post_comments`` and ``update_status`` are ``False``.
    """

    class MockSettings:
        """Static bag of configuration values without any credentials."""

        # No credentials for either platform.
        github_token = None
        gitlab_token = None
        github_base_url = "https://api.github.com"
        gitlab_base_url = "https://gitlab.com/api/v4"

        integration_timeout = 30
        integration_max_retries = 3

        # Comment posting and status updates are switched off.
        post_comments = False
        update_status = False
        status_check_context = "arbiter"
        templates_dir = Path("templates")

    settings = MockSettings()
    return settings
|
||||
Reference in New Issue
Block a user