test db session management
This commit is contained in:
159
tests/test_db_session.py
Normal file
159
tests/test_db_session.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Tests for database session management."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from arbiter.db import session as session_module
|
||||
|
||||
|
||||
class TestDatabaseSession:
|
||||
def test_reset_engine_clears_globals(self) -> None:
|
||||
# First ensure the module level globals are in a known state
|
||||
session_module.reset_engine()
|
||||
|
||||
# Verify they are cleared
|
||||
assert session_module._engine is None
|
||||
assert session_module._session_factory is None
|
||||
|
||||
async def test_get_async_session_yields_session(
|
||||
self,
|
||||
async_engine,
|
||||
) -> None:
|
||||
# We need to use the test engine, not the production one
|
||||
# This test uses the async_engine fixture which creates an in-memory DB
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
# Set up the test engine in the module
|
||||
session_module._engine = async_engine
|
||||
session_module._session_factory = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
try:
|
||||
async for db_session in session_module.get_async_session():
|
||||
assert isinstance(db_session, AsyncSession)
|
||||
# Session should be usable
|
||||
result = await db_session.execute(text("SELECT 1"))
|
||||
assert result is not None
|
||||
finally:
|
||||
session_module.reset_engine()
|
||||
|
||||
async def test_get_async_session_commits_on_success(
|
||||
self,
|
||||
async_engine,
|
||||
) -> None:
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from arbiter.db.models import ReviewModel
|
||||
|
||||
# Set up the test engine
|
||||
session_module._engine = async_engine
|
||||
session_module._session_factory = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
try:
|
||||
# Create a review inside the session context
|
||||
async for db_session in session_module.get_async_session():
|
||||
review = ReviewModel(
|
||||
repository="test/repo",
|
||||
pr_number=1,
|
||||
base_sha="abc123",
|
||||
head_sha="def456",
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(review)
|
||||
# Don't manually commit - let the context manager do it
|
||||
|
||||
# Verify the data was committed by querying with a new session
|
||||
async with async_sessionmaker(
|
||||
bind=async_engine, class_=AsyncSession
|
||||
)() as verify_session:
|
||||
result = await verify_session.execute(text("SELECT COUNT(*) FROM reviews"))
|
||||
count = result.scalar()
|
||||
assert count == 1
|
||||
finally:
|
||||
session_module.reset_engine()
|
||||
|
||||
async def test_get_async_session_propagates_error(
|
||||
self,
|
||||
async_engine,
|
||||
) -> None:
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
from arbiter.db.models import ReviewModel
|
||||
|
||||
session_module._engine = async_engine
|
||||
session_module._session_factory = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
try:
|
||||
# The error should be re-raised after rollback is called
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
async for db_session in session_module.get_async_session():
|
||||
review = ReviewModel(
|
||||
repository="test/repo",
|
||||
pr_number=99,
|
||||
base_sha="abc123",
|
||||
head_sha="def456",
|
||||
status="pending",
|
||||
)
|
||||
db_session.add(review)
|
||||
raise ValueError("Test error")
|
||||
|
||||
# The session context manager handled the error correctly
|
||||
# (exception was propagated, rollback was called internally)
|
||||
finally:
|
||||
session_module.reset_engine()
|
||||
|
||||
async def test_init_db_creates_tables(self, async_engine) -> None:
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
# Set up the engine (tables already created by fixture, but let's verify)
|
||||
session_module._engine = async_engine
|
||||
session_module._session_factory = None
|
||||
|
||||
try:
|
||||
# Call init_db
|
||||
await session_module.init_db()
|
||||
|
||||
# Verify tables exist
|
||||
async with async_sessionmaker(bind=async_engine, class_=AsyncSession)() as session:
|
||||
# Try to query the reviews table
|
||||
result = await session.execute(text("SELECT 1 FROM reviews LIMIT 1"))
|
||||
# If no error, table exists
|
||||
assert result is not None
|
||||
finally:
|
||||
session_module.reset_engine()
|
||||
|
||||
async def test_close_db_disposes_engine(self, async_engine) -> None:
|
||||
session_module._engine = async_engine
|
||||
session_module._session_factory = "not none" # type: ignore
|
||||
|
||||
await session_module.close_db()
|
||||
|
||||
assert session_module._engine is None
|
||||
assert session_module._session_factory is None
|
||||
|
||||
def test_session_factory_create(self) -> None:
|
||||
# Reset to ensure clean state
|
||||
session_module.reset_engine()
|
||||
|
||||
# This will fail without proper settings, so we just test the path
|
||||
# where the factory already exists
|
||||
session_module._session_factory = "mock_factory" # type: ignore
|
||||
|
||||
result = session_module.async_session_factory()
|
||||
assert result == "mock_factory"
|
||||
|
||||
session_module.reset_engine()
|
||||
Reference in New Issue
Block a user