Files
arbiter/tests/test_deps.py
2025-06-05 19:05:01 +00:00

92 lines
3.1 KiB
Python

"""Tests for API dependency injection."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from arbiter.api import deps as deps_module
from arbiter.config import Settings
class TestDependencies:
    """Tests for the dependency providers exposed by ``arbiter.api.deps``.

    Several tests temporarily replace module-level state (the cached Redis
    client, the DB engine/session factory, the settings cache). Each one
    restores that state in a ``finally`` block so tests stay independent.
    """

    def test_get_settings_returns_settings(self) -> None:
        # Clear the lru_cache to ensure fresh state.
        deps_module.get_settings.cache_clear()
        try:
            result = deps_module.get_settings()
            assert isinstance(result, Settings)
        finally:
            # Don't leak the cached instance into unrelated tests.
            deps_module.get_settings.cache_clear()

    def test_get_settings_is_cached(self) -> None:
        deps_module.get_settings.cache_clear()
        try:
            result1 = deps_module.get_settings()
            result2 = deps_module.get_settings()
            assert result1 is result2
        finally:
            deps_module.get_settings.cache_clear()

    @pytest.mark.asyncio
    async def test_close_redis_when_not_connected(self) -> None:
        # Ensure no redis client is set.
        original = deps_module._redis_client
        deps_module._redis_client = None
        try:
            # Should not raise.
            await deps_module.close_redis()
            assert deps_module._redis_client is None
        finally:
            deps_module._redis_client = original

    @pytest.mark.asyncio
    async def test_close_redis_when_connected(self) -> None:
        original = deps_module._redis_client
        mock_client = AsyncMock()
        deps_module._redis_client = mock_client
        try:
            await deps_module.close_redis()
            # assert_awaited_once (not assert_called_once) also catches a
            # close() coroutine that was created but never awaited.
            mock_client.close.assert_awaited_once()
            assert deps_module._redis_client is None
        finally:
            deps_module._redis_client = original

    @pytest.mark.asyncio
    async def test_get_db_yields_session(self, async_engine) -> None:
        from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

        from arbiter.db import session as session_module

        try:
            # Point the session module at the test engine. This is inside the
            # try so reset_engine() still runs if the setup itself fails.
            session_module._engine = async_engine
            session_module._session_factory = async_sessionmaker(
                bind=async_engine,
                class_=AsyncSession,
                expire_on_commit=False,
            )
            yielded = []
            async for session in deps_module.get_db():
                assert isinstance(session, AsyncSession)
                yielded.append(session)
            # Guard against a vacuous pass: the dependency must yield exactly once.
            assert len(yielded) == 1
        finally:
            session_module.reset_engine()

    @pytest.mark.asyncio
    async def test_get_redis_creates_client(self) -> None:
        original = deps_module._redis_client
        deps_module._redis_client = None
        try:
            with patch("arbiter.api.deps.Redis") as mock_redis:
                mock_instance = MagicMock()
                mock_redis.from_url.return_value = mock_instance

                first = [client async for client in deps_module.get_redis()]
                assert first == [mock_instance]
                mock_redis.from_url.assert_called_once()

                # Second call should reuse the cached client.
                second = [client async for client in deps_module.get_redis()]
                assert second == [mock_instance]
                # Still only called once due to caching.
                assert mock_redis.from_url.call_count == 1
        finally:
            deps_module._redis_client = original