Files
codetutor/backend/scripts/load_data.py

311 lines
10 KiB
Python

#!/usr/bin/env python
"""Load YAML content data into the database."""
import asyncio
import sys
from pathlib import Path
from typing import Any
import yaml
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.db.database import async_session_factory
from src.models import Category, Difficulty, Explanation, Pattern, Question, Solution
from src.models.question import QuestionPattern
async def load_categories(session: AsyncSession, data_dir: Path) -> dict[str, Category]:
    """Load categories from ``categories/categories.yaml`` and upsert them by slug.

    Existing rows (matched on ``slug``) get their name and description
    refreshed; unknown slugs are inserted. Changes are flushed to the session
    but not committed here.

    Args:
        session: Async session used for the lookups and inserts.
        data_dir: Root content directory that contains the ``categories`` folder.

    Returns:
        Mapping of category slug to its ``Category`` instance. Empty when the
        YAML file is missing or defines no categories.

    Raises:
        KeyError: If a category entry lacks ``slug`` or ``name``.
    """
    categories_file = data_dir / "categories" / "categories.yaml"
    if not categories_file.exists():
        print(f"Warning: {categories_file} not found")
        return {}

    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(categories_file, encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat it as "no data".
        data = yaml.safe_load(f) or {}

    categories: dict[str, Category] = {}
    # ``or []`` also covers a ``categories:`` key whose value is null.
    for item in data.get("categories") or []:
        slug = item["slug"]
        result = await session.execute(select(Category).where(Category.slug == slug))
        category = result.scalar_one_or_none()
        if category:
            category.name = item["name"]
            category.description = item.get("description")
        else:
            category = Category(
                name=item["name"],
                slug=slug,
                description=item.get("description"),
            )
            session.add(category)
        categories[slug] = category

    await session.flush()
    print(f"Loaded {len(categories)} categories")
    return categories
async def load_patterns(session: AsyncSession, data_dir: Path) -> dict[str, Pattern]:
    """Load patterns from YAML files and upsert them by slug.

    Supports both:
    - Individual files: patterns/<slug>.yaml (preferred for tutorials)
    - Legacy single file: patterns/patterns.yaml

    Individual files are processed first, in sorted filename order. The legacy
    file, when present, is then merged in: its entries are used only for slugs
    that no individual file already supplied. Changes are flushed to the
    session but not committed here.

    Args:
        session: Async session used for the lookups and inserts.
        data_dir: Root content directory that contains the ``patterns`` folder.

    Returns:
        Mapping of pattern slug to its ``Pattern`` instance.

    Raises:
        KeyError: If a pattern entry has a ``slug`` but no ``name``, or a
            legacy entry has no ``slug``.
    """
    patterns_dir = data_dir / "patterns"
    patterns: dict[str, Pattern] = {}

    # Sort so the load order (and the printed log) is stable between runs;
    # the legacy aggregate file is handled separately below.
    individual_files = sorted(
        f for f in patterns_dir.glob("*.yaml") if f.name != "patterns.yaml"
    )
    for pattern_file in individual_files:
        with open(pattern_file, encoding="utf-8") as f:
            item = yaml.safe_load(f)
        # Empty documents (None) and non-mapping documents cannot carry a slug.
        if not isinstance(item, dict) or "slug" not in item:
            print(f" Warning: Skipping {pattern_file.name} - missing slug")
            continue
        patterns[item["slug"]] = await _upsert_pattern(session, item)
        print(f" Loaded: {item['name']}")

    # Merge the legacy patterns.yaml: individual files take precedence, so only
    # slugs not seen above are loaded from it.
    legacy_file = patterns_dir / "patterns.yaml"
    if legacy_file.exists():
        with open(legacy_file, encoding="utf-8") as f:
            # safe_load returns None for an empty document.
            data = yaml.safe_load(f) or {}
        for item in data.get("patterns") or []:
            if item["slug"] in patterns:
                continue
            patterns[item["slug"]] = await _upsert_pattern(session, item)

    await session.flush()
    print(f"Loaded {len(patterns)} patterns")
    return patterns
async def _upsert_pattern(session: AsyncSession, item: dict[str, Any]) -> Pattern:
    """Create or refresh the ``Pattern`` row identified by ``item["slug"]``.

    A pattern that is not yet stored is instantiated and added to the session.
    In both cases every field below is overwritten from ``item``; keys missing
    from ``item`` (other than ``slug`` and ``name``) are written as ``None``.

    Args:
        session: Async session used for the lookup and for adding new rows.
        item: Parsed YAML mapping for one pattern; must contain ``slug`` and
            ``name``.

    Returns:
        The new or updated ``Pattern`` instance (not flushed by this function).
    """
    slug = item["slug"]
    lookup = await session.execute(select(Pattern).where(Pattern.slug == slug))
    pattern = lookup.scalar_one_or_none()
    if not pattern:
        pattern = Pattern(slug=slug)
        session.add(pattern)

    # The only required field besides the slug.
    pattern.name = item["name"]

    # Everything else is optional and copied verbatim (None when absent).
    optional_fields = (
        # Core fields
        "description",
        "when_to_use",
        # Tutorial content fields
        "metaphor",
        "core_concept",
        "visualization",
        "code_template",
        # Structured data fields (JSONB)
        "recognition_signals",
        "common_mistakes",
        "variations",
        "related_patterns",
        "prerequisite_patterns",
        # Difficulty level
        "difficulty_level",
        # Interactive visualization examples
        "visualization_examples",
        # Pattern classification
        "pattern_type",
        "display_order",
    )
    for field_name in optional_fields:
        setattr(pattern, field_name, item.get(field_name))

    return pattern
async def load_question(
    session: AsyncSession,
    question_file: Path,
    categories: dict[str, Category],
    patterns: dict[str, Pattern],
) -> None:
    """Load a single question from a YAML file and upsert it by slug.

    The question row is created or updated, its category links are replaced,
    its pattern links are deleted and rebuilt (so ``is_optimal`` changes take
    effect), its explanation is created or updated when the YAML provides one,
    and its solutions are deleted and recreated. Nothing is committed here.

    Args:
        session: Async session used for all reads and writes.
        question_file: Path of the question YAML file.
        categories: Known categories keyed by slug; unknown slugs referenced by
            the question are skipped with a warning.
        patterns: Known patterns keyed by slug; unknown slugs referenced by the
            question are skipped with a warning.

    Raises:
        ValueError: If the file does not contain a YAML mapping (e.g. is empty).
        KeyError: If a required key (``slug``, ``title``, ``difficulty``,
            ``description``, or a required explanation/solution key) is missing.
    """
    with open(question_file, encoding="utf-8") as f:
        data: dict[str, Any] = yaml.safe_load(f)
    # safe_load returns None for an empty file; fail with the file name instead
    # of an opaque "'NoneType' object is not subscriptable".
    if not isinstance(data, dict):
        raise ValueError(f"{question_file}: expected a YAML mapping at the top level")

    slug = data["slug"]
    result = await session.execute(
        select(Question)
        .where(Question.slug == slug)
        .options(
            selectinload(Question.explanation),
            selectinload(Question.solutions),
            selectinload(Question.categories),
            selectinload(Question.patterns),
        )
    )
    existing = result.scalar_one_or_none()

    # Single source of truth for the column values, shared by create and update.
    question_fields: dict[str, Any] = {
        "title": data["title"],
        "difficulty": Difficulty(data["difficulty"].lower()),
        "description": data["description"],
        "constraints": data.get("constraints"),
        "examples": data.get("examples"),
        "leetcode_id": data.get("leetcode_id"),
        "leetcode_url": data.get("leetcode_url"),
        "function_signature": data.get("function_signature"),
        "test_cases": data.get("test_cases"),
    }
    if existing:
        question = existing
        for field_name, value in question_fields.items():
            setattr(question, field_name, value)
    else:
        question = Question(slug=slug, **question_fields)
        session.add(question)

    # Link categories, reporting slugs that do not match a loaded category.
    linked_categories = []
    for cat_slug in data.get("categories") or []:
        if cat_slug in categories:
            linked_categories.append(categories[cat_slug])
        else:
            print(f" Warning: {slug} references unknown category '{cat_slug}'")
    question.categories = linked_categories

    # Flush before touching pattern links so a newly added question has its
    # primary key; building the DELETE earlier would filter on a NULL id.
    await session.flush()

    # Clear existing pattern links to handle is_optimal changes. A question
    # created in this call cannot have any links yet.
    if existing:
        await session.execute(
            sa.delete(QuestionPattern).where(QuestionPattern.question_id == question.id)
        )

    # Link patterns with is_optimal support
    for pat_entry in data.get("patterns") or []:
        # Support both formats:
        #   Old: "heap" (string)
        #   New: {slug: "heap", is_optimal: true} (dict)
        if isinstance(pat_entry, str):
            pat_slug, is_optimal = pat_entry, False
        else:
            pat_slug = pat_entry["slug"]
            is_optimal = pat_entry.get("is_optimal", False)

        pattern = patterns.get(pat_slug)
        if pattern is None:
            print(f" Warning: {slug} references unknown pattern '{pat_slug}'")
            continue
        session.add(
            QuestionPattern(
                question_id=question.id,
                pattern_id=pattern.id,
                is_optimal=is_optimal,
            )
        )
    await session.flush()

    # Handle explanation (a missing or null ``explanation`` key leaves any
    # stored explanation untouched).
    exp_data = data.get("explanation")
    if exp_data is not None:
        explanation_fields: dict[str, Any] = {
            "approach": exp_data["approach"],
            "intuition": exp_data["intuition"],
            "common_pitfalls": exp_data.get("common_pitfalls"),
            "key_takeaways": exp_data.get("key_takeaways"),
            "time_complexity": exp_data["time_complexity"],
            "space_complexity": exp_data["space_complexity"],
            "complexity_explanation": exp_data.get("complexity_explanation"),
        }
        if existing and existing.explanation:
            explanation = existing.explanation
            for field_name, value in explanation_fields.items():
                setattr(explanation, field_name, value)
        else:
            session.add(Explanation(question_id=question.id, **explanation_fields))

    # Handle solutions (delete existing and recreate)
    if existing and existing.solutions:
        for sol in existing.solutions:
            await session.delete(sol)
        await session.flush()

    for sol_data in data.get("solutions") or []:
        session.add(
            Solution(
                question_id=question.id,
                approach_name=sol_data["approach_name"],
                code=sol_data["code"],
                language=sol_data.get("language", "python"),
                is_optimal=sol_data.get("is_optimal", False),
                explanation=sol_data.get("explanation"),
            )
        )

    print(f" Loaded: {data['title']}")
async def load_questions(
    session: AsyncSession,
    data_dir: Path,
    categories: dict[str, Category],
    patterns: dict[str, Pattern],
) -> int:
    """Load every ``*.yaml`` file under ``<data_dir>/questions``.

    Files are processed one at a time in sorted path order, each delegated to
    ``load_question``.

    Args:
        session: Async session passed through to ``load_question``.
        data_dir: Root content directory that contains the ``questions`` folder.
        categories: Known categories keyed by slug.
        patterns: Known patterns keyed by slug.

    Returns:
        Number of question files processed; 0 when the folder does not exist.
    """
    questions_dir = data_dir / "questions"
    if questions_dir.exists():
        question_files = sorted(questions_dir.glob("*.yaml"))
        for path in question_files:
            await load_question(session, path, categories, patterns)
        return len(question_files)

    print(f"Warning: {questions_dir} not found")
    return 0
async def main() -> None:
    """Import categories, patterns and questions, then commit once.

    The data directory is the ``data`` folder next to this script's parent
    directory. All three loaders share one session, and the commit happens
    only after every loader has returned.
    """
    scripts_dir = Path(__file__).parent
    data_dir = scripts_dir.parent / "data"

    print("Loading content data...")
    print(f"Data directory: {data_dir}")

    async with async_session_factory() as session:
        # Order matters: questions reference the category and pattern maps.
        category_map = await load_categories(session, data_dir)
        pattern_map = await load_patterns(session, data_dir)
        question_count = await load_questions(
            session, data_dir, category_map, pattern_map
        )
        await session.commit()

    print(f"\nDone! Loaded {question_count} questions.")


# Run the loader only when executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())