491 lines
15 KiB
Python
491 lines
15 KiB
Python
"""Pytest configuration and fixtures."""
|
|
|
|
from collections.abc import AsyncGenerator
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from arbiter.db.models import Base, FindingModel, ReviewModel
|
|
from arbiter.integrations.base import Comment, CommitStatus, Platform, PlatformClient
|
|
from arbiter.llm.client import LLMClient, LLMResponse
|
|
from arbiter.llm.prompts import PromptRegistry
|
|
from arbiter.models import Policy
|
|
from arbiter.models.enums import AgentName, Severity, Verdict
|
|
|
|
|
|
class MockLLMClient(LLMClient):
    """Scripted stand-in for ``LLMClient`` used by the test suite.

    Replies are served from ``responses`` in order, one per ``complete``
    call; once they run out, every further call yields an empty string.
    Each invocation is logged in ``calls`` so tests can inspect the
    messages, model and extra keyword arguments that were passed.
    """

    def __init__(self, responses: list[str] | None = None) -> None:
        """Store the scripted replies and start with an empty call log.

        Args:
            responses: Reply texts to hand out in order. ``None`` (or an
                empty list) means every call returns an empty string.
        """
        self.responses = responses if responses else []
        self.calls: list[dict[str, Any]] = []
        self._call_index = 0

    async def complete(
        self,
        messages: list[dict[str, str]],
        model: str,
        **kwargs: Any,
    ) -> LLMResponse:
        """Log the request and return the next scripted reply.

        Args:
            messages: Chat messages that would be sent to the model.
            model: Model identifier; echoed back in the response.
            **kwargs: Any further options; recorded but otherwise unused.

        Returns:
            An ``LLMResponse`` whose content is the next scripted reply
            (or ``""`` when none remain) with fixed token and cost figures.
        """
        self.calls.append(dict(messages=messages, model=model, kwargs=kwargs))

        position = self._call_index
        if position < len(self.responses):
            # Consume one scripted reply and advance the cursor.
            reply = self.responses[position]
            self._call_index = position + 1
        else:
            reply = ""

        return LLMResponse(
            content=reply,
            model=model,
            tokens_in=100,
            tokens_out=50,
            cost_usd=0.001,
        )

    def reset(self) -> None:
        """Forget recorded calls and rewind to the first scripted reply."""
        self._call_index = 0
        self.calls = []
|
|
|
|
|
|
@pytest.fixture
def mock_llm() -> MockLLMClient:
    """Provide a ``MockLLMClient`` with no scripted replies."""
    client = MockLLMClient()
    return client
|
|
|
|
|
|
@pytest.fixture
def mock_llm_with_findings() -> MockLLMClient:
    """Provide a ``MockLLMClient`` scripted with one reply containing a finding.

    The single reply is a fenced ``json`` block holding an array with one
    high-severity SQL-injection finding against ``src/auth.py`` (lines 10-15).
    """
    response = """```json
[
  {
    "file": "src/auth.py",
    "line_start": 10,
    "line_end": 15,
    "severity": "high",
    "confidence": 0.9,
    "title": "SQL Injection vulnerability",
    "description": "User input is directly concatenated into SQL query",
    "reasoning": "String concatenation in SQL queries allows attackers to inject malicious SQL",
    "suggestion": "Use parameterized queries instead",
    "references": ["https://owasp.org/www-community/attacks/SQL_Injection"]
  }
]
```"""
    return MockLLMClient(responses=[response])
|
|
|
|
|
|
@pytest.fixture
def sample_diff() -> str:
    """Return a small unified diff for ``src/auth.py`` used as review input.

    The added lines build a SQL query by concatenating ``username`` into the
    statement text and then compare the stored password directly.
    """
    return """diff --git a/src/auth.py b/src/auth.py
index 1234567..abcdefg 100644
--- a/src/auth.py
+++ b/src/auth.py
@@ -8,6 +8,12 @@ def authenticate(username, password):
     if not username or not password:
         return None

+    # Check user in database
+    query = "SELECT * FROM users WHERE username = '" + username + "'"
+    cursor.execute(query)
+    user = cursor.fetchone()
+    if user and user.password == password:
+        return user
     return None
"""
|
|
|
|
|
|
@pytest.fixture
def sample_policy() -> Policy:
    """Provide a ``Policy`` built with no arguments (all defaults)."""
    default_policy = Policy()
    return default_policy
