diff --git a/tests/test_metrics/__init__.py b/tests/test_metrics/__init__.py new file mode 100644 index 0000000..97451ef --- /dev/null +++ b/tests/test_metrics/__init__.py @@ -0,0 +1 @@ +"""Tests for the metrics module.""" diff --git a/tests/test_metrics/test_bleu.py b/tests/test_metrics/test_bleu.py new file mode 100644 index 0000000..90c7c98 --- /dev/null +++ b/tests/test_metrics/test_bleu.py @@ -0,0 +1,240 @@ +"""Tests for the BLEU metric.""" + +import pytest + +from veritext.metrics import Bleu, BleuResult + + +class TestBleu: + """Tests for the Bleu metric class.""" + + @pytest.fixture + def bleu(self) -> Bleu: + """Provide a BLEU metric instance.""" + return Bleu() + + def test_name(self, bleu: Bleu) -> None: + """Test that name returns 'bleu'.""" + assert bleu.name == "bleu" + + def test_requires_reference(self, bleu: Bleu) -> None: + """Test that BLEU requires reference text.""" + assert bleu.requires_reference is True + + def test_identical_texts(self, bleu: Bleu) -> None: + """Test that identical texts produce perfect scores.""" + text = "The cat sat on the mat" + result = bleu.score(text, text) + + assert result.bleu1 == 1.0 + assert result.bleu2 == 1.0 + assert result.bleu3 == 1.0 + assert result.bleu4 == 1.0 + assert result.brevity_penalty == 1.0 + + def test_no_overlap(self, bleu: Bleu) -> None: + """Test that texts with no overlap produce zero scores.""" + candidate = "The cat sat" + reference = "A dog runs fast" + result = bleu.score(candidate, reference) + + assert result.bleu1 == 0.0 + assert result.bleu2 == 0.0 + assert result.bleu3 == 0.0 + assert result.bleu4 == 0.0 + + def test_partial_overlap(self, bleu: Bleu) -> None: + """Test partial overlap produces intermediate scores.""" + candidate = "The cat sat on the mat" + reference = "The cat is on the floor" + result = bleu.score(candidate, reference) + + # Should have some overlap but not perfect + assert 0.0 < result.bleu1 < 1.0 + assert 0.0 < result.bleu2 < 1.0 + + def test_brevity_penalty_applied(self, bleu: Bleu) -> None: + """Test that brevity penalty is applied for short candidates.""" + candidate = "cat" + reference = "The cat sat on the mat" + result = bleu.score(candidate, reference) + + # Candidate is much shorter, so brevity penalty should be < 1 + assert result.brevity_penalty < 1.0 + + def test_no_brevity_penalty_for_long_candidate(self, bleu: Bleu) -> None: + """Test that no brevity penalty for longer candidates.""" + candidate = "The cat sat on the mat today" + reference = "The cat sat" + result = bleu.score(candidate, reference) + + # Candidate is longer, so no brevity penalty + assert result.brevity_penalty == 1.0 + + def test_single_word_texts(self, bleu: Bleu) -> None: + """Test BLEU with single word texts.""" + result = bleu.score("cat", "cat") + + assert result.bleu1 == 1.0 + # Higher order n-grams can't be computed for single words + # but with geometric mean, they still produce a score + assert result.brevity_penalty == 1.0 + + def test_empty_candidate(self, bleu: Bleu) -> None: + """Test that empty candidate returns zero scores.""" + result = bleu.score("", "The cat sat") + + assert result.bleu1 == 0.0 + assert result.bleu2 == 0.0 + assert result.bleu3 == 0.0 + assert result.bleu4 == 0.0 + assert result.brevity_penalty == 0.0 + + def test_whitespace_only_candidate(self, bleu: Bleu) -> None: + """Test that whitespace-only candidate returns zero scores.""" + result = bleu.score(" \t\n ", "The cat sat") + + assert result.bleu1 == 0.0 + assert result.bleu4 == 0.0 + + def test_empty_reference_raises(self, bleu: Bleu) -> None: + """Test that empty reference raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + bleu.score("The cat sat", "") + + def test_none_reference_raises(self, bleu: Bleu) -> None: + """Test that None reference raises ValueError.""" + with pytest.raises(ValueError, match="requires reference"): + bleu.score("The cat sat", None) + + def test_multiple_references(self, bleu: Bleu) -> None: + """Test BLEU with multiple references uses max n-gram matches.""" + candidate = "The cat sat on the mat" + references = [ + "The cat is on the floor", + "The cat sat on the mat", # Exact match + ] + result = bleu.score(candidate, references) + + # Should get perfect score due to exact match reference + assert result.bleu4 == 1.0 + + def test_multiple_references_partial(self, bleu: Bleu) -> None: + """Test multiple references with partial matches.""" + candidate = "The quick brown fox" + references = [ + "The fast brown fox", + "A quick brown dog", + ] + result = bleu.score(candidate, references) + + # Should benefit from both references + assert result.bleu1 > 0.0 + + def test_result_score_property(self, bleu: Bleu) -> None: + """Test that result.score returns bleu4.""" + result = bleu.score("The cat sat", "The cat sat") + assert result.score == result.bleu4 + + def test_case_insensitivity(self, bleu: Bleu) -> None: + """Test that BLEU is case insensitive by default.""" + result = bleu.score("THE CAT SAT ON THE MAT", "the cat sat on the mat") + assert result.bleu4 == 1.0 + + def test_punctuation_ignored(self, bleu: Bleu) -> None: + """Test that punctuation is ignored by default.""" + result = bleu.score("The cat sat on the mat.", "The cat sat on the mat!") + assert result.bleu4 == 1.0 + + +class TestBleuBatch: + """Tests for BLEU batch scoring.""" + + @pytest.fixture + def bleu(self) -> Bleu: + """Provide a BLEU metric instance.""" + return Bleu() + + def test_batch_score_basic(self, bleu: Bleu) -> None: + """Test basic batch scoring.""" + candidates = ["The cat sat on the mat", "A quick brown dog runs fast"] + references = ["The cat sat on the mat", "A quick brown dog runs fast"] + result = bleu.batch_score(candidates, references) + + assert result.count == 2 + assert len(result.results) == 2 + assert all(r.bleu4 == 1.0 for r in result.results) + + def test_batch_score_statistics(self, bleu: Bleu) -> None: + """Test that batch scoring computes statistics.""" + candidates = ["The cat sat", "A different text entirely"] + references = ["The cat sat", "The cat sat"] + result = bleu.batch_score(candidates, references) + + # Check statistics are computed + assert "bleu1" in result.stats + assert "bleu4" in result.stats + assert "brevity_penalty" in result.stats + + # Mean should be between min and max + stats = result.stats["bleu4"] + assert stats.min <= stats.mean <= stats.max + + def test_batch_score_percentiles(self, bleu: Bleu) -> None: + """Test that batch scoring computes percentiles.""" + candidates = ["a", "b", "c", "d", "e"] + references = ["a", "b", "c", "d", "e"] + result = bleu.batch_score(candidates, references) + + stats = result.stats["bleu1"] + assert 25 in stats.percentiles + assert 50 in stats.percentiles + assert 75 in stats.percentiles + assert 95 in stats.percentiles + + def test_batch_score_none_references_raises(self, bleu: Bleu) -> None: + """Test that batch scoring raises for None references.""" + with pytest.raises(ValueError, match="requires reference"): + bleu.batch_score(["text"], None) + + def test_batch_score_length_mismatch_raises(self, bleu: Bleu) -> None: + """Test that batch scoring raises for mismatched lengths.""" + with pytest.raises(ValueError, match="must match"): + bleu.batch_score(["a", "b"], ["a"]) + + def test_batch_score_with_multiple_references(self, bleu: Bleu) -> None: + """Test batch scoring with multiple references per candidate.""" + candidates = [ + "The cat sat on the mat", + "A quick brown dog runs fast", + ] + references = [ + ["The cat sat on the mat", "A cat rests on floor"], + ["A quick brown dog runs fast", "Dogs run very quickly"], + ] + result = bleu.batch_score(candidates, references) + + assert result.count == 2 + assert result.results[0].bleu4 == 1.0 + assert result.results[1].bleu4 == 1.0 + + +class TestBleuResult: + """Tests for BleuResult type.""" + + def test_frozen(self) -> None: + """Test that BleuResult is frozen.""" + from pydantic import ValidationError + + result = BleuResult( + bleu1=0.5, bleu2=0.4, bleu3=0.3, bleu4=0.2, brevity_penalty=1.0 + ) + with pytest.raises(ValidationError): + result.bleu1 = 0.6 # type: ignore[misc] + + def test_score_property(self) -> None: + """Test that score property returns bleu4.""" + result = BleuResult( + bleu1=0.9, bleu2=0.8, bleu3=0.7, bleu4=0.6, brevity_penalty=1.0 + ) + assert result.score == 0.6 diff --git a/tests/test_metrics/test_lexical.py b/tests/test_metrics/test_lexical.py new file mode 100644 index 0000000..9dbcb0b --- /dev/null +++ b/tests/test_metrics/test_lexical.py @@ -0,0 +1,215 @@ +"""Tests for the lexical similarity metric.""" + +import pytest + +from veritext.metrics import Lexical, LexicalResult + + +class TestLexical: + """Tests for the Lexical metric class.""" + + @pytest.fixture + def lexical(self) -> Lexical: + """Provide a lexical metric instance.""" + return Lexical() + + def test_name(self, lexical: Lexical) -> None: + """Test that name returns 'lexical'.""" + assert lexical.name == "lexical" + + def test_requires_reference(self, lexical: Lexical) -> None: + """Test that lexical requires reference text.""" + assert lexical.requires_reference is True + + def test_identical_texts(self, lexical: Lexical) -> None: + """Test that identical texts produce perfect scores.""" + text = "The cat sat on the mat" + result = lexical.score(text, text) + + assert result.jaccard == 1.0 + assert result.token_overlap == 1.0 + + def test_no_overlap(self, lexical: Lexical) -> None: + """Test that texts with no overlap produce zero scores.""" + candidate = "apple banana cherry" + reference = "dog elephant fox" + result = lexical.score(candidate, reference) + + assert result.jaccard == 0.0 + assert result.token_overlap == 0.0 + + def test_partial_overlap_jaccard(self, lexical: Lexical) -> None: + """Test Jaccard with partial overlap.""" + candidate = "the cat sat" + reference = "the dog sat" + result = lexical.score(candidate, reference) + + # Intersection: {the, sat}, Union: {the, cat, sat, dog} + # Jaccard = 2/4 = 0.5 + assert result.jaccard == 0.5 + + def test_partial_overlap_token_overlap(self, lexical: Lexical) -> None: + """Test token overlap with partial overlap.""" + candidate = "the cat sat" + reference = "the dog sat" + result = lexical.score(candidate, reference) + + # Candidate tokens: {the, cat, sat} + # In reference: {the, sat} + # Overlap = 2/3 + assert abs(result.token_overlap - 2 / 3) < 1e-10 + + def test_candidate_subset_of_reference(self, lexical: Lexical) -> None: + """Test when candidate is a subset of reference.""" + candidate = "the cat" + reference = "the cat sat on the mat" + result = lexical.score(candidate, reference) + + # All candidate tokens are in reference + assert result.token_overlap == 1.0 + # But Jaccard is less than 1 due to extra tokens in reference + assert result.jaccard < 1.0 + + def test_reference_subset_of_candidate(self, lexical: Lexical) -> None: + """Test when reference is a subset of candidate.""" + candidate = "the cat sat on the mat" + reference = "the cat" + result = lexical.score(candidate, reference) + + # Jaccard is less than 1 + assert result.jaccard < 1.0 + # Token overlap is less than 1 + assert result.token_overlap < 1.0 + + def test_empty_candidate(self, lexical: Lexical) -> None: + """Test that empty candidate returns zero scores.""" + result = lexical.score("", "The cat sat") + + assert result.jaccard == 0.0 + assert result.token_overlap == 0.0 + + def test_whitespace_only_candidate(self, lexical: Lexical) -> None: + """Test that whitespace-only candidate returns zero scores.""" + result = lexical.score(" \t\n ", "The cat sat") + + assert result.jaccard == 0.0 + assert result.token_overlap == 0.0 + + def test_empty_reference_raises(self, lexical: Lexical) -> None: + """Test that empty reference raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + lexical.score("The cat sat", "") + + def test_none_reference_raises(self, lexical: Lexical) -> None: + """Test that None reference raises ValueError.""" + with pytest.raises(ValueError, match="requires reference"): + lexical.score("The cat sat", None) + + def test_multiple_references_uses_first(self, lexical: Lexical) -> None: + """Test that multiple references uses the first one.""" + candidate = "the cat sat" + references = ["the dog ran", "the cat sat"] # First differs + result = lexical.score(candidate, references) + + # Should use first reference, not second + assert result.jaccard < 1.0 + + def test_case_insensitivity(self, lexical: Lexical) -> None: + """Test that lexical is case insensitive by default.""" + result = lexical.score("THE CAT SAT", "the cat sat") + assert result.jaccard == 1.0 + assert result.token_overlap == 1.0 + + def test_punctuation_ignored(self, lexical: Lexical) -> None: + """Test that punctuation is ignored by default.""" + result = lexical.score("The cat sat.", "The cat sat!") + assert result.jaccard == 1.0 + assert result.token_overlap == 1.0 + + def test_repeated_tokens(self, lexical: Lexical) -> None: + """Test handling of repeated tokens.""" + candidate = "the the the" + reference = "the cat" + result = lexical.score(candidate, reference) + + # Sets: {the} and {the, cat} + # Jaccard = 1/2 = 0.5 + assert result.jaccard == 0.5 + # Token overlap: {the} / {the} = 1.0 + assert result.token_overlap == 1.0 + + +class TestLexicalBatch: + """Tests for lexical batch scoring.""" + + @pytest.fixture + def lexical(self) -> Lexical: + """Provide a lexical metric instance.""" + return Lexical() + + def test_batch_score_basic(self, lexical: Lexical) -> None: + """Test basic batch scoring.""" + candidates = ["The cat sat", "A dog runs"] + references = ["The cat sat", "A dog runs"] + result = lexical.batch_score(candidates, references) + + assert result.count == 2 + assert len(result.results) == 2 + assert all(r.jaccard == 1.0 for r in result.results) + + def test_batch_score_statistics(self, lexical: Lexical) -> None: + """Test that batch scoring computes statistics.""" + candidates = ["The cat sat", "Completely different words"] + references = ["The cat sat", "The cat sat"] + result = lexical.batch_score(candidates, references) + + # Check statistics are computed + assert "jaccard" in result.stats + assert "token_overlap" in result.stats + + # First result should be 1.0, second should be 0.0 + assert result.results[0].jaccard == 1.0 + assert result.results[1].jaccard == 0.0 + + # Mean should be 0.5 + assert result.stats["jaccard"].mean == 0.5 + + def test_batch_score_percentiles(self, lexical: Lexical) -> None: + """Test that batch scoring computes percentiles.""" + candidates = ["a", "b", "c", "d", "e"] + references = ["a", "b", "c", "d", "e"] + result = lexical.batch_score(candidates, references) + + stats = result.stats["jaccard"] + assert 25 in stats.percentiles + assert 50 in stats.percentiles + assert 75 in stats.percentiles + assert 95 in stats.percentiles + + def test_batch_score_none_references_raises(self, lexical: Lexical) -> None: + """Test that batch scoring raises for None references.""" + with pytest.raises(ValueError, match="requires reference"): + lexical.batch_score(["text"], None) + + def test_batch_score_length_mismatch_raises(self, lexical: Lexical) -> None: + """Test that batch scoring raises for mismatched lengths.""" + with pytest.raises(ValueError, match="must match"): + lexical.batch_score(["a", "b"], ["a"]) + + +class TestLexicalResult: + """Tests for LexicalResult type.""" + + def test_frozen(self) -> None: + """Test that LexicalResult is frozen.""" + from pydantic import ValidationError + + result = LexicalResult(jaccard=0.5, token_overlap=0.7) + with pytest.raises(ValidationError): + result.jaccard = 0.6 # type: ignore[misc] + + def test_values(self) -> None: + """Test that values are stored correctly.""" + result = LexicalResult(jaccard=0.5, token_overlap=0.7) + assert result.jaccard == 0.5 + assert result.token_overlap == 0.7