feat: BLEU scorer

Implement BLEU-1 through BLEU-4 with modified n-gram precision,
brevity penalty, and support for multiple references.
This commit is contained in:
2025-03-15 11:08:36 +00:00
parent 7832fa3d59
commit 82b6ffea79
2 changed files with 213 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
"""Metrics module: BLEU, lexical similarity, and batch processing."""
from veritext.metrics.base import AggregateStats, BatchResult, Metric
from veritext.metrics.bleu import Bleu
from veritext.metrics.lexical import Lexical
from veritext.metrics.results import BleuResult, LexicalResult
# Public API of the metrics package: exactly the names imported above, so
# `from veritext.metrics import *` exposes the metric classes, their result
# types, and the shared batch/aggregate containers.
__all__ = [
"AggregateStats",
"BatchResult",
"Bleu",
"BleuResult",
"Lexical",
"LexicalResult",
"Metric",
]

View File

@@ -0,0 +1,197 @@
"""BLEU (Bilingual Evaluation Understudy) metric implementation."""
import math
from collections import Counter
from veritext.core.tokenisation import WordTokeniser
from veritext.metrics.base import AggregateStats, BatchResult
from veritext.metrics.results import BleuResult
# Highest n-gram order scored. Bleu.score computes precisions for orders
# 1..MAX_NGRAM_ORDER and reports them as bleu1..bleu4, so this value is tied
# to the four score fields passed to BleuResult.
MAX_NGRAM_ORDER = 4
def _get_ngrams(tokens: list[str], n: int) -> Counter[tuple[str, ...]]:
    """
    Count every contiguous n-gram in a token sequence.

    Args:
        tokens: Tokens of a single text, in order.
        n: N-gram order (number of tokens per n-gram).

    Returns:
        Counter mapping each n-gram (a tuple of tokens) to its number of
        occurrences. Empty when ``n`` is not positive or exceeds the number
        of tokens.
    """
    # A non-positive order has no meaningful n-grams: n == 0 would count the
    # empty tuple len(tokens) + 1 times and n < 0 would yield arbitrary
    # slices, so both are treated the same way as "too few tokens".
    if n < 1 or n > len(tokens):
        return Counter()
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))
def _modified_precision(
    candidate_tokens: list[str],
    reference_token_lists: list[list[str]],
    n: int,
) -> tuple[int, int]:
    """
    Tally clipped and total candidate n-gram counts for one n-gram order.

    Each candidate n-gram is credited at most as many times as it occurs in
    the single reference that contains it most often.

    Args:
        candidate_tokens: Tokens of the text being scored.
        reference_token_lists: Tokens of each reference text.
        n: N-gram order to evaluate.

    Returns:
        Tuple of (clipped_count, total_count); (0, 0) when the candidate
        has no n-grams of this order.
    """
    hyp_counts = _get_ngrams(candidate_tokens, n)
    if not hyp_counts:
        return 0, 0

    # Counter union keeps, per n-gram, the largest count seen in any reference.
    ceiling: Counter[tuple[str, ...]] = Counter()
    for ref_tokens in reference_token_lists:
        ceiling |= _get_ngrams(ref_tokens, n)

    # Counter intersection takes the per-n-gram minimum, i.e. the clipped count.
    matched = hyp_counts & ceiling
    return sum(matched.values()), sum(hyp_counts.values())
def _brevity_penalty(candidate_len: int, reference_lens: list[int]) -> float:
    """
    Return the BLEU brevity penalty for a candidate of the given length.

    The candidate is compared against the reference length nearest to its
    own; when two reference lengths are equally near, the shorter is used.

    Args:
        candidate_len: Number of tokens in the candidate.
        reference_lens: Number of tokens in each reference.

    Returns:
        0.0 for an empty candidate, 1.0 when the candidate is at least as
        long as the chosen reference, otherwise exp(1 - ref_len / cand_len).
    """
    if candidate_len == 0:
        return 0.0

    def _distance_then_length(ref_len: int) -> tuple[int, int]:
        # The second element breaks distance ties in favour of the shorter
        # reference.
        return abs(ref_len - candidate_len), ref_len

    effective_ref_len = min(reference_lens, key=_distance_then_length)
    if candidate_len < effective_ref_len:
        return math.exp(1 - effective_ref_len / candidate_len)
    return 1.0