|
|
|
|
|
|
@pytest.fixture
def prompt_registry(tmp_path: Path) -> PromptRegistry:
    """Provide a ``PromptRegistry`` backed by throwaway template files.

    A ``templates`` directory is created under ``tmp_path`` and populated
    with one ``<agent>-v1.0.md`` file each for the security, style and
    complexity agents.

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.

    Returns:
        A registry pointed at the freshly created templates directory.
    """
    templates_dir = tmp_path / "templates"
    templates_dir.mkdir()

    # Every template shares the same placeholder tail; only the agent name
    # in the leading sentence and in the file name differs.
    placeholder_tail = ": {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}"
    for agent in ("security", "style", "complexity"):
        template_file = templates_dir / (agent + "-v1.0.md")
        template_file.write_text("Review for " + agent + placeholder_tail)

    return PromptRegistry(templates_dir)
|
|
|
|
|
|
# Database fixtures for integration tests
|
|
@pytest.fixture
async def async_engine():
    """Yield an async SQLAlchemy engine on an in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the engine
    is handed to the test.

    Yields:
        The configured async engine.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Dispose on every exit path, including a failure while creating the
        # schema, so the connection pool is never left open.
        await engine.dispose()
|
|
|
|
|
|
@pytest.fixture
async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to the test engine.

    The session is created with ``expire_on_commit=False`` and is closed by
    its context manager once the test finishes.

    Args:
        async_engine: Engine provided by the ``async_engine`` fixture.

    Yields:
        An open async session.
    """
    factory = async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with factory() as open_session:
        yield open_session
|
|
|
|
|
|
@pytest.fixture
async def test_client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client wired to the application with test dependencies.

    ``get_db`` is overridden to hand out the fixture's ``db_session`` and
    ``get_redis`` to hand out a new ``MockRedis`` each time it is invoked.
    Requests go through an in-process ASGI transport.

    Args:
        db_session: Session provided by the ``db_session`` fixture.

    Yields:
        An ``AsyncClient`` with base URL ``http://test``.
    """
    from arbiter.api.deps import get_db, get_redis
    from arbiter.main import app

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    async def override_get_redis() -> AsyncGenerator[MockRedis, None]:
        yield MockRedis()

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_redis] = override_get_redis

    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Always drop the overrides, even when client setup or teardown
        # raises, so they cannot leak into later tests sharing ``app``.
        app.dependency_overrides.clear()
|
|
|
|
|
|
class MockRedis:
|
|
"""Mock Redis client for testing."""
|
|
|
|
def __init__(self) -> None:
|
|
self._data: dict[str, Any] = {}
|
|
|
|
async def get(self, key: str) -> str | None:
|
|
return self._data.get(key)
|
|
|
|
async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002
|
|
self._data[key] = value
|
|
return True
|
|
|
|
async def delete(self, key: str) -> int:
|
|
if key in self._data:
|
|
del self._data[key]
|
|
return 1
|
|
return 0
|
|
|
|
async def ping(self) -> bool:
|
|
return True
|
|
|
|
async def llen(self, key: str) -> int:
|
|
data = self._data.get(key, [])
|
|
return len(data) if isinstance(data, list) else 0
|
|
|
|
async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: # noqa: ARG002
|
|
return type("Job", (), {"job_id": "test-job-id"})()
|
|
|
|
|
|
@pytest.fixture
def mock_redis() -> MockRedis:
    """Provide an empty ``MockRedis`` instance."""
    redis_double = MockRedis()
    return redis_double
|
|
|
|
|
|
class MockPlatformClient(PlatformClient):
    """Recording ``PlatformClient`` double that never touches the network.

    Posted and updated comments are collected in ``_posted_comments`` and
    commit-status updates in ``_status_updates``. Adding a method name to
    ``_fail_on`` makes that method raise ``IntegrationError`` instead of
    doing its normal work.
    """

    def __init__(self) -> None:
        """Initialise empty recordings and the canned diff text."""
        self._comments: list[Comment] = []
        self._posted_comments: list[dict[str, Any]] = []
        self._status_updates: list[dict[str, Any]] = []
        self._diff = "mock diff content"
        self._closed = False
        self._fail_on: set[str] = set()  # Methods to fail on

    def _fail_if_requested(self, method: str) -> None:
        """Raise ``IntegrationError`` when ``method`` is listed in ``_fail_on``."""
        if method not in self._fail_on:
            return
        # Imported only on the failure path, as in each original method.
        from arbiter.integrations import IntegrationError

        raise IntegrationError("Mock failure: " + method)

    @property
    def platform(self) -> Platform:
        """Identify this client as the GitHub platform."""
        return Platform.GITHUB

    async def get_pr_diff(
        self,
        repository: str,  # noqa: ARG002
        pr_number: int,  # noqa: ARG002
    ) -> str:
        """Return the canned diff text, ignoring both arguments."""
        self._fail_if_requested("get_pr_diff")
        return self._diff

    async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
        """Record a new comment and return its fabricated URL."""
        self._fail_if_requested("post_comment")
        comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-123"
        record = {
            "repository": repository,
            "pr_number": pr_number,
            "body": body,
            "url": comment_url,
        }
        self._posted_comments.append(record)
        return comment_url

    async def update_commit_status(
        self,
        repository: str,
        sha: str,
        status: CommitStatus,
        description: str,
        context: str,
        target_url: str | None = None,
    ) -> None:
        """Record a commit-status update with all supplied fields."""
        self._fail_if_requested("update_commit_status")
        record = {
            "repository": repository,
            "sha": sha,
            "status": status,
            "description": description,
            "context": context,
            "target_url": target_url,
        }
        self._status_updates.append(record)

    async def get_pr_info(self, repository: str, pr_number: int) -> Any:
        """Return a ``PullRequestInfo`` with fixed placeholder metadata."""
        self._fail_if_requested("get_pr_info")
        from arbiter.integrations.base import PullRequestInfo

        pr_url = f"https://github.com/{repository}/pull/{pr_number}"
        return PullRequestInfo(
            platform=Platform.GITHUB,
            repository=repository,
            pr_number=pr_number,
            head_sha="abc123",
            base_sha="def456",
            head_ref="feature",
            base_ref="main",
            title="Test PR",
            author="testuser",
            url=pr_url,
            is_draft=False,
        )

    async def get_comments(
        self,
        repository: str,  # noqa: ARG002
        pr_number: int,  # noqa: ARG002
    ) -> list[Comment]:
        """Return the preset comment list, ignoring both arguments."""
        self._fail_if_requested("get_comments")
        return self._comments

    async def update_comment(
        self, repository: str, pr_number: int, comment_id: str, body: str
    ) -> str:
        """Record a comment edit and return the comment's fabricated URL."""
        self._fail_if_requested("update_comment")
        comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-{comment_id}"
        record = {
            "repository": repository,
            "pr_number": pr_number,
            "comment_id": comment_id,
            "body": body,
            "url": comment_url,
        }
        self._posted_comments.append(record)
        return comment_url

    async def close(self) -> None:
        """Mark the client as closed."""
        self._closed = True
