Add comprehensive tests for BLEU and lexical metrics, including edge cases, batch scoring, and aggregate statistics.
241 lines
8.7 KiB
Python
"""Tests for the BLEU metric."""
|
|
|
|
import pytest
|
|
|
|
from veritext.metrics import Bleu, BleuResult
|
|
|
|
|
|
class TestBleu:
    """Tests for the Bleu metric class (single-pair scoring)."""

    @pytest.fixture
    def bleu(self) -> Bleu:
        """Provide a BLEU metric instance."""
        return Bleu()

    def test_name(self, bleu: Bleu) -> None:
        """Test that name returns 'bleu'."""
        assert bleu.name == "bleu"

    def test_requires_reference(self, bleu: Bleu) -> None:
        """Test that BLEU requires reference text."""
        assert bleu.requires_reference is True

    def test_identical_texts(self, bleu: Bleu) -> None:
        """Test that identical texts produce perfect scores."""
        text = "The cat sat on the mat"
        result = bleu.score(text, text)

        assert result.bleu1 == 1.0
        assert result.bleu2 == 1.0
        assert result.bleu3 == 1.0
        assert result.bleu4 == 1.0
        assert result.brevity_penalty == 1.0

    def test_no_overlap(self, bleu: Bleu) -> None:
        """Test that texts with no overlap produce zero scores."""
        candidate = "The cat sat"
        reference = "A dog runs fast"
        result = bleu.score(candidate, reference)

        assert result.bleu1 == 0.0
        assert result.bleu2 == 0.0
        assert result.bleu3 == 0.0
        assert result.bleu4 == 0.0

    def test_partial_overlap(self, bleu: Bleu) -> None:
        """Test partial overlap produces intermediate scores."""
        candidate = "The cat sat on the mat"
        reference = "The cat is on the floor"
        result = bleu.score(candidate, reference)

        # Unigrams and bigrams overlap partially; no trigram is shared, so
        # only the first two orders are guaranteed to be strictly between
        # 0 and 1.
        assert 0.0 < result.bleu1 < 1.0
        assert 0.0 < result.bleu2 < 1.0

    def test_brevity_penalty_applied(self, bleu: Bleu) -> None:
        """Test that brevity penalty is applied for short candidates."""
        candidate = "cat"
        reference = "The cat sat on the mat"
        result = bleu.score(candidate, reference)

        # Candidate is much shorter, so brevity penalty should be < 1
        assert result.brevity_penalty < 1.0

    def test_no_brevity_penalty_for_long_candidate(self, bleu: Bleu) -> None:
        """Test that no brevity penalty for longer candidates."""
        candidate = "The cat sat on the mat today"
        reference = "The cat sat"
        result = bleu.score(candidate, reference)

        # Candidate is longer, so no brevity penalty
        assert result.brevity_penalty == 1.0

    def test_single_word_texts(self, bleu: Bleu) -> None:
        """Test BLEU with single word texts."""
        result = bleu.score("cat", "cat")

        assert result.bleu1 == 1.0
        # A one-token candidate has no bigrams or higher-order n-grams, so
        # bleu2..bleu4 depend on the implementation's handling of empty
        # n-gram sets and are deliberately not pinned here.
        assert result.brevity_penalty == 1.0

    def test_empty_candidate(self, bleu: Bleu) -> None:
        """Test that empty candidate returns zero scores."""
        result = bleu.score("", "The cat sat")

        assert result.bleu1 == 0.0
        assert result.bleu2 == 0.0
        assert result.bleu3 == 0.0
        assert result.bleu4 == 0.0
        assert result.brevity_penalty == 0.0

    def test_whitespace_only_candidate(self, bleu: Bleu) -> None:
        """Test that whitespace-only candidate returns zero scores."""
        result = bleu.score(" \t\n ", "The cat sat")

        # Check every n-gram order, not just the extremes, so the test
        # matches its docstring.
        assert result.bleu1 == 0.0
        assert result.bleu2 == 0.0
        assert result.bleu3 == 0.0
        assert result.bleu4 == 0.0

    def test_empty_reference_raises(self, bleu: Bleu) -> None:
        """Test that empty reference raises ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):
            bleu.score("The cat sat", "")

    def test_none_reference_raises(self, bleu: Bleu) -> None:
        """Test that None reference raises ValueError."""
        with pytest.raises(ValueError, match="requires reference"):
            bleu.score("The cat sat", None)

    def test_multiple_references(self, bleu: Bleu) -> None:
        """Test BLEU with multiple references uses max n-gram matches."""
        candidate = "The cat sat on the mat"
        references = [
            "The cat is on the floor",
            "The cat sat on the mat",  # Exact match
        ]
        result = bleu.score(candidate, references)

        # Should get perfect score due to exact match reference
        assert result.bleu4 == 1.0

    def test_multiple_references_partial(self, bleu: Bleu) -> None:
        """Test that n-gram matches are pooled across multiple references.

        Neither reference alone covers every candidate unigram ("the" and
        "fox" occur only in the first, "quick" only in the second), so
        clipping against the maximum count over all references must beat
        either reference scored on its own.
        """
        candidate = "The quick brown fox"
        references = [
            "The fast brown fox",
            "A quick brown dog",
        ]
        result = bleu.score(candidate, references)
        single_ref_bleu1 = [bleu.score(candidate, ref).bleu1 for ref in references]

        # Every candidate unigram is matched by some reference and all texts
        # have four tokens (no brevity penalty), so BLEU-1 is perfect.
        assert result.bleu1 == pytest.approx(1.0)
        assert all(result.bleu1 > single for single in single_ref_bleu1)

    def test_result_score_property(self, bleu: Bleu) -> None:
        """Test that result.score returns bleu4."""
        # Near-identical 7-token texts give non-zero, mutually distinct
        # precisions at every order, so ``score`` cannot accidentally equal
        # bleu4 by aliasing another field with the same value.
        result = bleu.score(
            "The cat sat on the mat today",
            "The cat sat on the mat yesterday",
        )

        assert result.bleu4 != result.bleu1
        assert 0.0 < result.bleu4 < 1.0
        assert result.score == result.bleu4

    def test_case_insensitivity(self, bleu: Bleu) -> None:
        """Test that BLEU is case insensitive by default."""
        result = bleu.score("THE CAT SAT ON THE MAT", "the cat sat on the mat")
        assert result.bleu4 == 1.0

    def test_punctuation_ignored(self, bleu: Bleu) -> None:
        """Test that punctuation is ignored by default."""
        result = bleu.score("The cat sat on the mat.", "The cat sat on the mat!")
        assert result.bleu4 == 1.0
|
|
|
|
|
|
class TestBleuBatch:
    """Tests for BLEU batch scoring."""

    @pytest.fixture
    def bleu(self) -> Bleu:
        """Provide a BLEU metric instance."""
        return Bleu()

    def test_batch_score_basic(self, bleu: Bleu) -> None:
        """Test basic batch scoring."""
        candidates = ["The cat sat on the mat", "A quick brown dog runs fast"]
        references = ["The cat sat on the mat", "A quick brown dog runs fast"]
        result = bleu.batch_score(candidates, references)

        assert result.count == 2
        assert len(result.results) == 2
        assert all(r.bleu4 == 1.0 for r in result.results)

    def test_batch_score_statistics(self, bleu: Bleu) -> None:
        """Test that batch scoring aggregates per-item scores into statistics."""
        # One perfect match (bleu4 == 1.0) and one pair with no overlap
        # (bleu4 == 0.0) pin the aggregates to known values. The bare
        # ``min <= mean <= max`` check holds for any input, so on its own it
        # cannot detect a broken aggregation.
        candidates = ["The cat sat on the mat", "A dog runs fast"]
        references = ["The cat sat on the mat", "The cat sat on the mat"]
        result = bleu.batch_score(candidates, references)

        # Check statistics are computed
        assert "bleu1" in result.stats
        assert "bleu4" in result.stats
        assert "brevity_penalty" in result.stats

        assert result.results[0].bleu4 == 1.0
        assert result.results[1].bleu4 == 0.0

        stats = result.stats["bleu4"]
        assert stats.min == pytest.approx(0.0)
        assert stats.max == pytest.approx(1.0)
        assert stats.mean == pytest.approx(0.5)
        assert stats.min <= stats.mean <= stats.max

    def test_batch_score_percentiles(self, bleu: Bleu) -> None:
        """Test that batch scoring computes percentiles."""
        candidates = ["a", "b", "c", "d", "e"]
        references = ["a", "b", "c", "d", "e"]
        result = bleu.batch_score(candidates, references)

        stats = result.stats["bleu1"]
        assert 25 in stats.percentiles
        assert 50 in stats.percentiles
        assert 75 in stats.percentiles
        assert 95 in stats.percentiles

    def test_batch_score_none_references_raises(self, bleu: Bleu) -> None:
        """Test that batch scoring raises for None references."""
        with pytest.raises(ValueError, match="requires reference"):
            bleu.batch_score(["text"], None)

    def test_batch_score_length_mismatch_raises(self, bleu: Bleu) -> None:
        """Test that batch scoring raises for mismatched lengths."""
        with pytest.raises(ValueError, match="must match"):
            bleu.batch_score(["a", "b"], ["a"])

    def test_batch_score_with_multiple_references(self, bleu: Bleu) -> None:
        """Test batch scoring with multiple references per candidate."""
        candidates = [
            "The cat sat on the mat",
            "A quick brown dog runs fast",
        ]
        references = [
            ["The cat sat on the mat", "A cat rests on floor"],
            ["A quick brown dog runs fast", "Dogs run very quickly"],
        ]
        result = bleu.batch_score(candidates, references)

        assert result.count == 2
        assert result.results[0].bleu4 == 1.0
        assert result.results[1].bleu4 == 1.0
|
|
|
|
|
|
class TestBleuResult:
    """Unit tests for the ``BleuResult`` value object itself."""

    def test_frozen(self) -> None:
        """Test that a constructed BleuResult rejects attribute assignment."""
        from pydantic import ValidationError

        frozen = BleuResult(
            bleu1=0.5,
            bleu2=0.4,
            bleu3=0.3,
            bleu4=0.2,
            brevity_penalty=1.0,
        )
        # Mutating any field of the model must be refused.
        with pytest.raises(ValidationError):
            frozen.bleu1 = 0.6  # type: ignore[misc]

    def test_score_property(self) -> None:
        """Test that ``score`` exposes the cumulative BLEU-4 value."""
        orders = {"bleu1": 0.9, "bleu2": 0.8, "bleu3": 0.7, "bleu4": 0.6}
        outcome = BleuResult(**orders, brevity_penalty=1.0)
        assert outcome.score == orders["bleu4"]
|