diff --git a/src/veritext/metrics/rouge.py b/src/veritext/metrics/rouge.py new file mode 100644 index 0000000..e212853 --- /dev/null +++ b/src/veritext/metrics/rouge.py @@ -0,0 +1,268 @@ +"""ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric implementation.""" + +from collections import Counter + +from veritext.core.tokenisation import WordTokeniser +from veritext.metrics.base import AggregateStats, BatchResult +from veritext.metrics.results import RougeResult, RougeScore + + +def _get_ngrams(tokens: list[str], n: int) -> Counter[tuple[str, ...]]: + if n > len(tokens): + return Counter() + return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)) + + +def _ngram_overlap( + candidate_ngrams: Counter[tuple[str, ...]], + reference_ngrams: Counter[tuple[str, ...]], +) -> int: + overlap = 0 + for ngram, count in candidate_ngrams.items(): + overlap += min(count, reference_ngrams.get(ngram, 0)) + return overlap + + +def _compute_rouge_score( + candidate_tokens: list[str], + reference_tokens: list[str], + n: int, +) -> RougeScore: + """ + Compute ROUGE-n score for given n-gram size. + + Args: + candidate_tokens: Tokenised candidate text. + reference_tokens: Tokenised reference text. + n: N-gram size. + + Returns: + RougeScore with precision, recall, and F-measure. + """ + candidate_ngrams = _get_ngrams(candidate_tokens, n) + reference_ngrams = _get_ngrams(reference_tokens, n) + + candidate_count = sum(candidate_ngrams.values()) + reference_count = sum(reference_ngrams.values()) + + if candidate_count == 0 and reference_count == 0: + return RougeScore(precision=0.0, recall=0.0, fmeasure=0.0) + + overlap = _ngram_overlap(candidate_ngrams, reference_ngrams) + + precision = overlap / candidate_count if candidate_count > 0 else 0.0 + recall = overlap / reference_count if reference_count > 0 else 0.0 + + if precision + recall > 0: + fmeasure = 2 * precision * recall / (precision + recall) + else: + fmeasure = 0.0 + + return RougeScore(precision=precision, recall=recall, fmeasure=fmeasure) + + +def _lcs_length(seq1: list[str], seq2: list[str]) -> int: + """ + Compute the length of the longest common subsequence. + + Uses dynamic programming with O(m*n) time and O(min(m,n)) space. + """ + if not seq1 or not seq2: + return 0 + + if len(seq1) < len(seq2): + seq1, seq2 = seq2, seq1 + + m, n = len(seq1), len(seq2) + + prev = [0] * (n + 1) + curr = [0] * (n + 1) + + for i in range(1, m + 1): + for j in range(1, n + 1): + if seq1[i - 1] == seq2[j - 1]: + curr[j] = prev[j - 1] + 1 + else: + curr[j] = max(prev[j], curr[j - 1]) + prev, curr = curr, prev + + return prev[n] + + +def _compute_rouge_l( + candidate_tokens: list[str], + reference_tokens: list[str], +) -> RougeScore: + """ + Compute ROUGE-L score using longest common subsequence. + + Args: + candidate_tokens: Tokenised candidate text. + reference_tokens: Tokenised reference text. + + Returns: + RougeScore with precision, recall, and F-measure. + """ + if not candidate_tokens or not reference_tokens: + return RougeScore(precision=0.0, recall=0.0, fmeasure=0.0) + + lcs = _lcs_length(candidate_tokens, reference_tokens) + + precision = lcs / len(candidate_tokens) + recall = lcs / len(reference_tokens) + + if precision + recall > 0: + fmeasure = 2 * precision * recall / (precision + recall) + else: + fmeasure = 0.0 + + return RougeScore(precision=precision, recall=recall, fmeasure=fmeasure) + + +def _max_rouge_scores(scores: list[RougeScore]) -> RougeScore: + return max(scores, key=lambda s: s.fmeasure) + + +class Rouge: + """ + ROUGE metric for measuring summary/generation quality. + + Computes ROUGE-1 (unigram), ROUGE-2 (bigram), and ROUGE-L (LCS) scores. + ROUGE is recall-oriented, measuring how much of the reference is captured. + """ + + def __init__(self, tokeniser: WordTokeniser | None = None) -> None: + """ + Initialise the ROUGE metric. + + Args: + tokeniser: Tokeniser to use. Defaults to WordTokeniser(). + """ + self._tokeniser = tokeniser or WordTokeniser() + + @property + def name(self) -> str: + return "rouge" + + @property + def requires_reference(self) -> bool: + return True + + def score( + self, candidate: str, reference: str | list[str] | None = None + ) -> RougeResult: + """ + Compute ROUGE scores for a candidate text. + + Args: + candidate: The text to score. + reference: Reference text(s) for comparison. If multiple references + are provided, returns the maximum score for each variant. + + Returns: + RougeResult with ROUGE-1, ROUGE-2, and ROUGE-L scores. + + Raises: + ValueError: If reference is None or empty. + """ + if reference is None: + raise ValueError("ROUGE requires reference text") + + references = [reference] if isinstance(reference, str) else reference + + candidate_tokens = self._tokeniser.tokenise(candidate) + reference_token_lists = [self._tokeniser.tokenise(r) for r in references] + + if all(not ref for ref in reference_token_lists): + raise ValueError("Reference text cannot be empty") + + if not candidate_tokens: + return RougeResult( + rouge1=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0), + rouge2=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0), + rouge_l=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0), + ) + + rouge1_scores = [] + rouge2_scores = [] + rouge_l_scores = [] + + for ref_tokens in reference_token_lists: + if not ref_tokens: + continue + rouge1_scores.append(_compute_rouge_score(candidate_tokens, ref_tokens, 1)) + rouge2_scores.append(_compute_rouge_score(candidate_tokens, ref_tokens, 2)) + rouge_l_scores.append(_compute_rouge_l(candidate_tokens, ref_tokens)) + + if not rouge1_scores: + raise ValueError("Reference text cannot be empty") + + return RougeResult( + rouge1=_max_rouge_scores(rouge1_scores), + rouge2=_max_rouge_scores(rouge2_scores), + rouge_l=_max_rouge_scores(rouge_l_scores), + ) + + def batch_score( + self, + candidates: list[str], + references: list[str] | list[list[str]] | None = None, + ) -> BatchResult[RougeResult]: + """ + Compute ROUGE scores for a batch of candidates. + + Args: + candidates: List of texts to score. + references: Reference text(s) for each candidate. + + Returns: + BatchResult containing individual results and aggregate statistics. + + Raises: + ValueError: If references is None or length mismatch. + """ + if references is None: + raise ValueError("ROUGE requires reference texts") + + if len(candidates) != len(references): + raise ValueError( + f"Number of candidates ({len(candidates)}) must match " + f"number of references ({len(references)})" + ) + + results: list[RougeResult] = [] + for i, cand in enumerate(candidates): + ref: str | list[str] = references[i] + results.append(self.score(cand, ref)) + + stats = { + "rouge1_precision": AggregateStats.from_values( + [r.rouge1.precision for r in results] + ), + "rouge1_recall": AggregateStats.from_values( + [r.rouge1.recall for r in results] + ), + "rouge1_fmeasure": AggregateStats.from_values( + [r.rouge1.fmeasure for r in results] + ), + "rouge2_precision": AggregateStats.from_values( + [r.rouge2.precision for r in results] + ), + "rouge2_recall": AggregateStats.from_values( + [r.rouge2.recall for r in results] + ), + "rouge2_fmeasure": AggregateStats.from_values( + [r.rouge2.fmeasure for r in results] + ), + "rouge_l_precision": AggregateStats.from_values( + [r.rouge_l.precision for r in results] + ), + "rouge_l_recall": AggregateStats.from_values( + [r.rouge_l.recall for r in results] + ), + "rouge_l_fmeasure": AggregateStats.from_values( + [r.rouge_l.fmeasure for r in results] + ), + } + + return BatchResult(results=results, count=len(results), stats=stats)