# File: codetutor/backend/tests/conftest.py
# (92 lines, 2.5 KiB, Python)

import asyncio
from collections.abc import AsyncGenerator, Generator
from typing import Any
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import JSON, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from src.db.database import get_db
from src.main import app
from src.models import Base
# Connection URL for the test database: in-memory SQLite through the
# aiosqlite async driver.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
def patch_postgres_types() -> None:
    """Teach the SQLite type compiler to render ``JSONB`` and ``UUID`` types.

    Installs ``visit_JSONB`` and ``visit_UUID`` on SQLAlchemy's
    ``SQLiteTypeCompiler`` class: ``JSONB`` is rendered exactly like the
    generic ``JSON`` type and ``UUID`` is rendered as ``VARCHAR(36)``.
    Because the class itself is patched, the change applies to every SQLite
    engine in the process.
    """
    from sqlalchemy.dialects.sqlite.base import SQLiteTypeCompiler

    def visit_UUID(
        self: Any,  # noqa: ARG001
        type_: Any,  # noqa: ARG001
        **kw: Any,  # noqa: ARG001
    ) -> str:
        """Render a UUID column as a 36-character string column."""
        return "VARCHAR(36)"

    def visit_JSONB(
        self: Any,
        type_: Any,  # noqa: ARG001
        **kw: Any,
    ) -> str:
        """Render a JSONB column by delegating to the compiler's JSON handling."""
        json_ddl = self.visit_JSON(JSON(), **kw)
        return str(json_ddl)

    # The two assignments are independent of each other.
    SQLiteTypeCompiler.visit_UUID = visit_UUID  # type: ignore[method-assign]
    SQLiteTypeCompiler.visit_JSONB = visit_JSONB  # type: ignore[attr-defined]
# Install the SQLite type-compiler patches at import time, before the engine
# below is created.
patch_postgres_types()
# Async engine for the test database; SQL statement echoing is turned off.
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_connection: Any, _connection_record: Any) -> None:
    """Turn on foreign-key enforcement for each new test-engine connection.

    Registered as a ``connect`` listener on the engine's underlying sync
    engine, so it runs for every newly opened DBAPI connection and issues
    ``PRAGMA foreign_keys=ON`` on it.

    Args:
        dbapi_connection: The newly opened DBAPI connection.
        _connection_record: Second argument supplied by the ``connect``
            event (unused here).
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Release the cursor even if the PRAGMA statement raises.
        cursor.close()
# Factory producing AsyncSession objects bound to the test engine.
# expire_on_commit=False keeps loaded attributes usable after a commit.
async_session_factory = async_sessionmaker(
    engine,
    expire_on_commit=False,
    class_=AsyncSession,
)
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a single event loop shared by the whole test session.

    Yields:
        A newly created event loop, which is closed when the session ends.
    """
    # ``asyncio.new_event_loop()`` is the documented equivalent of
    # ``get_event_loop_policy().new_event_loop()`` and avoids the policy
    # accessor, which is deprecated as of Python 3.14.
    # NOTE(review): newer pytest-asyncio releases deprecate overriding the
    # ``event_loop`` fixture -- confirm against the pinned plugin version.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        loop.close()
@pytest.fixture(autouse=True)
async def setup_database() -> AsyncGenerator[None, None]:
    """Create every table in ``Base.metadata`` for a test, then drop them.

    Applied automatically to each test (``autouse=True``): the tables are
    created before the test body runs and dropped again after it finishes.
    """
    schema = Base.metadata

    async with engine.begin() as setup_conn:
        await setup_conn.run_sync(schema.create_all)

    yield

    async with engine.begin() as teardown_conn:
        await teardown_conn.run_sync(schema.drop_all)
@pytest.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session built by ``async_session_factory``.

    The session is used as an async context manager, so it is closed once
    the test that requested it has finished.
    """
    session_cm = async_session_factory()
    async with session_cm as session:
        yield session
@pytest.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client that calls the app in-process via ASGI.

    While the fixture is active, the app's ``get_db`` dependency is
    overridden so that request handlers receive the same ``db_session`` the
    test itself uses.

    Args:
        db_session: The session handed to the app in place of ``get_db``.

    Yields:
        An ``AsyncClient`` bound to the app with base URL ``http://test``.
    """

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        """Stand-in for ``get_db`` that yields the fixture's session."""
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(
            transport=transport,
            base_url="http://test",
        ) as test_client:
            yield test_client
    finally:
        # Undo the override even when client setup or shutdown raises, so a
        # failure here cannot leak the override into later tests.  Note that
        # clear() removes every override registered on the app, not only
        # the one for get_db.
        app.dependency_overrides.clear()