|
|
|
|
|
|
@pytest.fixture
def mock_platform_client() -> MockPlatformClient:
    """Provide a fresh ``MockPlatformClient`` with nothing recorded."""
    platform_double = MockPlatformClient()
    return platform_double
|
|
|
|
|
|
@pytest.fixture
async def completed_review_fixture(db_session: AsyncSession) -> ReviewModel:
    """Persist one completed review with two findings and return it.

    The review targets PR 42 of ``owner/repo`` with a ``COMMENT`` verdict.
    It carries a high-severity security finding and a medium-severity style
    finding, both on ``src/auth.py``.

    Args:
        db_session: Session provided by the ``db_session`` fixture.

    Returns:
        The committed and refreshed ``ReviewModel``.
    """
    per_agent_tokens = {"security": 500, "style": 500, "complexity": 500}
    per_agent_cost = {"security": 0.005, "style": 0.005, "complexity": 0.005}

    review = ReviewModel(
        id=str(uuid4()),
        repository="owner/repo",
        pr_number=42,
        pr_title="Test PR",
        base_sha="abc1234567890",
        head_sha="def0987654321",
        author="testuser",
        is_draft=False,
        status="completed",
        verdict=Verdict.COMMENT,
        verdict_confidence=0.75,
        verdict_reasoning="Found some issues to discuss",
        total_tokens=1500,
        total_cost_usd=0.015,
        tokens_by_agent=per_agent_tokens,
        cost_by_agent=per_agent_cost,
        created_at=datetime.now(UTC),
        started_at=datetime.now(UTC),
        completed_at=datetime.now(UTC),
    )
    db_session.add(review)

    # Findings attached to the review above.
    security_finding = FindingModel(
        id=str(uuid4()),
        review_id=review.id,
        agent=AgentName.SECURITY,
        file="src/auth.py",
        line_start=10,
        line_end=15,
        severity=Severity.HIGH,
        confidence=0.9,
        title="SQL Injection vulnerability",
        description="User input is directly concatenated into SQL query",
        reasoning="String concatenation in SQL queries allows attackers to inject malicious SQL",
        suggestion="Use parameterized queries instead",
        references=["https://owasp.org/www-community/attacks/SQL_Injection"],
        prompt_version="security-v1.0",
    )
    style_finding = FindingModel(
        id=str(uuid4()),
        review_id=review.id,
        agent=AgentName.STYLE,
        file="src/auth.py",
        line_start=20,
        line_end=25,
        severity=Severity.MEDIUM,
        confidence=0.85,
        title="Long function",
        description="Function exceeds 50 lines",
        reasoning="Long functions are harder to test and maintain",
        suggestion="Consider breaking into smaller functions",
        references=[],
        prompt_version="style-v1.0",
    )
    db_session.add_all([security_finding, style_finding])

    await db_session.commit()
    await db_session.refresh(review)
    return review
