test BLEU and lexical metrics
Add comprehensive tests for BLEU and lexical metrics including edge cases, batch scoring, and aggregate statistics.
This commit is contained in:
1
tests/test_metrics/__init__.py
Normal file
1
tests/test_metrics/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for the metrics module."""
|
||||
200
tests/test_metrics/test_bleu.py
Normal file
200
tests/test_metrics/test_bleu.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Tests for the BLEU metric."""
|
||||
|
||||
import pytest
|
||||
|
||||
from veritext.metrics import Bleu, BleuResult
|
||||
|
||||
|
||||
class TestBleu:
|
||||
@pytest.fixture
|
||||
def bleu(self) -> Bleu:
|
||||
return Bleu()
|
||||
|
||||
def test_name(self, bleu: Bleu) -> None:
|
||||
assert bleu.name == "bleu"
|
||||
|
||||
def test_requires_reference(self, bleu: Bleu) -> None:
|
||||
assert bleu.requires_reference is True
|
||||
|
||||
def test_identical_texts(self, bleu: Bleu) -> None:
|
||||
text = "The cat sat on the mat"
|
||||
result = bleu.score(text, text)
|
||||
|
||||
assert result.bleu1 == 1.0
|
||||
assert result.bleu2 == 1.0
|
||||
assert result.bleu3 == 1.0
|
||||
assert result.bleu4 == 1.0
|
||||
assert result.brevity_penalty == 1.0
|
||||
|
||||
def test_no_overlap(self, bleu: Bleu) -> None:
|
||||
candidate = "The cat sat"
|
||||
reference = "A dog runs fast"
|
||||
result = bleu.score(candidate, reference)
|
||||
|
||||
assert result.bleu1 == 0.0
|
||||
assert result.bleu2 == 0.0
|
||||
assert result.bleu3 == 0.0
|
||||
assert result.bleu4 == 0.0
|
||||
|
||||
def test_partial_overlap(self, bleu: Bleu) -> None:
|
||||
candidate = "The cat sat on the mat"
|
||||
reference = "The cat is on the floor"
|
||||
result = bleu.score(candidate, reference)
|
||||
|
||||
assert 0.0 < result.bleu1 < 1.0
|
||||
assert 0.0 < result.bleu2 < 1.0
|
||||
|
||||
def test_brevity_penalty_applied(self, bleu: Bleu) -> None:
|
||||
candidate = "cat"
|
||||
reference = "The cat sat on the mat"
|
||||
result = bleu.score(candidate, reference)
|
||||
|
||||
assert result.brevity_penalty < 1.0
|
||||
|
||||
def test_no_brevity_penalty_long_cand(self, bleu: Bleu) -> None:
|
||||
candidate = "The cat sat on the mat today"
|
||||
reference = "The cat sat"
|
||||
result = bleu.score(candidate, reference)
|
||||
|
||||
assert result.brevity_penalty == 1.0
|
||||
|
||||
def test_single_word_texts(self, bleu: Bleu) -> None:
|
||||
result = bleu.score("cat", "cat")
|
||||
|
||||
assert result.bleu1 == 1.0
|
||||
assert result.brevity_penalty == 1.0
|
||||
|
||||
def test_empty_candidate(self, bleu: Bleu) -> None:
|
||||
result = bleu.score("", "The cat sat")
|
||||
|
||||
assert result.bleu1 == 0.0
|
||||
assert result.bleu2 == 0.0
|
||||
assert result.bleu3 == 0.0
|
||||
assert result.bleu4 == 0.0
|
||||
assert result.brevity_penalty == 0.0
|
||||
|
||||
def test_whitespace_only_candidate(self, bleu: Bleu) -> None:
|
||||
result = bleu.score(" \t\n ", "The cat sat")
|
||||
|
||||
assert result.bleu1 == 0.0
|
||||
assert result.bleu4 == 0.0
|
||||
|
||||
def test_empty_reference_raises(self, bleu: Bleu) -> None:
|
||||
with pytest.raises(ValueError, match="cannot be empty"):
|
||||
bleu.score("The cat sat", "")
|
||||
|
||||
def test_none_reference_raises(self, bleu: Bleu) -> None:
|
||||
with pytest.raises(ValueError, match="requires reference"):
|
||||
bleu.score("The cat sat", None)
|
||||
|
||||
def test_multiple_references(self, bleu: Bleu) -> None:
|
||||
candidate = "The cat sat on the mat"
|
||||
references = [
|
||||
"The cat is on the floor",
|
||||
"The cat sat on the mat", # Exact match
|
||||
]
|
||||
result = bleu.score(candidate, references)
|
||||
|
||||
assert result.bleu4 == 1.0
|
||||
|
||||
def test_multiple_references_partial(self, bleu: Bleu) -> None:
|
||||
candidate = "The quick brown fox"
|
||||
references = [
|
||||
"The fast brown fox",
|
||||
"A quick brown dog",
|
||||
]
|
||||
result = bleu.score(candidate, references)
|
||||
|
||||
assert result.bleu1 > 0.0
|
||||
|
||||
def test_result_score_property(self, bleu: Bleu) -> None:
|
||||
result = bleu.score("The cat sat", "The cat sat")
|
||||
assert result.score == result.bleu4
|
||||
|
||||
def test_case_insensitivity(self, bleu: Bleu) -> None:
|
||||
result = bleu.score("THE CAT SAT ON THE MAT", "the cat sat on the mat")
|
||||
assert result.bleu4 == 1.0
|
||||
|
||||
def test_punctuation_ignored(self, bleu: Bleu) -> None:
|
||||
result = bleu.score("The cat sat on the mat.", "The cat sat on the mat!")
|
||||
assert result.bleu4 == 1.0
|
||||
|
||||
|
||||
class TestBleuBatch:
|
||||
@pytest.fixture
|
||||
def bleu(self) -> Bleu:
|
||||
return Bleu()
|
||||
|
||||
def test_batch_score_basic(self, bleu: Bleu) -> None:
|
||||
candidates = ["The cat sat on the mat", "A quick brown dog runs fast"]
|
||||
references = ["The cat sat on the mat", "A quick brown dog runs fast"]
|
||||
result = bleu.batch_score(candidates, references)
|
||||
|
||||
assert result.count == 2
|
||||
assert len(result.results) == 2
|
||||
assert all(r.bleu4 == 1.0 for r in result.results)
|
||||
|
||||
def test_batch_score_statistics(self, bleu: Bleu) -> None:
|
||||
candidates = ["The cat sat", "A different text entirely"]
|
||||
references = ["The cat sat", "The cat sat"]
|
||||
result = bleu.batch_score(candidates, references)
|
||||
|
||||
# Check statistics are computed
|
||||
assert "bleu1" in result.stats
|
||||
assert "bleu4" in result.stats
|
||||
assert "brevity_penalty" in result.stats
|
||||
|
||||
# Mean should be between min and max
|
||||
stats = result.stats["bleu4"]
|
||||
assert stats.min <= stats.mean <= stats.max
|
||||
|
||||
def test_batch_score_percentiles(self, bleu: Bleu) -> None:
|
||||
candidates = ["a", "b", "c", "d", "e"]
|
||||
references = ["a", "b", "c", "d", "e"]
|
||||
result = bleu.batch_score(candidates, references)
|
||||
|
||||
stats = result.stats["bleu1"]
|
||||
assert 25 in stats.percentiles
|
||||
assert 50 in stats.percentiles
|
||||
assert 75 in stats.percentiles
|
||||
assert 95 in stats.percentiles
|
||||
|
||||
def test_batch_score_none_references_raises(self, bleu: Bleu) -> None:
|
||||
with pytest.raises(ValueError, match="requires reference"):
|
||||
bleu.batch_score(["text"], None)
|
||||
|
||||
def test_batch_score_length_mismatch_raises(self, bleu: Bleu) -> None:
|
||||
with pytest.raises(ValueError, match="must match"):
|
||||
bleu.batch_score(["a", "b"], ["a"])
|
||||
|
||||
def test_batch_score_multi_refs(self, bleu: Bleu) -> None:
|
||||
candidates = [
|
||||
"The cat sat on the mat",
|
||||
"A quick brown dog runs fast",
|
||||
]
|
||||
references = [
|
||||
["The cat sat on the mat", "A cat rests on floor"],
|
||||
["A quick brown dog runs fast", "Dogs run very quickly"],
|
||||
]
|
||||
result = bleu.batch_score(candidates, references)
|
||||
|
||||
assert result.count == 2
|
||||
assert result.results[0].bleu4 == 1.0
|
||||
assert result.results[1].bleu4 == 1.0
|
||||
|
||||
|
||||
class TestBleuResult:
|
||||
def test_frozen(self) -> None:
|
||||
from pydantic import ValidationError
|
||||
|
||||
result = BleuResult(
|
||||
bleu1=0.5, bleu2=0.4, bleu3=0.3, bleu4=0.2, brevity_penalty=1.0
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
result.bleu1 = 0.6 # type: ignore[misc]
|
||||
|
||||
def test_score_property(self) -> None:
|
||||
result = BleuResult(
|
||||
bleu1=0.9, bleu2=0.8, bleu3=0.7, bleu4=0.6, brevity_penalty=1.0
|
||||
)
|
||||
assert result.score == 0.6
|
||||
184
tests/test_metrics/test_lexical.py
Normal file
184
tests/test_metrics/test_lexical.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Tests for the lexical similarity metric."""
|
||||
|
||||
import pytest
|
||||
|
||||
from veritext.metrics import Lexical, LexicalResult
|
||||
|
||||
|
||||
class TestLexical:
|
||||
@pytest.fixture
|
||||
def lexical(self) -> Lexical:
|
||||
return Lexical()
|
||||
|
||||
def test_name(self, lexical: Lexical) -> None:
|
||||
assert lexical.name == "lexical"
|
||||
|
||||
def test_requires_reference(self, lexical: Lexical) -> None:
|
||||
assert lexical.requires_reference is True
|
||||
|
||||
def test_identical_texts(self, lexical: Lexical) -> None:
|
||||
text = "The cat sat on the mat"
|
||||
result = lexical.score(text, text)
|
||||
|
||||
assert result.jaccard == 1.0
|
||||
assert result.token_overlap == 1.0
|
||||
|
||||
def test_no_overlap(self, lexical: Lexical) -> None:
|
||||
candidate = "apple banana cherry"
|
||||
reference = "dog elephant fox"
|
||||
result = lexical.score(candidate, reference)
|
||||
|
||||
assert result.jaccard == 0.0
|
||||
assert result.token_overlap == 0.0
|
||||
|
||||
def test_partial_overlap_jaccard(self, lexical: Lexical) -> None:
|
||||
candidate = "the cat sat"
|
||||
reference = "the dog sat"
|
||||
result = lexical.score(candidate, reference)
|
||||
|
||||
# Intersection: {the, sat}, Union: {the, cat, sat, dog}
|
||||
# Jaccard = 2/4 = 0.5
|
||||
assert result.jaccard == 0.5
|
||||
|
||||
def test_partial_overlap_token_overlap(self, lexical: Lexical) -> None:
|
||||
candidate = "the cat sat"
|
||||
reference = "the dog sat"
|
||||
result = lexical.score(candidate, reference)
|
||||
|
||||
# Candidate tokens: {the, cat, sat}
|
||||
# In reference: {the, sat}
|
||||
# Overlap = 2/3
|
||||
assert abs(result.token_overlap - 2 / 3) < 1e-10
|
||||
|
||||
def test_candidate_subset_of_reference(self, lexical: Lexical) -> None:
|
||||
candidate = "the cat"
|
||||
reference = "the cat sat on the mat"
|
||||
result = lexical.score(candidate, reference)
|
||||
|
||||
# All candidate tokens are in reference
|
||||
assert result.token_overlap == 1.0
|
||||
# But Jaccard is less than 1 due to extra tokens in reference
|
||||
assert result.jaccard < 1.0
|
||||
|
||||
def test_reference_subset_of_candidate(self, lexical: Lexical) -> None:
|
||||
candidate = "the cat sat on the mat"
|
||||
reference = "the cat"
|
||||
result = lexical.score(candidate, reference)
|
||||
|
||||
# Jaccard is less than 1
|
||||
assert result.jaccard < 1.0
|
||||
# Token overlap is less than 1
|
||||
assert result.token_overlap < 1.0
|
||||
|
||||
def test_empty_candidate(self, lexical: Lexical) -> None:
|
||||
result = lexical.score("", "The cat sat")
|
||||
|
||||
assert result.jaccard == 0.0
|
||||
assert result.token_overlap == 0.0
|
||||
|
||||
def test_whitespace_only_candidate(self, lexical: Lexical) -> None:
|
||||
result = lexical.score(" \t\n ", "The cat sat")
|
||||
|
||||
assert result.jaccard == 0.0
|
||||
assert result.token_overlap == 0.0
|
||||
|
||||
def test_empty_reference_raises(self, lexical: Lexical) -> None:
|
||||
with pytest.raises(ValueError, match="cannot be empty"):
|
||||
lexical.score("The cat sat", "")
|
||||
|
||||
def test_none_reference_raises(self, lexical: Lexical) -> None:
|
||||
with pytest.raises(ValueError, match="requires reference"):
|
||||
lexical.score("The cat sat", None)
|
||||
|
||||
def test_multiple_references_uses_first(self, lexical: Lexical) -> None:
|
||||
candidate = "the cat sat"
|
||||
references = ["the dog ran", "the cat sat"] # First differs
|
||||
result = lexical.score(candidate, references)
|
||||
|
||||
# Should use first reference, not second
|
||||
assert result.jaccard < 1.0
|
||||
|
||||
def test_case_insensitivity(self, lexical: Lexical) -> None:
|
||||
result = lexical.score("THE CAT SAT", "the cat sat")
|
||||
assert result.jaccard == 1.0
|
||||
assert result.token_overlap == 1.0
|
||||
|
||||
def test_punctuation_ignored(self, lexical: Lexical) -> None:
|
||||
result = lexical.score("The cat sat.", "The cat sat!")
|
||||
assert result.jaccard == 1.0
|
||||
assert result.token_overlap == 1.0
|
||||
|
||||
def test_repeated_tokens(self, lexical: Lexical) -> None:
|
||||
candidate = "the the the"
|
||||
reference = "the cat"
|
||||
result = lexical.score(candidate, reference)
|
||||
|
||||
# Sets: {the} and {the, cat}
|
||||
# Jaccard = 1/2 = 0.5
|
||||
assert result.jaccard == 0.5
|
||||
# Token overlap: {the} / {the} = 1.0
|
||||
assert result.token_overlap == 1.0
|
||||
|
||||
|
||||
class TestLexicalBatch:
|
||||
@pytest.fixture
|
||||
def lexical(self) -> Lexical:
|
||||
return Lexical()
|
||||
|
||||
def test_batch_score_basic(self, lexical: Lexical) -> None:
|
||||
candidates = ["The cat sat", "A dog runs"]
|
||||
references = ["The cat sat", "A dog runs"]
|
||||
result = lexical.batch_score(candidates, references)
|
||||
|
||||
assert result.count == 2
|
||||
assert len(result.results) == 2
|
||||
assert all(r.jaccard == 1.0 for r in result.results)
|
||||
|
||||
def test_batch_score_statistics(self, lexical: Lexical) -> None:
|
||||
candidates = ["The cat sat", "Completely different words"]
|
||||
references = ["The cat sat", "The cat sat"]
|
||||
result = lexical.batch_score(candidates, references)
|
||||
|
||||
# Check statistics are computed
|
||||
assert "jaccard" in result.stats
|
||||
assert "token_overlap" in result.stats
|
||||
|
||||
# First result should be 1.0, second should be 0.0
|
||||
assert result.results[0].jaccard == 1.0
|
||||
assert result.results[1].jaccard == 0.0
|
||||
|
||||
# Mean should be 0.5
|
||||
assert result.stats["jaccard"].mean == 0.5
|
||||
|
||||
def test_batch_score_percentiles(self, lexical: Lexical) -> None:
|
||||
candidates = ["a", "b", "c", "d", "e"]
|
||||
references = ["a", "b", "c", "d", "e"]
|
||||
result = lexical.batch_score(candidates, references)
|
||||
|
||||
stats = result.stats["jaccard"]
|
||||
assert 25 in stats.percentiles
|
||||
assert 50 in stats.percentiles
|
||||
assert 75 in stats.percentiles
|
||||
assert 95 in stats.percentiles
|
||||
|
||||
def test_batch_score_none_references_raises(self, lexical: Lexical) -> None:
|
||||
with pytest.raises(ValueError, match="requires reference"):
|
||||
lexical.batch_score(["text"], None)
|
||||
|
||||
def test_batch_score_length_mismatch_raises(self, lexical: Lexical) -> None:
|
||||
with pytest.raises(ValueError, match="must match"):
|
||||
lexical.batch_score(["a", "b"], ["a"])
|
||||
|
||||
|
||||
class TestLexicalResult:
|
||||
def test_frozen(self) -> None:
|
||||
from pydantic import ValidationError
|
||||
|
||||
result = LexicalResult(jaccard=0.5, token_overlap=0.7)
|
||||
with pytest.raises(ValidationError):
|
||||
result.jaccard = 0.6 # type: ignore[misc]
|
||||
|
||||
def test_values(self) -> None:
|
||||
result = LexicalResult(jaccard=0.5, token_overlap=0.7)
|
||||
assert result.jaccard == 0.5
|
||||
assert result.token_overlap == 0.7
|
||||
Reference in New Issue
Block a user