feat(metrics): add ROUGE implementation

This commit is contained in:
2026-02-03 17:03:19 +00:00
parent 914c738013
commit 2a7476046d

View File

@@ -0,0 +1,281 @@
"""ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric implementation."""
from collections import Counter
from veritext.core.tokenisation import WordTokeniser
from veritext.metrics.base import AggregateStats, BatchResult
from veritext.metrics.results import RougeResult, RougeScore
def _get_ngrams(tokens: list[str], n: int) -> Counter[tuple[str, ...]]:
"""Extract n-grams from a list of tokens."""
if n > len(tokens):
return Counter()
return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))
def _ngram_overlap(
candidate_ngrams: Counter[tuple[str, ...]],
reference_ngrams: Counter[tuple[str, ...]],
) -> int:
"""Compute the overlap count between candidate and reference n-grams."""
overlap = 0
for ngram, count in candidate_ngrams.items():
overlap += min(count, reference_ngrams.get(ngram, 0))
return overlap
def _compute_rouge_score(
candidate_tokens: list[str],
reference_tokens: list[str],
n: int,
) -> RougeScore:
"""
Compute ROUGE-n score for given n-gram size.
Args:
candidate_tokens: Tokenised candidate text.
reference_tokens: Tokenised reference text.
n: N-gram size.
Returns:
RougeScore with precision, recall, and F-measure.
"""
candidate_ngrams = _get_ngrams(candidate_tokens, n)
reference_ngrams = _get_ngrams(reference_tokens, n)
candidate_count = sum(candidate_ngrams.values())
reference_count = sum(reference_ngrams.values())
if candidate_count == 0 and reference_count == 0:
return RougeScore(precision=0.0, recall=0.0, fmeasure=0.0)
overlap = _ngram_overlap(candidate_ngrams, reference_ngrams)
precision = overlap / candidate_count if candidate_count > 0 else 0.0
recall = overlap / reference_count if reference_count > 0 else 0.0
if precision + recall > 0:
fmeasure = 2 * precision * recall / (precision + recall)
else:
fmeasure = 0.0
return RougeScore(precision=precision, recall=recall, fmeasure=fmeasure)
def _lcs_length(seq1: list[str], seq2: list[str]) -> int:
"""
Compute the length of the longest common subsequence.
Uses dynamic programming with O(m*n) time and O(min(m,n)) space.
"""
if not seq1 or not seq2:
return 0
# Optimise by using shorter sequence for columns
if len(seq1) < len(seq2):
seq1, seq2 = seq2, seq1
m, n = len(seq1), len(seq2)
# Only need two rows at a time
prev = [0] * (n + 1)
curr = [0] * (n + 1)
for i in range(1, m + 1):
for j in range(1, n + 1):
if seq1[i - 1] == seq2[j - 1]:
curr[j] = prev[j - 1] + 1
else:
curr[j] = max(prev[j], curr[j - 1])
prev, curr = curr, prev
return prev[n]
def _compute_rouge_l(
candidate_tokens: list[str],
reference_tokens: list[str],
) -> RougeScore:
"""
Compute ROUGE-L score using longest common subsequence.
Args:
candidate_tokens: Tokenised candidate text.
reference_tokens: Tokenised reference text.
Returns:
RougeScore with precision, recall, and F-measure.
"""
if not candidate_tokens and not reference_tokens:
return RougeScore(precision=0.0, recall=0.0, fmeasure=0.0)
if not candidate_tokens or not reference_tokens:
return RougeScore(precision=0.0, recall=0.0, fmeasure=0.0)
lcs = _lcs_length(candidate_tokens, reference_tokens)
precision = lcs / len(candidate_tokens)
recall = lcs / len(reference_tokens)
if precision + recall > 0:
fmeasure = 2 * precision * recall / (precision + recall)
else:
fmeasure = 0.0
return RougeScore(precision=precision, recall=recall, fmeasure=fmeasure)
def _max_rouge_scores(scores: list[RougeScore]) -> RougeScore:
"""Select the RougeScore with the highest F-measure from a list."""
return max(scores, key=lambda s: s.fmeasure)
class Rouge:
"""
ROUGE metric for measuring summary/generation quality.
Computes ROUGE-1 (unigram), ROUGE-2 (bigram), and ROUGE-L (LCS) scores.
ROUGE is recall-oriented, measuring how much of the reference is captured.
"""
def __init__(self, tokeniser: WordTokeniser | None = None) -> None:
"""
Initialise the ROUGE metric.
Args:
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
"""
self._tokeniser = tokeniser or WordTokeniser()
@property
def name(self) -> str:
"""Return the name of this metric."""
return "rouge"
@property
def requires_reference(self) -> bool:
"""Return whether this metric requires reference text."""
return True
def score(
self, candidate: str, reference: str | list[str] | None = None
) -> RougeResult:
"""
Compute ROUGE scores for a candidate text.
Args:
candidate: The text to score.
reference: Reference text(s) for comparison. If multiple references
are provided, returns the maximum score for each variant.
Returns:
RougeResult with ROUGE-1, ROUGE-2, and ROUGE-L scores.
Raises:
ValueError: If reference is None or empty.
"""
if reference is None:
raise ValueError("ROUGE requires reference text")
# Normalise reference to list
references = [reference] if isinstance(reference, str) else reference
# Tokenise
candidate_tokens = self._tokeniser.tokenise(candidate)
reference_token_lists = [self._tokeniser.tokenise(r) for r in references]
# Handle empty references
if all(not ref for ref in reference_token_lists):
raise ValueError("Reference text cannot be empty")
# Handle empty candidate
if not candidate_tokens:
return RougeResult(
rouge1=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0),
rouge2=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0),
rouge_l=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0),
)
# Compute scores for each reference and take max
rouge1_scores = []
rouge2_scores = []
rouge_l_scores = []
for ref_tokens in reference_token_lists:
if not ref_tokens:
continue
rouge1_scores.append(_compute_rouge_score(candidate_tokens, ref_tokens, 1))
rouge2_scores.append(_compute_rouge_score(candidate_tokens, ref_tokens, 2))
rouge_l_scores.append(_compute_rouge_l(candidate_tokens, ref_tokens))
return RougeResult(
rouge1=_max_rouge_scores(rouge1_scores),
rouge2=_max_rouge_scores(rouge2_scores),
rouge_l=_max_rouge_scores(rouge_l_scores),
)
def batch_score(
self,
candidates: list[str],
references: list[str] | list[list[str]] | None = None,
) -> BatchResult[RougeResult]:
"""
Compute ROUGE scores for a batch of candidates.
Args:
candidates: List of texts to score.
references: Reference text(s) for each candidate.
Returns:
BatchResult containing individual results and aggregate statistics.
Raises:
ValueError: If references is None or length mismatch.
"""
if references is None:
raise ValueError("ROUGE requires reference texts")
if len(candidates) != len(references):
raise ValueError(
f"Number of candidates ({len(candidates)}) must match "
f"number of references ({len(references)})"
)
results: list[RougeResult] = []
for i, cand in enumerate(candidates):
ref: str | list[str] = references[i]
results.append(self.score(cand, ref))
# Compute aggregate statistics for each score type
stats = {
"rouge1_precision": AggregateStats.from_values(
[r.rouge1.precision for r in results]
),
"rouge1_recall": AggregateStats.from_values(
[r.rouge1.recall for r in results]
),
"rouge1_fmeasure": AggregateStats.from_values(
[r.rouge1.fmeasure for r in results]
),
"rouge2_precision": AggregateStats.from_values(
[r.rouge2.precision for r in results]
),
"rouge2_recall": AggregateStats.from_values(
[r.rouge2.recall for r in results]
),
"rouge2_fmeasure": AggregateStats.from_values(
[r.rouge2.fmeasure for r in results]
),
"rouge_l_precision": AggregateStats.from_values(
[r.rouge_l.precision for r in results]
),
"rouge_l_recall": AggregateStats.from_values(
[r.rouge_l.recall for r in results]
),
"rouge_l_fmeasure": AggregateStats.from_values(
[r.rouge_l.fmeasure for r in results]
),
}
return BatchResult(results=results, count=len(results), stats=stats)