feat(benchmark): add Benchmark runner class
Main Benchmark class for evaluating text quality and tracking regressions.
This commit is contained in:
186
src/veritext/benchmark/runner.py
Normal file
186
src/veritext/benchmark/runner.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""Benchmark execution and tracking."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import veritext
|
||||||
|
from veritext.benchmark.models import BenchmarkRun, RegressionReport
|
||||||
|
from veritext.benchmark.regression import compute_baseline, detect_regression
|
||||||
|
from veritext.benchmark.storage import BenchmarkStorage
|
||||||
|
from veritext.core.exceptions import RegressionDetectedError
|
||||||
|
from veritext.metrics.bleu import Bleu
|
||||||
|
from veritext.metrics.rouge import Rouge
|
||||||
|
|
||||||
|
# Default metrics to use for evaluation
|
||||||
|
# Benchmark.evaluate falls back to these names when `metrics` is None or empty.
DEFAULT_METRICS = ["rouge_l", "bleu4"]
|
||||||
|
|
||||||
|
|
||||||
|
class Benchmark:
|
||||||
|
"""Track text quality over time."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
storage_path: str | Path = "benchmarks/",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise a benchmark tracker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name identifying this benchmark suite.
|
||||||
|
storage_path: Directory for storing benchmark data.
|
||||||
|
"""
|
||||||
|
self._name = name
|
||||||
|
self._storage_path = Path(storage_path)
|
||||||
|
self._storage = BenchmarkStorage(self._storage_path / f"{name}.db")
|
||||||
|
|
||||||
|
# Initialise metrics
|
||||||
|
self._bleu = Bleu()
|
||||||
|
self._rouge = Rouge()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the benchmark name."""
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def _compute_metrics(
|
||||||
|
self,
|
||||||
|
candidates: list[str],
|
||||||
|
references: list[str] | list[list[str]],
|
||||||
|
metric_names: list[str],
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Compute requested metrics for the given samples."""
|
||||||
|
results: dict[str, float] = {}
|
||||||
|
|
||||||
|
for metric_name in metric_names:
|
||||||
|
if metric_name in ("bleu1", "bleu2", "bleu3", "bleu4"):
|
||||||
|
batch_result = self._bleu.batch_score(candidates, references)
|
||||||
|
stats = batch_result.stats.get(metric_name)
|
||||||
|
if stats:
|
||||||
|
results[metric_name] = stats.mean
|
||||||
|
|
||||||
|
elif metric_name in (
|
||||||
|
"rouge1",
|
||||||
|
"rouge2",
|
||||||
|
"rouge_l",
|
||||||
|
"rouge1_fmeasure",
|
||||||
|
"rouge2_fmeasure",
|
||||||
|
"rouge_l_fmeasure",
|
||||||
|
):
|
||||||
|
rouge_result = self._rouge.batch_score(candidates, references)
|
||||||
|
# Map short names to stat names
|
||||||
|
stat_name = metric_name
|
||||||
|
if metric_name == "rouge1":
|
||||||
|
stat_name = "rouge1_fmeasure"
|
||||||
|
elif metric_name == "rouge2":
|
||||||
|
stat_name = "rouge2_fmeasure"
|
||||||
|
elif metric_name == "rouge_l":
|
||||||
|
stat_name = "rouge_l_fmeasure"
|
||||||
|
|
||||||
|
stats = rouge_result.stats.get(stat_name)
|
||||||
|
if stats:
|
||||||
|
results[metric_name] = stats.mean
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
self,
|
||||||
|
candidates: list[str],
|
||||||
|
references: list[str] | list[list[str]],
|
||||||
|
metrics: list[str] | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> BenchmarkRun:
|
||||||
|
"""
|
||||||
|
Evaluate candidates against references, store results, and return the run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidates: List of candidate texts to evaluate.
|
||||||
|
references: Reference text(s) for each candidate.
|
||||||
|
metrics: List of metrics to compute. Defaults to ["rouge_l", "bleu4"].
|
||||||
|
metadata: Optional metadata (git_sha, model version, etc.).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The BenchmarkRun record that was created and stored.
|
||||||
|
"""
|
||||||
|
metric_names = metrics or DEFAULT_METRICS
|
||||||
|
metric_results = self._compute_metrics(candidates, references, metric_names)
|
||||||
|
|
||||||
|
run = BenchmarkRun(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
benchmark_name=self._name,
|
||||||
|
timestamp=datetime.now(UTC),
|
||||||
|
veritext_version=veritext.__version__,
|
||||||
|
metrics=metric_results,
|
||||||
|
sample_count=len(candidates),
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._storage.save_run(run)
|
||||||
|
return run
|
||||||
|
|
||||||
|
def check_regression(
|
||||||
|
self,
|
||||||
|
tolerance: float = 0.05,
|
||||||
|
window: int = 10,
|
||||||
|
) -> RegressionReport:
|
||||||
|
"""
|
||||||
|
Compare latest run against historical baseline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tolerance: Maximum allowed metric drop before regression is flagged.
|
||||||
|
window: Number of historical runs to include in baseline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RegressionReport with comparison results.
|
||||||
|
"""
|
||||||
|
runs = self._storage.get_runs(self._name)
|
||||||
|
|
||||||
|
if not runs:
|
||||||
|
# No runs at all
|
||||||
|
return RegressionReport(
|
||||||
|
detected=False,
|
||||||
|
baseline={},
|
||||||
|
current={},
|
||||||
|
deltas={},
|
||||||
|
tolerance=tolerance,
|
||||||
|
)
|
||||||
|
|
||||||
|
current_run = runs[0]
|
||||||
|
# Baseline excludes the current run
|
||||||
|
historical_runs = runs[1:]
|
||||||
|
baseline = compute_baseline(historical_runs, window=window)
|
||||||
|
|
||||||
|
return detect_regression(current_run.metrics, baseline, tolerance)
|
||||||
|
|
||||||
|
def assert_no_regression(
|
||||||
|
self,
|
||||||
|
tolerance: float = 0.05,
|
||||||
|
window: int = 10,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Raise RegressionDetectedError if quality dropped.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tolerance: Maximum allowed metric drop before regression is flagged.
|
||||||
|
window: Number of historical runs to include in baseline.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RegressionDetectedError: If a regression is detected.
|
||||||
|
"""
|
||||||
|
report = self.check_regression(tolerance=tolerance, window=window)
|
||||||
|
if report.detected:
|
||||||
|
raise RegressionDetectedError(report.summary)
|
||||||
|
|
||||||
|
def get_history(self, limit: int = 20) -> list[BenchmarkRun]:
|
||||||
|
"""
|
||||||
|
Get recent benchmark runs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: Maximum number of runs to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BenchmarkRun objects, most recent first.
|
||||||
|
"""
|
||||||
|
return self._storage.get_runs(self._name, limit=limit)
|
||||||
Reference in New Issue
Block a user