feat: ROUGE-L scorer
This commit is contained in:
268
src/veritext/metrics/rouge.py
Normal file
268
src/veritext/metrics/rouge.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric implementation."""
|
||||
|
||||
from collections import Counter
|
||||
|
||||
from veritext.core.tokenisation import WordTokeniser
|
||||
from veritext.metrics.base import AggregateStats, BatchResult
|
||||
from veritext.metrics.results import RougeResult, RougeScore
|
||||
|
||||
|
||||
def _get_ngrams(tokens: list[str], n: int) -> Counter[tuple[str, ...]]:
|
||||
if n > len(tokens):
|
||||
return Counter()
|
||||
return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))
|
||||
|
||||
|
||||
def _ngram_overlap(
|
||||
candidate_ngrams: Counter[tuple[str, ...]],
|
||||
reference_ngrams: Counter[tuple[str, ...]],
|
||||
) -> int:
|
||||
overlap = 0
|
||||
for ngram, count in candidate_ngrams.items():
|
||||
overlap += min(count, reference_ngrams.get(ngram, 0))
|
||||
return overlap
|
||||
|
||||
|
||||
def _compute_rouge_score(
    candidate_tokens: list[str],
    reference_tokens: list[str],
    n: int,
) -> RougeScore:
    """
    Compute the ROUGE-n score for a single candidate/reference pair.

    Precision is the clipped n-gram overlap divided by the number of
    candidate n-grams; recall divides the same overlap by the number of
    reference n-grams. The F-measure is their harmonic mean. Any ratio whose
    denominator would be zero is reported as 0.0.

    Args:
        candidate_tokens: Tokenised candidate text.
        reference_tokens: Tokenised reference text.
        n: N-gram size.

    Returns:
        RougeScore with precision, recall, and F-measure.
    """
    cand_counts = _get_ngrams(candidate_tokens, n)
    ref_counts = _get_ngrams(reference_tokens, n)

    cand_total = sum(cand_counts.values())
    ref_total = sum(ref_counts.values())

    # Neither side has an n-gram of this size: nothing to compare.
    if not (cand_total or ref_total):
        return RougeScore(precision=0.0, recall=0.0, fmeasure=0.0)

    matched = _ngram_overlap(cand_counts, ref_counts)

    precision = 0.0
    if cand_total > 0:
        precision = matched / cand_total

    recall = 0.0
    if ref_total > 0:
        recall = matched / ref_total

    denominator = precision + recall
    fmeasure = 2 * precision * recall / denominator if denominator > 0 else 0.0

    return RougeScore(precision=precision, recall=recall, fmeasure=fmeasure)
|
||||
|
||||
|
||||
def _lcs_length(seq1: list[str], seq2: list[str]) -> int:
|
||||
"""
|
||||
Compute the length of the longest common subsequence.
|
||||
|
||||
Uses dynamic programming with O(m*n) time and O(min(m,n)) space.
|
||||
"""
|
||||
if not seq1 or not seq2:
|
||||
return 0
|
||||
|
||||
if len(seq1) < len(seq2):
|
||||
seq1, seq2 = seq2, seq1
|
||||
|
||||
m, n = len(seq1), len(seq2)
|
||||
|
||||
prev = [0] * (n + 1)
|
||||
curr = [0] * (n + 1)
|
||||
|
||||
for i in range(1, m + 1):
|
||||
for j in range(1, n + 1):
|
||||
if seq1[i - 1] == seq2[j - 1]:
|
||||
curr[j] = prev[j - 1] + 1
|
||||
else:
|
||||
curr[j] = max(prev[j], curr[j - 1])
|
||||
prev, curr = curr, prev
|
||||
|
||||
return prev[n]
|
||||
|
||||
|
||||
def _compute_rouge_l(
    candidate_tokens: list[str],
    reference_tokens: list[str],
) -> RougeScore:
    """
    Compute the ROUGE-L score for a single candidate/reference pair.

    Precision is the longest-common-subsequence length divided by the
    candidate length; recall divides it by the reference length. The
    F-measure is their harmonic mean, or 0.0 when both are zero.

    Args:
        candidate_tokens: Tokenised candidate text.
        reference_tokens: Tokenised reference text.

    Returns:
        RougeScore with precision, recall, and F-measure. All three are 0.0
        when either token list is empty.
    """
    if not (candidate_tokens and reference_tokens):
        return RougeScore(precision=0.0, recall=0.0, fmeasure=0.0)

    shared = _lcs_length(candidate_tokens, reference_tokens)

    precision = shared / len(candidate_tokens)
    recall = shared / len(reference_tokens)

    total = precision + recall
    fmeasure = 2 * precision * recall / total if total > 0 else 0.0

    return RougeScore(precision=precision, recall=recall, fmeasure=fmeasure)
|
||||
|
||||
|
||||
def _max_rouge_scores(scores: list[RougeScore]) -> RougeScore:
    """
    Pick the score with the highest F-measure.

    When several scores share the highest F-measure, the earliest one in the
    list is returned. An empty list makes the built-in max() raise ValueError.
    """

    def _fmeasure_of(score):
        return score.fmeasure

    return max(scores, key=_fmeasure_of)
|
||||
|
||||
|
||||
class Rouge:
    """
    ROUGE metric for measuring summary/generation quality.

    Computes ROUGE-1 (unigram), ROUGE-2 (bigram), and ROUGE-L (LCS) scores.
    ROUGE is recall-oriented, measuring how much of the reference is captured.
    """

    def __init__(self, tokeniser: WordTokeniser | None = None) -> None:
        """
        Initialise the ROUGE metric.

        Args:
            tokeniser: Tokeniser to use. Defaults to WordTokeniser().
        """
        # Compare against None explicitly: a caller-supplied tokeniser that
        # happens to evaluate as falsy must not be swapped for the default.
        self._tokeniser = tokeniser if tokeniser is not None else WordTokeniser()

    @property
    def name(self) -> str:
        """Identifier of this metric."""
        return "rouge"

    @property
    def requires_reference(self) -> bool:
        """Whether scoring needs reference text; always True for ROUGE."""
        return True

    def score(
        self, candidate: str, reference: str | list[str] | None = None
    ) -> RougeResult:
        """
        Compute ROUGE scores for a candidate text.

        Args:
            candidate: The text to score.
            reference: Reference text(s) for comparison. If multiple references
                are provided, each variant (ROUGE-1, ROUGE-2, ROUGE-L)
                independently reports the score with the highest F-measure.
                References that tokenise to nothing are ignored.

        Returns:
            RougeResult with ROUGE-1, ROUGE-2, and ROUGE-L scores. All scores
            are zero when the candidate tokenises to nothing.

        Raises:
            ValueError: If reference is None, or if every reference tokenises
                to nothing.
        """
        if reference is None:
            raise ValueError("ROUGE requires reference text")

        references = [reference] if isinstance(reference, str) else reference

        candidate_tokens = self._tokeniser.tokenise(candidate)

        # Drop references that tokenise to nothing up front, so the emptiness
        # check and the scoring below work from the same filtered list.
        usable_references = [
            tokens
            for tokens in (self._tokeniser.tokenise(r) for r in references)
            if tokens
        ]
        if not usable_references:
            raise ValueError("Reference text cannot be empty")

        if not candidate_tokens:
            return RougeResult(
                rouge1=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0),
                rouge2=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0),
                rouge_l=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0),
            )

        rouge1_scores = [
            _compute_rouge_score(candidate_tokens, ref_tokens, 1)
            for ref_tokens in usable_references
        ]
        rouge2_scores = [
            _compute_rouge_score(candidate_tokens, ref_tokens, 2)
            for ref_tokens in usable_references
        ]
        rouge_l_scores = [
            _compute_rouge_l(candidate_tokens, ref_tokens)
            for ref_tokens in usable_references
        ]

        return RougeResult(
            rouge1=_max_rouge_scores(rouge1_scores),
            rouge2=_max_rouge_scores(rouge2_scores),
            rouge_l=_max_rouge_scores(rouge_l_scores),
        )

    def batch_score(
        self,
        candidates: list[str],
        references: list[str] | list[list[str]] | None = None,
    ) -> BatchResult[RougeResult]:
        """
        Compute ROUGE scores for a batch of candidates.

        Args:
            candidates: List of texts to score.
            references: Reference text(s) for each candidate, aligned by
                position with candidates.

        Returns:
            BatchResult containing individual results and aggregate statistics.
            The statistics are keyed "<variant>_<measure>" for each variant in
            rouge1, rouge2, rouge_l and each measure in precision, recall,
            fmeasure.

        Raises:
            ValueError: If references is None or length mismatch. Errors raised
                by score() for an individual pair also propagate.
        """
        if references is None:
            raise ValueError("ROUGE requires reference texts")

        if len(candidates) != len(references):
            raise ValueError(
                f"Number of candidates ({len(candidates)}) must match "
                f"number of references ({len(references)})"
            )

        results: list[RougeResult] = [
            self.score(cand, ref) for cand, ref in zip(candidates, references)
        ]

        # One aggregate per (variant, measure) pair, inserted in the order
        # rouge1_precision, rouge1_recall, rouge1_fmeasure, rouge2_..., rouge_l_...
        stats = {}
        for variant in ("rouge1", "rouge2", "rouge_l"):
            for measure in ("precision", "recall", "fmeasure"):
                stats[f"{variant}_{measure}"] = AggregateStats.from_values(
                    [getattr(getattr(r, variant), measure) for r in results]
                )

        return BatchResult(results=results, count=len(results), stats=stats)
|
||||
Reference in New Issue
Block a user