209 lines
7.8 KiB
Python
209 lines
7.8 KiB
Python
"""Tests for the semantic similarity metric."""
|
|
|
|
import pytest
|
|
|
|
# Skip all tests if sentence-transformers is not installed
|
|
pytest.importorskip("sentence_transformers")
|
|
|
|
from veritext.metrics.results import SemanticResult
|
|
from veritext.semantic import SemanticSimilarity
|
|
|
|
|
|
class TestSemanticSimilarity:
    """Single-pair scoring tests for ``SemanticSimilarity``."""

    @pytest.fixture
    def semantic(self) -> SemanticSimilarity:
        """Provide a metric instance built with default arguments."""
        return SemanticSimilarity()

    def test_name(self, semantic: SemanticSimilarity) -> None:
        """The metric identifies itself as ``semantic``."""
        assert semantic.name == "semantic"

    def test_requires_reference(self, semantic: SemanticSimilarity) -> None:
        """The metric declares that it needs a reference text."""
        assert semantic.requires_reference is True

    def test_identical_texts(self, semantic: SemanticSimilarity) -> None:
        """Scoring a text against itself yields near-perfect similarity."""
        text = "The cat sat on the mat"
        result = semantic.score(text, text)

        # Identical texts should have very high similarity (close to 1.0)
        assert result.similarity >= 0.99
        # The default-constructed metric records this model name.
        assert result.model == "all-MiniLM-L6-v2"

    def test_semantically_similar_texts(self, semantic: SemanticSimilarity) -> None:
        """A paraphrase with different wording still scores reasonably."""
        candidate = "The cat sat on the mat"
        reference = "A feline rested on the rug"
        result = semantic.score(candidate, reference)

        # Similar meanings should have reasonable similarity
        assert result.similarity > 0.3

    def test_unrelated_texts(self, semantic: SemanticSimilarity) -> None:
        """Texts about unrelated topics score low."""
        candidate = "The quick brown fox"
        reference = "Quantum physics describes particle behaviour"
        result = semantic.score(candidate, reference)

        # Unrelated texts should have low similarity
        assert result.similarity < 0.5

    def test_empty_candidate(self, semantic: SemanticSimilarity) -> None:
        """An empty candidate scores exactly zero."""
        result = semantic.score("", "The cat sat on the mat")
        assert result.similarity == 0.0

    def test_whitespace_only_candidate(self, semantic: SemanticSimilarity) -> None:
        """A whitespace-only candidate scores exactly zero."""
        result = semantic.score(" \t\n ", "The cat sat on the mat")
        assert result.similarity == 0.0

    def test_none_reference_raises(self, semantic: SemanticSimilarity) -> None:
        """A missing reference is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match="requires reference"):
            semantic.score("The cat sat", None)

    def test_empty_reference_raises(self, semantic: SemanticSimilarity) -> None:
        """An empty reference is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match="cannot be empty"):
            semantic.score("The cat sat", "")

    def test_whitespace_reference_raises(self, semantic: SemanticSimilarity) -> None:
        """A whitespace-only reference is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match="cannot be empty"):
            semantic.score("The cat sat", " \t\n ")

    def test_multiple_references(self, semantic: SemanticSimilarity) -> None:
        """An exact match among several references gives a high score."""
        candidate = "The cat sat on the mat"
        references = [
            "A dog ran through the park",
            "The cat sat on the mat",  # Exact match
        ]
        result = semantic.score(candidate, references)

        # Should get high similarity due to exact match reference
        assert result.similarity >= 0.99

    def test_multiple_references_takes_max(self, semantic: SemanticSimilarity) -> None:
        """The multi-reference score equals the best single-reference score."""
        candidate = "The cat sat on the mat"
        unrelated = "Quantum physics is complex"  # Low similarity
        paraphrase = "A feline rested on the rug"  # Higher similarity
        result = semantic.score(candidate, [unrelated, paraphrase])

        # Score each reference on its own so the expected maximum is known.
        low = semantic.score(candidate, unrelated).similarity
        high = semantic.score(candidate, paraphrase).similarity

        # The two references must really differ, otherwise "max" cannot be
        # told apart from "first", "last" or "mean".
        assert low < high

        # Should use the higher similarity
        assert result.similarity > 0.3
        # A small tolerance is allowed because scoring a reference alone and
        # alongside others may differ by floating-point noise.
        assert result.similarity == pytest.approx(max(low, high), abs=1e-4)

    def test_result_score_property(self, semantic: SemanticSimilarity) -> None:
        """``score`` on the result mirrors ``similarity``."""
        result = semantic.score("The cat sat", "The cat sat")
        assert result.score == result.similarity

    def test_caching_behaviour(self) -> None:
        """Scores are stable across repeated calls and after a cache clear."""
        semantic = SemanticSimilarity(cache_embeddings=True)

        # Score same texts multiple times
        text = "The cat sat on the mat"
        result1 = semantic.score(text, text)
        result2 = semantic.score(text, text)

        # Results should be identical
        assert result1.similarity == result2.similarity

        # Clear cache and check again
        semantic.clear_cache()
        result3 = semantic.score(text, text)
        assert result3.similarity == result1.similarity

    def test_caching_disabled(self) -> None:
        """With caching off, scores stay stable and ``clear_cache`` is safe."""
        semantic = SemanticSimilarity(cache_embeddings=False)

        text = "The cat sat on the mat"
        result1 = semantic.score(text, text)
        result2 = semantic.score(text, text)

        # Results should still be identical (just not cached)
        assert result1.similarity == result2.similarity

        # Clear cache should not raise even when disabled
        semantic.clear_cache()

    def test_custom_model(self) -> None:
        """An explicitly supplied model name is recorded on the result."""
        # Use the same model but verify it's recorded correctly
        semantic = SemanticSimilarity(model="all-MiniLM-L6-v2")
        result = semantic.score("Test text", "Test text")
        assert result.model == "all-MiniLM-L6-v2"
|
|
|
|
|
|
class TestSemanticSimilarityBatch:
    """Batch-scoring tests for ``SemanticSimilarity.batch_score``."""

    @pytest.fixture
    def semantic(self) -> SemanticSimilarity:
        """Provide a metric instance built with default arguments."""
        metric = SemanticSimilarity()
        return metric

    def test_batch_score_basic(self, semantic: SemanticSimilarity) -> None:
        """Each candidate paired with an identical reference scores high."""
        texts = ["The cat sat on the mat", "A quick brown dog runs fast"]
        # Pass a separate copy so candidates and references stay distinct
        # list objects.
        batch = semantic.batch_score(texts, list(texts))

        assert batch.count == 2
        assert len(batch.results) == 2
        # Identical texts should have very high similarity
        for item in batch.results:
            assert item.similarity >= 0.99

    def test_batch_score_statistics(self, semantic: SemanticSimilarity) -> None:
        """Aggregate statistics exist and are internally consistent."""
        batch = semantic.batch_score(
            ["The cat sat", "Quantum physics is complex"],
            ["The cat sat", "The cat sat"],
        )

        # Check statistics are computed
        assert "similarity" in batch.stats

        # Mean should be between min and max
        summary = batch.stats["similarity"]
        assert summary.min <= summary.mean <= summary.max

    def test_batch_score_percentiles(self, semantic: SemanticSimilarity) -> None:
        """The similarity statistics expose the expected percentile keys."""
        letters = ["a", "b", "c", "d", "e"]
        batch = semantic.batch_score(letters, list(letters))

        percentiles = batch.stats["similarity"].percentiles
        for key in (25, 50, 75, 95):
            assert key in percentiles

    def test_batch_score_none_references_raises(
        self, semantic: SemanticSimilarity
    ) -> None:
        """Missing references are rejected with ``ValueError``."""
        expected_error = pytest.raises(ValueError, match="requires reference")
        with expected_error:
            semantic.batch_score(["text"], None)

    def test_batch_score_length_mismatch_raises(
        self, semantic: SemanticSimilarity
    ) -> None:
        """Unequal candidate and reference counts raise ``ValueError``."""
        expected_error = pytest.raises(ValueError, match="must match")
        with expected_error:
            semantic.batch_score(["a", "b"], ["a"])

    def test_batch_score_multi_refs(
        self, semantic: SemanticSimilarity
    ) -> None:
        """Each candidate may be scored against its own list of references."""
        cat_text = "The cat sat on the mat"
        dog_text = "A quick brown dog runs fast"
        reference_sets = [
            [cat_text, "A cat rests on floor"],
            [dog_text, "Dogs run very quickly"],
        ]
        batch = semantic.batch_score([cat_text, dog_text], reference_sets)

        assert batch.count == 2
        # Both pairs include an exact match among their references
        first, second = batch.results[0], batch.results[1]
        assert first.similarity >= 0.99
        assert second.similarity >= 0.99
|
|
|
|
|
|
class TestSemanticResult:
    """Tests for the ``SemanticResult`` model itself."""

    def test_frozen(self) -> None:
        """Assigning to a field of a built result raises ``ValidationError``."""
        from pydantic import ValidationError

        frozen_result = SemanticResult(similarity=0.85, model="test-model")
        expect_rejection = pytest.raises(ValidationError)
        with expect_rejection:
            frozen_result.similarity = 0.9  # type: ignore[misc]

    def test_score_property(self) -> None:
        """``score`` returns the similarity the result was built with."""
        built = SemanticResult(similarity=0.75, model="test-model")
        assert built.score == 0.75
|