287 lines
10 KiB
Python
287 lines
10 KiB
Python
from typing import Any
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from src.models import Category, Difficulty, Pattern, Question
|
|
from src.services import CategoryService, PatternService, QuestionService, StatsService
|
|
|
|
# Mapping returned by ``create_test_data``: short fixture key -> committed ORM instance.
TestData = dict[str, Any]


async def create_test_data(db_session: AsyncSession) -> TestData:
    """Seed ``db_session`` with a small, fixed catalogue and commit it.

    Seeded rows:

    * categories: "arrays", "strings"
    * patterns: "two-pointers", "sliding-window"
    * questions:
        - "two-sum"           EASY    arrays   / two-pointers
        - "three-sum"         MEDIUM  arrays   / two-pointers
        - "longest-substring" MEDIUM  strings  / sliding-window

    Returns:
        A dict keyed by ``cat_arrays``, ``cat_strings``, ``pat_two_pointers``,
        ``pat_sliding``, ``q1``, ``q2`` and ``q3`` holding the instances that
        were added, so tests can refer back to them (e.g. ``data["q1"].id``).
    """
    arrays = Category(name="Arrays", slug="arrays")
    strings = Category(name="Strings", slug="strings")
    db_session.add_all([arrays, strings])

    two_pointers = Pattern(name="Two Pointers", slug="two-pointers")
    sliding = Pattern(name="Sliding Window", slug="sliding-window")
    db_session.add_all([two_pointers, sliding])

    # Flush the taxonomy rows before the questions that reference them are built.
    await db_session.flush()

    def build(
        title: str,
        slug: str,
        difficulty: Difficulty,
        description: str,
        category: Category,
        pattern: Pattern,
    ) -> Question:
        """Create one question tagged with a single category and pattern."""
        question = Question(
            title=title,
            slug=slug,
            difficulty=difficulty,
            description=description,
        )
        question.categories = [category]
        question.patterns = [pattern]
        return question

    two_sum = build(
        "Two Sum", "two-sum", Difficulty.EASY, "Find two numbers.", arrays, two_pointers
    )
    three_sum = build(
        "Three Sum", "three-sum", Difficulty.MEDIUM, "Find three numbers.", arrays, two_pointers
    )
    longest = build(
        "Longest Substring",
        "longest-substring",
        Difficulty.MEDIUM,
        "Find longest substring.",
        strings,
        sliding,
    )

    db_session.add_all([two_sum, three_sum, longest])
    await db_session.commit()

    return dict(
        cat_arrays=arrays,
        cat_strings=strings,
        pat_two_pointers=two_pointers,
        pat_sliding=sliding,
        q1=two_sum,
        q2=three_sum,
        q3=longest,
    )
