diff --git a/tests/unit/test_repository.py b/tests/unit/test_repository.py new file mode 100644 index 0000000..1732eba --- /dev/null +++ b/tests/unit/test_repository.py @@ -0,0 +1,295 @@ +"""Unit tests for data repository.""" + +import tempfile +from pathlib import Path +from uuid import uuid4 + +import pandas as pd +import pytest + +from py_dvt_ate.data.models import Measurement, TestStatus +from py_dvt_ate.data.repository import SQLiteRepository + + +@pytest.fixture +def temp_db(): + """Create a temporary database for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.db" + yield db_path + + +@pytest.fixture +def repository(temp_db): + """Create a repository instance for testing.""" + return SQLiteRepository(temp_db) + + +def test_create_run(repository): + """Test creating a new test run.""" + config = {"temperature": 25.0, "voltage": 3.3} + run_id = repository.create_run( + test_name="TempCo Test", + config=config, + operator="Test Engineer", + description="Test description", + ) + + assert run_id is not None + + # Verify run was created + run = repository.get_run(run_id) + assert run.test_name == "TempCo Test" + assert run.operator == "Test Engineer" + assert run.description == "Test description" + assert run.status == TestStatus.PENDING + + +def test_update_run_status(repository): + """Test updating test run status.""" + run_id = repository.create_run("Test", config={}) + + repository.update_run_status(run_id, TestStatus.RUNNING) + run = repository.get_run(run_id) + assert run.status == TestStatus.RUNNING + + repository.update_run_status(run_id, TestStatus.PASSED) + run = repository.get_run(run_id) + assert run.status == TestStatus.PASSED + + +def test_complete_run(repository): + """Test completing a test run.""" + run_id = repository.create_run("Test", config={}) + + repository.complete_run(run_id, TestStatus.PASSED) + run = repository.get_run(run_id) + + assert run.status == TestStatus.PASSED + assert run.completed_at is not None + + +def test_save_result(repository): + """Test saving a test result.""" + run_id = repository.create_run("Test", config={}) + + repository.save_result( + run_id=run_id, + parameter="output_voltage", + value=3.305, + unit="V", + lower_limit=3.267, + upper_limit=3.333, + ) + + results = repository.get_results(run_id) + assert len(results) == 1 + + result = results[0] + assert result.parameter == "output_voltage" + assert result.value == 3.305 + assert result.unit == "V" + assert result.lower_limit == 3.267 + assert result.upper_limit == 3.333 + assert result.passed is True + + +def test_save_result_fail(repository): + """Test saving a failing test result.""" + run_id = repository.create_run("Test", config={}) + + repository.save_result( + run_id=run_id, + parameter="output_voltage", + value=3.350, # Outside upper limit + unit="V", + lower_limit=3.267, + upper_limit=3.333, + ) + + results = repository.get_results(run_id) + result = results[0] + assert result.passed is False + + +def test_save_result_no_limits(repository): + """Test saving a result without limits.""" + run_id = repository.create_run("Test", config={}) + + repository.save_result( + run_id=run_id, + parameter="temperature", + value=25.5, + unit="°C", + ) + + results = repository.get_results(run_id) + result = results[0] + assert result.passed is None # No limits defined + + +def test_save_measurements(repository): + """Test saving time-series measurements to Parquet.""" + run_id = repository.create_run("Test", config={}) + + measurements = [ + Measurement( + timestamp=1234567890.0, + parameter="voltage", + value=3.3, + unit="V", + temperature=25.0, + input_voltage=5.0, + load_current=0.1, + ), + Measurement( + timestamp=1234567891.0, + parameter="voltage", + value=3.31, + unit="V", + temperature=25.1, + input_voltage=5.0, + load_current=0.1, + ), + ] + + repository.save_measurements(run_id, measurements) + + # Verify measurements were saved + df = repository.get_measurements_dataframe(run_id) + assert df is not None + assert len(df) == 2 + assert list(df["parameter"]) == ["voltage", "voltage"] + assert list(df["value"]) == [3.3, 3.31] + + +def test_save_measurements_append(repository): + """Test appending measurements to existing Parquet file.""" + run_id = repository.create_run("Test", config={}) + + # Save first batch + measurements1 = [ + Measurement( + timestamp=1234567890.0, + parameter="voltage", + value=3.3, + unit="V", + ) + ] + repository.save_measurements(run_id, measurements1) + + # Save second batch + measurements2 = [ + Measurement( + timestamp=1234567891.0, + parameter="voltage", + value=3.31, + unit="V", + ) + ] + repository.save_measurements(run_id, measurements2) + + # Verify both batches are present + df = repository.get_measurements_dataframe(run_id) + assert df is not None + assert len(df) == 2 + + +def test_get_measurements_nonexistent(repository): + """Test getting measurements for non-existent run.""" + fake_id = uuid4() + df = repository.get_measurements_dataframe(fake_id) + assert df is None + + +def test_save_empty_measurements(repository): + """Test saving empty measurement list.""" + run_id = repository.create_run("Test", config={}) + repository.save_measurements(run_id, []) + + df = repository.get_measurements_dataframe(run_id) + assert df is None + + +def test_get_nonexistent_run(repository): + """Test getting a non-existent run raises error.""" + fake_id = uuid4() + with pytest.raises(ValueError, match="not found"): + repository.get_run(fake_id) + + +def test_multiple_results(repository): + """Test saving and retrieving multiple results.""" + run_id = repository.create_run("Test", config={}) + + repository.save_result(run_id, "voltage", 3.3, "V") + repository.save_result(run_id, "current", 50.0, "uA") + repository.save_result(run_id, "temperature", 25.0, "°C") + + results = repository.get_results(run_id) + assert len(results) == 3 + + parameters = {r.parameter for r in results} + assert parameters == {"voltage", "current", "temperature"} + + +def test_custom_measurements_dir(temp_db): + """Test using a custom measurements directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + measurements_dir = Path(tmpdir) / "custom_measurements" + repo = SQLiteRepository(temp_db, measurements_dir=measurements_dir) + + run_id = repo.create_run("Test", config={}) + measurements = [ + Measurement( + timestamp=1234567890.0, + parameter="voltage", + value=3.3, + unit="V", + ) + ] + repo.save_measurements(run_id, measurements) + + # Verify file is in custom directory + expected_path = measurements_dir / f"run_{run_id}" / "measurements.parquet" + assert expected_path.exists() + + +def test_parquet_schema(repository): + """Test that Parquet file has correct schema.""" + run_id = repository.create_run("Test", config={}) + + measurements = [ + Measurement( + timestamp=1234567890.123, + parameter="voltage", + value=3.3, + unit="V", + temperature=25.5, + input_voltage=5.0, + load_current=0.1, + ) + ] + repository.save_measurements(run_id, measurements) + + df = repository.get_measurements_dataframe(run_id) + assert df is not None + + # Check columns + expected_columns = { + "timestamp", + "parameter", + "value", + "unit", + "temperature", + "input_voltage", + "load_current", + } + assert set(df.columns) == expected_columns + + # Check data types (approximately) + assert pd.api.types.is_float_dtype(df["timestamp"]) + assert pd.api.types.is_string_dtype(df["parameter"]) or pd.api.types.is_object_dtype( + df["parameter"] + ) + assert pd.api.types.is_float_dtype(df["value"])