benchmark runner
Main Benchmark class for evaluating text quality and tracking regressions.
This commit is contained in:
179
src/veritext/benchmark/runner.py
Normal file
179
src/veritext/benchmark/runner.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Benchmark execution and tracking."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import veritext
|
||||
from veritext.benchmark.models import BenchmarkRun, RegressionReport
|
||||
from veritext.benchmark.regression import compute_baseline, detect_regression
|
||||
from veritext.benchmark.storage import BenchmarkStorage
|
||||
from veritext.core.exceptions import RegressionDetectedError
|
||||
from veritext.metrics.bleu import Bleu
|
||||
from veritext.metrics.rouge import Rouge
|
||||
|
||||
# Metric names that Benchmark.evaluate computes when the caller passes no
# (or an empty) ``metrics`` list.
DEFAULT_METRICS = ["rouge_l", "bleu4"]
|
||||
|
||||
|
||||
class Benchmark:
|
||||
"""Track text quality over time."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
storage_path: str | Path = "benchmarks/",
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a benchmark tracker.
|
||||
|
||||
Args:
|
||||
name: Name identifying this benchmark suite.
|
||||
storage_path: Directory for storing benchmark data.
|
||||
"""
|
||||
self._name = name
|
||||
self._storage_path = Path(storage_path)
|
||||
self._storage = BenchmarkStorage(self._storage_path / f"{name}.db")
|
||||
|
||||
self._bleu = Bleu()
|
||||
self._rouge = Rouge()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def _compute_metrics(
|
||||
self,
|
||||
candidates: list[str],
|
||||
references: list[str] | list[list[str]],
|
||||
metric_names: list[str],
|
||||
) -> dict[str, float]:
|
||||
results: dict[str, float] = {}
|
||||
|
||||
for metric_name in metric_names:
|
||||
if metric_name in ("bleu1", "bleu2", "bleu3", "bleu4"):
|
||||
batch_result = self._bleu.batch_score(candidates, references)
|
||||
stats = batch_result.stats.get(metric_name)
|
||||
if stats:
|
||||
results[metric_name] = stats.mean
|
||||
|
||||
elif metric_name in (
|
||||
"rouge1",
|
||||
"rouge2",
|
||||
"rouge_l",
|
||||
"rouge1_fmeasure",
|
||||
"rouge2_fmeasure",
|
||||
"rouge_l_fmeasure",
|
||||
):
|
||||
rouge_result = self._rouge.batch_score(candidates, references)
|
||||
stat_name = metric_name
|
||||
if metric_name == "rouge1":
|
||||
stat_name = "rouge1_fmeasure"
|
||||
elif metric_name == "rouge2":
|
||||
stat_name = "rouge2_fmeasure"
|
||||
elif metric_name == "rouge_l":
|
||||
stat_name = "rouge_l_fmeasure"
|
||||
|
||||
stats = rouge_result.stats.get(stat_name)
|
||||
if stats:
|
||||
results[metric_name] = stats.mean
|
||||
|
||||
return results
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
candidates: list[str],
|
||||
references: list[str] | list[list[str]],
|
||||
metrics: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> BenchmarkRun:
|
||||
"""
|
||||
Evaluate candidates against references, store results, and return the run.
|
||||
|
||||
Args:
|
||||
candidates: List of candidate texts to evaluate.
|
||||
references: Reference text(s) for each candidate.
|
||||
metrics: List of metrics to compute. Defaults to ["rouge_l", "bleu4"].
|
||||
metadata: Optional metadata (git_sha, model version, etc.).
|
||||
|
||||
Returns:
|
||||
The BenchmarkRun record that was created and stored.
|
||||
"""
|
||||
metric_names = metrics or DEFAULT_METRICS
|
||||
metric_results = self._compute_metrics(candidates, references, metric_names)
|
||||
|
||||
run = BenchmarkRun(
|
||||
id=str(uuid.uuid4()),
|
||||
benchmark_name=self._name,
|
||||
timestamp=datetime.now(UTC),
|
||||
veritext_version=veritext.__version__,
|
||||
metrics=metric_results,
|
||||
sample_count=len(candidates),
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
self._storage.save_run(run)
|
||||
return run
|
||||
|
||||
def check_regression(
|
||||
self,
|
||||
tolerance: float = 0.05,
|
||||
window: int = 10,
|
||||
) -> RegressionReport:
|
||||
"""
|
||||
Compare latest run against historical baseline.
|
||||
|
||||
Args:
|
||||
tolerance: Maximum allowed metric drop before regression is flagged.
|
||||
window: Number of historical runs to include in baseline.
|
||||
|
||||
Returns:
|
||||
RegressionReport with comparison results.
|
||||
"""
|
||||
runs = self._storage.get_runs(self._name)
|
||||
|
||||
if not runs:
|
||||
return RegressionReport(
|
||||
detected=False,
|
||||
baseline={},
|
||||
current={},
|
||||
deltas={},
|
||||
tolerance=tolerance,
|
||||
)
|
||||
|
||||
current_run = runs[0]
|
||||
historical_runs = runs[1:]
|
||||
baseline = compute_baseline(historical_runs, window=window)
|
||||
|
||||
return detect_regression(current_run.metrics, baseline, tolerance)
|
||||
|
||||
def assert_no_regression(
|
||||
self,
|
||||
tolerance: float = 0.05,
|
||||
window: int = 10,
|
||||
) -> None:
|
||||
"""
|
||||
Raise RegressionDetectedError if quality dropped.
|
||||
|
||||
Args:
|
||||
tolerance: Maximum allowed metric drop before regression is flagged.
|
||||
window: Number of historical runs to include in baseline.
|
||||
|
||||
Raises:
|
||||
RegressionDetectedError: If a regression is detected.
|
||||
"""
|
||||
report = self.check_regression(tolerance=tolerance, window=window)
|
||||
if report.detected:
|
||||
raise RegressionDetectedError(report.summary)
|
||||
|
||||
def get_history(self, limit: int = 20) -> list[BenchmarkRun]:
|
||||
"""
|
||||
Get recent benchmark runs.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of runs to return.
|
||||
|
||||
Returns:
|
||||
List of BenchmarkRun objects, most recent first.
|
||||
"""
|
||||
return self._storage.get_runs(self._name, limit=limit)
|
||||
Reference in New Issue
Block a user