diff --git a/src/veritext/metrics/__init__.py b/src/veritext/metrics/__init__.py
new file mode 100644
index 0000000..9fd2fc9
--- /dev/null
+++ b/src/veritext/metrics/__init__.py
@@ -0,0 +1,16 @@
+"""Metrics module: BLEU, lexical similarity, and batch processing."""
+
+from veritext.metrics.base import AggregateStats, BatchResult, Metric
+from veritext.metrics.bleu import Bleu
+from veritext.metrics.lexical import Lexical
+from veritext.metrics.results import BleuResult, LexicalResult
+
+__all__ = [
+    "AggregateStats",
+    "BatchResult",
+    "Bleu",
+    "BleuResult",
+    "Lexical",
+    "LexicalResult",
+    "Metric",
+]
diff --git a/src/veritext/metrics/bleu.py b/src/veritext/metrics/bleu.py
new file mode 100644
index 0000000..050f72c
--- /dev/null
+++ b/src/veritext/metrics/bleu.py
@@ -0,0 +1,248 @@
+"""BLEU (Bilingual Evaluation Understudy) metric implementation."""
+
+import math
+from collections import Counter
+
+from veritext.core.tokenisation import WordTokeniser
+from veritext.metrics.base import AggregateStats, BatchResult
+from veritext.metrics.results import BleuResult
+
+# BleuResult fields that batch_score aggregates, in reporting order.
+_SCORE_FIELDS = ("bleu1", "bleu2", "bleu3", "bleu4", "brevity_penalty")
+
+
+def _get_ngrams(tokens: list[str], n: int) -> Counter[tuple[str, ...]]:
+    """
+    Count the n-grams of order ``n`` in a list of tokens.
+
+    Args:
+        tokens: Tokens to slide the n-gram window over.
+        n: N-gram order; this module calls it with 1 through 4.
+
+    Returns:
+        Counter mapping each n-gram (a tuple of tokens) to its number of
+        occurrences. Empty when ``n`` exceeds the number of tokens.
+    """
+    if n > len(tokens):
+        return Counter()
+    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))
+
+
+def _modified_precision(
+    candidate_tokens: list[str],
+    reference_token_lists: list[list[str]],
+    n: int,
+) -> tuple[int, int]:
+    """
+    Compute modified n-gram precision (clipped counts).
+
+    Each candidate n-gram is credited at most as many times as it occurs in
+    the reference where it is most frequent.
+
+    Args:
+        candidate_tokens: Tokens of the text being scored.
+        reference_token_lists: Tokens of each reference text.
+        n: N-gram order.
+
+    Returns:
+        Tuple of (clipped_count, total_count); (0, 0) when the candidate has
+        no n-grams of this order.
+    """
+    candidate_ngrams = _get_ngrams(candidate_tokens, n)
+    if not candidate_ngrams:
+        return 0, 0
+
+    # Counter union (|) keeps the per-n-gram maximum, i.e. the highest count
+    # of each n-gram in any single reference.
+    max_ref_counts: Counter[tuple[str, ...]] = Counter()
+    for ref_tokens in reference_token_lists:
+        max_ref_counts |= _get_ngrams(ref_tokens, n)
+
+    # Counter intersection (&) keeps the per-n-gram minimum, which clips each
+    # candidate count to its reference maximum.
+    clipped = candidate_ngrams & max_ref_counts
+    return sum(clipped.values()), sum(candidate_ngrams.values())
+
+
+def _brevity_penalty(candidate_len: int, reference_lens: list[int]) -> float:
+    """
+    Compute brevity penalty.
+
+    Uses the reference length closest to the candidate length; when two
+    lengths are equally close, the shorter one is used.
+
+    Args:
+        candidate_len: Number of tokens in the candidate.
+        reference_lens: Number of tokens in each reference.
+
+    Returns:
+        0.0 for an empty candidate, 1.0 when the candidate is at least as
+        long as the closest reference, otherwise exp(1 - ref_len / cand_len).
+
+    Raises:
+        ValueError: If the candidate is non-empty and ``reference_lens`` is
+            empty.
+    """
+    if candidate_len == 0:
+        return 0.0
+    if not reference_lens:
+        # Fail with a clear message instead of min()'s generic empty-sequence
+        # error.
+        raise ValueError("At least one reference length is required")
+
+    # The key is (distance, length) so ties go to the shorter reference.
+    closest_ref_len = min(reference_lens, key=lambda r: (abs(r - candidate_len), r))
+
+    if candidate_len >= closest_ref_len:
+        return 1.0
+    return math.exp(1 - closest_ref_len / candidate_len)
+
+
+def _geometric_mean(values: list[float]) -> float:
+    """
+    Compute the geometric mean of a non-empty list of precisions.
+
+    No smoothing is applied: a single zero makes the whole mean zero, which
+    also avoids evaluating log(0).
+    """
+    if any(value == 0.0 for value in values):
+        return 0.0
+    return math.exp(sum(math.log(value) for value in values) / len(values))
+
+
+class Bleu:
+    """
+    BLEU metric for measuring translation/generation quality.
+
+    Computes BLEU-1 through BLEU-4 scores using modified n-gram precision
+    with brevity penalty.
+    """
+
+    def __init__(self, tokeniser: WordTokeniser | None = None) -> None:
+        """
+        Initialise the BLEU metric.
+
+        Args:
+            tokeniser: Tokeniser to use. Defaults to WordTokeniser().
+        """
+        # Test against None rather than truthiness so a supplied tokeniser
+        # that happens to evaluate as falsy is not silently replaced.
+        self._tokeniser = WordTokeniser() if tokeniser is None else tokeniser
+
+    @property
+    def name(self) -> str:
+        """Return the name of this metric."""
+        return "bleu"
+
+    @property
+    def requires_reference(self) -> bool:
+        """Return whether this metric requires reference text."""
+        return True
+
+    def score(
+        self, candidate: str, reference: str | list[str] | None = None
+    ) -> BleuResult:
+        """
+        Compute BLEU scores for a candidate text.
+
+        Args:
+            candidate: The text to score.
+            reference: Reference text(s) for comparison. A single string is
+                treated as one reference.
+
+        Returns:
+            BleuResult with BLEU-1 through BLEU-4 and brevity penalty. All
+            fields are 0.0 when the candidate tokenises to nothing.
+
+        Raises:
+            ValueError: If reference is None, or if the candidate is non-empty
+                and no reference yields any tokens.
+        """
+        if reference is None:
+            raise ValueError("BLEU requires reference text")
+
+        # Normalise reference to list
+        references = [reference] if isinstance(reference, str) else reference
+
+        candidate_tokens = self._tokeniser.tokenise(candidate)
+        reference_token_lists = [self._tokeniser.tokenise(r) for r in references]
+
+        # An empty candidate short-circuits to an all-zero result before the
+        # references are checked for emptiness.
+        if not candidate_tokens:
+            return BleuResult(
+                bleu1=0.0, bleu2=0.0, bleu3=0.0, bleu4=0.0, brevity_penalty=0.0
+            )
+
+        if not any(reference_token_lists):
+            raise ValueError("Reference text cannot be empty")
+
+        # Modified precisions for n = 1..4; an order with no candidate n-grams
+        # (candidate shorter than n) counts as precision 0.
+        precisions: list[float] = []
+        for n in range(1, 5):
+            clipped, total = _modified_precision(
+                candidate_tokens, reference_token_lists, n
+            )
+            precisions.append(clipped / total if total else 0.0)
+
+        bp = _brevity_penalty(
+            len(candidate_tokens),
+            [len(ref) for ref in reference_token_lists],
+        )
+
+        # BLEU-n is the brevity penalty times the geometric mean of the
+        # precisions of orders 1 through n.
+        bleu1, bleu2, bleu3, bleu4 = (
+            bp * _geometric_mean(precisions[:n]) for n in range(1, 5)
+        )
+
+        return BleuResult(
+            bleu1=bleu1,
+            bleu2=bleu2,
+            bleu3=bleu3,
+            bleu4=bleu4,
+            brevity_penalty=bp,
+        )
+
+    def batch_score(
+        self,
+        candidates: list[str],
+        references: list[str] | list[list[str]] | None = None,
+    ) -> BatchResult[BleuResult]:
+        """
+        Compute BLEU scores for a batch of candidates.
+
+        Args:
+            candidates: List of texts to score.
+            references: Reference text(s) for each candidate, in the same
+                order as ``candidates``.
+
+        Returns:
+            BatchResult containing individual results and aggregate statistics.
+
+        Raises:
+            ValueError: If references is None or length mismatch.
+        """
+        if references is None:
+            raise ValueError("BLEU requires reference texts")
+
+        if len(candidates) != len(references):
+            raise ValueError(
+                f"Number of candidates ({len(candidates)}) must match "
+                f"number of references ({len(references)})"
+            )
+
+        # Lengths were validated above, so zip pairs every candidate with its
+        # own reference(s).
+        results: list[BleuResult] = [
+            self.score(cand, ref) for cand, ref in zip(candidates, references)
+        ]
+
+        # Compute aggregate statistics for each score type
+        stats = {
+            field: AggregateStats.from_values([getattr(r, field) for r in results])
+            for field in _SCORE_FIELDS
+        }
+
+        return BatchResult(results=results, count=len(results), stats=stats)