From 73a65656a1d3ca50a80888d09372d95193efe79a Mon Sep 17 00:00:00 2001 From: Kai Chappell Date: Sat, 19 Apr 2025 12:08:56 +0000 Subject: [PATCH] benchmark runner Main Benchmark class for evaluating text quality and tracking regressions. --- src/veritext/benchmark/runner.py | 179 +++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 src/veritext/benchmark/runner.py diff --git a/src/veritext/benchmark/runner.py b/src/veritext/benchmark/runner.py new file mode 100644 index 0000000..c98820b --- /dev/null +++ b/src/veritext/benchmark/runner.py @@ -0,0 +1,179 @@ +"""Benchmark execution and tracking.""" + +import uuid +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import veritext +from veritext.benchmark.models import BenchmarkRun, RegressionReport +from veritext.benchmark.regression import compute_baseline, detect_regression +from veritext.benchmark.storage import BenchmarkStorage +from veritext.core.exceptions import RegressionDetectedError +from veritext.metrics.bleu import Bleu +from veritext.metrics.rouge import Rouge + +DEFAULT_METRICS = ["rouge_l", "bleu4"] + + +class Benchmark: + """Track text quality over time.""" + + def __init__( + self, + name: str, + storage_path: str | Path = "benchmarks/", + ) -> None: + """ + Initialise a benchmark tracker. + + Args: + name: Name identifying this benchmark suite. + storage_path: Directory for storing benchmark data. + """ + self._name = name + self._storage_path = Path(storage_path) + self._storage = BenchmarkStorage(self._storage_path / f"{name}.db") + + self._bleu = Bleu() + self._rouge = Rouge() + + @property + def name(self) -> str: + return self._name + + def _compute_metrics( + self, + candidates: list[str], + references: list[str] | list[list[str]], + metric_names: list[str], + ) -> dict[str, float]: + results: dict[str, float] = {} + + for metric_name in metric_names: + if metric_name in ("bleu1", "bleu2", "bleu3", "bleu4"): + batch_result = self._bleu.batch_score(candidates, references) + stats = batch_result.stats.get(metric_name) + if stats: + results[metric_name] = stats.mean + + elif metric_name in ( + "rouge1", + "rouge2", + "rouge_l", + "rouge1_fmeasure", + "rouge2_fmeasure", + "rouge_l_fmeasure", + ): + rouge_result = self._rouge.batch_score(candidates, references) + stat_name = metric_name + if metric_name == "rouge1": + stat_name = "rouge1_fmeasure" + elif metric_name == "rouge2": + stat_name = "rouge2_fmeasure" + elif metric_name == "rouge_l": + stat_name = "rouge_l_fmeasure" + + stats = rouge_result.stats.get(stat_name) + if stats: + results[metric_name] = stats.mean + + return results + + def evaluate( + self, + candidates: list[str], + references: list[str] | list[list[str]], + metrics: list[str] | None = None, + metadata: dict[str, Any] | None = None, + ) -> BenchmarkRun: + """ + Evaluate candidates against references, store results, and return the run. + + Args: + candidates: List of candidate texts to evaluate. + references: Reference text(s) for each candidate. + metrics: List of metrics to compute. Defaults to ["rouge_l", "bleu4"]. + metadata: Optional metadata (git_sha, model version, etc.). + + Returns: + The BenchmarkRun record that was created and stored. + """ + metric_names = metrics or DEFAULT_METRICS + metric_results = self._compute_metrics(candidates, references, metric_names) + + run = BenchmarkRun( + id=str(uuid.uuid4()), + benchmark_name=self._name, + timestamp=datetime.now(UTC), + veritext_version=veritext.__version__, + metrics=metric_results, + sample_count=len(candidates), + metadata=metadata or {}, + ) + + self._storage.save_run(run) + return run + + def check_regression( + self, + tolerance: float = 0.05, + window: int = 10, + ) -> RegressionReport: + """ + Compare latest run against historical baseline. + + Args: + tolerance: Maximum allowed metric drop before regression is flagged. + window: Number of historical runs to include in baseline. + + Returns: + RegressionReport with comparison results. + """ + runs = self._storage.get_runs(self._name) + + if not runs: + return RegressionReport( + detected=False, + baseline={}, + current={}, + deltas={}, + tolerance=tolerance, + ) + + current_run = runs[0] + historical_runs = runs[1:] + baseline = compute_baseline(historical_runs, window=window) + + return detect_regression(current_run.metrics, baseline, tolerance) + + def assert_no_regression( + self, + tolerance: float = 0.05, + window: int = 10, + ) -> None: + """ + Raise RegressionDetectedError if quality dropped. + + Args: + tolerance: Maximum allowed metric drop before regression is flagged. + window: Number of historical runs to include in baseline. + + Raises: + RegressionDetectedError: If a regression is detected. + """ + report = self.check_regression(tolerance=tolerance, window=window) + if report.detected: + raise RegressionDetectedError(report.summary) + + def get_history(self, limit: int = 20) -> list[BenchmarkRun]: + """ + Get recent benchmark runs. + + Args: + limit: Maximum number of runs to return. + + Returns: + List of BenchmarkRun objects, most recent first. + """ + return self._storage.get_runs(self._name, limit=limit)