55 lines
1.4 KiB
Python
55 lines
1.4 KiB
Python
"""FastAPI dependency injection for Arbiter."""
|
|
|
|
from collections.abc import AsyncGenerator
from contextlib import aclosing
from functools import lru_cache
from typing import Annotated

from fastapi import Depends
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession

from arbiter.config import Settings
from arbiter.config import get_settings as _get_settings
from arbiter.db.session import get_async_session
|
|
|
|
# Process-wide Redis client cache: None until get_redis() lazily creates it
# on first use, and reset back to None by close_redis().
_redis_client: Redis | None = None
|
|
|
|
|
|
@lru_cache()
def get_settings() -> Settings:
    """Return the application ``Settings``, memoized for this process.

    Delegates to ``arbiter.config.get_settings``. Because of the
    ``lru_cache`` wrapper, that delegate is invoked only until it first
    returns successfully. After that, every caller (including FastAPI's
    ``Depends``) receives the same object.
    """
    settings = _get_settings()
    return settings
|
|
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session for the duration of a request.

    Thin wrapper that re-yields each session produced by
    ``arbiter.db.session.get_async_session``.

    The inner async generator is wrapped in ``contextlib.aclosing`` so that
    it is closed deterministically when this dependency finishes. This
    matters when the request handler raises and the exception propagates
    out of the ``yield`` below, or when this generator is itself closed
    early. A bare ``async for`` does not close its iterator on such an
    early exit. The inner generator's post-yield cleanup would then run
    only whenever the garbage collector / event-loop finalizer gets to it.
    NOTE(review): that cleanup presumably closes or rolls back the session;
    confirm in ``arbiter.db.session``.

    On the normal path the inner generator is exhausted by the loop, so the
    extra ``aclose()`` is a no-op and behavior matches the plain loop.
    """
    async with aclosing(get_async_session()) as sessions:
        async for session in sessions:
            yield session
|
|
|
|
|
|
async def get_redis() -> AsyncGenerator[Redis, None]:
    """Get Redis client dependency.

    Yields the shared module-level asyncio Redis client. On first use, the
    client is built from ``settings.redis_url``, with a pool capped at
    ``settings.redis_max_connections`` and ``decode_responses=True``. Later
    calls reuse the cached instance.

    Nothing runs after the ``yield``: the client outlives individual
    requests and is released by ``close_redis()``.
    """
    global _redis_client

    cfg = get_settings()

    client = _redis_client
    if client is None:
        # No await between the check and the assignment, so tasks on the
        # same event loop cannot interleave here.
        client = Redis.from_url(
            cfg.redis_url,
            decode_responses=True,
            max_connections=cfg.redis_max_connections,
        )
        _redis_client = client

    yield client
|
|
|
|
|
|
async def close_redis() -> None:
    """Close the shared Redis client created by ``get_redis`` and forget it.

    Does nothing when no client is cached, so repeated calls are harmless.

    The module-level reference is cleared *before* awaiting ``close()``.
    If ``close()`` raises, a half-closed client therefore does not stay
    cached and is not handed out by later ``get_redis()`` calls; the
    exception still propagates to the caller. A ``get_redis()`` call that
    runs while the close is in progress builds a fresh client rather than
    receiving the one being torn down.
    """
    global _redis_client

    client, _redis_client = _redis_client, None
    if client is not None:
        await client.close()
|
|
|
|
|
|
# ``Annotated`` dependency aliases: route handlers can declare parameters
# such as ``db: DbSession`` or ``settings: AppSettings`` instead of
# spelling out ``Annotated[..., Depends(...)]`` every time.
AppSettings = Annotated[Settings, Depends(get_settings)]
DbSession = Annotated[AsyncSession, Depends(get_db)]
RedisClient = Annotated[Redis, Depends(get_redis)]
|