Files
arbiter/tests/conftest.py

491 lines
15 KiB
Python

"""Pytest configuration and fixtures."""
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from uuid import uuid4
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from arbiter.db.models import Base, FindingModel, ReviewModel
from arbiter.integrations.base import Comment, CommitStatus, Platform, PlatformClient
from arbiter.llm.client import LLMClient, LLMResponse
from arbiter.llm.prompts import PromptRegistry
from arbiter.models import Policy
from arbiter.models.enums import AgentName, Severity, Verdict
class MockLLMClient(LLMClient):
    """Scripted LLM client for tests.

    Each call to :meth:`complete` is recorded in ``calls`` and answered with
    the next entry of ``responses``; once the script is exhausted the reply
    content is an empty string. Token counts and cost are fixed constants.
    """

    def __init__(self, responses: list[str] | None = None) -> None:
        self.responses = responses or []
        self.calls: list[dict[str, Any]] = []
        self._call_index = 0

    async def complete(
        self,
        messages: list[dict[str, str]],
        model: str,
        **kwargs: Any,
    ) -> LLMResponse:
        """Log the request and hand back the next scripted reply."""
        request = {"messages": messages, "model": model, "kwargs": kwargs}
        self.calls.append(request)

        position = self._call_index
        if position >= len(self.responses):
            # Script exhausted: answer with empty content, leave the cursor alone.
            reply_text = ""
        else:
            reply_text = self.responses[position]
            self._call_index = position + 1

        return LLMResponse(
            content=reply_text,
            model=model,
            tokens_in=100,
            tokens_out=50,
            cost_usd=0.001,
        )

    def reset(self) -> None:
        """Forget recorded calls and rewind to the first scripted reply."""
        self._call_index = 0
        self.calls = []
@pytest.fixture
def mock_llm() -> MockLLMClient:
    """Provide a ``MockLLMClient`` with no scripted responses."""
    client = MockLLMClient()
    return client
@pytest.fixture
def mock_llm_with_findings() -> MockLLMClient:
    """Provide a ``MockLLMClient`` scripted with one fenced-JSON reply.

    The reply holds a single high-severity finding for ``src/auth.py``.
    """
    # The literal below is runtime data; its lines are intentionally flush-left.
    canned_reply = """```json
[
{
"file": "src/auth.py",
"line_start": 10,
"line_end": 15,
"severity": "high",
"confidence": 0.9,
"title": "SQL Injection vulnerability",
"description": "User input is directly concatenated into SQL query",
"reasoning": "String concatenation in SQL queries allows attackers to inject malicious SQL",
"suggestion": "Use parameterized queries instead",
"references": ["https://owasp.org/www-community/attacks/SQL_Injection"]
}
]
```"""
    return MockLLMClient(responses=[canned_reply])
@pytest.fixture
def sample_diff() -> str:
    """Provide diff text for ``src/auth.py`` that adds a string-built SQL query."""
    # The literal below is runtime data; its lines are intentionally flush-left.
    diff_text = """diff --git a/src/auth.py b/src/auth.py
index 1234567..abcdefg 100644
--- a/src/auth.py
+++ b/src/auth.py
@@ -8,6 +8,12 @@ def authenticate(username, password):
if not username or not password:
return None
+ # Check user in database
+ query = "SELECT * FROM users WHERE username = '" + username + "'"
+ cursor.execute(query)
+ user = cursor.fetchone()
+ if user and user.password == password:
+ return user
return None
"""
    return diff_text
@pytest.fixture
def sample_policy() -> Policy:
    """Provide a ``Policy`` constructed with no arguments."""
    policy = Policy()
    return policy
@pytest.fixture
def prompt_registry(tmp_path: Path) -> PromptRegistry:
    """Provide a ``PromptRegistry`` rooted at a temporary ``templates`` directory.

    The directory is populated with one v1.0 template each for the security,
    style and complexity agents.
    """
    templates_dir = tmp_path / "templates"
    templates_dir.mkdir()

    # File name -> template body, written in this order.
    template_bodies = {
        "security-v1.0.md": (
            "Review for security: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
        ),
        "style-v1.0.md": (
            "Review for style: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
        ),
        "complexity-v1.0.md": (
            "Review for complexity: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
        ),
    }
    for file_name, body in template_bodies.items():
        (templates_dir / file_name).write_text(body)

    return PromptRegistry(templates_dir)
