"""Pytest configuration and fixtures.""" from collections.abc import AsyncGenerator from datetime import UTC, datetime from pathlib import Path from typing import Any from uuid import uuid4 import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from arbiter.db.models import Base, FindingModel, ReviewModel from arbiter.integrations.base import Comment, CommitStatus, Platform, PlatformClient from arbiter.llm.client import LLMClient, LLMResponse from arbiter.llm.prompts import PromptRegistry from arbiter.models import Policy from arbiter.models.enums import AgentName, Severity, Verdict class MockLLMClient(LLMClient): """Mock LLM client for testing.""" def __init__(self, responses: list[str] | None = None) -> None: self.responses = responses or [] self.calls: list[dict[str, Any]] = [] self._call_index = 0 async def complete( self, messages: list[dict[str, str]], model: str, **kwargs: Any, ) -> LLMResponse: """Record the call and return a canned response.""" self.calls.append( { "messages": messages, "model": model, "kwargs": kwargs, } ) content = "" if self._call_index < len(self.responses): content = self.responses[self._call_index] self._call_index += 1 return LLMResponse( content=content, model=model, tokens_in=100, tokens_out=50, cost_usd=0.001, ) def reset(self) -> None: self.calls = [] self._call_index = 0 @pytest.fixture def mock_llm() -> MockLLMClient: return MockLLMClient() @pytest.fixture def mock_llm_with_findings() -> MockLLMClient: response = """```json [ { "file": "src/auth.py", "line_start": 10, "line_end": 15, "severity": "high", "confidence": 0.9, "title": "SQL Injection vulnerability", "description": "User input is directly concatenated into SQL query", "reasoning": "String concatenation in SQL queries allows attackers to inject malicious SQL", "suggestion": "Use parameterized queries instead", "references": ["https://owasp.org/www-community/attacks/SQL_Injection"] } ] ```""" return MockLLMClient(responses=[response]) @pytest.fixture def sample_diff() -> str: return """diff --git a/src/auth.py b/src/auth.py index 1234567..abcdefg 100644 --- a/src/auth.py +++ b/src/auth.py @@ -8,6 +8,12 @@ def authenticate(username, password): if not username or not password: return None + # Check user in database + query = "SELECT * FROM users WHERE username = '" + username + "'" + cursor.execute(query) + user = cursor.fetchone() + if user and user.password == password: + return user return None """ @pytest.fixture def sample_policy() -> Policy: return Policy() @pytest.fixture def prompt_registry(tmp_path: Path) -> PromptRegistry: templates_dir = tmp_path / "templates" templates_dir.mkdir() (templates_dir / "security-v1.0.md").write_text( "Review for security: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}" ) (templates_dir / "style-v1.0.md").write_text( "Review for style: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}" ) (templates_dir / "complexity-v1.0.md").write_text( "Review for complexity: {{diff}}\n{{prompt_additions}}\n{{static_analysis_context}}" ) return PromptRegistry(templates_dir) # Database fixtures for integration tests @pytest.fixture async def async_engine(): engine = create_async_engine( "sqlite+aiosqlite:///:memory:", echo=False, ) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield engine await engine.dispose() @pytest.fixture async def db_session(async_engine) -> AsyncGenerator[AsyncSession, None]: session_factory = async_sessionmaker( bind=async_engine, class_=AsyncSession, expire_on_commit=False, ) async with session_factory() as session: yield session @pytest.fixture async def test_client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: from arbiter.api.deps import get_db, get_redis from arbiter.main import app async def override_get_db() -> AsyncGenerator[AsyncSession, None]: yield db_session async def override_get_redis() -> AsyncGenerator[MockRedis, None]: yield MockRedis() app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_redis] = override_get_redis transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: yield client app.dependency_overrides.clear() class MockRedis: """Mock Redis client for testing.""" def __init__(self) -> None: self._data: dict[str, Any] = {} async def get(self, key: str) -> str | None: return self._data.get(key) async def set(self, key: str, value: str, ex: int | None = None) -> bool: # noqa: ARG002 self._data[key] = value return True async def delete(self, key: str) -> int: if key in self._data: del self._data[key] return 1 return 0 async def ping(self) -> bool: return True async def llen(self, key: str) -> int: data = self._data.get(key, []) return len(data) if isinstance(data, list) else 0 async def enqueue_job(self, func_name: str, **kwargs: Any) -> Any: # noqa: ARG002 return type("Job", (), {"job_id": "test-job-id"})() @pytest.fixture def mock_redis() -> MockRedis: return MockRedis() class MockPlatformClient(PlatformClient): """Mock platform client for testing.""" def __init__(self) -> None: self._comments: list[Comment] = [] self._posted_comments: list[dict[str, Any]] = [] self._status_updates: list[dict[str, Any]] = [] self._diff = "mock diff content" self._closed = False self._fail_on: set[str] = set() # Methods to fail on @property def platform(self) -> Platform: return Platform.GITHUB async def get_pr_diff( self, repository: str, # noqa: ARG002 pr_number: int, # noqa: ARG002 ) -> str: if "get_pr_diff" in self._fail_on: from arbiter.integrations import IntegrationError raise IntegrationError("Mock failure: get_pr_diff") return self._diff async def post_comment(self, repository: str, pr_number: int, body: str) -> str: if "post_comment" in self._fail_on: from arbiter.integrations import IntegrationError raise IntegrationError("Mock failure: post_comment") comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-123" self._posted_comments.append( {"repository": repository, "pr_number": pr_number, "body": body, "url": comment_url} ) return comment_url async def update_commit_status( self, repository: str, sha: str, status: CommitStatus, description: str, context: str, target_url: str | None = None, ) -> None: if "update_commit_status" in self._fail_on: from arbiter.integrations import IntegrationError raise IntegrationError("Mock failure: update_commit_status") self._status_updates.append( { "repository": repository, "sha": sha, "status": status, "description": description, "context": context, "target_url": target_url, } ) async def get_pr_info(self, repository: str, pr_number: int) -> Any: if "get_pr_info" in self._fail_on: from arbiter.integrations import IntegrationError raise IntegrationError("Mock failure: get_pr_info") from arbiter.integrations.base import PullRequestInfo return PullRequestInfo( platform=Platform.GITHUB, repository=repository, pr_number=pr_number, head_sha="abc123", base_sha="def456", head_ref="feature", base_ref="main", title="Test PR", author="testuser", url=f"https://github.com/{repository}/pull/{pr_number}", is_draft=False, ) async def get_comments( self, repository: str, # noqa: ARG002 pr_number: int, # noqa: ARG002 ) -> list[Comment]: if "get_comments" in self._fail_on: from arbiter.integrations import IntegrationError raise IntegrationError("Mock failure: get_comments") return self._comments async def update_comment( self, repository: str, pr_number: int, comment_id: str, body: str ) -> str: if "update_comment" in self._fail_on: from arbiter.integrations import IntegrationError raise IntegrationError("Mock failure: update_comment") comment_url = f"https://github.com/{repository}/pull/{pr_number}#comment-{comment_id}" self._posted_comments.append( { "repository": repository, "pr_number": pr_number, "comment_id": comment_id, "body": body, "url": comment_url, } ) return comment_url async def close(self) -> None: self._closed = True @pytest.fixture def mock_platform_client() -> MockPlatformClient: return MockPlatformClient() @pytest.fixture async def completed_review_fixture(db_session: AsyncSession) -> ReviewModel: review = ReviewModel( id=str(uuid4()), repository="owner/repo", pr_number=42, pr_title="Test PR", base_sha="abc1234567890", head_sha="def0987654321", author="testuser", is_draft=False, status="completed", verdict=Verdict.COMMENT, verdict_confidence=0.75, verdict_reasoning="Found some issues to discuss", total_tokens=1500, total_cost_usd=0.015, tokens_by_agent={"security": 500, "style": 500, "complexity": 500}, cost_by_agent={"security": 0.005, "style": 0.005, "complexity": 0.005}, created_at=datetime.now(UTC), started_at=datetime.now(UTC), completed_at=datetime.now(UTC), ) db_session.add(review) # Add findings finding1 = FindingModel( id=str(uuid4()), review_id=review.id, agent=AgentName.SECURITY, file="src/auth.py", line_start=10, line_end=15, severity=Severity.HIGH, confidence=0.9, title="SQL Injection vulnerability", description="User input is directly concatenated into SQL query", reasoning="String concatenation in SQL queries allows attackers to inject malicious SQL", suggestion="Use parameterized queries instead", references=["https://owasp.org/www-community/attacks/SQL_Injection"], prompt_version="security-v1.0", ) finding2 = FindingModel( id=str(uuid4()), review_id=review.id, agent=AgentName.STYLE, file="src/auth.py", line_start=20, line_end=25, severity=Severity.MEDIUM, confidence=0.85, title="Long function", description="Function exceeds 50 lines", reasoning="Long functions are harder to test and maintain", suggestion="Consider breaking into smaller functions", references=[], prompt_version="style-v1.0", ) db_session.add(finding1) db_session.add(finding2) await db_session.commit() await db_session.refresh(review) return review @pytest.fixture async def sample_reviews_fixture(db_session: AsyncSession) -> list[ReviewModel]: reviews = [] for i in range(5): review = ReviewModel( id=str(uuid4()), repository="owner/repo" if i < 3 else "other/repo", pr_number=i + 1, pr_title=f"Test PR #{i + 1}", base_sha=f"base{i:040d}", head_sha=f"head{i:040d}", author="testuser" if i % 2 == 0 else "otheruser", is_draft=False, status="completed" if i < 4 else "failed", verdict=Verdict.APPROVE if i == 0 else (Verdict.COMMENT if i < 4 else None), verdict_confidence=0.8 if i < 4 else None, total_tokens=1000 * (i + 1), total_cost_usd=0.01 * (i + 1), created_at=datetime.now(UTC), completed_at=datetime.now(UTC) if i < 4 else None, ) db_session.add(review) reviews.append(review) # Add a finding to each completed review if i < 4: finding = FindingModel( id=str(uuid4()), review_id=review.id, agent=AgentName.SECURITY, file="src/test.py", line_start=10, line_end=15, severity=Severity.CRITICAL if i == 0 else Severity.HIGH, confidence=0.9, title=f"Finding {i + 1}", description="Test finding", reasoning="Test reasoning", prompt_version="security-v1.0", ) db_session.add(finding) await db_session.commit() for review in reviews: await db_session.refresh(review) return reviews @pytest.fixture def mock_settings() -> Any: class MockSecretStr: def __init__(self, value: str) -> None: self._value = value def get_secret_value(self) -> str: return self._value class MockSettings: github_token = MockSecretStr("ghp_test_token") gitlab_token = MockSecretStr("glpat_test_token") github_base_url = "https://api.github.com" gitlab_base_url = "https://gitlab.com/api/v4" github_webhook_secret = MockSecretStr("webhook_secret") gitlab_webhook_token = MockSecretStr("gitlab_token") integration_timeout = 30 integration_max_retries = 3 llm_timeout = 60 llm_max_retries = 3 post_comments = True update_status = True status_check_context = "arbiter" templates_dir = Path("templates") followup_enabled = True followup_confidence_threshold = 0.5 return MockSettings() @pytest.fixture def mock_settings_no_github() -> Any: class MockSettings: github_token = None gitlab_token = None github_base_url = "https://api.github.com" gitlab_base_url = "https://gitlab.com/api/v4" integration_timeout = 30 integration_max_retries = 3 post_comments = False update_status = False status_check_context = "arbiter" templates_dir = Path("templates") return MockSettings()