"""Shared pytest fixtures for the test suite.

Provides an in-memory SQLite database (via aiosqlite) in place of the real
database, a per-test ``AsyncSession``, and an HTTPX ``AsyncClient`` that calls
the ASGI ``app`` in-process with its ``get_db`` dependency overridden.
"""

import asyncio
from collections.abc import AsyncGenerator, Generator
from typing import Any

import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import JSON, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from src.db.database import get_db
from src.main import app
from src.models import Base

TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"


def patch_postgres_types() -> None:
    """Teach SQLite's type compiler to render PostgreSQL-only column types.

    Adds ``visit_JSONB`` and replaces ``visit_UUID`` on
    ``SQLiteTypeCompiler`` so DDL for such columns can be emitted on SQLite.
    ``JSONB`` is rendered the way SQLite renders ``JSON``, and ``UUID`` is
    rendered as ``VARCHAR(36)`` (the length of a hyphenated UUID string).

    NOTE(review): presumably the models in ``src.models`` declare PostgreSQL
    ``JSONB``/``UUID`` columns; confirm before relying on this patch.
    """
    from sqlalchemy.dialects.sqlite import base as sqlite_base

    def visit_JSONB(
        self: Any,
        type_: Any,  # noqa: ARG001
        **kw: Any,
    ) -> str:
        # Delegate to the compiler's own JSON rendering.
        return str(self.visit_JSON(JSON(), **kw))

    sqlite_base.SQLiteTypeCompiler.visit_JSONB = visit_JSONB  # type: ignore[attr-defined]

    def visit_UUID(
        self: Any,  # noqa: ARG001
        type_: Any,  # noqa: ARG001
        **kw: Any,  # noqa: ARG001
    ) -> str:
        return "VARCHAR(36)"

    sqlite_base.SQLiteTypeCompiler.visit_UUID = visit_UUID  # type: ignore[method-assign]


# Applied at import time so the patch is in place before any fixture runs
# ``create_all``.
patch_postgres_types()

engine = create_async_engine(TEST_DATABASE_URL, echo=False)


@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_connection: Any, _connection_record: Any) -> None:
    """Enable foreign-key enforcement on each new DBAPI connection.

    SQLite does not enforce ``FOREIGN KEY`` constraints unless
    ``PRAGMA foreign_keys=ON`` is issued on the connection.
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Release the cursor even if the pragma fails.
        cursor.close()


async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a single event loop for the whole test session.

    Every test shares the module-level ``engine``; one session-wide loop keeps
    its connections on the same loop. On teardown the engine is disposed on
    that loop before the loop is closed, so pooled connections are closed
    instead of being abandoned.

    NOTE(review): overriding the ``event_loop`` fixture is deprecated in newer
    pytest-asyncio releases; confirm the pinned version still honours it.
    """
    # Same as get_event_loop_policy().new_event_loop(), without using the
    # policy API that Python 3.14 deprecates.
    loop = asyncio.new_event_loop()
    yield loop
    try:
        loop.run_until_complete(engine.dispose())
    finally:
        loop.close()


@pytest.fixture(autouse=True)
async def setup_database() -> AsyncGenerator[None, None]:
    """Create all tables before each test and drop them afterwards.

    Because it is autouse, every test starts from a freshly created schema.
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)


@pytest.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` from the test session factory, closed afterwards."""
    async with async_session_factory() as session:
        yield session


@pytest.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTPX client that calls ``app`` in-process.

    ``get_db`` is overridden to yield ``db_session``, so the test body and
    request handlers share the same session.
    """

    async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as test_client:
            yield test_client
    finally:
        # Clears every override, including any a test installed itself, even
        # if building or closing the client raised.
        app.dependency_overrides.clear()