|
|
|
|
|
|
class TestQuestionService:
    """Listing, filtering, pagination and lookup behaviour of ``QuestionService``.

    Every test that calls ``create_test_data`` works against three questions:
    "Two Sum" (EASY, arrays, two-pointers), "Three Sum" (MEDIUM, arrays,
    two-pointers) and "Longest Substring" (MEDIUM, strings, sliding-window).
    """

    async def test_get_questions(self, db_session: AsyncSession) -> None:
        """With no filters, every seeded question is counted and returned."""
        await create_test_data(db_session)
        service = QuestionService(db_session)
        questions, total = await service.get_questions()

        assert total == 3
        assert len(questions) == 3

    async def test_pagination(self, db_session: AsyncSession) -> None:
        """``total`` counts all matches while the page holds at most ``limit`` rows."""
        await create_test_data(db_session)
        service = QuestionService(db_session)
        questions, total = await service.get_questions(page=1, limit=2)

        assert total == 3
        assert len(questions) == 2

    async def test_filter_difficulty(self, db_session: AsyncSession) -> None:
        """Filtering by EASY keeps only "Two Sum"."""
        await create_test_data(db_session)
        service = QuestionService(db_session)
        questions, total = await service.get_questions(difficulties=[Difficulty.EASY])

        assert total == 1
        # Check the list length before indexing, so a count/list mismatch
        # fails as an assertion rather than an IndexError.
        assert len(questions) == 1
        assert questions[0].title == "Two Sum"

    async def test_filter_category(self, db_session: AsyncSession) -> None:
        """Filtering by the "arrays" category returns exactly its two questions."""
        await create_test_data(db_session)
        service = QuestionService(db_session)
        questions, total = await service.get_questions(category_slug="arrays")

        assert total == 2
        # Verify *which* rows matched, not just how many.
        assert {q.title for q in questions} == {"Two Sum", "Three Sum"}

    async def test_filter_pattern(self, db_session: AsyncSession) -> None:
        """Filtering by "sliding-window" keeps only "Longest Substring"."""
        await create_test_data(db_session)
        service = QuestionService(db_session)
        questions, total = await service.get_questions(pattern_slug="sliding-window")

        assert total == 1
        assert len(questions) == 1
        assert questions[0].title == "Longest Substring"

    async def test_search(self, db_session: AsyncSession) -> None:
        """A text search for "Sum" matches the two "... Sum" questions."""
        await create_test_data(db_session)
        service = QuestionService(db_session)
        questions, total = await service.get_questions(search="Sum")

        assert total == 2
        assert {q.title for q in questions} == {"Two Sum", "Three Sum"}

    async def test_get_by_slug(self, db_session: AsyncSession) -> None:
        """A known slug resolves to its question."""
        await create_test_data(db_session)
        service = QuestionService(db_session)
        question = await service.get_question_by_slug("two-sum")

        assert question is not None
        assert question.title == "Two Sum"

    async def test_get_by_slug_not_found(self, db_session: AsyncSession) -> None:
        """An unknown slug yields ``None`` rather than raising."""
        service = QuestionService(db_session)
        question = await service.get_question_by_slug("nonexistent")

        assert question is None

    async def test_get_by_id(self, db_session: AsyncSession) -> None:
        """A seeded question can be fetched back by its primary key."""
        data = await create_test_data(db_session)
        service = QuestionService(db_session)
        # NOTE(review): reading ``.id`` after ``commit()`` on an AsyncSession only
        # works without implicit IO if the ``db_session`` fixture disables
        # ``expire_on_commit`` — confirm in conftest.
        question = await service.get_question_by_id(data["q1"].id)

        assert question is not None
        assert question.title == "Two Sum"

    async def test_get_by_id_not_found(self, db_session: AsyncSession) -> None:
        """A random UUID that was never inserted yields ``None``."""
        from uuid import uuid4

        service = QuestionService(db_session)
        question = await service.get_question_by_id(uuid4())

        assert question is None

    async def test_combined_filters(self, db_session: AsyncSession) -> None:
        """Difficulty and category filters are combined (MEDIUM and arrays)."""
        await create_test_data(db_session)
        service = QuestionService(db_session)
        questions, total = await service.get_questions(
            difficulties=[Difficulty.MEDIUM],
            category_slug="arrays",
        )

        assert total == 1
        assert len(questions) == 1
        assert questions[0].title == "Three Sum"

    async def test_pagination_page_two(self, db_session: AsyncSession) -> None:
        """The second page of size 2 holds the single remaining question."""
        await create_test_data(db_session)
        service = QuestionService(db_session)
        questions, total = await service.get_questions(page=2, limit=2)

        assert total == 3
        assert len(questions) == 1

    async def test_pagination_beyond_data(self, db_session: AsyncSession) -> None:
        """A page past the end is empty but still reports the full total."""
        await create_test_data(db_session)
        service = QuestionService(db_session)
        questions, total = await service.get_questions(page=10, limit=20)

        assert total == 3
        assert len(questions) == 0
|
|
|
|
|
|
class TestCategoryService:
    """Listing and slug lookup behaviour of ``CategoryService``."""

    async def test_get_categories(self, db_session: AsyncSession) -> None:
        """Both seeded categories are listed; "arrays" carries two questions."""
        await create_test_data(db_session)
        rows = await CategoryService(db_session).get_categories()

        assert len(rows) == 2

        # Each row unpacks as (category, count); pull out the count for "arrays".
        arrays_count = next(
            count for category, count in rows if category.slug == "arrays"
        )
        assert arrays_count == 2

    async def test_get_by_slug(self, db_session: AsyncSession) -> None:
        """A known slug resolves to its category."""
        await create_test_data(db_session)
        found = await CategoryService(db_session).get_category_by_slug("arrays")

        assert found is not None
        assert found.name == "Arrays"

    async def test_get_by_slug_not_found(self, db_session: AsyncSession) -> None:
        """An unknown slug yields ``None``."""
        found = await CategoryService(db_session).get_category_by_slug("nonexistent")

        assert found is None