class Bleu:
    """
    BLEU metric for measuring translation/generation quality.

    Computes BLEU-1 through BLEU-4 scores using modified n-gram precision
    with brevity penalty. Each BLEU-n is the brevity penalty multiplied by
    the geometric mean of the precisions for orders 1..n.
    """

    def __init__(self, tokeniser: WordTokeniser | None = None) -> None:
        """
        Initialise the BLEU metric.

        Args:
            tokeniser: Tokeniser to use. Defaults to WordTokeniser().
        """
        # Compare against None explicitly: `tokeniser or ...` would silently
        # discard a supplied tokeniser whose truth value happens to be False.
        self._tokeniser = WordTokeniser() if tokeniser is None else tokeniser

    @property
    def name(self) -> str:
        """Identifier of this metric."""
        return "bleu"

    @property
    def requires_reference(self) -> bool:
        """Whether scoring needs reference text; always True for BLEU."""
        return True

    @staticmethod
    def _geometric_mean(values: list[float]) -> float:
        """Return the geometric mean of ``values``, or 0.0 if any is zero."""
        # log(0) is undefined, and a zero precision zeroes the product anyway.
        if any(v == 0.0 for v in values):
            return 0.0
        log_sum = sum(math.log(v) for v in values)
        return math.exp(log_sum / len(values))

    def score(
        self, candidate: str, reference: str | list[str] | None = None
    ) -> BleuResult:
        """
        Compute BLEU scores for a candidate text.

        Args:
            candidate: The text to score.
            reference: Reference text(s) for comparison. A single string is
                treated as one reference.

        Returns:
            BleuResult with BLEU-1 through BLEU-4 and brevity penalty. All
            fields are 0.0 when the candidate tokenises to nothing (this is
            decided before the references are checked for emptiness).

        Raises:
            ValueError: If reference is None, or if the candidate is
                non-empty and no reference contains any tokens.
        """
        if reference is None:
            raise ValueError("BLEU requires reference text")
        references = [reference] if isinstance(reference, str) else reference

        candidate_tokens = self._tokeniser.tokenise(candidate)
        reference_token_lists = [self._tokeniser.tokenise(r) for r in references]

        if not candidate_tokens:
            return BleuResult(
                bleu1=0.0, bleu2=0.0, bleu3=0.0, bleu4=0.0, brevity_penalty=0.0
            )

        # Ignore references without tokens. They can never match an n-gram,
        # but their length of 0 could be picked as the "closest" reference
        # length for a short candidate and wrongly cancel the brevity penalty.
        usable_references = [ref for ref in reference_token_lists if ref]
        if not usable_references:
            raise ValueError("Reference text cannot be empty")

        precisions: list[float] = []
        for n in range(1, MAX_NGRAM_ORDER + 1):
            clipped, total = _modified_precision(
                candidate_tokens, usable_references, n
            )
            # total is 0 when the candidate is shorter than n tokens.
            precisions.append(clipped / total if total else 0.0)

        bp = _brevity_penalty(
            len(candidate_tokens),
            [len(ref) for ref in usable_references],
        )

        # BLEU-n uses only the precisions of orders 1..n.
        bleu_scores = [
            bp * self._geometric_mean(precisions[:n])
            for n in range(1, MAX_NGRAM_ORDER + 1)
        ]

        return BleuResult(
            bleu1=bleu_scores[0],
            bleu2=bleu_scores[1],
            bleu3=bleu_scores[2],
            bleu4=bleu_scores[3],
            brevity_penalty=bp,
        )

    def batch_score(
        self,
        candidates: list[str],
        references: list[str] | list[list[str]] | None = None,
    ) -> BatchResult[BleuResult]:
        """
        Compute BLEU scores for a batch of candidates.

        Args:
            candidates: List of texts to score.
            references: Reference text(s) for each candidate, aligned by
                position with ``candidates``.

        Returns:
            BatchResult containing individual results and aggregate
            statistics for each BLEU order and the brevity penalty.

        Raises:
            ValueError: If references is None, if its length differs from
                that of candidates, or if scoring any pair raises (see
                ``score``).
        """
        if references is None:
            raise ValueError("BLEU requires reference texts")
        if len(candidates) != len(references):
            raise ValueError(
                f"Number of candidates ({len(candidates)}) must match "
                f"number of references ({len(references)})"
            )

        # Lengths were checked above, so zip pairs every candidate with its
        # own reference entry.
        results: list[BleuResult] = [
            self.score(cand, ref) for cand, ref in zip(candidates, references)
        ]

        stats = {
            "bleu1": AggregateStats.from_values([r.bleu1 for r in results]),
            "bleu2": AggregateStats.from_values([r.bleu2 for r in results]),
            "bleu3": AggregateStats.from_values([r.bleu3 for r in results]),
            "bleu4": AggregateStats.from_values([r.bleu4 for r in results]),
            "brevity_penalty": AggregateStats.from_values(
                [r.brevity_penalty for r in results]
            ),
        }
        return BatchResult(results=results, count=len(results), stats=stats)