"""Tests for benchmark SQLite storage."""
|
|
|
|
import sqlite3
|
|
import threading
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from veritext.benchmark.models import BenchmarkRun
|
|
from veritext.benchmark.storage import BenchmarkStorage
|
|
from veritext.core.exceptions import StorageError
|
|
|
|
|
|
@pytest.fixture
def db_path(tmp_path: Path) -> Path:
    """Build the path of a throwaway database below pytest's ``tmp_path``.

    Only the path is computed here; nothing is created on disk by this
    fixture.
    """
    benchmarks_dir = tmp_path / "benchmarks"
    return benchmarks_dir / "test.db"
|
|
|
|
|
|
@pytest.fixture
def storage(db_path: Path) -> BenchmarkStorage:
    """Provide a BenchmarkStorage constructed from the temporary db path."""
    backend = BenchmarkStorage(db_path)
    return backend
|
|
|
|
|
|
@pytest.fixture
def sample_run() -> BenchmarkRun:
    """Provide a benchmark run with every field populated, metadata included."""
    recorded_at = datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC)
    scores = {"bleu4": 0.75, "rouge_l": 0.82}
    extra_info = {"git_sha": "abc123"}

    return BenchmarkRun(
        id="run-001",
        benchmark_name="test-suite",
        timestamp=recorded_at,
        veritext_version="0.1.0-dev",
        metrics=scores,
        sample_count=100,
        metadata=extra_info,
    )
|
|
|
|
|
|
class TestDatabaseCreation:
    """Tests for database initialisation."""

    @staticmethod
    def _schema_object_names(db_path: Path, object_type: str) -> set[str]:
        """Return the names of all ``sqlite_master`` entries of one type.

        Args:
            db_path: Database file to inspect.
            object_type: Value matched against the ``type`` column, e.g.
                ``"table"`` or ``"index"``.

        Returns:
            The set of ``name`` values of the matching rows.
        """
        conn = sqlite3.connect(str(db_path))
        try:
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type = ?", (object_type,)
            )
            return {row[0] for row in cursor.fetchall()}
        finally:
            # Close even when the query raises, so a failing test does not
            # leave an open handle on the temporary database file. A plain
            # ``with conn:`` would only manage the transaction, not close.
            conn.close()

    def test_creates_database_file(self, db_path: Path) -> None:
        """Storage creates the database file on init."""
        # Precondition: the fixture only builds a path, so no file yet.
        assert not db_path.exists()
        BenchmarkStorage(db_path)
        assert db_path.exists()

    def test_creates_parent_directories(self, tmp_path: Path) -> None:
        """Storage creates parent directories if needed."""
        nested_path = tmp_path / "deep" / "nested" / "path" / "test.db"
        BenchmarkStorage(nested_path)
        assert nested_path.exists()

    def test_creates_tables(self, db_path: Path) -> None:
        """Storage creates required tables."""
        BenchmarkStorage(db_path)

        tables = self._schema_object_names(db_path, "table")

        assert "benchmark_runs" in tables
        assert "benchmark_metrics" in tables

    def test_creates_index(self, db_path: Path) -> None:
        """Storage creates the ``idx_benchmark_name`` index."""
        BenchmarkStorage(db_path)

        indices = self._schema_object_names(db_path, "index")

        assert "idx_benchmark_name" in indices
|
|
|
|
|
|
class TestSaveRun:
    """Tests covering persistence of benchmark runs."""

    def test_save_run(
        self, storage: BenchmarkStorage, sample_run: BenchmarkRun
    ) -> None:
        """A saved run can be read back through its benchmark name."""
        storage.save_run(sample_run)

        stored = storage.get_runs("test-suite")
        assert len(stored) == 1
        first = stored[0]
        assert first.id == "run-001"

    def test_save_preserves_all_fields(
        self, storage: BenchmarkStorage, sample_run: BenchmarkRun
    ) -> None:
        """Every field of the run survives a save/load round trip."""
        storage.save_run(sample_run)

        loaded = storage.get_runs("test-suite")[0]

        # Compared in the same order as the fields are declared on the sample.
        compared_fields = (
            "id",
            "benchmark_name",
            "timestamp",
            "veritext_version",
            "metrics",
            "sample_count",
            "metadata",
        )
        for field_name in compared_fields:
            assert getattr(loaded, field_name) == getattr(sample_run, field_name)

    def test_save_duplicate_id_raises(
        self, storage: BenchmarkStorage, sample_run: BenchmarkRun
    ) -> None:
        """A second save with the same run ID is rejected with StorageError."""
        storage.save_run(sample_run)

        expect_duplicate_error = pytest.raises(StorageError, match="already exists")
        with expect_duplicate_error:
            storage.save_run(sample_run)

    def test_save_run_empty_metadata(self, storage: BenchmarkStorage) -> None:
        """A run built without metadata reads back with an empty dict."""
        bare_run = BenchmarkRun(
            id="run-no-meta",
            benchmark_name="test-suite",
            timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0-dev",
            metrics={"bleu4": 0.5},
            sample_count=10,
        )
        storage.save_run(bare_run)

        fetched = storage.get_latest_run("test-suite")

        assert fetched is not None
        assert fetched.metadata == {}
