160 lines
5.6 KiB
Python
160 lines
5.6 KiB
Python
"""Tests for database session management."""
|
|
|
|
import pytest
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from arbiter.db import session as session_module
|
|
|
|
|
|
class TestDatabaseSession:
    """Tests for the module-level engine / session-factory state in ``session_module``.

    Every test that assigns ``session_module._engine`` or
    ``session_module._session_factory`` restores a clean state in a
    ``finally`` block, so a failing assertion cannot leak globals into
    later tests.
    """

    @staticmethod
    def _install_engine(async_engine) -> None:
        """Point the module globals at the test engine instead of the production one.

        Mirrors the factory configuration the tests expect
        (``AsyncSession`` class, ``expire_on_commit=False``).
        """
        from sqlalchemy.ext.asyncio import async_sessionmaker

        session_module._engine = async_engine
        session_module._session_factory = async_sessionmaker(
            bind=async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

    @staticmethod
    def _new_session(async_engine) -> "AsyncSession":
        """Return a fresh session on ``async_engine`` that is independent of
        ``session_module``, for verifying what was actually committed."""
        from sqlalchemy.ext.asyncio import async_sessionmaker

        return async_sessionmaker(bind=async_engine, class_=AsyncSession)()

    def test_reset_engine_clears_globals(self) -> None:
        """``reset_engine`` sets both module globals back to ``None``."""
        session_module.reset_engine()
        # Give reset_engine something to clear; without this the assertions
        # below would pass even if reset_engine were a no-op.
        session_module._session_factory = "mock_factory"  # type: ignore
        try:
            session_module.reset_engine()

            assert session_module._engine is None
            assert session_module._session_factory is None
        finally:
            session_module.reset_engine()

    async def test_get_async_session_yields_session(
        self,
        async_engine,
    ) -> None:
        """``get_async_session`` yields exactly one usable ``AsyncSession``.

        Uses the ``async_engine`` fixture (an in-memory test database)
        rather than the production engine.
        """
        from sqlalchemy import text

        self._install_engine(async_engine)
        try:
            yielded = []
            async for db_session in session_module.get_async_session():
                assert isinstance(db_session, AsyncSession)
                # The session must be able to run a real statement.
                result = await db_session.execute(text("SELECT 1"))
                assert result.scalar() == 1
                yielded.append(db_session)

            assert len(yielded) == 1
        finally:
            session_module.reset_engine()

    async def test_get_async_session_commits_on_success(
        self,
        async_engine,
    ) -> None:
        """Work added inside ``get_async_session`` is committed when the consumer
        finishes without error. The test itself never calls ``commit()``."""
        from sqlalchemy import text

        from arbiter.db.models import ReviewModel

        self._install_engine(async_engine)
        try:
            async for db_session in session_module.get_async_session():
                db_session.add(
                    ReviewModel(
                        repository="test/repo",
                        pr_number=1,
                        base_sha="abc123",
                        head_sha="def456",
                        status="pending",
                    )
                )
                # Deliberately no manual commit; the generator is responsible.

            # A separate session only sees the row if it was really committed.
            async with self._new_session(async_engine) as verify_session:
                result = await verify_session.execute(
                    text("SELECT COUNT(*) FROM reviews")
                )
                assert result.scalar() == 1
        finally:
            session_module.reset_engine()

    async def test_get_async_session_propagates_error(
        self,
        async_engine,
    ) -> None:
        """An error raised at the ``yield`` point propagates out of
        ``get_async_session``, and the pending work is not committed.

        The error is thrown *into* the generator with ``athrow``. Raising it
        inside an ``async for`` body would never reach the generator's error
        handling, so such a test would pass regardless of that handling.
        """
        from sqlalchemy import text

        from arbiter.db.models import ReviewModel

        self._install_engine(async_engine)
        session_gen = session_module.get_async_session()
        try:
            db_session = await session_gen.__anext__()
            db_session.add(
                ReviewModel(
                    repository="test/repo",
                    pr_number=99,
                    base_sha="abc123",
                    head_sha="def456",
                    status="pending",
                )
            )

            # Resume the generator with an exception at its yield point;
            # it must re-raise rather than swallow it.
            with pytest.raises(ValueError, match="Test error"):
                await session_gen.athrow(ValueError("Test error"))

            # The failed unit of work must not have been persisted.
            async with self._new_session(async_engine) as verify_session:
                result = await verify_session.execute(
                    text("SELECT COUNT(*) FROM reviews WHERE pr_number = 99")
                )
                assert result.scalar() == 0
        finally:
            try:
                # No-op if the generator already finished; otherwise it
                # releases the session it is still holding.
                await session_gen.aclose()
            finally:
                session_module.reset_engine()

    async def test_init_db_creates_tables(self, async_engine) -> None:
        """``init_db`` runs against the installed engine and leaves a queryable
        ``reviews`` table.

        NOTE(review): the ``async_engine`` fixture has already created the
        tables, so this only confirms that ``init_db`` succeeds on an existing
        schema. It does not confirm that ``init_db`` creates a missing one.
        """
        from sqlalchemy import text

        session_module._engine = async_engine
        session_module._session_factory = None
        try:
            await session_module.init_db()

            # If the table is missing, execute() raises and fails the test.
            async with self._new_session(async_engine) as session:
                await session.execute(text("SELECT 1 FROM reviews LIMIT 1"))
        finally:
            session_module.reset_engine()

    async def test_close_db_disposes_engine(self, async_engine) -> None:
        """``close_db`` clears both module globals.

        Whether the engine was disposed is not asserted here; only the
        resulting global state is checked.
        """
        session_module._engine = async_engine
        session_module._session_factory = "not none"  # type: ignore
        try:
            await session_module.close_db()

            assert session_module._engine is None
            assert session_module._session_factory is None
        finally:
            session_module.reset_engine()

    def test_session_factory_create(self) -> None:
        """``async_session_factory`` returns an already-installed factory as-is."""
        session_module.reset_engine()
        # Building a real factory would need application settings, so only
        # the "factory already exists" path is exercised.
        session_module._session_factory = "mock_factory"  # type: ignore
        try:
            assert session_module.async_session_factory() == "mock_factory"
        finally:
            session_module.reset_engine()
|