92 lines
3.1 KiB
Python
92 lines
3.1 KiB
Python
"""Tests for API dependency injection."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from arbiter.api import deps as deps_module
|
|
from arbiter.config import Settings
|
|
|
|
|
|
class TestDependencies:
    """Tests for the dependency providers in ``arbiter.api.deps``.

    Several providers keep module-level state: the ``get_settings`` lru_cache,
    the ``_redis_client`` global, and the engine / session factory in
    ``arbiter.db.session``. Each test resets or restores the state it touches
    so that tests do not depend on execution order.
    """

    def test_get_settings_returns_settings(self) -> None:
        """``get_settings`` returns a ``Settings`` instance."""
        # Clear the lru_cache to ensure fresh state.
        deps_module.get_settings.cache_clear()
        try:
            result = deps_module.get_settings()
            assert isinstance(result, Settings)
        finally:
            # Don't leak the instance cached here into later tests.
            deps_module.get_settings.cache_clear()

    def test_get_settings_is_cached(self) -> None:
        """Repeated ``get_settings`` calls return the very same object."""
        deps_module.get_settings.cache_clear()
        try:
            result1 = deps_module.get_settings()
            result2 = deps_module.get_settings()
            assert result1 is result2
        finally:
            deps_module.get_settings.cache_clear()

    @pytest.mark.asyncio
    async def test_close_redis_when_not_connected(self) -> None:
        """``close_redis`` is a no-op when no client has been created."""
        original = deps_module._redis_client
        # Ensure no redis client is set.
        deps_module._redis_client = None
        try:
            # Should not raise.
            await deps_module.close_redis()
            assert deps_module._redis_client is None
        finally:
            deps_module._redis_client = original

    @pytest.mark.asyncio
    async def test_close_redis_when_connected(self) -> None:
        """``close_redis`` awaits ``close()`` on the client and clears the global."""
        original = deps_module._redis_client
        mock_client = AsyncMock()
        deps_module._redis_client = mock_client
        try:
            await deps_module.close_redis()
            # assert_awaited_once (not assert_called_once): calling an
            # AsyncMock method without awaiting it would still count as a
            # call, hiding a never-awaited close() coroutine.
            mock_client.close.assert_awaited_once()
            assert deps_module._redis_client is None
        finally:
            deps_module._redis_client = original

    @pytest.mark.asyncio
    async def test_get_db_yields_session(self, async_engine) -> None:
        """``get_db`` yields exactly one ``AsyncSession`` bound to the test engine.

        ``async_engine`` is a fixture defined outside this module (presumably
        in conftest.py — not visible here).
        """
        from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

        from arbiter.db import session as session_module

        # Point the session module's globals at the test engine.
        session_module._engine = async_engine
        session_module._session_factory = async_sessionmaker(
            bind=async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

        try:
            sessions = [session async for session in deps_module.get_db()]
            # Checking the count first prevents a vacuous pass if the
            # generator yields nothing.
            assert len(sessions) == 1
            assert isinstance(sessions[0], AsyncSession)
        finally:
            session_module.reset_engine()

    @pytest.mark.asyncio
    async def test_get_redis_creates_client(self) -> None:
        """``get_redis`` builds one client via ``Redis.from_url`` and reuses it."""
        original = deps_module._redis_client
        deps_module._redis_client = None

        try:
            with patch("arbiter.api.deps.Redis") as mock_redis:
                mock_instance = MagicMock()
                mock_redis.from_url.return_value = mock_instance

                first = [client async for client in deps_module.get_redis()]
                assert len(first) == 1
                assert first[0] is mock_instance
                mock_redis.from_url.assert_called_once()

                # Second call should use the cached client.
                second = [client async for client in deps_module.get_redis()]
                assert len(second) == 1
                assert second[0] is mock_instance
                # Still only called once due to caching.
                assert mock_redis.from_url.call_count == 1
        finally:
            deps_module._redis_client = original
|