Add comprehensive tests for BLEU and lexical metrics including edge cases, batch scoring, and aggregate statistics.
216 lines
8.0 KiB
Python
216 lines
8.0 KiB
Python
"""Tests for the lexical similarity metric."""
|
|
|
|
import pytest
|
|
|
|
from veritext.metrics import Lexical, LexicalResult
|
|
|
|
|
|
class TestLexical:
    """Tests for the Lexical metric class.

    Expected values in these tests follow the set-based definitions that the
    tests themselves document (see ``test_repeated_tokens``), where ``C`` and
    ``R`` are the token sets of the candidate and the reference:

    * ``jaccard = |C intersection R| / |C union R|``
    * ``token_overlap = |C intersection R| / |C|``

    Tokens are compared case-insensitively with punctuation ignored
    (see ``test_case_insensitivity`` and ``test_punctuation_ignored``).
    """

    @pytest.fixture
    def lexical(self) -> Lexical:
        """Provide a lexical metric instance."""
        return Lexical()

    def test_name(self, lexical: Lexical) -> None:
        """Test that name returns 'lexical'."""
        assert lexical.name == "lexical"

    def test_requires_reference(self, lexical: Lexical) -> None:
        """Test that lexical requires reference text."""
        assert lexical.requires_reference is True

    def test_identical_texts(self, lexical: Lexical) -> None:
        """Test that identical texts produce perfect scores."""
        text = "The cat sat on the mat"
        result = lexical.score(text, text)

        assert result.jaccard == 1.0
        assert result.token_overlap == 1.0

    def test_no_overlap(self, lexical: Lexical) -> None:
        """Test that texts with no overlap produce zero scores."""
        candidate = "apple banana cherry"
        reference = "dog elephant fox"
        result = lexical.score(candidate, reference)

        assert result.jaccard == 0.0
        assert result.token_overlap == 0.0

    def test_partial_overlap_jaccard(self, lexical: Lexical) -> None:
        """Test Jaccard with partial overlap."""
        candidate = "the cat sat"
        reference = "the dog sat"
        result = lexical.score(candidate, reference)

        # Intersection: {the, sat}, Union: {the, cat, sat, dog}
        # Jaccard = 2/4 = 0.5
        assert result.jaccard == 0.5

    def test_partial_overlap_token_overlap(self, lexical: Lexical) -> None:
        """Test token overlap with partial overlap."""
        candidate = "the cat sat"
        reference = "the dog sat"
        result = lexical.score(candidate, reference)

        # Candidate tokens: {the, cat, sat}
        # In reference: {the, sat}
        # Overlap = 2/3 (not exactly representable, hence approx)
        assert result.token_overlap == pytest.approx(2 / 3, abs=1e-10)

    def test_candidate_subset_of_reference(self, lexical: Lexical) -> None:
        """Test when candidate is a subset of reference."""
        candidate = "the cat"
        reference = "the cat sat on the mat"
        result = lexical.score(candidate, reference)

        # Candidate set {the, cat}; reference set {the, cat, sat, on, mat}.
        # All candidate tokens are in the reference ...
        assert result.token_overlap == 1.0
        # ... but the extra reference tokens widen the union: Jaccard = 2/5.
        assert result.jaccard == pytest.approx(2 / 5, abs=1e-10)

    def test_reference_subset_of_candidate(self, lexical: Lexical) -> None:
        """Test when reference is a subset of candidate."""
        candidate = "the cat sat on the mat"
        reference = "the cat"
        result = lexical.score(candidate, reference)

        # Candidate set {the, cat, sat, on, mat}; reference set {the, cat}.
        # Intersection {the, cat} over a 5-token union: Jaccard = 2/5.
        assert result.jaccard == pytest.approx(2 / 5, abs=1e-10)
        # Only 2 of the 5 candidate tokens appear in the reference.
        assert result.token_overlap == pytest.approx(2 / 5, abs=1e-10)

    def test_empty_candidate(self, lexical: Lexical) -> None:
        """Test that empty candidate returns zero scores."""
        result = lexical.score("", "The cat sat")

        assert result.jaccard == 0.0
        assert result.token_overlap == 0.0

    def test_whitespace_only_candidate(self, lexical: Lexical) -> None:
        """Test that whitespace-only candidate returns zero scores."""
        result = lexical.score(" \t\n ", "The cat sat")

        assert result.jaccard == 0.0
        assert result.token_overlap == 0.0

    def test_empty_reference_raises(self, lexical: Lexical) -> None:
        """Test that empty reference raises ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):
            lexical.score("The cat sat", "")

    def test_none_reference_raises(self, lexical: Lexical) -> None:
        """Test that None reference raises ValueError."""
        with pytest.raises(ValueError, match="requires reference"):
            lexical.score("The cat sat", None)

    def test_multiple_references_uses_first(self, lexical: Lexical) -> None:
        """Test that multiple references uses the first one."""
        candidate = "the cat sat"
        references = ["the dog ran", "the cat sat"]  # Only the second matches
        result = lexical.score(candidate, references)

        # The result must equal scoring against the first reference alone.
        # Scoring against the second alone would give 1.0 on both metrics,
        # so this also rules out "best match" or "last reference" selection.
        first_only = lexical.score(candidate, references[0])
        assert result.jaccard == first_only.jaccard
        assert result.token_overlap == first_only.token_overlap

        # Shared {the} out of union {the, cat, sat, dog, ran}: Jaccard = 1/5.
        assert result.jaccard == pytest.approx(1 / 5, abs=1e-10)
        # 1 of the 3 candidate tokens is in the first reference.
        assert result.token_overlap == pytest.approx(1 / 3, abs=1e-10)

    def test_case_insensitivity(self, lexical: Lexical) -> None:
        """Test that lexical is case insensitive by default."""
        result = lexical.score("THE CAT SAT", "the cat sat")
        assert result.jaccard == 1.0
        assert result.token_overlap == 1.0

    def test_punctuation_ignored(self, lexical: Lexical) -> None:
        """Test that punctuation is ignored by default."""
        result = lexical.score("The cat sat.", "The cat sat!")
        assert result.jaccard == 1.0
        assert result.token_overlap == 1.0

    def test_repeated_tokens(self, lexical: Lexical) -> None:
        """Test handling of repeated tokens."""
        candidate = "the the the"
        reference = "the cat"
        result = lexical.score(candidate, reference)

        # Sets: {the} and {the, cat}
        # Jaccard = 1/2 = 0.5
        assert result.jaccard == 0.5
        # Token overlap: {the} / {the} = 1.0
        assert result.token_overlap == 1.0
|
|
|
|
|
|
class TestLexicalBatch:
    """Tests for lexical batch scoring.

    Covers per-item results and aggregate statistics (mean and percentiles)
    for both the ``jaccard`` and ``token_overlap`` metrics, plus input
    validation.
    """

    @pytest.fixture
    def lexical(self) -> Lexical:
        """Provide a lexical metric instance."""
        return Lexical()

    def test_batch_score_basic(self, lexical: Lexical) -> None:
        """Test basic batch scoring: one result per candidate/reference pair."""
        candidates = ["The cat sat", "A dog runs"]
        references = ["The cat sat", "A dog runs"]
        result = lexical.batch_score(candidates, references)

        assert result.count == 2
        assert len(result.results) == 2
        # Each candidate is identical to its reference, so both metrics are perfect.
        assert all(r.jaccard == 1.0 for r in result.results)
        assert all(r.token_overlap == 1.0 for r in result.results)

    def test_batch_score_statistics(self, lexical: Lexical) -> None:
        """Test that batch scoring computes statistics."""
        candidates = ["The cat sat", "Completely different words"]
        references = ["The cat sat", "The cat sat"]
        result = lexical.batch_score(candidates, references)

        # Check statistics are computed
        assert "jaccard" in result.stats
        assert "token_overlap" in result.stats

        # First pair is identical, second pair shares no tokens.
        assert result.results[0].jaccard == 1.0
        assert result.results[1].jaccard == 0.0
        assert result.results[0].token_overlap == 1.0
        assert result.results[1].token_overlap == 0.0

        # Mean of [1.0, 0.0] is 0.5 for both metrics.
        assert result.stats["jaccard"].mean == 0.5
        assert result.stats["token_overlap"].mean == 0.5

    def test_batch_score_percentiles(self, lexical: Lexical) -> None:
        """Test that batch scoring computes percentiles."""
        candidates = ["a", "b", "c", "d", "e"]
        references = ["a", "b", "c", "d", "e"]
        result = lexical.batch_score(candidates, references)

        stats = result.stats["jaccard"]
        for pct in (25, 50, 75, 95):
            assert pct in stats.percentiles
            # Every pair is identical, so every item has the same score and
            # each percentile must collapse to the mean.
            assert stats.percentiles[pct] == stats.mean

    def test_batch_score_none_references_raises(self, lexical: Lexical) -> None:
        """Test that batch scoring raises for None references."""
        with pytest.raises(ValueError, match="requires reference"):
            lexical.batch_score(["text"], None)

    def test_batch_score_length_mismatch_raises(self, lexical: Lexical) -> None:
        """Test that batch scoring raises for mismatched lengths."""
        with pytest.raises(ValueError, match="must match"):
            lexical.batch_score(["a", "b"], ["a"])
|
|
|
|
|
|
class TestLexicalResult:
    """Behavioural checks for the immutable ``LexicalResult`` value type."""

    def test_frozen(self) -> None:
        """Assigning to a field of an existing LexicalResult must be rejected."""
        from pydantic import ValidationError

        frozen_result = LexicalResult(token_overlap=0.7, jaccard=0.5)
        with pytest.raises(ValidationError):
            frozen_result.jaccard = 0.6  # type: ignore[misc]

    def test_values(self) -> None:
        """Constructor keyword arguments are exposed unchanged as attributes."""
        stored = LexicalResult(token_overlap=0.7, jaccard=0.5)
        assert (stored.jaccard, stored.token_overlap) == (0.5, 0.7)
|