diff --git a/tests/test_db_session.py b/tests/test_db_session.py new file mode 100644 index 0000000..258a3f9 --- /dev/null +++ b/tests/test_db_session.py @@ -0,0 +1,159 @@ +"""Tests for database session management.""" + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from arbiter.db import session as session_module + + +class TestDatabaseSession: + def test_reset_engine_clears_globals(self) -> None: + # First ensure the module level globals are in a known state + session_module.reset_engine() + + # Verify they are cleared + assert session_module._engine is None + assert session_module._session_factory is None + + async def test_get_async_session_yields_session( + self, + async_engine, + ) -> None: + # We need to use the test engine, not the production one + # This test uses the async_engine fixture which creates an in-memory DB + from sqlalchemy import text + from sqlalchemy.ext.asyncio import async_sessionmaker + + # Set up the test engine in the module + session_module._engine = async_engine + session_module._session_factory = async_sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + try: + async for db_session in session_module.get_async_session(): + assert isinstance(db_session, AsyncSession) + # Session should be usable + result = await db_session.execute(text("SELECT 1")) + assert result is not None + finally: + session_module.reset_engine() + + async def test_get_async_session_commits_on_success( + self, + async_engine, + ) -> None: + from sqlalchemy import text + from sqlalchemy.ext.asyncio import async_sessionmaker + + from arbiter.db.models import ReviewModel + + # Set up the test engine + session_module._engine = async_engine + session_module._session_factory = async_sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + try: + # Create a review inside the session context + async for db_session in session_module.get_async_session(): + review = ReviewModel( + repository="test/repo", + pr_number=1, + base_sha="abc123", + head_sha="def456", + status="pending", + ) + db_session.add(review) + # Don't manually commit - let the context manager do it + + # Verify the data was committed by querying with a new session + async with async_sessionmaker( + bind=async_engine, class_=AsyncSession + )() as verify_session: + result = await verify_session.execute(text("SELECT COUNT(*) FROM reviews")) + count = result.scalar() + assert count == 1 + finally: + session_module.reset_engine() + + async def test_get_async_session_propagates_error( + self, + async_engine, + ) -> None: + from sqlalchemy.ext.asyncio import async_sessionmaker + + from arbiter.db.models import ReviewModel + + session_module._engine = async_engine + session_module._session_factory = async_sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + try: + # The error should be re-raised after rollback is called + with pytest.raises(ValueError, match="Test error"): + async for db_session in session_module.get_async_session(): + review = ReviewModel( + repository="test/repo", + pr_number=99, + base_sha="abc123", + head_sha="def456", + status="pending", + ) + db_session.add(review) + raise ValueError("Test error") + + # The session context manager handled the error correctly + # (exception was propagated, rollback was called internally) + finally: + session_module.reset_engine() + + async def test_init_db_creates_tables(self, async_engine) -> None: + from sqlalchemy import text + from sqlalchemy.ext.asyncio import async_sessionmaker + + # Set up the engine (tables already created by fixture, but let's verify) + session_module._engine = async_engine + session_module._session_factory = None + + try: + # Call init_db + await session_module.init_db() + + # Verify tables exist + async with async_sessionmaker(bind=async_engine, class_=AsyncSession)() as session: + # Try to query the reviews table + result = await session.execute(text("SELECT 1 FROM reviews LIMIT 1")) + # If no error, table exists + assert result is not None + finally: + session_module.reset_engine() + + async def test_close_db_disposes_engine(self, async_engine) -> None: + session_module._engine = async_engine + session_module._session_factory = "not none" # type: ignore + + await session_module.close_db() + + assert session_module._engine is None + assert session_module._session_factory is None + + def test_session_factory_create(self) -> None: + # Reset to ensure clean state + session_module.reset_engine() + + # This will fail without proper settings, so we just test the path + # where the factory already exists + session_module._session_factory = "mock_factory" # type: ignore + + result = session_module.async_session_factory() + assert result == "mock_factory" + + session_module.reset_engine()