diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..4cad880 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,91 @@ +import asyncio +from collections.abc import AsyncGenerator, Generator +from typing import Any + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy import JSON, event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from src.db.database import get_db +from src.main import app +from src.models import Base + +TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" + + +def patch_postgres_types() -> None: + from sqlalchemy.dialects.sqlite import base as sqlite_base + + def visit_JSONB( + self: Any, + type_: Any, # noqa: ARG001 + **kw: Any, + ) -> str: + return str(self.visit_JSON(JSON(), **kw)) + + sqlite_base.SQLiteTypeCompiler.visit_JSONB = visit_JSONB # type: ignore[attr-defined] + + def visit_UUID( + self: Any, # noqa: ARG001 + type_: Any, # noqa: ARG001 + **kw: Any, # noqa: ARG001 + ) -> str: + return "VARCHAR(36)" + + sqlite_base.SQLiteTypeCompiler.visit_UUID = visit_UUID # type: ignore[method-assign] + + +patch_postgres_types() + +engine = create_async_engine(TEST_DATABASE_URL, echo=False) + + +@event.listens_for(engine.sync_engine, "connect") +def set_sqlite_pragma(dbapi_connection: Any, _connection_record: Any) -> None: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + +async_session_factory = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False, +) + + +@pytest.fixture(scope="session") +def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(autouse=True) +async def setup_database() -> AsyncGenerator[None, None]: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +@pytest.fixture +async def db_session() -> AsyncGenerator[AsyncSession, None]: + async with async_session_factory() as session: + yield session + + +@pytest.fixture +async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: + async def override_get_db() -> AsyncGenerator[AsyncSession, None]: + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as test_client: + yield test_client + + app.dependency_overrides.clear() diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py new file mode 100644 index 0000000..19a82fe --- /dev/null +++ b/backend/tests/test_api.py @@ -0,0 +1,179 @@ +from typing import Any + +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from src.models import Category, Difficulty, Explanation, Pattern, Question, Solution + +SampleData = dict[str, Any] + + +async def create_sample_data(db_session: AsyncSession) -> SampleData: + category = Category(name="Arrays", slug="arrays", description="Array problems") + db_session.add(category) + + pattern = Pattern( + name="Two Pointers", + slug="two-pointers", + description="Two pointer technique", + ) + db_session.add(pattern) + + await db_session.flush() + + question = Question( + title="Two Sum", + slug="two-sum", + difficulty=Difficulty.EASY, + description="Find two numbers that add up to target.", + leetcode_id=1, + leetcode_url="https://leetcode.com/problems/two-sum/", + ) + question.categories = [category] + question.patterns = [pattern] + db_session.add(question) + + await db_session.flush() + + explanation = Explanation( + question_id=question.id, + approach="Use a hash map.", + intuition="Store seen values for O(1) lookup.", + time_complexity="O(n)", + space_complexity="O(n)", + ) + db_session.add(explanation) + + solution = Solution( + question_id=question.id, + approach_name="Hash Map", + code="def two_sum(nums, target): pass", + is_optimal=True, + ) + db_session.add(solution) + + await db_session.commit() + + return {"question": question, "category": category, "pattern": pattern} + + +async def test_health_check(client: AsyncClient) -> None: + response = await client.get("/api/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +async def test_list_questions_empty(client: AsyncClient) -> None: + response = await client.get("/api/questions") + assert response.status_code == 200 + data = response.json() + assert data["items"] == [] + assert data["total"] == 0 + + +async def test_list_questions(client: AsyncClient, db_session: AsyncSession) -> None: + await create_sample_data(db_session) + response = await client.get("/api/questions") + assert response.status_code == 200 + data = response.json() + assert data["total"] == 1 + assert len(data["items"]) == 1 + assert data["items"][0]["title"] == "Two Sum" + assert data["items"][0]["difficulty"] == "easy" + + +async def test_list_questions_filter_difficulty( + client: AsyncClient, db_session: AsyncSession +) -> None: + await create_sample_data(db_session) + response = await client.get("/api/questions?difficulty=easy") + assert response.status_code == 200 + assert response.json()["total"] == 1 + + response = await client.get("/api/questions?difficulty=hard") + assert response.status_code == 200 + assert response.json()["total"] == 0 + + +async def test_list_questions_invalid_difficulty(client: AsyncClient) -> None: + response = await client.get("/api/questions?difficulty=invalid") + assert response.status_code == 400 + + +async def test_get_question(client: AsyncClient, db_session: AsyncSession) -> None: + await create_sample_data(db_session) + response = await client.get("/api/questions/two-sum") + assert response.status_code == 200 + data = response.json() + assert data["title"] == "Two Sum" + assert data["explanation"]["time_complexity"] == "O(n)" + assert len(data["solutions"]) == 1 + + +async def test_get_question_not_found(client: AsyncClient) -> None: + response = await client.get("/api/questions/nonexistent") + assert response.status_code == 404 + + +async def test_list_categories_empty(client: AsyncClient) -> None: + response = await client.get("/api/categories") + assert response.status_code == 200 + assert response.json()["items"] == [] + + +async def test_list_categories(client: AsyncClient, db_session: AsyncSession) -> None: + await create_sample_data(db_session) + response = await client.get("/api/categories") + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 1 + assert data["items"][0]["name"] == "Arrays" + assert data["items"][0]["question_count"] == 1 + + +async def test_list_patterns_empty(client: AsyncClient) -> None: + response = await client.get("/api/patterns") + assert response.status_code == 200 + assert response.json()["items"] == [] + + +async def test_list_patterns(client: AsyncClient, db_session: AsyncSession) -> None: + await create_sample_data(db_session) + response = await client.get("/api/patterns") + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 1 + assert data["items"][0]["name"] == "Two Pointers" + + +async def test_get_pattern(client: AsyncClient, db_session: AsyncSession) -> None: + await create_sample_data(db_session) + response = await client.get("/api/patterns/two-pointers") + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Two Pointers" + assert data["question_count"] == 1 + + +async def test_get_pattern_not_found(client: AsyncClient) -> None: + response = await client.get("/api/patterns/nonexistent") + assert response.status_code == 404 + + +async def test_get_stats_empty(client: AsyncClient) -> None: + response = await client.get("/api/stats") + assert response.status_code == 200 + data = response.json() + assert data["total_questions"] == 0 + assert data["by_difficulty"]["easy"] == 0 + + +async def test_get_stats(client: AsyncClient, db_session: AsyncSession) -> None: + await create_sample_data(db_session) + response = await client.get("/api/stats") + assert response.status_code == 200 + data = response.json() + assert data["total_questions"] == 1 + assert data["by_difficulty"]["easy"] == 1 + assert len(data["by_category"]) == 1 + assert len(data["by_pattern"]) == 1 diff --git a/backend/tests/test_models.py b/backend/tests/test_models.py new file mode 100644 index 0000000..094ee14 --- /dev/null +++ b/backend/tests/test_models.py @@ -0,0 +1,156 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from src.models import Category, Difficulty, Explanation, Pattern, Question, Solution + + +async def test_create_category(db_session: AsyncSession) -> None: + category = Category( + name="Arrays", + slug="arrays", + description="Array manipulation problems", + ) + db_session.add(category) + await db_session.commit() + + result = await db_session.execute(select(Category)) + categories = list(result.scalars().all()) + assert len(categories) == 1 + assert categories[0].name == "Arrays" + assert categories[0].slug == "arrays" + + +async def test_create_pattern(db_session: AsyncSession) -> None: + pattern = Pattern( + name="Two Pointers", + slug="two-pointers", + description="Use two pointers to traverse data", + when_to_use="Sorted arrays, linked lists", + ) + db_session.add(pattern) + await db_session.commit() + + result = await db_session.execute(select(Pattern)) + patterns = list(result.scalars().all()) + assert len(patterns) == 1 + assert patterns[0].name == "Two Pointers" + + +async def test_create_question(db_session: AsyncSession) -> None: + question = Question( + title="Two Sum", + slug="two-sum", + difficulty=Difficulty.EASY, + description="Find two numbers that add up to target.", + leetcode_id=1, + ) + db_session.add(question) + await db_session.commit() + + result = await db_session.execute(select(Question)) + questions = list(result.scalars().all()) + assert len(questions) == 1 + assert questions[0].title == "Two Sum" + assert questions[0].difficulty == Difficulty.EASY + + +async def test_question_with_categories(db_session: AsyncSession) -> None: + category = Category(name="Arrays", slug="arrays") + question = Question( + title="Two Sum", + slug="two-sum", + difficulty=Difficulty.EASY, + description="Find two numbers.", + ) + question.categories = [category] + + db_session.add(question) + await db_session.commit() + + result = await db_session.execute(select(Question)) + q = result.scalar_one() + assert len(q.categories) == 1 + assert q.categories[0].name == "Arrays" + + +async def test_question_with_patterns(db_session: AsyncSession) -> None: + pattern = Pattern(name="Two Pointers", slug="two-pointers") + question = Question( + title="Two Sum", + slug="two-sum", + difficulty=Difficulty.EASY, + description="Find two numbers.", + ) + question.patterns = [pattern] + + db_session.add(question) + await db_session.commit() + + result = await db_session.execute(select(Question)) + q = result.scalar_one() + assert len(q.patterns) == 1 + assert q.patterns[0].name == "Two Pointers" + + +async def test_question_with_explanation(db_session: AsyncSession) -> None: + question = Question( + title="Two Sum", + slug="two-sum", + difficulty=Difficulty.EASY, + description="Find two numbers.", + ) + db_session.add(question) + await db_session.flush() + + explanation = Explanation( + question_id=question.id, + approach="Use hash map", + intuition="Store seen values", + time_complexity="O(n)", + space_complexity="O(n)", + ) + db_session.add(explanation) + await db_session.commit() + + result = await db_session.execute( + select(Question) + .options(selectinload(Question.explanation)) + .where(Question.id == question.id) + ) + loaded_question = result.scalar_one() + assert loaded_question.explanation is not None + assert loaded_question.explanation.time_complexity == "O(n)" + + +async def test_question_with_solutions(db_session: AsyncSession) -> None: + question = Question( + title="Two Sum", + slug="two-sum", + difficulty=Difficulty.EASY, + description="Find two numbers.", + ) + db_session.add(question) + await db_session.flush() + + solution = Solution( + question_id=question.id, + approach_name="Hash Map", + code="def solve(): pass", + is_optimal=True, + ) + db_session.add(solution) + await db_session.commit() + + result = await db_session.execute( + select(Question).options(selectinload(Question.solutions)).where(Question.id == question.id) + ) + loaded_question = result.scalar_one() + assert len(loaded_question.solutions) == 1 + assert loaded_question.solutions[0].is_optimal is True + + +async def test_difficulty_enum() -> None: + assert Difficulty.EASY.value == "easy" + assert Difficulty.MEDIUM.value == "medium" + assert Difficulty.HARD.value == "hard" diff --git a/backend/tests/test_services.py b/backend/tests/test_services.py new file mode 100644 index 0000000..c92207e --- /dev/null +++ b/backend/tests/test_services.py @@ -0,0 +1,237 @@ +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from src.models import Category, Difficulty, Pattern, Question +from src.services import CategoryService, PatternService, QuestionService, StatsService + +TestData = dict[str, Any] + + +async def create_test_data(db_session: AsyncSession) -> TestData: + cat_arrays = Category(name="Arrays", slug="arrays") + cat_strings = Category(name="Strings", slug="strings") + db_session.add_all([cat_arrays, cat_strings]) + + pat_two_pointers = Pattern(name="Two Pointers", slug="two-pointers") + pat_sliding = Pattern(name="Sliding Window", slug="sliding-window") + db_session.add_all([pat_two_pointers, pat_sliding]) + + await db_session.flush() + + q1 = Question( + title="Two Sum", + slug="two-sum", + difficulty=Difficulty.EASY, + description="Find two numbers.", + ) + q1.categories = [cat_arrays] + q1.patterns = [pat_two_pointers] + + q2 = Question( + title="Three Sum", + slug="three-sum", + difficulty=Difficulty.MEDIUM, + description="Find three numbers.", + ) + q2.categories = [cat_arrays] + q2.patterns = [pat_two_pointers] + + q3 = Question( + title="Longest Substring", + slug="longest-substring", + difficulty=Difficulty.MEDIUM, + description="Find longest substring.", + ) + q3.categories = [cat_strings] + q3.patterns = [pat_sliding] + + db_session.add_all([q1, q2, q3]) + await db_session.commit() + + return { + "cat_arrays": cat_arrays, + "cat_strings": cat_strings, + "pat_two_pointers": pat_two_pointers, + "pat_sliding": pat_sliding, + "q1": q1, + "q2": q2, + "q3": q3, + } + + +class TestQuestionService: + async def test_get_questions(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = QuestionService(db_session) + questions, total = await service.get_questions() + + assert total == 3 + assert len(questions) == 3 + + async def test_pagination(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = QuestionService(db_session) + questions, total = await service.get_questions(page=1, limit=2) + + assert total == 3 + assert len(questions) == 2 + + async def test_filter_difficulty(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = QuestionService(db_session) + questions, total = await service.get_questions(difficulties=[Difficulty.EASY]) + + assert total == 1 + assert questions[0].title == "Two Sum" + + async def test_filter_category(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = QuestionService(db_session) + questions, total = await service.get_questions(category_slug="arrays") + + assert total == 2 + + async def test_filter_pattern(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = QuestionService(db_session) + questions, total = await service.get_questions(pattern_slug="sliding-window") + + assert total == 1 + assert questions[0].title == "Longest Substring" + + async def test_search(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = QuestionService(db_session) + questions, total = await service.get_questions(search="Sum") + + assert total == 2 + + async def test_get_by_slug(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = QuestionService(db_session) + question = await service.get_question_by_slug("two-sum") + + assert question is not None + assert question.title == "Two Sum" + + async def test_get_by_slug_not_found(self, db_session: AsyncSession) -> None: + service = QuestionService(db_session) + question = await service.get_question_by_slug("nonexistent") + + assert question is None + + async def test_get_by_id(self, db_session: AsyncSession) -> None: + data = await create_test_data(db_session) + service = QuestionService(db_session) + question = await service.get_question_by_id(data["q1"].id) + + assert question is not None + assert question.title == "Two Sum" + + async def test_get_by_id_not_found(self, db_session: AsyncSession) -> None: + from uuid import uuid4 + + service = QuestionService(db_session) + question = await service.get_question_by_id(uuid4()) + + assert question is None + + async def test_combined_filters(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = QuestionService(db_session) + questions, total = await service.get_questions( + difficulties=[Difficulty.MEDIUM], + category_slug="arrays", + ) + + assert total == 1 + assert questions[0].title == "Three Sum" + + async def test_pagination_page_two(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = QuestionService(db_session) + questions, total = await service.get_questions(page=2, limit=2) + + assert total == 3 + assert len(questions) == 1 + + async def test_pagination_beyond_data(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = QuestionService(db_session) + questions, total = await service.get_questions(page=10, limit=20) + + assert total == 3 + assert len(questions) == 0 + + +class TestCategoryService: + async def test_get_categories(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = CategoryService(db_session) + categories = await service.get_categories() + + assert len(categories) == 2 + + arrays = next((c, cnt) for c, cnt in categories if c.slug == "arrays") + assert arrays[1] == 2 + + async def test_get_by_slug(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = CategoryService(db_session) + category = await service.get_category_by_slug("arrays") + + assert category is not None + assert category.name == "Arrays" + + async def test_get_by_slug_not_found(self, db_session: AsyncSession) -> None: + service = CategoryService(db_session) + category = await service.get_category_by_slug("nonexistent") + + assert category is None + + +class TestPatternService: + async def test_get_patterns(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = PatternService(db_session) + patterns = await service.get_patterns() + + assert len(patterns) == 2 + + async def test_get_by_slug(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = PatternService(db_session) + result = await service.get_pattern_by_slug("two-pointers") + + assert result is not None + pattern, count = result + assert pattern.name == "Two Pointers" + assert count == 2 + + async def test_get_by_slug_not_found(self, db_session: AsyncSession) -> None: + service = PatternService(db_session) + result = await service.get_pattern_by_slug("nonexistent") + + assert result is None + + +class TestStatsService: + async def test_get_stats(self, db_session: AsyncSession) -> None: + await create_test_data(db_session) + service = StatsService(db_session) + stats = await service.get_stats() + + assert stats.total_questions == 3 + assert stats.by_difficulty.easy == 1 + assert stats.by_difficulty.medium == 2 + assert stats.by_difficulty.hard == 0 + assert len(stats.by_category) == 2 + assert len(stats.by_pattern) == 2 + + async def test_get_stats_empty(self, db_session: AsyncSession) -> None: + service = StatsService(db_session) + stats = await service.get_stats() + + assert stats.total_questions == 0 + assert stats.by_difficulty.easy == 0