Files
codetutor/backend/tests/test_models.py

157 lines
4.6 KiB
Python

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from src.models import Category, Difficulty, Explanation, Pattern, Question, Solution
async def test_create_category(db_session: AsyncSession) -> None:
    """Persist a single Category and check that exactly one row is queried back."""
    db_session.add(
        Category(
            name="Arrays",
            slug="arrays",
            description="Array manipulation problems",
        )
    )
    await db_session.commit()

    rows = (await db_session.execute(select(Category))).scalars().all()
    stored = list(rows)
    assert len(stored) == 1
    only = stored[0]
    assert only.name == "Arrays"
    assert only.slug == "arrays"
async def test_create_pattern(db_session: AsyncSession) -> None:
    """Persist a single Pattern and check that exactly one row is queried back."""
    two_pointers = Pattern(
        name="Two Pointers",
        slug="two-pointers",
        description="Use two pointers to traverse data",
        when_to_use="Sorted arrays, linked lists",
    )
    db_session.add(two_pointers)
    await db_session.commit()

    fetched = list((await db_session.execute(select(Pattern))).scalars().all())
    assert len(fetched) == 1
    assert fetched[0].name == "Two Pointers"
async def test_create_question(db_session: AsyncSession) -> None:
    """Persist a Question and verify its columns survive a database round trip.

    The identity map is cleared before re-querying, so the assertions check the
    values actually read back from the database (including how the
    ``Difficulty`` enum is stored) rather than the attributes of the
    in-memory object created here.
    """
    question = Question(
        title="Two Sum",
        slug="two-sum",
        difficulty=Difficulty.EASY,
        description="Find two numbers that add up to target.",
        leetcode_id=1,
    )
    db_session.add(question)
    await db_session.commit()
    # With expire_on_commit=False the query below would otherwise hand back the
    # same in-memory instance unchanged, making the checks vacuous.
    db_session.expunge_all()

    result = await db_session.execute(select(Question))
    questions = list(result.scalars().all())
    assert len(questions) == 1
    assert questions[0].title == "Two Sum"
    assert questions[0].difficulty == Difficulty.EASY
async def test_question_with_categories(db_session: AsyncSession) -> None:
    """Persist a Question linked to a Category and load the link back from the DB.

    Categories are eager-loaded with ``selectinload`` so no implicit lazy load
    is attempted on the AsyncSession. The identity map is cleared first so the
    association is read from the database instead of the in-memory list
    assigned below.
    """
    category = Category(name="Arrays", slug="arrays")
    question = Question(
        title="Two Sum",
        slug="two-sum",
        difficulty=Difficulty.EASY,
        description="Find two numbers.",
    )
    question.categories = [category]
    # The category is persisted via the relationship when the question is added.
    db_session.add(question)
    await db_session.commit()
    db_session.expunge_all()

    result = await db_session.execute(
        select(Question).options(selectinload(Question.categories))
    )
    loaded_question = result.scalar_one()
    assert len(loaded_question.categories) == 1
    assert loaded_question.categories[0].name == "Arrays"
async def test_question_with_patterns(db_session: AsyncSession) -> None:
    """Persist a Question linked to a Pattern and load the link back from the DB.

    Patterns are eager-loaded with ``selectinload`` so no implicit lazy load
    is attempted on the AsyncSession. The identity map is cleared first so the
    association is read from the database instead of the in-memory list
    assigned below.
    """
    pattern = Pattern(name="Two Pointers", slug="two-pointers")
    question = Question(
        title="Two Sum",
        slug="two-sum",
        difficulty=Difficulty.EASY,
        description="Find two numbers.",
    )
    question.patterns = [pattern]
    # The pattern is persisted via the relationship when the question is added.
    db_session.add(question)
    await db_session.commit()
    db_session.expunge_all()

    result = await db_session.execute(
        select(Question).options(selectinload(Question.patterns))
    )
    loaded_question = result.scalar_one()
    assert len(loaded_question.patterns) == 1
    assert loaded_question.patterns[0].name == "Two Pointers"
async def test_question_with_explanation(db_session: AsyncSession) -> None:
    """Persist a Question plus an Explanation keyed by its id, then eager-load it."""
    question = Question(
        title="Two Sum",
        slug="two-sum",
        difficulty=Difficulty.EASY,
        description="Find two numbers.",
    )
    db_session.add(question)
    # Flush so the database assigns the primary key used as the foreign key below.
    await db_session.flush()
    # Capture the id before commit(): under expire_on_commit=True the instance
    # is expired by commit, and reading question.id afterwards would require
    # an implicit refresh, which AsyncSession cannot perform.
    question_id = question.id

    explanation = Explanation(
        question_id=question_id,
        approach="Use hash map",
        intuition="Store seen values",
        time_complexity="O(n)",
        space_complexity="O(n)",
    )
    db_session.add(explanation)
    await db_session.commit()
    # Force the relationship to be loaded from the database, not the identity map.
    db_session.expunge_all()

    result = await db_session.execute(
        select(Question)
        .options(selectinload(Question.explanation))
        .where(Question.id == question_id)
    )
    loaded_question = result.scalar_one()
    assert loaded_question.explanation is not None
    assert loaded_question.explanation.time_complexity == "O(n)"
async def test_question_with_solutions(db_session: AsyncSession) -> None:
    """Persist a Question plus one Solution keyed by its id, then eager-load it."""
    question = Question(
        title="Two Sum",
        slug="two-sum",
        difficulty=Difficulty.EASY,
        description="Find two numbers.",
    )
    db_session.add(question)
    # Flush so the database assigns the primary key used as the foreign key below.
    await db_session.flush()
    # Capture the id before commit(): under expire_on_commit=True the instance
    # is expired by commit, and reading question.id afterwards would require
    # an implicit refresh, which AsyncSession cannot perform.
    question_id = question.id

    solution = Solution(
        question_id=question_id,
        approach_name="Hash Map",
        code="def solve(): pass",
        is_optimal=True,
    )
    db_session.add(solution)
    await db_session.commit()
    # Force the relationship to be loaded from the database, not the identity map.
    db_session.expunge_all()

    result = await db_session.execute(
        select(Question)
        .options(selectinload(Question.solutions))
        .where(Question.id == question_id)
    )
    loaded_question = result.scalar_one()
    assert len(loaded_question.solutions) == 1
    assert loaded_question.solutions[0].is_optimal is True
async def test_difficulty_enum() -> None:
    """Each Difficulty member carries its lowercase string value."""
    expected_values = (
        (Difficulty.EASY, "easy"),
        (Difficulty.MEDIUM, "medium"),
        (Difficulty.HARD, "hard"),
    )
    for member, value in expected_values:
        assert member.value == value