feat(metrics): add metric protocol and batch types

Add Metric protocol, AggregateStats for statistical summaries, and
BatchResult for batch processing support.
This commit is contained in:
2026-02-03 16:45:38 +00:00
parent 14dcddcbba
commit e6167005e5
2 changed files with 180 additions and 0 deletions

View File

@@ -0,0 +1,139 @@
"""Base types and protocols for metrics."""
import math
from typing import Generic, Protocol, TypeVar
from pydantic import BaseModel, ConfigDict
T = TypeVar("T")
class MetricResult(Protocol):
"""Protocol for metric result types."""
class AggregateStats(BaseModel):
"""Aggregate statistics for a batch of metric scores."""
model_config = ConfigDict(frozen=True)
mean: float
"""Mean of the scores."""
std: float
"""Standard deviation of the scores."""
min: float
"""Minimum score."""
max: float
"""Maximum score."""
percentiles: dict[int, float]
"""Percentile values (typically 25, 50, 75, 95)."""
@classmethod
def from_values(cls, values: list[float]) -> "AggregateStats":
"""
Compute aggregate statistics from a list of values.
Args:
values: List of numeric values to aggregate.
Returns:
AggregateStats with computed statistics.
Raises:
ValueError: If values list is empty.
"""
if not values:
raise ValueError("Cannot compute statistics from empty list")
n = len(values)
mean = sum(values) / n
if n == 1:
std = 0.0
else:
variance = sum((v - mean) ** 2 for v in values) / (n - 1)
std = math.sqrt(variance)
sorted_values = sorted(values)
def percentile(p: int) -> float:
"""Compute percentile using linear interpolation."""
if n == 1:
return sorted_values[0]
k = (n - 1) * p / 100
f = math.floor(k)
c = math.ceil(k)
if f == c:
return sorted_values[int(k)]
return sorted_values[f] * (c - k) + sorted_values[c] * (k - f)
return cls(
mean=mean,
std=std,
min=sorted_values[0],
max=sorted_values[-1],
percentiles={p: percentile(p) for p in (25, 50, 75, 95)},
)
class BatchResult(BaseModel, Generic[T]):
"""Result of batch metric computation."""
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
results: list[T]
"""Individual results for each input."""
count: int
"""Number of results."""
stats: dict[str, AggregateStats]
"""Aggregate statistics keyed by score name."""
class Metric(Protocol[T]):
"""Protocol for metrics that compute scores from text."""
@property
def name(self) -> str:
"""Return the name of this metric."""
...
@property
def requires_reference(self) -> bool:
"""Return whether this metric requires reference text."""
...
def score(self, candidate: str, reference: str | list[str] | None = None) -> T:
"""
Compute the metric score for a candidate text.
Args:
candidate: The text to score.
reference: Reference text(s) for comparison, if required.
Returns:
The computed metric result.
"""
...
def batch_score(
self,
candidates: list[str],
references: list[str] | list[list[str]] | None = None,
) -> BatchResult[T]:
"""
Compute metric scores for a batch of candidates.
Args:
candidates: List of texts to score.
references: Reference text(s) for each candidate, if required.
Returns:
BatchResult containing individual results and aggregate statistics.
"""
...

View File

@@ -0,0 +1,41 @@
"""Result types for metrics."""
from pydantic import BaseModel, ConfigDict
class BleuResult(BaseModel):
"""Result of BLEU score computation."""
model_config = ConfigDict(frozen=True)
bleu1: float
"""Unigram BLEU score (precision)."""
bleu2: float
"""Bigram BLEU score (precision)."""
bleu3: float
"""Trigram BLEU score (precision)."""
bleu4: float
"""4-gram BLEU score (precision)."""
brevity_penalty: float
"""Brevity penalty applied to the score."""
@property
def score(self) -> float:
"""Return the composite BLEU-4 score with brevity penalty."""
return self.bleu4
class LexicalResult(BaseModel):
"""Result of lexical similarity computation."""
model_config = ConfigDict(frozen=True)
jaccard: float
"""Jaccard similarity: |intersection| / |union| of token sets."""
token_overlap: float
"""Proportion of candidate tokens found in reference."""