# Database fixtures for integration tests
@pytest.fixture
async def async_engine():
    """Yield an async SQLAlchemy engine backed by in-memory SQLite.

    All tables registered on ``Base.metadata`` are created before the engine
    is handed to the test. The engine is disposed afterwards, including when
    table creation fails or the fixture generator is closed early.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Always release the engine, even if schema creation raised.
        await engine.dispose()
@pytest.fixture
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to the ``async_engine`` fixture.

    The session is created with ``expire_on_commit=False``; its async context
    manager is exited once the test is done with it.
    """
    make_session = async_sessionmaker(
        expire_on_commit=False,
        class_=AsyncSession,
        bind=async_engine,
    )
    async with make_session() as active_session:
        yield active_session
@pytest.fixture
async def test_client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client wired to the application with test dependencies.

    ``get_db`` is overridden to yield the ``db_session`` fixture and
    ``get_redis`` to yield a fresh ``MockRedis`` each time it is resolved.
    All dependency overrides on the app are cleared afterwards, even if
    creating or closing the client raises.
    """
    # Imported lazily so the application is only loaded by tests that need it.
    from arbiter.api.deps import get_db, get_redis
    from arbiter.main import app

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    async def override_get_redis() -> AsyncGenerator[MockRedis, None]:
        yield MockRedis()

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_redis] = override_get_redis
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # ``app`` is a shared module-level object; never leak overrides to other tests.
        app.dependency_overrides.clear()
class MockRedis:
"""Mock Redis client for testing."""
def __init__(self) -> None:
self._data: dict[str, Any] = {}
async def get(self, key: str) -> str | None:
return self._data.get(key)
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
self._data[key] = value
return True
async def delete(self, key: str) -> int:
if key in self._data:
del self._data[key]
return 1
return 0
async def ping(self) -> bool:
return True
async def llen(self, key: str) -> int:
data = self._data.get(key, [])
return len(data) if isinstance(data, list) else 0
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: # noqa: ARG002
return type("Job", (), {"job_id": "test-job-id"})()
@pytest.fixture
def mock_redis() -> MockRedis:
    """Provide an empty ``MockRedis`` instance."""
    redis_stub = MockRedis()
    return redis_stub
class MockPlatformClient(PlatformClient):
    """Mock platform client for testing.

    Posted and updated comments are recorded in ``_posted_comments`` and
    commit-status updates in ``_status_updates`` so tests can inspect them.
    Adding a method name to ``_fail_on`` makes that method raise
    ``IntegrationError`` instead of doing its work.
    """

    def __init__(self) -> None:
        self._comments: list[Comment] = []
        self._posted_comments: list[dict[str, Any]] = []
        self._status_updates: list[dict[str, Any]] = []
        self._diff = "mock diff content"
        self._closed = False
        self._fail_on: set[str] = set()  # Methods to fail on

    def _fail_if_requested(self, method: str) -> None:
        """Raise ``IntegrationError`` when *method* is listed in ``_fail_on``."""
        if method in self._fail_on:
            # Imported lazily, only on the failure path.
            from arbiter.integrations import IntegrationError

            raise IntegrationError(f"Mock failure: {method}")

    @property
    def platform(self) -> Platform:
        """Report GitHub as the platform."""
        return Platform.GITHUB

    async def get_pr_diff(
        self,
        repository: str,  # noqa: ARG002
        pr_number: int,  # noqa: ARG002
    ) -> str:
        """Return the configured ``_diff`` text regardless of the arguments."""
        self._fail_if_requested("get_pr_diff")
        return self._diff

    async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
        """Record a posted comment and return its synthetic URL."""
        self._fail_if_requested("post_comment")
        comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-123"
        self._posted_comments.append(
            {"repository": repository, "pr_number": pr_number, "body": body, "url": comment_url}
        )
        return comment_url

    async def update_commit_status(
        self,
        repository: str,
        sha: str,
        status: CommitStatus,
        description: str,
        context: str,
        target_url: str | None = None,
    ) -> None:
        """Record a commit-status update in ``_status_updates``."""
        self._fail_if_requested("update_commit_status")
        self._status_updates.append(
            {
                "repository": repository,
                "sha": sha,
                "status": status,
                "description": description,
                "context": context,
                "target_url": target_url,
            }
        )

    async def get_pr_info(self, repository: str, pr_number: int) -> Any:
        """Return a fixed, non-draft ``PullRequestInfo`` for the given PR."""
        self._fail_if_requested("get_pr_info")
        from arbiter.integrations.base import PullRequestInfo

        return PullRequestInfo(
            platform=Platform.GITHUB,
            repository=repository,
            pr_number=pr_number,
            head_sha="abc123",
            base_sha="def456",
            head_ref="feature",
            base_ref="main",
            title="Test PR",
            author="testuser",
            url=f"https://github.com/{repository}/pull/{pr_number}",
            is_draft=False,
        )

    async def get_comments(
        self,
        repository: str,  # noqa: ARG002
        pr_number: int,  # noqa: ARG002
    ) -> list[Comment]:
        """Return the preloaded ``_comments`` list (the same list object)."""
        self._fail_if_requested("get_comments")
        return self._comments

    async def update_comment(
        self, repository: str, pr_number: int, comment_id: str, body: str
    ) -> str:
        """Record a comment update in ``_posted_comments`` and return its URL."""
        self._fail_if_requested("update_comment")
        comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-{comment_id}"
        self._posted_comments.append(
            {
                "repository": repository,
                "pr_number": pr_number,
                "comment_id": comment_id,
                "body": body,
                "url": comment_url,
            }
        )
        return comment_url

    async def close(self) -> None:
        """Mark the client as closed."""
        self._closed = True
@pytest.fixture
def mock_platform_client() -> MockPlatformClient:
    """Provide a fresh ``MockPlatformClient`` with no failures configured."""
    client = MockPlatformClient()
    return client
@pytest.fixture
async def completed_review_fixture(db_session: AsyncSession) -> ReviewModel:
    """Persist and return one completed review with two findings.

    The review targets PR 42 of ``owner/repo`` with a COMMENT verdict. Its
    findings are a high-severity security finding and a medium-severity style
    finding, both in ``src/auth.py``.
    """
    review_id = str(uuid4())
    review = ReviewModel(
        id=review_id,
        repository="owner/repo",
        pr_number=42,
        pr_title="Test PR",
        base_sha="abc1234567890",
        head_sha="def0987654321",
        author="testuser",
        is_draft=False,
        status="completed",
        verdict=Verdict.COMMENT,
        verdict_confidence=0.75,
        verdict_reasoning="Found some issues to discuss",
        total_tokens=1500,
        total_cost_usd=0.015,
        tokens_by_agent={"security": 500, "style": 500, "complexity": 500},
        cost_by_agent={"security": 0.005, "style": 0.005, "complexity": 0.005},
        created_at=datetime.now(UTC),
        started_at=datetime.now(UTC),
        completed_at=datetime.now(UTC),
    )
    db_session.add(review)

    # Findings attached to the review above.
    findings = [
        FindingModel(
            id=str(uuid4()),
            review_id=review.id,
            agent=AgentName.SECURITY,
            file="src/auth.py",
            line_start=10,
            line_end=15,
            severity=Severity.HIGH,
            confidence=0.9,
            title="SQL Injection vulnerability",
            description="User input is directly concatenated into SQL query",
            reasoning="String concatenation in SQL queries allows attackers to inject malicious SQL",
            suggestion="Use parameterized queries instead",
            references=["https://owasp.org/www-community/attacks/SQL_Injection"],
            prompt_version="security-v1.0",
        ),
        FindingModel(
            id=str(uuid4()),
            review_id=review.id,
            agent=AgentName.STYLE,
            file="src/auth.py",
            line_start=20,
            line_end=25,
            severity=Severity.MEDIUM,
            confidence=0.85,
            title="Long function",
            description="Function exceeds 50 lines",
            reasoning="Long functions are harder to test and maintain",
            suggestion="Consider breaking into smaller functions",
            references=[],
            prompt_version="style-v1.0",
        ),
    ]
    for finding in findings:
        db_session.add(finding)

    await db_session.commit()
    await db_session.refresh(review)
    return review
@pytest.fixture
async def sample_reviews_fixture(db_session: AsyncSession) -> list[ReviewModel]:
    """Persist and return five reviews with varied attributes.

    The first three belong to ``owner/repo`` and the last two to
    ``other/repo``. The first four are completed (the first approved, the rest
    with a COMMENT verdict) and each carries one security finding. The fifth
    is failed, with no verdict, confidence, completion time or finding.
    """
    created: list[ReviewModel] = []
    for index in range(5):
        number = index + 1
        finished = index < 4

        if index == 0:
            verdict = Verdict.APPROVE
        elif finished:
            verdict = Verdict.COMMENT
        else:
            verdict = None

        review = ReviewModel(
            id=str(uuid4()),
            repository="owner/repo" if index < 3 else "other/repo",
            pr_number=number,
            pr_title=f"Test PR #{number}",
            base_sha=f"base{index:040d}",
            head_sha=f"head{index:040d}",
            author="testuser" if index % 2 == 0 else "otheruser",
            is_draft=False,
            status="completed" if finished else "failed",
            verdict=verdict,
            verdict_confidence=0.8 if finished else None,
            total_tokens=1000 * number,
            total_cost_usd=0.01 * number,
            created_at=datetime.now(UTC),
            completed_at=datetime.now(UTC) if finished else None,
        )
        db_session.add(review)
        created.append(review)

        # Only completed reviews get a finding.
        if not finished:
            continue
        db_session.add(
            FindingModel(
                id=str(uuid4()),
                review_id=review.id,
                agent=AgentName.SECURITY,
                file="src/test.py",
                line_start=10,
                line_end=15,
                severity=Severity.CRITICAL if index == 0 else Severity.HIGH,
                confidence=0.9,
                title=f"Finding {number}",
                description="Test finding",
                reasoning="Test reasoning",
                prompt_version="security-v1.0",
            )
        )

    await db_session.commit()
    for stored in created:
        await db_session.refresh(stored)
    return created
@pytest.fixture
def mock_settings() -> Any:
    """Provide a settings stand-in with GitHub and GitLab credentials set.

    Secret values are wrapped in a minimal object exposing
    ``get_secret_value()``.
    """

    class MockSecretStr:
        """Tiny wrapper that hands back its value via ``get_secret_value``."""

        def __init__(self, value: str) -> None:
            self._value = value

        def get_secret_value(self) -> str:
            return self._value

    class MockSettings:
        # Platform credentials and endpoints
        github_token = MockSecretStr("ghp_test_token")
        github_base_url = "https://api.github.com"
        github_webhook_secret = MockSecretStr("webhook_secret")
        gitlab_token = MockSecretStr("glpat_test_token")
        gitlab_base_url = "https://gitlab.com/api/v4"
        gitlab_webhook_token = MockSecretStr("gitlab_token")

        # Timeouts and retries
        integration_timeout = 30
        integration_max_retries = 3
        llm_timeout = 60
        llm_max_retries = 3

        # Reporting behaviour
        post_comments = True
        update_status = True
        status_check_context = "arbiter"

        # Prompt templates and follow-up
        templates_dir = Path("templates")
        followup_enabled = True
        followup_confidence_threshold = 0.5

    settings = MockSettings()
    return settings
@pytest.fixture
def mock_settings_no_github() -> Any:
    """Provide a settings stand-in with no platform tokens configured.

    Both tokens are ``None`` and comment posting and status updates are
    disabled. Only the attributes listed in the class body are defined.
    """

    class MockSettings:
        # No credentials for either platform
        github_token = None
        gitlab_token = None

        # Endpoints
        github_base_url = "https://api.github.com"
        gitlab_base_url = "https://gitlab.com/api/v4"

        # Timeouts and retries
        integration_timeout = 30
        integration_max_retries = 3

        # Reporting behaviour
        post_comments = False
        update_status = False
        status_check_context = "arbiter"

        templates_dir = Path("templates")

    settings = MockSettings()
    return settings