test db session management

This commit is contained in:
2025-06-01 15:10:57 +00:00
parent e4f814efff
commit 46ce30476c

159
tests/test_db_session.py Normal file
View File

@@ -0,0 +1,159 @@
"""Tests for database session management."""
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from arbiter.db import session as session_module
class TestDatabaseSession:
def test_reset_engine_clears_globals(self) -> None:
# First ensure the module level globals are in a known state
session_module.reset_engine()
# Verify they are cleared
assert session_module._engine is None
assert session_module._session_factory is None
async def test_get_async_session_yields_session(
self,
async_engine,
) -> None:
# We need to use the test engine, not the production one
# This test uses the async_engine fixture which creates an in-memory DB
from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker
# Set up the test engine in the module
session_module._engine = async_engine
session_module._session_factory = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
try:
async for db_session in session_module.get_async_session():
assert isinstance(db_session, AsyncSession)
# Session should be usable
result = await db_session.execute(text("SELECT 1"))
assert result is not None
finally:
session_module.reset_engine()
async def test_get_async_session_commits_on_success(
self,
async_engine,
) -> None:
from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker
from arbiter.db.models import ReviewModel
# Set up the test engine
session_module._engine = async_engine
session_module._session_factory = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
try:
# Create a review inside the session context
async for db_session in session_module.get_async_session():
review = ReviewModel(
repository="test/repo",
pr_number=1,
base_sha="abc123",
head_sha="def456",
status="pending",
)
db_session.add(review)
# Don't manually commit - let the context manager do it
# Verify the data was committed by querying with a new session
async with async_sessionmaker(
bind=async_engine, class_=AsyncSession
)() as verify_session:
result = await verify_session.execute(text("SELECT COUNT(*) FROM reviews"))
count = result.scalar()
assert count == 1
finally:
session_module.reset_engine()
async def test_get_async_session_propagates_error(
self,
async_engine,
) -> None:
from sqlalchemy.ext.asyncio import async_sessionmaker
from arbiter.db.models import ReviewModel
session_module._engine = async_engine
session_module._session_factory = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
try:
# The error should be re-raised after rollback is called
with pytest.raises(ValueError, match="Test error"):
async for db_session in session_module.get_async_session():
review = ReviewModel(
repository="test/repo",
pr_number=99,
base_sha="abc123",
head_sha="def456",
status="pending",
)
db_session.add(review)
raise ValueError("Test error")
# The session context manager handled the error correctly
# (exception was propagated, rollback was called internally)
finally:
session_module.reset_engine()
async def test_init_db_creates_tables(self, async_engine) -> None:
from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker
# Set up the engine (tables already created by fixture, but let's verify)
session_module._engine = async_engine
session_module._session_factory = None
try:
# Call init_db
await session_module.init_db()
# Verify tables exist
async with async_sessionmaker(bind=async_engine, class_=AsyncSession)() as session:
# Try to query the reviews table
result = await session.execute(text("SELECT 1 FROM reviews LIMIT 1"))
# If no error, table exists
assert result is not None
finally:
session_module.reset_engine()
async def test_close_db_disposes_engine(self, async_engine) -> None:
session_module._engine = async_engine
session_module._session_factory = "not none" # type: ignore
await session_module.close_db()
assert session_module._engine is None
assert session_module._session_factory is None
def test_session_factory_create(self) -> None:
# Reset to ensure clean state
session_module.reset_engine()
# This will fail without proper settings, so we just test the path
# where the factory already exists
session_module._session_factory = "mock_factory" # type: ignore
result = session_module.async_session_factory()
assert result == "mock_factory"
session_module.reset_engine()