Add comprehensive tests for BLEU and lexical metrics including edge cases, batch scoring, and aggregate statistics.
185 lines
6.4 KiB
Python
185 lines
6.4 KiB
Python
"""Tests for the lexical similarity metric."""
|
|
|
|
import pytest
|
|
|
|
from veritext.metrics import Lexical, LexicalResult
|
|
|
|
|
|
class TestLexical:
    """Single-pair scoring tests for the ``Lexical`` metric."""

    @pytest.fixture
    def lexical(self) -> Lexical:
        """Provide a default-constructed ``Lexical`` metric to each test."""
        metric = Lexical()
        return metric

    def test_name(self, lexical: Lexical) -> None:
        """The metric is expected to report its name as ``"lexical"``."""
        expected_name = "lexical"
        assert lexical.name == expected_name

    def test_requires_reference(self, lexical: Lexical) -> None:
        """The metric is expected to declare that it needs a reference."""
        flag = lexical.requires_reference
        assert flag is True

    def test_identical_texts(self, lexical: Lexical) -> None:
        """Scoring a text against itself should give 1.0 for both scores."""
        sentence = "The cat sat on the mat"
        scores = lexical.score(sentence, sentence)

        assert scores.jaccard == 1.0
        assert scores.token_overlap == 1.0

    def test_no_overlap(self, lexical: Lexical) -> None:
        """Texts sharing no words should give 0.0 for both scores."""
        scores = lexical.score("apple banana cherry", "dog elephant fox")

        assert scores.jaccard == 0.0
        assert scores.token_overlap == 0.0

    def test_partial_overlap_jaccard(self, lexical: Lexical) -> None:
        """Two of four distinct words shared should give a Jaccard of 0.5."""
        hypothesis = "the cat sat"
        target = "the dog sat"
        scores = lexical.score(hypothesis, target)

        # Expected: shared words {the, sat}; all words {the, cat, sat, dog};
        # so the ratio is 2/4.
        assert scores.jaccard == 0.5

    def test_partial_overlap_token_overlap(self, lexical: Lexical) -> None:
        """Two of three candidate words found should give an overlap of 2/3."""
        hypothesis = "the cat sat"
        target = "the dog sat"
        scores = lexical.score(hypothesis, target)

        # Expected: candidate words {the, cat, sat}, of which {the, sat}
        # also occur in the reference.
        expected_overlap = 2 / 3
        assert abs(scores.token_overlap - expected_overlap) < 1e-10

    def test_candidate_subset_of_reference(self, lexical: Lexical) -> None:
        """A candidate fully contained in a longer reference."""
        scores = lexical.score("the cat", "the cat sat on the mat")

        # Every candidate word appears in the reference.
        assert scores.token_overlap == 1.0
        # The reference has extra words, so Jaccard must fall short of 1.
        assert scores.jaccard < 1.0

    def test_reference_subset_of_candidate(self, lexical: Lexical) -> None:
        """A reference fully contained in a longer candidate."""
        scores = lexical.score("the cat sat on the mat", "the cat")

        # Both scores are expected to be strictly below 1.
        assert scores.jaccard < 1.0
        assert scores.token_overlap < 1.0

    def test_empty_candidate(self, lexical: Lexical) -> None:
        """An empty candidate should score 0.0 on both measures."""
        scores = lexical.score("", "The cat sat")

        assert scores.jaccard == 0.0
        assert scores.token_overlap == 0.0

    def test_whitespace_only_candidate(self, lexical: Lexical) -> None:
        """A whitespace-only candidate should score 0.0 on both measures."""
        blank = " \t\n "
        scores = lexical.score(blank, "The cat sat")

        assert scores.jaccard == 0.0
        assert scores.token_overlap == 0.0

    def test_empty_reference_raises(self, lexical: Lexical) -> None:
        """An empty reference should raise ``ValueError`` ("cannot be empty")."""
        with pytest.raises(ValueError, match="cannot be empty"):
            lexical.score("The cat sat", "")

    def test_none_reference_raises(self, lexical: Lexical) -> None:
        """A ``None`` reference should raise ``ValueError`` ("requires reference")."""
        with pytest.raises(ValueError, match="requires reference"):
            lexical.score("The cat sat", None)

    def test_multiple_references_uses_first(self, lexical: Lexical) -> None:
        """With a list of references, only the first one should be scored."""
        hypothesis = "the cat sat"
        # Only the second entry equals the candidate; the first differs.
        targets = ["the dog ran", "the cat sat"]
        scores = lexical.score(hypothesis, targets)

        # A perfect score would mean the second reference had been used.
        assert scores.jaccard < 1.0

    def test_case_insensitivity(self, lexical: Lexical) -> None:
        """Texts differing only in letter case should score 1.0 on both."""
        scores = lexical.score("THE CAT SAT", "the cat sat")

        assert scores.jaccard == 1.0
        assert scores.token_overlap == 1.0

    def test_punctuation_ignored(self, lexical: Lexical) -> None:
        """Texts differing only in trailing punctuation should score 1.0."""
        scores = lexical.score("The cat sat.", "The cat sat!")

        assert scores.jaccard == 1.0
        assert scores.token_overlap == 1.0

    def test_repeated_tokens(self, lexical: Lexical) -> None:
        """Repeated candidate words should count once."""
        scores = lexical.score("the the the", "the cat")

        # Expected word sets: {the} vs {the, cat}, so Jaccard is 1/2.
        assert scores.jaccard == 0.5
        # The single distinct candidate word is in the reference.
        assert scores.token_overlap == 1.0