|
|
|
|
|
|
class TestGetRuns:
    """Tests covering retrieval of stored benchmark runs."""

    def test_get_runs_empty_database(self, storage: BenchmarkStorage) -> None:
        """Querying a fresh database yields an empty list."""
        assert storage.get_runs("nonexistent") == []

    def test_get_runs_filters_by_name(self, storage: BenchmarkStorage) -> None:
        """Only runs whose benchmark name matches are returned."""
        shared_timestamp = datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC)
        suite_a_run = BenchmarkRun(
            id="run-1",
            benchmark_name="suite-a",
            timestamp=shared_timestamp,
            veritext_version="0.1.0",
            metrics={"bleu4": 0.5},
            sample_count=10,
        )
        suite_b_run = BenchmarkRun(
            id="run-2",
            benchmark_name="suite-b",
            timestamp=shared_timestamp,
            veritext_version="0.1.0",
            metrics={"bleu4": 0.6},
            sample_count=10,
        )
        storage.save_run(suite_a_run)
        storage.save_run(suite_b_run)

        found_a = storage.get_runs("suite-a")
        found_b = storage.get_runs("suite-b")

        assert len(found_a) == 1
        assert found_a[0].id == "run-1"
        assert len(found_b) == 1
        assert found_b[0].id == "run-2"

    def test_get_runs_ordered_by_timestamp(self, storage: BenchmarkStorage) -> None:
        """Runs come back newest first, regardless of insertion order."""
        older = BenchmarkRun(
            id="run-old",
            benchmark_name="test",
            timestamp=datetime(2025, 1, 10, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"bleu4": 0.5},
            sample_count=10,
        )
        newer = BenchmarkRun(
            id="run-new",
            benchmark_name="test",
            timestamp=datetime(2025, 1, 20, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"bleu4": 0.6},
            sample_count=10,
        )

        # Insert the newer run first so ordering cannot come from insert order.
        storage.save_run(newer)
        storage.save_run(older)

        fetched = storage.get_runs("test")
        first_id = fetched[0].id
        second_id = fetched[1].id
        assert first_id == "run-new"
        assert second_id == "run-old"

    def test_get_runs_with_limit(self, storage: BenchmarkStorage) -> None:
        """The ``limit`` argument caps the number of returned runs."""
        for index in range(5):
            storage.save_run(
                BenchmarkRun(
                    id=f"run-{index}",
                    benchmark_name="test",
                    timestamp=datetime(2025, 1, index + 1, 12, 0, 0, tzinfo=UTC),
                    veritext_version="0.1.0",
                    metrics={"bleu4": 0.5 + index * 0.1},
                    sample_count=10,
                )
            )

        limited = storage.get_runs("test", limit=3)
        assert len(limited) == 3
|
|
|
|
|
|
class TestGetLatestRun:
    """Tests covering lookup of the most recent run."""

    def test_get_latest_run_empty(self, storage: BenchmarkStorage) -> None:
        """With nothing stored, the lookup yields ``None``."""
        assert storage.get_latest_run("nonexistent") is None

    def test_get_latest_run(self, storage: BenchmarkStorage) -> None:
        """The run with the later timestamp is the one returned."""
        earlier = BenchmarkRun(
            id="run-old",
            benchmark_name="test",
            timestamp=datetime(2025, 1, 10, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"bleu4": 0.5},
            sample_count=10,
        )
        later = BenchmarkRun(
            id="run-new",
            benchmark_name="test",
            timestamp=datetime(2025, 1, 20, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"bleu4": 0.6},
            sample_count=10,
        )
        storage.save_run(earlier)
        storage.save_run(later)

        newest = storage.get_latest_run("test")

        assert newest is not None
        assert newest.id == "run-new"
|
|
|
|
|
|
class TestConcurrentAccess:
    """Tests for concurrent database access."""

    def test_concurrent_writes(self, db_path: Path) -> None:
        """Multiple threads can write concurrently with WAL mode."""
        # Exceptions raised inside worker threads would otherwise be lost, so
        # they are collected here and asserted on from the main thread.
        errors: list[Exception] = []

        def write_run(run_id: int) -> None:
            """Save one run through a BenchmarkStorage built in this thread."""
            try:
                storage = BenchmarkStorage(db_path)
                run = BenchmarkRun(
                    id=f"run-{run_id}",
                    benchmark_name="test",
                    # run_id doubles as the seconds field so timestamps differ.
                    timestamp=datetime(2025, 1, 15, 12, 0, run_id, tzinfo=UTC),
                    veritext_version="0.1.0",
                    metrics={"bleu4": 0.5},
                    sample_count=10,
                )
                storage.save_run(run)
            except Exception as e:  # deliberately broad: any failure fails the test
                errors.append(e)

        threads = [threading.Thread(target=write_run, args=(i,)) for i in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors, f"Concurrent writes failed: {errors}"

        # Read back through a fresh instance to confirm every write landed.
        storage = BenchmarkStorage(db_path)
        runs = storage.get_runs("test")
        assert len(runs) == 10

    def test_wal_mode_enabled(self, db_path: Path) -> None:
        """Database uses WAL journal mode."""
        BenchmarkStorage(db_path)

        conn = sqlite3.connect(str(db_path))
        try:
            cursor = conn.execute("PRAGMA journal_mode")
            mode = cursor.fetchone()[0]
        finally:
            # Close even if the PRAGMA query raises, so the handle on the
            # temporary database file is never leaked.
            conn.close()

        assert mode.lower() == "wal"
|