feat: BLEU scorer
Implement BLEU-1 through BLEU-4 with modified n-gram precision, brevity penalty, and support for multiple references.
This commit is contained in:
16
src/veritext/metrics/__init__.py
Normal file
16
src/veritext/metrics/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Metrics module: BLEU, lexical similarity, and batch processing."""
|
||||||
|
|
||||||
|
from veritext.metrics.base import AggregateStats, BatchResult, Metric
|
||||||
|
from veritext.metrics.bleu import Bleu
|
||||||
|
from veritext.metrics.lexical import Lexical
|
||||||
|
from veritext.metrics.results import BleuResult, LexicalResult
|
||||||
|
|
||||||
|
# Names re-exported as the public API of the metrics package (alphabetical).
__all__ = [
    "AggregateStats",
    "BatchResult",
    "Bleu",
    "BleuResult",
    "Lexical",
    "LexicalResult",
    "Metric",
]
|
||||||
197
src/veritext/metrics/bleu.py
Normal file
197
src/veritext/metrics/bleu.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""BLEU (Bilingual Evaluation Understudy) metric implementation."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
from veritext.core.tokenisation import WordTokeniser
|
||||||
|
from veritext.metrics.base import AggregateStats, BatchResult
|
||||||
|
from veritext.metrics.results import BleuResult
|
||||||
|
|
||||||
|
# Highest n-gram order that is scored; orders 1..MAX_NGRAM_ORDER are computed,
# which is what yields the BLEU-1 through BLEU-4 values.
MAX_NGRAM_ORDER = 4
|
||||||
|
|
||||||
|
|
||||||
|
def _get_ngrams(tokens: list[str], n: int) -> Counter[tuple[str, ...]]:
    """
    Count every contiguous n-gram in a token sequence.

    Args:
        tokens: Tokens to extract n-grams from.
        n: N-gram order; must be at least 1.

    Returns:
        Counter mapping each n-gram (as a tuple of tokens) to the number of
        times it occurs. Empty when the sequence is shorter than n.

    Raises:
        ValueError: If n is smaller than 1.
    """
    if n < 1:
        # A non-positive order would otherwise yield counts of empty or
        # truncated tuples instead of real n-grams.
        raise ValueError(f"n-gram order must be a positive integer, got {n}")

    # Zipping n staggered views of the sequence produces one tuple per window
    # and naturally produces nothing when there are fewer than n tokens.
    return Counter(zip(*(tokens[offset:] for offset in range(n))))
|
||||||
|
|
||||||
|
|
||||||
|
def _modified_precision(
    candidate_tokens: list[str],
    reference_token_lists: list[list[str]],
    n: int,
) -> tuple[int, int]:
    """
    Count the candidate's n-grams and how many of them the references support.

    Each candidate n-gram is credited at most as many times as it occurs in
    the single reference that contains it most often (clipped counts).

    Args:
        candidate_tokens: Tokens of the candidate text.
        reference_token_lists: Tokens of each reference text.
        n: N-gram order.

    Returns:
        Tuple of (clipped_count, total_count); (0, 0) when the candidate has
        no n-grams of order n.
    """
    hypothesis_counts = _get_ngrams(candidate_tokens, n)
    if len(hypothesis_counts) == 0:
        return 0, 0

    # Counter union keeps, per n-gram, the largest count seen in any reference.
    reference_ceiling: Counter[tuple[str, ...]] = Counter()
    for tokens in reference_token_lists:
        reference_ceiling |= _get_ngrams(tokens, n)

    # Counter intersection is the element-wise minimum, i.e. the clipping step.
    matched = hypothesis_counts & reference_ceiling
    return sum(matched.values()), sum(hypothesis_counts.values())
|
||||||
|
|
||||||
|
|
||||||
|
def _brevity_penalty(candidate_len: int, reference_lens: list[int]) -> float:
    """
    Compute the BLEU brevity penalty.

    The candidate length is compared with the reference length closest to
    it; when two reference lengths are equally close, the shorter one wins.

    Args:
        candidate_len: Number of tokens in the candidate.
        reference_lens: Number of tokens in each reference.

    Returns:
        0.0 for an empty candidate, 1.0 when the candidate is at least as
        long as the closest reference, otherwise exp(1 - r / c) where r is
        the closest reference length and c the candidate length.

    Raises:
        ValueError: If the candidate is non-empty and reference_lens is empty.
    """
    if candidate_len == 0:
        return 0.0

    if not reference_lens:
        # Fail with a descriptive message rather than min()'s generic one.
        raise ValueError("At least one reference length is required")

    # The secondary sort key (the length itself) breaks distance ties in
    # favour of the shorter reference.
    closest_ref_len = min(reference_lens, key=lambda r: (abs(r - candidate_len), r))

    if candidate_len >= closest_ref_len:
        return 1.0

    return math.exp(1 - closest_ref_len / candidate_len)
|
||||||
|
|
||||||
|
|
||||||
|
class Bleu:
    """
    BLEU metric for measuring translation/generation quality.

    Computes BLEU-1 through BLEU-4 scores using modified n-gram precision
    with brevity penalty. Each candidate may be scored against a single
    reference or against several references.
    """

    # Result fields that batch_score aggregates, in reporting order.
    _STAT_FIELDS = ("bleu1", "bleu2", "bleu3", "bleu4", "brevity_penalty")

    def __init__(self, tokeniser: WordTokeniser | None = None) -> None:
        """
        Initialise the BLEU metric.

        Args:
            tokeniser: Tokeniser to use. Defaults to WordTokeniser().
        """
        # Compare against None explicitly so that a caller-supplied tokeniser
        # that happens to be falsy is not silently replaced by the default.
        self._tokeniser = tokeniser if tokeniser is not None else WordTokeniser()

    @property
    def name(self) -> str:
        """Identifier of this metric."""
        return "bleu"

    @property
    def requires_reference(self) -> bool:
        """Whether scoring needs reference text; always True for BLEU."""
        return True

    @staticmethod
    def _geometric_mean(values: list[float]) -> float:
        """Return the geometric mean of values, or 0.0 if any value is 0.0."""
        if any(value == 0.0 for value in values):
            # log(0) is undefined; a zero precision zeroes the whole score.
            return 0.0
        log_sum = sum(math.log(value) for value in values)
        return math.exp(log_sum / len(values))

    def score(
        self, candidate: str, reference: str | list[str] | None = None
    ) -> BleuResult:
        """
        Compute BLEU scores for a candidate text.

        Args:
            candidate: The text to score.
            reference: Reference text(s) for comparison. A single string is
                treated as a one-element list of references.

        Returns:
            BleuResult with BLEU-1 through BLEU-4 and brevity penalty. All
            fields are 0.0 when the candidate tokenises to nothing.

        Raises:
            ValueError: If reference is None, or if the candidate is
                non-empty and every reference tokenises to nothing.
        """
        if reference is None:
            raise ValueError("BLEU requires reference text")

        references = [reference] if isinstance(reference, str) else reference

        candidate_tokens = self._tokeniser.tokenise(candidate)
        reference_token_lists = [self._tokeniser.tokenise(r) for r in references]

        # An empty candidate is scored as zero before the references are
        # validated, so it never raises for empty references.
        if not candidate_tokens:
            return BleuResult(
                bleu1=0.0, bleu2=0.0, bleu3=0.0, bleu4=0.0, brevity_penalty=0.0
            )

        if all(not ref for ref in reference_token_lists):
            raise ValueError("Reference text cannot be empty")

        # NOTE(review): empty references that sit alongside non-empty ones are
        # kept and still take part in the brevity-penalty length selection.
        precisions: list[float] = []
        for n in range(1, MAX_NGRAM_ORDER + 1):
            clipped, total = _modified_precision(
                candidate_tokens, reference_token_lists, n
            )
            # total == 0 means the candidate is shorter than n tokens.
            precisions.append(clipped / total if total else 0.0)

        bp = _brevity_penalty(
            len(candidate_tokens),
            [len(ref) for ref in reference_token_lists],
        )

        # BLEU-n is the brevity penalty times the geometric mean of the
        # precisions for orders 1..n.
        bleu_scores = [
            bp * self._geometric_mean(precisions[:n])
            for n in range(1, MAX_NGRAM_ORDER + 1)
        ]

        return BleuResult(
            bleu1=bleu_scores[0],
            bleu2=bleu_scores[1],
            bleu3=bleu_scores[2],
            bleu4=bleu_scores[3],
            brevity_penalty=bp,
        )

    def batch_score(
        self,
        candidates: list[str],
        references: list[str] | list[list[str]] | None = None,
    ) -> BatchResult[BleuResult]:
        """
        Compute BLEU scores for a batch of candidates.

        Args:
            candidates: List of texts to score.
            references: Reference text(s) for each candidate; entry i is the
                reference (or list of references) for candidates[i].

        Returns:
            BatchResult containing individual results and aggregate statistics.

        Raises:
            ValueError: If references is None or length mismatch.
        """
        if references is None:
            raise ValueError("BLEU requires reference texts")

        if len(candidates) != len(references):
            raise ValueError(
                f"Number of candidates ({len(candidates)}) must match "
                f"number of references ({len(references)})"
            )

        # Lengths were validated above, so zip pairs every candidate with
        # its own reference entry.
        results: list[BleuResult] = [
            self.score(cand, ref) for cand, ref in zip(candidates, references)
        ]

        stats = {
            field: AggregateStats.from_values([getattr(r, field) for r in results])
            for field in self._STAT_FIELDS
        }

        return BatchResult(results=results, count=len(results), stats=stats)
|
||||||
Reference in New Issue
Block a user