|
|
|
|
|
|
class TestLexicalBatch:
    """Tests for ``Lexical.batch_score`` and the statistics it reports."""

    @pytest.fixture
    def lexical(self) -> Lexical:
        """Provide a default-constructed ``Lexical`` metric to each test."""
        metric = Lexical()
        return metric

    def test_batch_score_basic(self, lexical: Lexical) -> None:
        """Two identical pairs should give two results, each with Jaccard 1.0."""
        texts = ["The cat sat", "A dog runs"]
        batch = lexical.batch_score(texts, ["The cat sat", "A dog runs"])

        assert batch.count == 2
        assert len(batch.results) == 2
        for item in batch.results:
            assert item.jaccard == 1.0

    def test_batch_score_statistics(self, lexical: Lexical) -> None:
        """One perfect and one disjoint pair should average to 0.5."""
        hypotheses = ["The cat sat", "Completely different words"]
        targets = ["The cat sat", "The cat sat"]
        batch = lexical.batch_score(hypotheses, targets)

        # Statistics must be present for both scores.
        for key in ("jaccard", "token_overlap"):
            assert key in batch.stats

        # The first pair matches exactly; the second shares nothing.
        first, second = batch.results[0], batch.results[1]
        assert first.jaccard == 1.0
        assert second.jaccard == 0.0

        # The mean of 1.0 and 0.0.
        assert batch.stats["jaccard"].mean == 0.5

    def test_batch_score_percentiles(self, lexical: Lexical) -> None:
        """The Jaccard statistics should include percentiles 25, 50, 75, 95."""
        letters = ["a", "b", "c", "d", "e"]
        batch = lexical.batch_score(letters, ["a", "b", "c", "d", "e"])

        jaccard_stats = batch.stats["jaccard"]
        for pct in (25, 50, 75, 95):
            assert pct in jaccard_stats.percentiles

    def test_batch_score_none_references_raises(self, lexical: Lexical) -> None:
        """``None`` references should raise ``ValueError`` ("requires reference")."""
        with pytest.raises(ValueError, match="requires reference"):
            lexical.batch_score(["text"], None)

    def test_batch_score_length_mismatch_raises(self, lexical: Lexical) -> None:
        """Unequal list lengths should raise ``ValueError`` ("must match")."""
        with pytest.raises(ValueError, match="must match"):
            lexical.batch_score(["a", "b"], ["a"])
|
|
|
|
|
|
class TestLexicalResult:
    """Tests for the ``LexicalResult`` value object."""

    def test_frozen(self) -> None:
        """Assigning to a field of a built result should raise ``ValidationError``."""
        from pydantic import ValidationError

        frozen_result = LexicalResult(jaccard=0.5, token_overlap=0.7)
        with pytest.raises(ValidationError):
            frozen_result.jaccard = 0.6  # type: ignore[misc]

    def test_values(self) -> None:
        """Constructor keyword values should be readable back unchanged."""
        built = LexicalResult(jaccard=0.5, token_overlap=0.7)

        assert built.jaccard == 0.5
        assert built.token_overlap == 0.7
|