test(backend): add unit tests
This commit is contained in:
91
backend/tests/conftest.py
Normal file
91
backend/tests/conftest.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import asyncio
|
||||||
|
from collections.abc import AsyncGenerator, Generator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlalchemy import JSON, event
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from src.db.database import get_db
|
||||||
|
from src.main import app
|
||||||
|
from src.models import Base
|
||||||
|
|
||||||
|
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||||
|
|
||||||
|
|
||||||
|
def patch_postgres_types() -> None:
|
||||||
|
from sqlalchemy.dialects.sqlite import base as sqlite_base
|
||||||
|
|
||||||
|
def visit_JSONB(
|
||||||
|
self: Any,
|
||||||
|
type_: Any, # noqa: ARG001
|
||||||
|
**kw: Any,
|
||||||
|
) -> str:
|
||||||
|
return str(self.visit_JSON(JSON(), **kw))
|
||||||
|
|
||||||
|
sqlite_base.SQLiteTypeCompiler.visit_JSONB = visit_JSONB # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
def visit_UUID(
|
||||||
|
self: Any, # noqa: ARG001
|
||||||
|
type_: Any, # noqa: ARG001
|
||||||
|
**kw: Any, # noqa: ARG001
|
||||||
|
) -> str:
|
||||||
|
return "VARCHAR(36)"
|
||||||
|
|
||||||
|
sqlite_base.SQLiteTypeCompiler.visit_UUID = visit_UUID # type: ignore[method-assign]
|
||||||
|
|
||||||
|
|
||||||
|
patch_postgres_types()
|
||||||
|
|
||||||
|
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(engine.sync_engine, "connect")
|
||||||
|
def set_sqlite_pragma(dbapi_connection: Any, _connection_record: Any) -> None:
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
|
||||||
|
async_session_factory = async_sessionmaker(
|
||||||
|
engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
|
||||||
|
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def setup_database() -> AsyncGenerator[None, None]:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
yield
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
async with async_session_factory() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as test_client:
|
||||||
|
yield test_client
|
||||||
|
|
||||||
|
app.dependency_overrides.clear()
|
||||||
268
backend/tests/test_api.py
Normal file
268
backend/tests/test_api.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from src.models import Category, Difficulty, Explanation, Pattern, Question, Solution
|
||||||
|
|
||||||
|
SampleData = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
async def create_sample_data(db_session: AsyncSession) -> SampleData:
|
||||||
|
category = Category(name="Arrays", slug="arrays", description="Array problems")
|
||||||
|
db_session.add(category)
|
||||||
|
|
||||||
|
pattern = Pattern(
|
||||||
|
name="Two Pointers",
|
||||||
|
slug="two-pointers",
|
||||||
|
description="Two pointer technique",
|
||||||
|
)
|
||||||
|
db_session.add(pattern)
|
||||||
|
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
question = Question(
|
||||||
|
title="Two Sum",
|
||||||
|
slug="two-sum",
|
||||||
|
difficulty=Difficulty.EASY,
|
||||||
|
description="Find two numbers that add up to target.",
|
||||||
|
leetcode_id=1,
|
||||||
|
leetcode_url="https://leetcode.com/problems/two-sum/",
|
||||||
|
)
|
||||||
|
question.categories = [category]
|
||||||
|
question.patterns = [pattern]
|
||||||
|
db_session.add(question)
|
||||||
|
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
explanation = Explanation(
|
||||||
|
question_id=question.id,
|
||||||
|
approach="Use a hash map.",
|
||||||
|
intuition="Store seen values for O(1) lookup.",
|
||||||
|
time_complexity="O(n)",
|
||||||
|
space_complexity="O(n)",
|
||||||
|
)
|
||||||
|
db_session.add(explanation)
|
||||||
|
|
||||||
|
solution = Solution(
|
||||||
|
question_id=question.id,
|
||||||
|
approach_name="Hash Map",
|
||||||
|
code="def two_sum(nums, target): pass",
|
||||||
|
is_optimal=True,
|
||||||
|
)
|
||||||
|
db_session.add(solution)
|
||||||
|
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
return {"question": question, "category": category, "pattern": pattern}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_health_check(client: AsyncClient) -> None:
|
||||||
|
response = await client.get("/api/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "healthy"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_questions_empty(client: AsyncClient) -> None:
|
||||||
|
response = await client.get("/api/questions")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["items"] == []
|
||||||
|
assert data["total"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_questions(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
await create_sample_data(db_session)
|
||||||
|
response = await client.get("/api/questions")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 1
|
||||||
|
assert len(data["items"]) == 1
|
||||||
|
assert data["items"][0]["title"] == "Two Sum"
|
||||||
|
assert data["items"][0]["difficulty"] == "easy"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_questions_filter_difficulty(
|
||||||
|
client: AsyncClient, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
await create_sample_data(db_session)
|
||||||
|
response = await client.get("/api/questions?difficulty=easy")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["total"] == 1
|
||||||
|
|
||||||
|
response = await client.get("/api/questions?difficulty=hard")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["total"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_questions_invalid_difficulty(client: AsyncClient) -> None:
|
||||||
|
response = await client.get("/api/questions?difficulty=invalid")
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_question(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
await create_sample_data(db_session)
|
||||||
|
response = await client.get("/api/questions/two-sum")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["title"] == "Two Sum"
|
||||||
|
assert data["explanation"]["time_complexity"] == "O(n)"
|
||||||
|
assert len(data["solutions"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_question_not_found(client: AsyncClient) -> None:
|
||||||
|
response = await client.get("/api/questions/nonexistent")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_categories_empty(client: AsyncClient) -> None:
|
||||||
|
response = await client.get("/api/categories")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["items"] == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_categories(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
await create_sample_data(db_session)
|
||||||
|
response = await client.get("/api/categories")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["items"]) == 1
|
||||||
|
assert data["items"][0]["name"] == "Arrays"
|
||||||
|
assert data["items"][0]["question_count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_patterns_empty(client: AsyncClient) -> None:
|
||||||
|
response = await client.get("/api/patterns")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["items"] == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_patterns(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
await create_sample_data(db_session)
|
||||||
|
response = await client.get("/api/patterns")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["items"]) == 1
|
||||||
|
assert data["items"][0]["name"] == "Two Pointers"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_pattern(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
await create_sample_data(db_session)
|
||||||
|
response = await client.get("/api/patterns/two-pointers")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Two Pointers"
|
||||||
|
assert data["question_count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_pattern_not_found(client: AsyncClient) -> None:
|
||||||
|
response = await client.get("/api/patterns/nonexistent")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_stats_empty(client: AsyncClient) -> None:
|
||||||
|
response = await client.get("/api/stats")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total_questions"] == 0
|
||||||
|
assert data["by_difficulty"]["easy"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_stats(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
await create_sample_data(db_session)
|
||||||
|
response = await client.get("/api/stats")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total_questions"] == 1
|
||||||
|
assert data["by_difficulty"]["easy"] == 1
|
||||||
|
assert len(data["by_category"]) == 1
|
||||||
|
assert len(data["by_pattern"]) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_pattern_tutorial(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
await create_sample_data(db_session)
|
||||||
|
response = await client.get("/api/patterns/two-pointers/tutorial")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["name"] == "Two Pointers"
|
||||||
|
assert data["learning_progression"]["warmup"][0]["title"] == "Two Sum"
|
||||||
|
assert len(data["learning_progression"]["core"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_pattern_tutorial_not_found(client: AsyncClient) -> None:
|
||||||
|
response = await client.get("/api/patterns/nonexistent/tutorial")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_submit_passing(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
data = await create_sample_data(db_session)
|
||||||
|
q = data["question"]
|
||||||
|
q.test_cases = {
|
||||||
|
"visible": [{"input": {"nums": [2, 7], "target": 9}, "expected": [0, 1]}],
|
||||||
|
"hidden": [{"input": {"nums": [3, 2, 4], "target": 6}, "expected": [1, 2]}],
|
||||||
|
}
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/questions/two-sum/submit",
|
||||||
|
json={
|
||||||
|
"code": "def solve(nums, t): pass",
|
||||||
|
"hidden_outputs": [{"test_id": 1, "output": [1, 2]}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
result = response.json()
|
||||||
|
assert result["passed"] is True
|
||||||
|
assert result["total_passed"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_submit_wrong_output(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
data = await create_sample_data(db_session)
|
||||||
|
data["question"].test_cases = {
|
||||||
|
"visible": [],
|
||||||
|
"hidden": [{"input": {"nums": [1, 2], "target": 3}, "expected": [0, 1]}],
|
||||||
|
}
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/questions/two-sum/submit",
|
||||||
|
json={
|
||||||
|
"code": "def solve(nums, t): pass",
|
||||||
|
"hidden_outputs": [{"test_id": 0, "output": [9, 9]}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["passed"] is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_submit_missing_output(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
data = await create_sample_data(db_session)
|
||||||
|
data["question"].test_cases = {
|
||||||
|
"visible": [],
|
||||||
|
"hidden": [{"input": {"nums": [1, 2], "target": 3}, "expected": [0, 1]}],
|
||||||
|
}
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/questions/two-sum/submit",
|
||||||
|
json={"code": "def solve(nums, t): pass", "hidden_outputs": []},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["hidden_results"][0]["error"] == "No output provided"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_submit_no_test_cases(client: AsyncClient, db_session: AsyncSession) -> None:
|
||||||
|
await create_sample_data(db_session)
|
||||||
|
response = await client.post(
|
||||||
|
"/api/questions/two-sum/submit",
|
||||||
|
json={"code": "pass", "hidden_outputs": []},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
async def test_submit_not_found(client: AsyncClient) -> None:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/questions/nonexistent/submit",
|
||||||
|
json={"code": "pass", "hidden_outputs": []},
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
156
backend/tests/test_models.py
Normal file
156
backend/tests/test_models.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from src.models import Category, Difficulty, Explanation, Pattern, Question, Solution
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_category(db_session: AsyncSession) -> None:
|
||||||
|
category = Category(
|
||||||
|
name="Arrays",
|
||||||
|
slug="arrays",
|
||||||
|
description="Array manipulation problems",
|
||||||
|
)
|
||||||
|
db_session.add(category)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(select(Category))
|
||||||
|
categories = list(result.scalars().all())
|
||||||
|
assert len(categories) == 1
|
||||||
|
assert categories[0].name == "Arrays"
|
||||||
|
assert categories[0].slug == "arrays"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_pattern(db_session: AsyncSession) -> None:
|
||||||
|
pattern = Pattern(
|
||||||
|
name="Two Pointers",
|
||||||
|
slug="two-pointers",
|
||||||
|
description="Use two pointers to traverse data",
|
||||||
|
when_to_use="Sorted arrays, linked lists",
|
||||||
|
)
|
||||||
|
db_session.add(pattern)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(select(Pattern))
|
||||||
|
patterns = list(result.scalars().all())
|
||||||
|
assert len(patterns) == 1
|
||||||
|
assert patterns[0].name == "Two Pointers"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_question(db_session: AsyncSession) -> None:
|
||||||
|
question = Question(
|
||||||
|
title="Two Sum",
|
||||||
|
slug="two-sum",
|
||||||
|
difficulty=Difficulty.EASY,
|
||||||
|
description="Find two numbers that add up to target.",
|
||||||
|
leetcode_id=1,
|
||||||
|
)
|
||||||
|
db_session.add(question)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(select(Question))
|
||||||
|
questions = list(result.scalars().all())
|
||||||
|
assert len(questions) == 1
|
||||||
|
assert questions[0].title == "Two Sum"
|
||||||
|
assert questions[0].difficulty == Difficulty.EASY
|
||||||
|
|
||||||
|
|
||||||
|
async def test_question_with_categories(db_session: AsyncSession) -> None:
|
||||||
|
category = Category(name="Arrays", slug="arrays")
|
||||||
|
question = Question(
|
||||||
|
title="Two Sum",
|
||||||
|
slug="two-sum",
|
||||||
|
difficulty=Difficulty.EASY,
|
||||||
|
description="Find two numbers.",
|
||||||
|
)
|
||||||
|
question.categories = [category]
|
||||||
|
|
||||||
|
db_session.add(question)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(select(Question))
|
||||||
|
q = result.scalar_one()
|
||||||
|
assert len(q.categories) == 1
|
||||||
|
assert q.categories[0].name == "Arrays"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_question_with_patterns(db_session: AsyncSession) -> None:
|
||||||
|
pattern = Pattern(name="Two Pointers", slug="two-pointers")
|
||||||
|
question = Question(
|
||||||
|
title="Two Sum",
|
||||||
|
slug="two-sum",
|
||||||
|
difficulty=Difficulty.EASY,
|
||||||
|
description="Find two numbers.",
|
||||||
|
)
|
||||||
|
question.patterns = [pattern]
|
||||||
|
|
||||||
|
db_session.add(question)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(select(Question))
|
||||||
|
q = result.scalar_one()
|
||||||
|
assert len(q.patterns) == 1
|
||||||
|
assert q.patterns[0].name == "Two Pointers"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_question_with_explanation(db_session: AsyncSession) -> None:
|
||||||
|
question = Question(
|
||||||
|
title="Two Sum",
|
||||||
|
slug="two-sum",
|
||||||
|
difficulty=Difficulty.EASY,
|
||||||
|
description="Find two numbers.",
|
||||||
|
)
|
||||||
|
db_session.add(question)
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
explanation = Explanation(
|
||||||
|
question_id=question.id,
|
||||||
|
approach="Use hash map",
|
||||||
|
intuition="Store seen values",
|
||||||
|
time_complexity="O(n)",
|
||||||
|
space_complexity="O(n)",
|
||||||
|
)
|
||||||
|
db_session.add(explanation)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(Question)
|
||||||
|
.options(selectinload(Question.explanation))
|
||||||
|
.where(Question.id == question.id)
|
||||||
|
)
|
||||||
|
loaded_question = result.scalar_one()
|
||||||
|
assert loaded_question.explanation is not None
|
||||||
|
assert loaded_question.explanation.time_complexity == "O(n)"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_question_with_solutions(db_session: AsyncSession) -> None:
|
||||||
|
question = Question(
|
||||||
|
title="Two Sum",
|
||||||
|
slug="two-sum",
|
||||||
|
difficulty=Difficulty.EASY,
|
||||||
|
description="Find two numbers.",
|
||||||
|
)
|
||||||
|
db_session.add(question)
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
solution = Solution(
|
||||||
|
question_id=question.id,
|
||||||
|
approach_name="Hash Map",
|
||||||
|
code="def solve(): pass",
|
||||||
|
is_optimal=True,
|
||||||
|
)
|
||||||
|
db_session.add(solution)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(Question).options(selectinload(Question.solutions)).where(Question.id == question.id)
|
||||||
|
)
|
||||||
|
loaded_question = result.scalar_one()
|
||||||
|
assert len(loaded_question.solutions) == 1
|
||||||
|
assert loaded_question.solutions[0].is_optimal is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_difficulty_enum() -> None:
|
||||||
|
assert Difficulty.EASY.value == "easy"
|
||||||
|
assert Difficulty.MEDIUM.value == "medium"
|
||||||
|
assert Difficulty.HARD.value == "hard"
|
||||||
286
backend/tests/test_services.py
Normal file
286
backend/tests/test_services.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from src.models import Category, Difficulty, Pattern, Question
|
||||||
|
from src.services import CategoryService, PatternService, QuestionService, StatsService
|
||||||
|
|
||||||
|
TestData = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
async def create_test_data(db_session: AsyncSession) -> TestData:
|
||||||
|
cat_arrays = Category(name="Arrays", slug="arrays")
|
||||||
|
cat_strings = Category(name="Strings", slug="strings")
|
||||||
|
db_session.add_all([cat_arrays, cat_strings])
|
||||||
|
|
||||||
|
pat_two_pointers = Pattern(name="Two Pointers", slug="two-pointers")
|
||||||
|
pat_sliding = Pattern(name="Sliding Window", slug="sliding-window")
|
||||||
|
db_session.add_all([pat_two_pointers, pat_sliding])
|
||||||
|
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
q1 = Question(
|
||||||
|
title="Two Sum",
|
||||||
|
slug="two-sum",
|
||||||
|
difficulty=Difficulty.EASY,
|
||||||
|
description="Find two numbers.",
|
||||||
|
)
|
||||||
|
q1.categories = [cat_arrays]
|
||||||
|
q1.patterns = [pat_two_pointers]
|
||||||
|
|
||||||
|
q2 = Question(
|
||||||
|
title="Three Sum",
|
||||||
|
slug="three-sum",
|
||||||
|
difficulty=Difficulty.MEDIUM,
|
||||||
|
description="Find three numbers.",
|
||||||
|
)
|
||||||
|
q2.categories = [cat_arrays]
|
||||||
|
q2.patterns = [pat_two_pointers]
|
||||||
|
|
||||||
|
q3 = Question(
|
||||||
|
title="Longest Substring",
|
||||||
|
slug="longest-substring",
|
||||||
|
difficulty=Difficulty.MEDIUM,
|
||||||
|
description="Find longest substring.",
|
||||||
|
)
|
||||||
|
q3.categories = [cat_strings]
|
||||||
|
q3.patterns = [pat_sliding]
|
||||||
|
|
||||||
|
db_session.add_all([q1, q2, q3])
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"cat_arrays": cat_arrays,
|
||||||
|
"cat_strings": cat_strings,
|
||||||
|
"pat_two_pointers": pat_two_pointers,
|
||||||
|
"pat_sliding": pat_sliding,
|
||||||
|
"q1": q1,
|
||||||
|
"q2": q2,
|
||||||
|
"q3": q3,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuestionService:
|
||||||
|
async def test_get_questions(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
questions, total = await service.get_questions()
|
||||||
|
|
||||||
|
assert total == 3
|
||||||
|
assert len(questions) == 3
|
||||||
|
|
||||||
|
async def test_pagination(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
questions, total = await service.get_questions(page=1, limit=2)
|
||||||
|
|
||||||
|
assert total == 3
|
||||||
|
assert len(questions) == 2
|
||||||
|
|
||||||
|
async def test_filter_difficulty(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
questions, total = await service.get_questions(difficulties=[Difficulty.EASY])
|
||||||
|
|
||||||
|
assert total == 1
|
||||||
|
assert questions[0].title == "Two Sum"
|
||||||
|
|
||||||
|
async def test_filter_category(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
questions, total = await service.get_questions(category_slug="arrays")
|
||||||
|
|
||||||
|
assert total == 2
|
||||||
|
|
||||||
|
async def test_filter_pattern(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
questions, total = await service.get_questions(pattern_slug="sliding-window")
|
||||||
|
|
||||||
|
assert total == 1
|
||||||
|
assert questions[0].title == "Longest Substring"
|
||||||
|
|
||||||
|
async def test_search(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
questions, total = await service.get_questions(search="Sum")
|
||||||
|
|
||||||
|
assert total == 2
|
||||||
|
|
||||||
|
async def test_get_by_slug(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
question = await service.get_question_by_slug("two-sum")
|
||||||
|
|
||||||
|
assert question is not None
|
||||||
|
assert question.title == "Two Sum"
|
||||||
|
|
||||||
|
async def test_get_by_slug_not_found(self, db_session: AsyncSession) -> None:
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
question = await service.get_question_by_slug("nonexistent")
|
||||||
|
|
||||||
|
assert question is None
|
||||||
|
|
||||||
|
async def test_get_by_id(self, db_session: AsyncSession) -> None:
|
||||||
|
data = await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
question = await service.get_question_by_id(data["q1"].id)
|
||||||
|
|
||||||
|
assert question is not None
|
||||||
|
assert question.title == "Two Sum"
|
||||||
|
|
||||||
|
async def test_get_by_id_not_found(self, db_session: AsyncSession) -> None:
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
question = await service.get_question_by_id(uuid4())
|
||||||
|
|
||||||
|
assert question is None
|
||||||
|
|
||||||
|
async def test_combined_filters(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
questions, total = await service.get_questions(
|
||||||
|
difficulties=[Difficulty.MEDIUM],
|
||||||
|
category_slug="arrays",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert total == 1
|
||||||
|
assert questions[0].title == "Three Sum"
|
||||||
|
|
||||||
|
async def test_pagination_page_two(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
questions, total = await service.get_questions(page=2, limit=2)
|
||||||
|
|
||||||
|
assert total == 3
|
||||||
|
assert len(questions) == 1
|
||||||
|
|
||||||
|
async def test_pagination_beyond_data(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = QuestionService(db_session)
|
||||||
|
questions, total = await service.get_questions(page=10, limit=20)
|
||||||
|
|
||||||
|
assert total == 3
|
||||||
|
assert len(questions) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCategoryService:
|
||||||
|
async def test_get_categories(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = CategoryService(db_session)
|
||||||
|
categories = await service.get_categories()
|
||||||
|
|
||||||
|
assert len(categories) == 2
|
||||||
|
|
||||||
|
arrays = next((c, cnt) for c, cnt in categories if c.slug == "arrays")
|
||||||
|
assert arrays[1] == 2
|
||||||
|
|
||||||
|
async def test_get_by_slug(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = CategoryService(db_session)
|
||||||
|
category = await service.get_category_by_slug("arrays")
|
||||||
|
|
||||||
|
assert category is not None
|
||||||
|
assert category.name == "Arrays"
|
||||||
|
|
||||||
|
async def test_get_by_slug_not_found(self, db_session: AsyncSession) -> None:
|
||||||
|
service = CategoryService(db_session)
|
||||||
|
category = await service.get_category_by_slug("nonexistent")
|
||||||
|
|
||||||
|
assert category is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPatternService:
|
||||||
|
async def test_get_patterns(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = PatternService(db_session)
|
||||||
|
patterns = await service.get_patterns()
|
||||||
|
|
||||||
|
assert len(patterns) == 2
|
||||||
|
|
||||||
|
async def test_get_by_slug(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = PatternService(db_session)
|
||||||
|
result = await service.get_pattern_by_slug("two-pointers")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
pattern, count = result
|
||||||
|
assert pattern.name == "Two Pointers"
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
async def test_get_by_slug_not_found(self, db_session: AsyncSession) -> None:
|
||||||
|
service = PatternService(db_session)
|
||||||
|
result = await service.get_pattern_by_slug("nonexistent")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_get_tutorial(self, db_session: AsyncSession) -> None:
|
||||||
|
related = Pattern(name="BFS", slug="bfs", description="Breadth-first search")
|
||||||
|
pattern = Pattern(
|
||||||
|
name="Two Pointers",
|
||||||
|
slug="two-pointers",
|
||||||
|
description="Two pointer technique",
|
||||||
|
common_mistakes=[
|
||||||
|
{"title": "Off by one", "description": "Wrong bounds", "fix": "Use <="}
|
||||||
|
],
|
||||||
|
variations=[{"name": "Fast/slow", "description": "Different speeds"}],
|
||||||
|
related_patterns=["bfs"],
|
||||||
|
visualization_examples=[{"id": "ex1", "title": "Basic", "code": "x=1", "steps": []}],
|
||||||
|
)
|
||||||
|
db_session.add_all([related, pattern])
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
q1 = Question(
|
||||||
|
title="Easy Q", slug="easy-q", difficulty=Difficulty.EASY, description="Easy."
|
||||||
|
)
|
||||||
|
q2 = Question(title="Med Q", slug="med-q", difficulty=Difficulty.MEDIUM, description="Med.")
|
||||||
|
q3 = Question(
|
||||||
|
title="Hard Q", slug="hard-q", difficulty=Difficulty.HARD, description="Hard."
|
||||||
|
)
|
||||||
|
q1.patterns = [pattern]
|
||||||
|
q2.patterns = [pattern]
|
||||||
|
q3.patterns = [pattern]
|
||||||
|
db_session.add_all([q1, q2, q3])
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
service = PatternService(db_session)
|
||||||
|
result = await service.get_pattern_tutorial("two-pointers")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.name == "Two Pointers"
|
||||||
|
assert len(result.common_mistakes) == 1
|
||||||
|
assert len(result.variations) == 1
|
||||||
|
assert len(result.related_patterns) == 1
|
||||||
|
assert result.related_patterns[0].slug == "bfs"
|
||||||
|
assert len(result.learning_progression.warmup) == 1
|
||||||
|
assert len(result.learning_progression.core) == 1
|
||||||
|
assert len(result.learning_progression.challenge) == 1
|
||||||
|
assert len(result.visualization_examples) == 1
|
||||||
|
|
||||||
|
async def test_get_tutorial_not_found(self, db_session: AsyncSession) -> None:
|
||||||
|
service = PatternService(db_session)
|
||||||
|
result = await service.get_pattern_tutorial("nonexistent")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestStatsService:
|
||||||
|
async def test_get_stats(self, db_session: AsyncSession) -> None:
|
||||||
|
await create_test_data(db_session)
|
||||||
|
service = StatsService(db_session)
|
||||||
|
stats = await service.get_stats()
|
||||||
|
|
||||||
|
assert stats.total_questions == 3
|
||||||
|
assert stats.by_difficulty.easy == 1
|
||||||
|
assert stats.by_difficulty.medium == 2
|
||||||
|
assert stats.by_difficulty.hard == 0
|
||||||
|
assert len(stats.by_category) == 2
|
||||||
|
assert len(stats.by_pattern) == 2
|
||||||
|
|
||||||
|
async def test_get_stats_empty(self, db_session: AsyncSession) -> None:
|
||||||
|
service = StatsService(db_session)
|
||||||
|
stats = await service.get_stats()
|
||||||
|
|
||||||
|
assert stats.total_questions == 0
|
||||||
|
assert stats.by_difficulty.easy == 0
|
||||||
Reference in New Issue
Block a user