|
|
|
|
|
|
@pytest.fixture
async def sample_reviews_fixture(db_session: AsyncSession) -> list[ReviewModel]:
    """Persist five varied reviews and return them in creation order.

    The first three belong to ``owner/repo`` and the last two to
    ``other/repo``; authors alternate between ``testuser`` and
    ``otheruser``. The first four are completed (the first approved, the
    rest commented) and each has one security finding. The fifth is a
    failed review with no verdict, confidence, completion time or finding.

    Args:
        db_session: Session provided by the ``db_session`` fixture.

    Returns:
        The committed and refreshed reviews.
    """
    reviews = []
    for idx in range(5):
        number = idx + 1
        finished = idx < 4

        if idx == 0:
            verdict = Verdict.APPROVE
        elif finished:
            verdict = Verdict.COMMENT
        else:
            verdict = None

        review = ReviewModel(
            id=str(uuid4()),
            repository="owner/repo" if idx < 3 else "other/repo",
            pr_number=number,
            pr_title=f"Test PR #{number}",
            base_sha=f"base{idx:040d}",
            head_sha=f"head{idx:040d}",
            author="testuser" if idx % 2 == 0 else "otheruser",
            is_draft=False,
            status="completed" if finished else "failed",
            verdict=verdict,
            verdict_confidence=0.8 if finished else None,
            total_tokens=1000 * number,
            total_cost_usd=0.01 * number,
            created_at=datetime.now(UTC),
            completed_at=datetime.now(UTC) if finished else None,
        )
        db_session.add(review)
        reviews.append(review)

        if not finished:
            continue

        # Completed reviews each get a single security finding.
        db_session.add(
            FindingModel(
                id=str(uuid4()),
                review_id=review.id,
                agent=AgentName.SECURITY,
                file="src/test.py",
                line_start=10,
                line_end=15,
                severity=Severity.CRITICAL if idx == 0 else Severity.HIGH,
                confidence=0.9,
                title=f"Finding {number}",
                description="Test finding",
                reasoning="Test reasoning",
                prompt_version="security-v1.0",
            )
        )

    await db_session.commit()
    for stored in reviews:
        await db_session.refresh(stored)
    return reviews
|
|
|
|
|
|
@pytest.fixture
def mock_settings() -> Any:
    """Provide a settings-like object with every integration enabled.

    Secret values are wrapped in a small local class exposing
    ``get_secret_value()``.
    """

    class MockSecretStr:
        """Minimal secret wrapper exposing ``get_secret_value``."""

        def __init__(self, value: str) -> None:
            self._value = value

        def get_secret_value(self) -> str:
            """Return the wrapped plain-text value."""
            return self._value

    class MockSettings:
        """Static attribute bag standing in for application settings."""

        # Platform credentials and endpoints
        github_token = MockSecretStr("ghp_test_token")
        gitlab_token = MockSecretStr("glpat_test_token")
        github_base_url = "https://api.github.com"
        gitlab_base_url = "https://gitlab.com/api/v4"
        github_webhook_secret = MockSecretStr("webhook_secret")
        gitlab_webhook_token = MockSecretStr("gitlab_token")

        # Timeouts and retry limits
        integration_timeout = 30
        integration_max_retries = 3
        llm_timeout = 60
        llm_max_retries = 3

        # Reporting behaviour
        post_comments = True
        update_status = True
        status_check_context = "arbiter"

        # Prompt templates and follow-up handling
        templates_dir = Path("templates")
        followup_enabled = True
        followup_confidence_threshold = 0.5

    settings = MockSettings()
    return settings
|
|
|
|
|
|
@pytest.fixture
def mock_settings_no_github() -> Any:
    """Provide a settings-like object with no platform tokens configured.

    Both ``github_token`` and ``gitlab_token`` are ``None``, and comment
    posting and status updates are switched off.
    """

    class MockSettings:
        """Static attribute bag standing in for token-less settings."""

        # No credentials for either platform
        github_token = None
        gitlab_token = None
        github_base_url = "https://api.github.com"
        gitlab_base_url = "https://gitlab.com/api/v4"

        # Timeouts and retry limits
        integration_timeout = 30
        integration_max_retries = 3

        # Reporting disabled
        post_comments = False
        update_status = False
        status_check_context = "arbiter"

        templates_dir = Path("templates")

    settings = MockSettings()
    return settings
|