# Add comprehensive tests for BLEU and lexical metrics, including edge cases,
# batch scoring, and aggregate statistics.
# 201 lines · 6.6 KiB · Python
"""Tests for the BLEU metric."""
|
|
|
|
import pytest
|
|
|
|
from veritext.metrics import Bleu, BleuResult
|
|
|
|
|
|
class TestBleu:
    """Single candidate/reference scoring behaviour of the ``Bleu`` metric."""

    @pytest.fixture
    def bleu(self) -> Bleu:
        """Provide a default-configured ``Bleu`` instance for each test."""
        return Bleu()

    def test_name(self, bleu: Bleu) -> None:
        assert bleu.name == "bleu"

    def test_requires_reference(self, bleu: Bleu) -> None:
        assert bleu.requires_reference is True

    def test_identical_texts(self, bleu: Bleu) -> None:
        """A candidate equal to its reference scores perfectly at every order."""
        text = "The cat sat on the mat"
        result = bleu.score(text, text)

        # approx instead of ==: scores are built from products/exponentials
        # of precisions, so an implementation may land a few ULPs off 1.0.
        assert result.bleu1 == pytest.approx(1.0)
        assert result.bleu2 == pytest.approx(1.0)
        assert result.bleu3 == pytest.approx(1.0)
        assert result.bleu4 == pytest.approx(1.0)
        assert result.brevity_penalty == pytest.approx(1.0)

    def test_no_overlap(self, bleu: Bleu) -> None:
        """Candidates sharing no tokens with the reference score zero."""
        candidate = "The cat sat"
        reference = "A dog runs fast"
        result = bleu.score(candidate, reference)

        assert result.bleu1 == pytest.approx(0.0)
        assert result.bleu2 == pytest.approx(0.0)
        assert result.bleu3 == pytest.approx(0.0)
        assert result.bleu4 == pytest.approx(0.0)

    def test_partial_overlap(self, bleu: Bleu) -> None:
        """Some shared unigrams and bigrams give scores strictly inside (0, 1)."""
        candidate = "The cat sat on the mat"
        reference = "The cat is on the floor"
        result = bleu.score(candidate, reference)

        assert 0.0 < result.bleu1 < 1.0
        assert 0.0 < result.bleu2 < 1.0

    def test_brevity_penalty_applied(self, bleu: Bleu) -> None:
        """A too-short candidate is penalised, but not zeroed like an empty one."""
        candidate = "cat"
        reference = "The cat sat on the mat"
        result = bleu.score(candidate, reference)

        assert result.brevity_penalty < 1.0
        # A non-empty candidate should keep a positive penalty factor. The
        # empty candidate is the case expected to hit exactly 0.0
        # (see test_empty_candidate).
        assert result.brevity_penalty > 0.0

    def test_no_brevity_penalty_long_cand(self, bleu: Bleu) -> None:
        """Candidates longer than the reference are not penalised."""
        candidate = "The cat sat on the mat today"
        reference = "The cat sat"
        result = bleu.score(candidate, reference)

        assert result.brevity_penalty == pytest.approx(1.0)

    def test_single_word_texts(self, bleu: Bleu) -> None:
        result = bleu.score("cat", "cat")

        assert result.bleu1 == pytest.approx(1.0)
        assert result.brevity_penalty == pytest.approx(1.0)

    def test_empty_candidate(self, bleu: Bleu) -> None:
        """An empty candidate has no n-grams: all scores and the penalty are 0."""
        result = bleu.score("", "The cat sat")

        assert result.bleu1 == pytest.approx(0.0)
        assert result.bleu2 == pytest.approx(0.0)
        assert result.bleu3 == pytest.approx(0.0)
        assert result.bleu4 == pytest.approx(0.0)
        assert result.brevity_penalty == pytest.approx(0.0)

    def test_whitespace_only_candidate(self, bleu: Bleu) -> None:
        result = bleu.score(" \t\n ", "The cat sat")

        assert result.bleu1 == pytest.approx(0.0)
        assert result.bleu4 == pytest.approx(0.0)

    def test_empty_reference_raises(self, bleu: Bleu) -> None:
        with pytest.raises(ValueError, match="cannot be empty"):
            bleu.score("The cat sat", "")

    def test_none_reference_raises(self, bleu: Bleu) -> None:
        with pytest.raises(ValueError, match="requires reference"):
            bleu.score("The cat sat", None)

    def test_multiple_references(self, bleu: Bleu) -> None:
        """An exact match among several references yields a perfect score."""
        candidate = "The cat sat on the mat"
        references = [
            "The cat is on the floor",
            "The cat sat on the mat",  # Exact match
        ]
        result = bleu.score(candidate, references)

        assert result.bleu4 == pytest.approx(1.0)

    def test_multiple_references_partial(self, bleu: Bleu) -> None:
        """Adding references must never score worse than the best single one.

        Each reference on its own misses at least one candidate unigram
        (the first lacks "quick", the second lacks "the" and "fox"). A
        ``> 0`` check alone would still pass if the second reference were
        silently ignored, so compare against per-reference scores too.
        """
        candidate = "The quick brown fox"
        references = [
            "The fast brown fox",
            "A quick brown dog",
        ]
        result = bleu.score(candidate, references)
        best_single = max(bleu.score(candidate, ref).bleu1 for ref in references)

        assert result.bleu1 > 0.0
        # Small tolerance: a "best-of-references" implementation would produce
        # exactly best_single via a possibly different float path.
        assert result.bleu1 >= best_single - 1e-9

    def test_result_score_property(self, bleu: Bleu) -> None:
        result = bleu.score("The cat sat", "The cat sat")
        assert result.score == result.bleu4

    def test_case_insensitivity(self, bleu: Bleu) -> None:
        result = bleu.score("THE CAT SAT ON THE MAT", "the cat sat on the mat")
        assert result.bleu4 == pytest.approx(1.0)

    def test_punctuation_ignored(self, bleu: Bleu) -> None:
        result = bleu.score("The cat sat on the mat.", "The cat sat on the mat!")
        assert result.bleu4 == pytest.approx(1.0)