|
|
|
|
|
|
class TestPatternService:
    """Listing, slug lookup and tutorial assembly behaviour of ``PatternService``."""

    async def test_get_patterns(self, db_session: AsyncSession) -> None:
        """Both seeded patterns are listed."""
        await create_test_data(db_session)
        listed = await PatternService(db_session).get_patterns()

        assert len(listed) == 2

    async def test_get_by_slug(self, db_session: AsyncSession) -> None:
        """A known slug yields the pattern together with its question count."""
        await create_test_data(db_session)
        found = await PatternService(db_session).get_pattern_by_slug("two-pointers")

        assert found is not None
        matched, question_count = found
        assert matched.name == "Two Pointers"
        assert question_count == 2

    async def test_get_by_slug_not_found(self, db_session: AsyncSession) -> None:
        """An unknown slug yields ``None``."""
        found = await PatternService(db_session).get_pattern_by_slug("nonexistent")

        assert found is None

    async def test_get_tutorial(self, db_session: AsyncSession) -> None:
        """The tutorial bundles content fields, related patterns and a progression."""
        bfs = Pattern(name="BFS", slug="bfs", description="Breadth-first search")
        two_pointers = Pattern(
            name="Two Pointers",
            slug="two-pointers",
            description="Two pointer technique",
            common_mistakes=[
                {"title": "Off by one", "description": "Wrong bounds", "fix": "Use <="}
            ],
            variations=[{"name": "Fast/slow", "description": "Different speeds"}],
            related_patterns=["bfs"],
            visualization_examples=[{"id": "ex1", "title": "Basic", "code": "x=1", "steps": []}],
        )
        db_session.add_all([bfs, two_pointers])
        await db_session.flush()

        # One question per difficulty, all tagged with the pattern under test.
        tagged = [
            Question(title="Easy Q", slug="easy-q", difficulty=Difficulty.EASY, description="Easy."),
            Question(title="Med Q", slug="med-q", difficulty=Difficulty.MEDIUM, description="Med."),
            Question(title="Hard Q", slug="hard-q", difficulty=Difficulty.HARD, description="Hard."),
        ]
        for question in tagged:
            question.patterns = [two_pointers]
        db_session.add_all(tagged)
        await db_session.commit()

        tutorial = await PatternService(db_session).get_pattern_tutorial("two-pointers")

        assert tutorial is not None
        assert tutorial.name == "Two Pointers"
        assert len(tutorial.common_mistakes) == 1
        assert len(tutorial.variations) == 1
        # The "bfs" slug reference is expected to resolve to the BFS pattern.
        assert len(tutorial.related_patterns) == 1
        assert tutorial.related_patterns[0].slug == "bfs"

        # Presumably EASY -> warmup, MEDIUM -> core, HARD -> challenge;
        # confirm against PatternService.
        progression = tutorial.learning_progression
        assert len(progression.warmup) == 1
        assert len(progression.core) == 1
        assert len(progression.challenge) == 1

        assert len(tutorial.visualization_examples) == 1

    async def test_get_tutorial_not_found(self, db_session: AsyncSession) -> None:
        """An unknown slug yields no tutorial."""
        tutorial = await PatternService(db_session).get_pattern_tutorial("nonexistent")

        assert tutorial is None
|
|
|
|
|
|
class TestStatsService:
|
|
async def test_get_stats(self, db_session: AsyncSession) -> None:
    """Aggregates reflect the seeded questions, categories and patterns."""
    await create_test_data(db_session)
    stats = await StatsService(db_session).get_stats()

    assert stats.total_questions == 3

    # Seed data has one EASY, two MEDIUM and no HARD questions.
    difficulty = stats.by_difficulty
    assert (difficulty.easy, difficulty.medium, difficulty.hard) == (1, 2, 0)

    assert len(stats.by_category) == 2
    assert len(stats.by_pattern) == 2
|
|
|
|
async def test_get_stats_empty(self, db_session: AsyncSession) -> None:
|
|
service = StatsService(db_session)
|
|
stats = await service.get_stats()
|
|
|
|
assert stats.total_questions == 0
|
|
assert stats.by_difficulty.easy == 0
|