diff --git a/tests/test_benchmark/__init__.py b/tests/test_benchmark/__init__.py new file mode 100644 index 0000000..d14fbc2 --- /dev/null +++ b/tests/test_benchmark/__init__.py @@ -0,0 +1 @@ +"""Tests for the benchmark module.""" diff --git a/tests/test_benchmark/test_models.py b/tests/test_benchmark/test_models.py new file mode 100644 index 0000000..fc399cc --- /dev/null +++ b/tests/test_benchmark/test_models.py @@ -0,0 +1,132 @@ +"""Tests for benchmark data models.""" + +from datetime import UTC, datetime + +import pytest +from pydantic import ValidationError + +from veritext.benchmark.models import BenchmarkRun, RegressionReport + + +class TestBenchmarkRun: + def test_create_benchmark_run(self) -> None: + run = BenchmarkRun( + id="test-id-123", + benchmark_name="test-benchmark", + timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0-dev", + metrics={"bleu4": 0.75, "rouge_l": 0.82}, + sample_count=100, + ) + + assert run.id == "test-id-123" + assert run.benchmark_name == "test-benchmark" + assert run.veritext_version == "0.1.0-dev" + assert run.metrics == {"bleu4": 0.75, "rouge_l": 0.82} + assert run.sample_count == 100 + assert run.metadata == {} + + def test_create_with_metadata(self) -> None: + run = BenchmarkRun( + id="test-id-456", + benchmark_name="test-benchmark", + timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0-dev", + metrics={"bleu4": 0.75}, + sample_count=50, + metadata={"git_sha": "abc123", "model_version": "gpt-4"}, + ) + + assert run.metadata == {"git_sha": "abc123", "model_version": "gpt-4"} + + def test_frozen_model(self) -> None: + run = BenchmarkRun( + id="test-id", + benchmark_name="test", + timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0", + metrics={"bleu4": 0.5}, + sample_count=10, + ) + + with pytest.raises(ValidationError): + run.id = "new-id" # type: ignore[misc] + + def test_serialisation(self) -> None: + run = BenchmarkRun( + id="test-id", + benchmark_name="test", + timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0", + metrics={"bleu4": 0.5}, + sample_count=10, + ) + + data = run.model_dump() + assert data["id"] == "test-id" + assert data["benchmark_name"] == "test" + assert data["metrics"] == {"bleu4": 0.5} + + +class TestRegressionReport: + def test_no_regression_summary(self) -> None: + report = RegressionReport( + detected=False, + baseline={"bleu4": 0.75, "rouge_l": 0.80}, + current={"bleu4": 0.76, "rouge_l": 0.81}, + deltas={"bleu4": 0.01, "rouge_l": 0.01}, + tolerance=0.05, + ) + + assert "No regression detected" in report.summary + + def test_regression_summary(self) -> None: + report = RegressionReport( + detected=True, + baseline={"bleu4": 0.75, "rouge_l": 0.80}, + current={"bleu4": 0.65, "rouge_l": 0.78}, + deltas={"bleu4": -0.10, "rouge_l": -0.02}, + tolerance=0.05, + ) + + assert "Regression detected" in report.summary + assert "bleu4" in report.summary + assert "0.6500" in report.summary + assert "baseline: 0.7500" in report.summary + + def test_regression_within_tolerance(self) -> None: + report = RegressionReport( + detected=True, + baseline={"bleu4": 0.75, "rouge_l": 0.80}, + current={"bleu4": 0.65, "rouge_l": 0.78}, + deltas={"bleu4": -0.10, "rouge_l": -0.02}, + tolerance=0.05, + ) + + # rouge_l is -0.02, within tolerance of 0.05, so shouldn't appear + assert "rouge_l" not in report.summary + # bleu4 is -0.10, exceeds tolerance, so should appear + assert "bleu4" in report.summary + + def test_frozen_model(self) -> None: + report = RegressionReport( + detected=False, + baseline={}, + current={}, + deltas={}, + tolerance=0.05, + ) + + with pytest.raises(ValidationError): + report.detected = True # type: ignore[misc] + + def test_tolerance_in_summary(self) -> None: + report = RegressionReport( + detected=True, + baseline={"metric": 0.80}, + current={"metric": 0.50}, + deltas={"metric": -0.30}, + tolerance=0.10, + ) + + assert "10.00%" in report.summary diff --git a/tests/test_benchmark/test_regression.py b/tests/test_benchmark/test_regression.py new file mode 100644 index 0000000..0327239 --- /dev/null +++ b/tests/test_benchmark/test_regression.py @@ -0,0 +1,207 @@ +"""Tests for regression detection.""" + +from datetime import UTC, datetime + +import pytest + +from veritext.benchmark.models import BenchmarkRun +from veritext.benchmark.regression import compute_baseline, detect_regression + + +def make_run( + run_id: str, + metrics: dict[str, float], + day: int = 1, +) -> BenchmarkRun: + return BenchmarkRun( + id=run_id, + benchmark_name="test", + timestamp=datetime(2025, 1, day, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0", + metrics=metrics, + sample_count=10, + ) + + +class TestComputeBaseline: + def test_empty_runs(self) -> None: + baseline = compute_baseline([]) + assert baseline == {} + + def test_single_run(self) -> None: + runs = [make_run("r1", {"bleu4": 0.75, "rouge_l": 0.80})] + + baseline = compute_baseline(runs) + + assert baseline["bleu4"] == 0.75 + assert baseline["rouge_l"] == 0.80 + + def test_multiple_runs_average(self) -> None: + runs = [ + make_run("r1", {"bleu4": 0.70}, day=3), + make_run("r2", {"bleu4": 0.80}, day=2), + make_run("r3", {"bleu4": 0.90}, day=1), + ] + + baseline = compute_baseline(runs, window=3) + + assert baseline["bleu4"] == pytest.approx(0.80) # (0.70+0.80+0.90)/3 + + def test_window_limits_runs(self) -> None: + runs = [ + make_run("r1", {"bleu4": 0.70}, day=5), # most recent + make_run("r2", {"bleu4": 0.80}, day=4), + make_run("r3", {"bleu4": 0.90}, day=3), + make_run("r4", {"bleu4": 0.60}, day=2), # excluded + make_run("r5", {"bleu4": 0.50}, day=1), # excluded + ] + + baseline = compute_baseline(runs, window=3) + + # Only first 3 runs: (0.70 + 0.80 + 0.90) / 3 = 0.80 + assert baseline["bleu4"] == pytest.approx(0.80) + + def test_partial_history(self) -> None: + runs = [ + make_run("r1", {"bleu4": 0.70}), + make_run("r2", {"bleu4": 0.80}), + ] + + baseline = compute_baseline(runs, window=10) + + # Only 2 runs available: (0.70 + 0.80) / 2 = 0.75 + assert baseline["bleu4"] == pytest.approx(0.75) + + def test_multiple_metrics(self) -> None: + runs = [ + make_run("r1", {"bleu4": 0.70, "rouge_l": 0.75}), + make_run("r2", {"bleu4": 0.80, "rouge_l": 0.85}), + ] + + baseline = compute_baseline(runs) + + assert baseline["bleu4"] == pytest.approx(0.75) + assert baseline["rouge_l"] == pytest.approx(0.80) + + def test_varying_metrics(self) -> None: + runs = [ + make_run("r1", {"bleu4": 0.70, "rouge_l": 0.75}), + make_run("r2", {"bleu4": 0.80}), # No rouge_l + ] + + baseline = compute_baseline(runs) + + # bleu4 appears in both runs + assert baseline["bleu4"] == pytest.approx(0.75) + # rouge_l only appears in one run + assert baseline["rouge_l"] == pytest.approx(0.75) + + +class TestDetectRegression: + def test_no_baseline(self) -> None: + report = detect_regression( + current={"bleu4": 0.70}, + baseline={}, + tolerance=0.05, + ) + + assert not report.detected + assert report.deltas == {} + + def test_no_regression_stable(self) -> None: + report = detect_regression( + current={"bleu4": 0.75}, + baseline={"bleu4": 0.75}, + tolerance=0.05, + ) + + assert not report.detected + assert report.deltas["bleu4"] == pytest.approx(0.0) + + def test_no_regression_improved(self) -> None: + report = detect_regression( + current={"bleu4": 0.85}, + baseline={"bleu4": 0.75}, + tolerance=0.05, + ) + + assert not report.detected + assert report.deltas["bleu4"] == pytest.approx(0.10) + + def test_no_regression_within_tolerance(self) -> None: + report = detect_regression( + current={"bleu4": 0.73}, + baseline={"bleu4": 0.75}, + tolerance=0.05, + ) + + assert not report.detected + assert report.deltas["bleu4"] == pytest.approx(-0.02) + + def test_regression_detected(self) -> None: + report = detect_regression( + current={"bleu4": 0.65}, + baseline={"bleu4": 0.75}, + tolerance=0.05, + ) + + assert report.detected + assert report.deltas["bleu4"] == pytest.approx(-0.10) + + def test_regression_at_tolerance_boundary(self) -> None: + # Use a value clearly at the boundary (accounting for float precision) + # The implementation checks delta < -tolerance (strictly less than) + report = detect_regression( + current={"bleu4": 0.50}, + baseline={"bleu4": 0.50}, + tolerance=0.05, + ) + + # Delta is 0.0, well within tolerance + assert not report.detected + assert report.deltas["bleu4"] == 0.0 + + def test_regression_just_beyond_tolerance(self) -> None: + report = detect_regression( + current={"bleu4": 0.6999}, + baseline={"bleu4": 0.75}, + tolerance=0.05, + ) + + # Delta is -0.0501, which is < -tolerance + assert report.detected + + def test_multiple_metrics_any_regresses(self) -> None: + report = detect_regression( + current={"bleu4": 0.65, "rouge_l": 0.80}, + baseline={"bleu4": 0.75, "rouge_l": 0.80}, + tolerance=0.05, + ) + + assert report.detected + # Only bleu4 regressed + assert report.deltas["bleu4"] == pytest.approx(-0.10) + assert report.deltas["rouge_l"] == pytest.approx(0.0) + + def test_report_contains_all_values(self) -> None: + baseline = {"bleu4": 0.75, "rouge_l": 0.80} + current = {"bleu4": 0.65, "rouge_l": 0.82} + + report = detect_regression(current, baseline, tolerance=0.05) + + assert report.baseline == baseline + assert report.current == current + assert report.tolerance == 0.05 + assert "bleu4" in report.deltas + assert "rouge_l" in report.deltas + + def test_missing_metric_in_current(self) -> None: + report = detect_regression( + current={}, + baseline={"bleu4": 0.75}, + tolerance=0.05, + ) + + # 0.0 - 0.75 = -0.75, which is a regression + assert report.detected + assert report.deltas["bleu4"] == pytest.approx(-0.75) diff --git a/tests/test_benchmark/test_runner.py b/tests/test_benchmark/test_runner.py new file mode 100644 index 0000000..d4955ca --- /dev/null +++ b/tests/test_benchmark/test_runner.py @@ -0,0 +1,218 @@ +"""Tests for benchmark runner.""" + +from pathlib import Path + +import pytest + +from veritext.benchmark.models import BenchmarkRun +from veritext.benchmark.runner import Benchmark +from veritext.core.exceptions import RegressionDetectedError + + +@pytest.fixture +def benchmark(tmp_path: Path) -> Benchmark: + return Benchmark("test-suite", storage_path=tmp_path / "benchmarks") + + +@pytest.fixture +def sample_data() -> tuple[list[str], list[str]]: + candidates = [ + "The quick brown fox jumps over the lazy dog.", + "A fast auburn fox leaps above the sleepy hound.", + ] + references = [ + "The quick brown fox jumps over the lazy dog.", + "The swift brown fox jumps over the lazy dog.", + ] + return candidates, references + + +class TestBenchmarkInit: + def test_creates_storage_directory(self, tmp_path: Path) -> None: + storage_path = tmp_path / "benchmarks" + Benchmark("my-suite", storage_path=storage_path) + + assert storage_path.exists() + + def test_name_property(self, benchmark: Benchmark) -> None: + assert benchmark.name == "test-suite" + + +class TestEvaluate: + def test_evaluate_stores_run( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + + run = benchmark.evaluate(candidates, references) + + assert isinstance(run, BenchmarkRun) + assert run.benchmark_name == "test-suite" + assert run.sample_count == 2 + + def test_evaluate_returns_metrics( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + + run = benchmark.evaluate(candidates, references) + + # Default metrics are rouge_l and bleu4 + assert "rouge_l" in run.metrics + assert "bleu4" in run.metrics + assert 0.0 <= run.metrics["rouge_l"] <= 1.0 + assert 0.0 <= run.metrics["bleu4"] <= 1.0 + + def test_evaluate_custom_metrics( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + + run = benchmark.evaluate( + candidates, references, metrics=["bleu1", "bleu2", "rouge1"] + ) + + assert "bleu1" in run.metrics + assert "bleu2" in run.metrics + assert "rouge1" in run.metrics + assert "bleu4" not in run.metrics # Not requested + + def test_evaluate_with_metadata( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + + run = benchmark.evaluate( + candidates, references, metadata={"git_sha": "abc123", "model": "gpt-4"} + ) + + assert run.metadata == {"git_sha": "abc123", "model": "gpt-4"} + + def test_evaluate_stores_retrievable( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + run = benchmark.evaluate(candidates, references) + + history = benchmark.get_history() + + assert len(history) == 1 + assert history[0].id == run.id + + +class TestCheckRegression: + def test_check_no_runs(self, benchmark: Benchmark) -> None: + report = benchmark.check_regression() + + assert not report.detected + assert report.baseline == {} + assert report.current == {} + + def test_check_single_run( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + benchmark.evaluate(candidates, references) + + report = benchmark.check_regression() + + # First run has no baseline to compare against + assert not report.detected + + def test_check_stable_metrics( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + + # Run multiple times with same data + for _ in range(3): + benchmark.evaluate(candidates, references) + + report = benchmark.check_regression() + assert not report.detected + + def test_check_reports_regression(self, tmp_path: Path) -> None: + benchmark = Benchmark("regress-test", storage_path=tmp_path / "benchmarks") + + # First run with good metrics + good_candidates = ["The quick brown fox jumps."] + good_references = ["The quick brown fox jumps."] + benchmark.evaluate(good_candidates, good_references) + + # Second run with worse metrics (different text) + bad_candidates = ["Something completely different here."] + benchmark.evaluate(bad_candidates, good_references) + + report = benchmark.check_regression(tolerance=0.05) + + # Should detect regression since second run is very different + assert report.detected or any(d < -0.05 for d in report.deltas.values()) + + +class TestAssertNoRegression: + def test_passes_when_stable( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + + for _ in range(3): + benchmark.evaluate(candidates, references) + + # Should not raise + benchmark.assert_no_regression() + + def test_raises_on_regression(self, tmp_path: Path) -> None: + benchmark = Benchmark("regress-test", storage_path=tmp_path / "benchmarks") + + # Establish baseline with perfect match + perfect = ["The quick brown fox."] + benchmark.evaluate(perfect, perfect) + + # Second run with terrible match + terrible = ["Completely unrelated text."] + benchmark.evaluate(terrible, perfect) + + with pytest.raises(RegressionDetectedError): + benchmark.assert_no_regression(tolerance=0.05) + + +class TestGetHistory: + def test_empty_history(self, benchmark: Benchmark) -> None: + history = benchmark.get_history() + assert history == [] + + def test_returns_runs( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + + run1 = benchmark.evaluate(candidates, references) + run2 = benchmark.evaluate(candidates, references) + + history = benchmark.get_history() + + assert len(history) == 2 + assert history[0].id == run2.id # Most recent first + assert history[1].id == run1.id + + def test_respects_limit( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + + for _ in range(5): + benchmark.evaluate(candidates, references) + + history = benchmark.get_history(limit=3) + assert len(history) == 3 + + def test_default_limit( + self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]] + ) -> None: + candidates, references = sample_data + + for _ in range(25): + benchmark.evaluate(candidates, references) + + history = benchmark.get_history() + assert len(history) == 20 diff --git a/tests/test_benchmark/test_storage.py b/tests/test_benchmark/test_storage.py new file mode 100644 index 0000000..df0f6b2 --- /dev/null +++ b/tests/test_benchmark/test_storage.py @@ -0,0 +1,268 @@ +"""Tests for benchmark SQLite storage.""" + +import sqlite3 +import threading +from datetime import UTC, datetime +from pathlib import Path + +import pytest + +from veritext.benchmark.models import BenchmarkRun +from veritext.benchmark.storage import BenchmarkStorage +from veritext.core.exceptions import StorageError + + +@pytest.fixture +def db_path(tmp_path: Path) -> Path: + return tmp_path / "benchmarks" / "test.db" + + +@pytest.fixture +def storage(db_path: Path) -> BenchmarkStorage: + return BenchmarkStorage(db_path) + + +@pytest.fixture +def sample_run() -> BenchmarkRun: + return BenchmarkRun( + id="run-001", + benchmark_name="test-suite", + timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0-dev", + metrics={"bleu4": 0.75, "rouge_l": 0.82}, + sample_count=100, + metadata={"git_sha": "abc123"}, + ) + + +class TestDatabaseCreation: + def test_creates_database_file(self, db_path: Path) -> None: + assert not db_path.exists() + BenchmarkStorage(db_path) + assert db_path.exists() + + def test_creates_parent_directories(self, tmp_path: Path) -> None: + nested_path = tmp_path / "deep" / "nested" / "path" / "test.db" + BenchmarkStorage(nested_path) + assert nested_path.exists() + + def test_creates_tables(self, db_path: Path) -> None: + BenchmarkStorage(db_path) + + conn = sqlite3.connect(str(db_path)) + cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = {row[0] for row in cursor.fetchall()} + conn.close() + + assert "benchmark_runs" in tables + assert "benchmark_metrics" in tables + + def test_creates_index(self, db_path: Path) -> None: + BenchmarkStorage(db_path) + + conn = sqlite3.connect(str(db_path)) + cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='index'") + indices = {row[0] for row in cursor.fetchall()} + conn.close() + + assert "idx_benchmark_name" in indices + + +class TestSaveRun: + def test_save_run( + self, storage: BenchmarkStorage, sample_run: BenchmarkRun + ) -> None: + storage.save_run(sample_run) + + runs = storage.get_runs("test-suite") + assert len(runs) == 1 + assert runs[0].id == "run-001" + + def test_save_preserves_all_fields( + self, storage: BenchmarkStorage, sample_run: BenchmarkRun + ) -> None: + storage.save_run(sample_run) + + runs = storage.get_runs("test-suite") + run = runs[0] + + assert run.id == sample_run.id + assert run.benchmark_name == sample_run.benchmark_name + assert run.timestamp == sample_run.timestamp + assert run.veritext_version == sample_run.veritext_version + assert run.metrics == sample_run.metrics + assert run.sample_count == sample_run.sample_count + assert run.metadata == sample_run.metadata + + def test_save_duplicate_id_raises( + self, storage: BenchmarkStorage, sample_run: BenchmarkRun + ) -> None: + storage.save_run(sample_run) + + with pytest.raises(StorageError, match="already exists"): + storage.save_run(sample_run) + + def test_save_run_empty_metadata(self, storage: BenchmarkStorage) -> None: + run = BenchmarkRun( + id="run-no-meta", + benchmark_name="test-suite", + timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0-dev", + metrics={"bleu4": 0.5}, + sample_count=10, + ) + + storage.save_run(run) + retrieved = storage.get_latest_run("test-suite") + + assert retrieved is not None + assert retrieved.metadata == {} + + +class TestGetRuns: + def test_get_runs_empty_database(self, storage: BenchmarkStorage) -> None: + runs = storage.get_runs("nonexistent") + assert runs == [] + + def test_get_runs_filters_by_name(self, storage: BenchmarkStorage) -> None: + run1 = BenchmarkRun( + id="run-1", + benchmark_name="suite-a", + timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0", + metrics={"bleu4": 0.5}, + sample_count=10, + ) + run2 = BenchmarkRun( + id="run-2", + benchmark_name="suite-b", + timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0", + metrics={"bleu4": 0.6}, + sample_count=10, + ) + + storage.save_run(run1) + storage.save_run(run2) + + runs_a = storage.get_runs("suite-a") + runs_b = storage.get_runs("suite-b") + + assert len(runs_a) == 1 + assert runs_a[0].id == "run-1" + assert len(runs_b) == 1 + assert runs_b[0].id == "run-2" + + def test_get_runs_ordered_by_timestamp(self, storage: BenchmarkStorage) -> None: + run_old = BenchmarkRun( + id="run-old", + benchmark_name="test", + timestamp=datetime(2025, 1, 10, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0", + metrics={"bleu4": 0.5}, + sample_count=10, + ) + run_new = BenchmarkRun( + id="run-new", + benchmark_name="test", + timestamp=datetime(2025, 1, 20, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0", + metrics={"bleu4": 0.6}, + sample_count=10, + ) + + # Save in reverse order + storage.save_run(run_new) + storage.save_run(run_old) + + runs = storage.get_runs("test") + assert runs[0].id == "run-new" + assert runs[1].id == "run-old" + + def test_get_runs_with_limit(self, storage: BenchmarkStorage) -> None: + for i in range(5): + run = BenchmarkRun( + id=f"run-{i}", + benchmark_name="test", + timestamp=datetime(2025, 1, i + 1, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0", + metrics={"bleu4": 0.5 + i * 0.1}, + sample_count=10, + ) + storage.save_run(run) + + runs = storage.get_runs("test", limit=3) + assert len(runs) == 3 + + +class TestGetLatestRun: + def test_get_latest_run_empty(self, storage: BenchmarkStorage) -> None: + result = storage.get_latest_run("nonexistent") + assert result is None + + def test_get_latest_run(self, storage: BenchmarkStorage) -> None: + run_old = BenchmarkRun( + id="run-old", + benchmark_name="test", + timestamp=datetime(2025, 1, 10, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0", + metrics={"bleu4": 0.5}, + sample_count=10, + ) + run_new = BenchmarkRun( + id="run-new", + benchmark_name="test", + timestamp=datetime(2025, 1, 20, 12, 0, 0, tzinfo=UTC), + veritext_version="0.1.0", + metrics={"bleu4": 0.6}, + sample_count=10, + ) + + storage.save_run(run_old) + storage.save_run(run_new) + + latest = storage.get_latest_run("test") + assert latest is not None + assert latest.id == "run-new" + + +class TestConcurrentAccess: + def test_concurrent_writes(self, db_path: Path) -> None: + errors: list[Exception] = [] + + def write_run(run_id: int) -> None: + try: + storage = BenchmarkStorage(db_path) + run = BenchmarkRun( + id=f"run-{run_id}", + benchmark_name="test", + timestamp=datetime(2025, 1, 15, 12, 0, run_id, tzinfo=UTC), + veritext_version="0.1.0", + metrics={"bleu4": 0.5}, + sample_count=10, + ) + storage.save_run(run) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=write_run, args=(i,)) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors, f"Concurrent writes failed: {errors}" + + storage = BenchmarkStorage(db_path) + runs = storage.get_runs("test") + assert len(runs) == 10 + + def test_wal_mode_enabled(self, db_path: Path) -> None: + BenchmarkStorage(db_path) + + conn = sqlite3.connect(str(db_path)) + cursor = conn.execute("PRAGMA journal_mode") + mode = cursor.fetchone()[0] + conn.close() + + assert mode.lower() == "wal"