|
|
|
|
|
|
class TestBleuBatch:
    """Batch scoring and aggregate statistics of the ``Bleu`` metric."""

    @pytest.fixture
    def bleu(self) -> Bleu:
        """Provide a default-configured ``Bleu`` instance for each test."""
        return Bleu()

    def test_batch_score_basic(self, bleu: Bleu) -> None:
        """Each identical pair produces one perfect per-item result."""
        candidates = ["The cat sat on the mat", "A quick brown dog runs fast"]
        references = ["The cat sat on the mat", "A quick brown dog runs fast"]
        result = bleu.batch_score(candidates, references)

        assert result.count == 2
        assert len(result.results) == 2
        assert all(r.bleu4 == pytest.approx(1.0) for r in result.results)

    def test_batch_score_statistics(self, bleu: Bleu) -> None:
        """Aggregate stats must actually summarise the per-item scores.

        ``min <= mean <= max`` alone would also hold for placeholder values,
        so the bleu4 stats are cross-checked against ``result.results``.
        """
        candidates = ["The cat sat", "A different text entirely"]
        references = ["The cat sat", "The cat sat"]
        result = bleu.batch_score(candidates, references)

        # Check statistics are computed
        assert "bleu1" in result.stats
        assert "bleu4" in result.stats
        assert "brevity_penalty" in result.stats

        # Mean should be between min and max
        stats = result.stats["bleu4"]
        assert stats.min <= stats.mean <= stats.max

        # ...and the stats should agree with the individual results.
        per_item = [r.bleu4 for r in result.results]
        assert stats.min == pytest.approx(min(per_item))
        assert stats.max == pytest.approx(max(per_item))
        assert stats.mean == pytest.approx(sum(per_item) / len(per_item))

    def test_batch_score_percentiles(self, bleu: Bleu) -> None:
        """The standard percentile keys are reported for each metric."""
        candidates = ["a", "b", "c", "d", "e"]
        references = ["a", "b", "c", "d", "e"]
        result = bleu.batch_score(candidates, references)

        stats = result.stats["bleu1"]
        assert 25 in stats.percentiles
        assert 50 in stats.percentiles
        assert 75 in stats.percentiles
        assert 95 in stats.percentiles

    def test_batch_score_none_references_raises(self, bleu: Bleu) -> None:
        with pytest.raises(ValueError, match="requires reference"):
            bleu.batch_score(["text"], None)

    def test_batch_score_length_mismatch_raises(self, bleu: Bleu) -> None:
        with pytest.raises(ValueError, match="must match"):
            bleu.batch_score(["a", "b"], ["a"])

    def test_batch_score_multi_refs(self, bleu: Bleu) -> None:
        """Each candidate may be paired with its own list of references."""
        candidates = [
            "The cat sat on the mat",
            "A quick brown dog runs fast",
        ]
        references = [
            ["The cat sat on the mat", "A cat rests on floor"],
            ["A quick brown dog runs fast", "Dogs run very quickly"],
        ]
        result = bleu.batch_score(candidates, references)

        assert result.count == 2
        assert result.results[0].bleu4 == pytest.approx(1.0)
        assert result.results[1].bleu4 == pytest.approx(1.0)
|
|
|
|
|
|
class TestBleuResult:
    """Behaviour of the ``BleuResult`` value object on its own."""

    def test_frozen(self) -> None:
        """Reassigning a field of a constructed result must be rejected."""
        from pydantic import ValidationError

        frozen = BleuResult(
            bleu1=0.5, bleu2=0.4, bleu3=0.3, bleu4=0.2, brevity_penalty=1.0
        )
        with pytest.raises(ValidationError):
            frozen.bleu1 = 0.6  # type: ignore[misc]

    def test_score_property(self) -> None:
        """``score`` exposes the bleu4 value as the headline number."""
        fields = dict(bleu1=0.9, bleu2=0.8, bleu3=0.7, bleu4=0.6, brevity_penalty=1.0)
        headline = BleuResult(**fields).score
        assert headline == 0.6
|