"""Tests for database session management."""

import pytest
from sqlalchemy import func, select, text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

from arbiter.db import session as session_module
from arbiter.db.models import ReviewModel


def _install_test_engine(engine: AsyncEngine) -> None:
    """Point the session module's globals at *engine* with a matching factory.

    Tests must use the fixture-provided test engine rather than whatever the
    production settings would build, so they patch the module globals directly
    and undo it with ``session_module.reset_engine()`` in a ``finally``.
    """
    session_module._engine = engine
    session_module._session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )


class TestDatabaseSession:
    def test_reset_engine_clears_globals(self) -> None:
        # Put the module-level globals into a known (cleared) state.
        session_module.reset_engine()

        assert session_module._engine is None
        assert session_module._session_factory is None

    async def test_get_async_session_yields_session(
        self,
        async_engine: AsyncEngine,
    ) -> None:
        # The async_engine fixture supplies the test database; never touch the
        # production engine here.
        _install_test_engine(async_engine)

        try:
            yielded = False
            async for db_session in session_module.get_async_session():
                yielded = True
                assert isinstance(db_session, AsyncSession)
                # The session must be usable for real queries.
                result = await db_session.execute(text("SELECT 1"))
                assert result.scalar() == 1
            # Guard against a generator that yields nothing, which would
            # otherwise make every assertion above silently skip.
            assert yielded
        finally:
            session_module.reset_engine()

    async def test_get_async_session_commits_on_success(
        self,
        async_engine: AsyncEngine,
    ) -> None:
        _install_test_engine(async_engine)

        try:
            # Add a review inside the session context without committing;
            # exhausting the generator should commit it.
            async for db_session in session_module.get_async_session():
                review = ReviewModel(
                    repository="test/repo",
                    pr_number=1,
                    base_sha="abc123",
                    head_sha="def456",
                    status="pending",
                )
                db_session.add(review)

            # Verify the data was committed by querying with a separate session.
            async with async_sessionmaker(
                bind=async_engine, class_=AsyncSession
            )() as verify_session:
                result = await verify_session.execute(
                    text("SELECT COUNT(*) FROM reviews")
                )
                count = result.scalar()
                assert count == 1
        finally:
            session_module.reset_engine()

    async def test_get_async_session_propagates_error(
        self,
        async_engine: AsyncEngine,
    ) -> None:
        _install_test_engine(async_engine)

        try:
            session_gen = session_module.get_async_session()
            db_session = await session_gen.__anext__()
            db_session.add(
                ReviewModel(
                    repository="test/repo",
                    pr_number=99,
                    base_sha="abc123",
                    head_sha="def456",
                    status="pending",
                )
            )

            # An exception raised in an ``async for`` body never reaches the
            # generator. ``athrow`` delivers it at the generator's ``yield``
            # (as dependency-injection runners do), so this exercises the
            # generator's own error handling, which is expected to re-raise.
            with pytest.raises(ValueError, match="Test error"):
                await session_gen.athrow(ValueError("Test error"))

            # The failed unit of work must not have been committed.
            async with async_sessionmaker(
                bind=async_engine, class_=AsyncSession
            )() as verify_session:
                result = await verify_session.execute(
                    select(func.count())
                    .select_from(ReviewModel)
                    .where(ReviewModel.pr_number == 99)
                )
                assert result.scalar() == 0
        finally:
            session_module.reset_engine()

    async def test_init_db_creates_tables(self, async_engine: AsyncEngine) -> None:
        # Tables are already created by the fixture; this verifies init_db runs
        # against an existing engine and the reviews table is queryable.
        session_module._engine = async_engine
        session_module._session_factory = None

        try:
            await session_module.init_db()

            async with async_sessionmaker(
                bind=async_engine, class_=AsyncSession
            )() as session:
                # Raises if the reviews table does not exist.
                result = await session.execute(text("SELECT 1 FROM reviews LIMIT 1"))
                assert result is not None
        finally:
            session_module.reset_engine()

    async def test_close_db_disposes_engine(self, async_engine: AsyncEngine) -> None:
        session_module._engine = async_engine
        session_module._session_factory = "not none"  # type: ignore

        try:
            await session_module.close_db()

            assert session_module._engine is None
            assert session_module._session_factory is None
        finally:
            # Ensure globals are cleared even if close_db() raised or left
            # state behind, so later tests start clean.
            session_module.reset_engine()

    def test_session_factory_create(self) -> None:
        # Building a real factory needs proper settings, so only the path where
        # a factory is already cached is exercised here.
        session_module.reset_engine()
        session_module._session_factory = "mock_factory"  # type: ignore

        try:
            result = session_module.async_session_factory()
            assert result == "mock_factory"
        finally:
            # Never leak the fake factory into other tests, even on failure.
            session_module.reset_engine()