test(semantic): add semantic similarity tests
This commit is contained in:
1
tests/test_semantic/__init__.py
Normal file
1
tests/test_semantic/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for semantic similarity module."""
|
||||||
240
tests/test_semantic/test_similarity.py
Normal file
240
tests/test_semantic/test_similarity.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""Tests for the semantic similarity metric."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Skip all tests if sentence-transformers is not installed
|
||||||
|
pytest.importorskip("sentence_transformers")
|
||||||
|
|
||||||
|
from veritext.metrics.results import SemanticResult
|
||||||
|
from veritext.semantic import SemanticSimilarity
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemanticSimilarity:
|
||||||
|
"""Tests for the SemanticSimilarity metric class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def semantic(self) -> SemanticSimilarity:
|
||||||
|
"""Provide a SemanticSimilarity metric instance."""
|
||||||
|
return SemanticSimilarity()
|
||||||
|
|
||||||
|
def test_name(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that name returns 'semantic'."""
|
||||||
|
assert semantic.name == "semantic"
|
||||||
|
|
||||||
|
def test_requires_reference(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that semantic similarity requires reference text."""
|
||||||
|
assert semantic.requires_reference is True
|
||||||
|
|
||||||
|
def test_identical_texts(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that identical texts produce high similarity."""
|
||||||
|
text = "The cat sat on the mat"
|
||||||
|
result = semantic.score(text, text)
|
||||||
|
|
||||||
|
# Identical texts should have very high similarity (close to 1.0)
|
||||||
|
assert result.similarity >= 0.99
|
||||||
|
assert result.model == "all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
def test_semantically_similar_texts(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that semantically similar texts have high similarity."""
|
||||||
|
candidate = "The cat sat on the mat"
|
||||||
|
reference = "A feline rested on the rug"
|
||||||
|
result = semantic.score(candidate, reference)
|
||||||
|
|
||||||
|
# Similar meanings should have reasonable similarity
|
||||||
|
assert result.similarity > 0.3
|
||||||
|
|
||||||
|
def test_unrelated_texts(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that unrelated texts have low similarity."""
|
||||||
|
candidate = "The quick brown fox"
|
||||||
|
reference = "Quantum physics describes particle behaviour"
|
||||||
|
result = semantic.score(candidate, reference)
|
||||||
|
|
||||||
|
# Unrelated texts should have low similarity
|
||||||
|
assert result.similarity < 0.5
|
||||||
|
|
||||||
|
def test_empty_candidate(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that empty candidate returns zero similarity."""
|
||||||
|
result = semantic.score("", "The cat sat on the mat")
|
||||||
|
assert result.similarity == 0.0
|
||||||
|
|
||||||
|
def test_whitespace_only_candidate(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that whitespace-only candidate returns zero similarity."""
|
||||||
|
result = semantic.score(" \t\n ", "The cat sat on the mat")
|
||||||
|
assert result.similarity == 0.0
|
||||||
|
|
||||||
|
def test_none_reference_raises(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that None reference raises ValueError."""
|
||||||
|
with pytest.raises(ValueError, match="requires reference"):
|
||||||
|
semantic.score("The cat sat", None)
|
||||||
|
|
||||||
|
def test_empty_reference_raises(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that empty reference raises ValueError."""
|
||||||
|
with pytest.raises(ValueError, match="cannot be empty"):
|
||||||
|
semantic.score("The cat sat", "")
|
||||||
|
|
||||||
|
def test_whitespace_reference_raises(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that whitespace-only reference raises ValueError."""
|
||||||
|
with pytest.raises(ValueError, match="cannot be empty"):
|
||||||
|
semantic.score("The cat sat", " \t\n ")
|
||||||
|
|
||||||
|
def test_multiple_references(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test semantic similarity with multiple references uses max."""
|
||||||
|
candidate = "The cat sat on the mat"
|
||||||
|
references = [
|
||||||
|
"A dog ran through the park",
|
||||||
|
"The cat sat on the mat", # Exact match
|
||||||
|
]
|
||||||
|
result = semantic.score(candidate, references)
|
||||||
|
|
||||||
|
# Should get high similarity due to exact match reference
|
||||||
|
assert result.similarity >= 0.99
|
||||||
|
|
||||||
|
def test_multiple_references_takes_max(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that multiple references returns maximum similarity."""
|
||||||
|
candidate = "The cat sat on the mat"
|
||||||
|
references = [
|
||||||
|
"Quantum physics is complex", # Low similarity
|
||||||
|
"A feline rested on the rug", # Higher similarity
|
||||||
|
]
|
||||||
|
result = semantic.score(candidate, references)
|
||||||
|
|
||||||
|
# Should use the higher similarity
|
||||||
|
assert result.similarity > 0.3
|
||||||
|
|
||||||
|
def test_result_score_property(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that result.score returns similarity."""
|
||||||
|
result = semantic.score("The cat sat", "The cat sat")
|
||||||
|
assert result.score == result.similarity
|
||||||
|
|
||||||
|
def test_caching_behaviour(self) -> None:
|
||||||
|
"""Test that caching works for repeated texts."""
|
||||||
|
semantic = SemanticSimilarity(cache_embeddings=True)
|
||||||
|
|
||||||
|
# Score same texts multiple times
|
||||||
|
text = "The cat sat on the mat"
|
||||||
|
result1 = semantic.score(text, text)
|
||||||
|
result2 = semantic.score(text, text)
|
||||||
|
|
||||||
|
# Results should be identical
|
||||||
|
assert result1.similarity == result2.similarity
|
||||||
|
|
||||||
|
# Clear cache and check again
|
||||||
|
semantic.clear_cache()
|
||||||
|
result3 = semantic.score(text, text)
|
||||||
|
assert result3.similarity == result1.similarity
|
||||||
|
|
||||||
|
def test_caching_disabled(self) -> None:
|
||||||
|
"""Test that caching can be disabled."""
|
||||||
|
semantic = SemanticSimilarity(cache_embeddings=False)
|
||||||
|
|
||||||
|
text = "The cat sat on the mat"
|
||||||
|
result1 = semantic.score(text, text)
|
||||||
|
result2 = semantic.score(text, text)
|
||||||
|
|
||||||
|
# Results should still be identical (just not cached)
|
||||||
|
assert result1.similarity == result2.similarity
|
||||||
|
|
||||||
|
# Clear cache should not raise even when disabled
|
||||||
|
semantic.clear_cache()
|
||||||
|
|
||||||
|
def test_custom_model(self) -> None:
|
||||||
|
"""Test that custom model name is recorded in result."""
|
||||||
|
# Use the same model but verify it's recorded correctly
|
||||||
|
semantic = SemanticSimilarity(model="all-MiniLM-L6-v2")
|
||||||
|
result = semantic.score("Test text", "Test text")
|
||||||
|
assert result.model == "all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemanticSimilarityBatch:
|
||||||
|
"""Tests for semantic similarity batch scoring."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def semantic(self) -> SemanticSimilarity:
|
||||||
|
"""Provide a SemanticSimilarity metric instance."""
|
||||||
|
return SemanticSimilarity()
|
||||||
|
|
||||||
|
def test_batch_score_basic(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test basic batch scoring."""
|
||||||
|
candidates = ["The cat sat on the mat", "A quick brown dog runs fast"]
|
||||||
|
references = ["The cat sat on the mat", "A quick brown dog runs fast"]
|
||||||
|
result = semantic.batch_score(candidates, references)
|
||||||
|
|
||||||
|
assert result.count == 2
|
||||||
|
assert len(result.results) == 2
|
||||||
|
# Identical texts should have very high similarity
|
||||||
|
assert all(r.similarity >= 0.99 for r in result.results)
|
||||||
|
|
||||||
|
def test_batch_score_statistics(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that batch scoring computes statistics."""
|
||||||
|
candidates = ["The cat sat", "Quantum physics is complex"]
|
||||||
|
references = ["The cat sat", "The cat sat"]
|
||||||
|
result = semantic.batch_score(candidates, references)
|
||||||
|
|
||||||
|
# Check statistics are computed
|
||||||
|
assert "similarity" in result.stats
|
||||||
|
|
||||||
|
# Mean should be between min and max
|
||||||
|
stats = result.stats["similarity"]
|
||||||
|
assert stats.min <= stats.mean <= stats.max
|
||||||
|
|
||||||
|
def test_batch_score_percentiles(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that batch scoring computes percentiles."""
|
||||||
|
candidates = ["a", "b", "c", "d", "e"]
|
||||||
|
references = ["a", "b", "c", "d", "e"]
|
||||||
|
result = semantic.batch_score(candidates, references)
|
||||||
|
|
||||||
|
stats = result.stats["similarity"]
|
||||||
|
assert 25 in stats.percentiles
|
||||||
|
assert 50 in stats.percentiles
|
||||||
|
assert 75 in stats.percentiles
|
||||||
|
assert 95 in stats.percentiles
|
||||||
|
|
||||||
|
def test_batch_score_none_references_raises(
|
||||||
|
self, semantic: SemanticSimilarity
|
||||||
|
) -> None:
|
||||||
|
"""Test that batch scoring raises for None references."""
|
||||||
|
with pytest.raises(ValueError, match="requires reference"):
|
||||||
|
semantic.batch_score(["text"], None)
|
||||||
|
|
||||||
|
def test_batch_score_length_mismatch_raises(
|
||||||
|
self, semantic: SemanticSimilarity
|
||||||
|
) -> None:
|
||||||
|
"""Test that batch scoring raises for mismatched lengths."""
|
||||||
|
with pytest.raises(ValueError, match="must match"):
|
||||||
|
semantic.batch_score(["a", "b"], ["a"])
|
||||||
|
|
||||||
|
def test_batch_score_with_multiple_references(
|
||||||
|
self, semantic: SemanticSimilarity
|
||||||
|
) -> None:
|
||||||
|
"""Test batch scoring with multiple references per candidate."""
|
||||||
|
candidates = [
|
||||||
|
"The cat sat on the mat",
|
||||||
|
"A quick brown dog runs fast",
|
||||||
|
]
|
||||||
|
references = [
|
||||||
|
["The cat sat on the mat", "A cat rests on floor"],
|
||||||
|
["A quick brown dog runs fast", "Dogs run very quickly"],
|
||||||
|
]
|
||||||
|
result = semantic.batch_score(candidates, references)
|
||||||
|
|
||||||
|
assert result.count == 2
|
||||||
|
# First pair has exact match
|
||||||
|
assert result.results[0].similarity >= 0.99
|
||||||
|
assert result.results[1].similarity >= 0.99
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemanticResult:
|
||||||
|
"""Tests for SemanticResult type."""
|
||||||
|
|
||||||
|
def test_frozen(self) -> None:
|
||||||
|
"""Test that SemanticResult is frozen."""
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
result = SemanticResult(similarity=0.85, model="test-model")
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
result.similarity = 0.9 # type: ignore[misc]
|
||||||
|
|
||||||
|
def test_score_property(self) -> None:
|
||||||
|
"""Test that score property returns similarity."""
|
||||||
|
result = SemanticResult(similarity=0.75, model="test-model")
|
||||||
|
assert result.score == 0.75
|
||||||
@@ -207,3 +207,77 @@ class TestLexicalValidator:
|
|||||||
validator = lexical(min_jaccard=0.5, min_overlap=0.6)
|
validator = lexical(min_jaccard=0.5, min_overlap=0.6)
|
||||||
assert isinstance(validator, LexicalValidator)
|
assert isinstance(validator, LexicalValidator)
|
||||||
assert validator.name == "lexical"
|
assert validator.name == "lexical"
|
||||||
|
|
||||||
|
|
||||||
|
# SemanticValidator tests - conditionally run if sentence-transformers is installed
|
||||||
|
class TestSemanticValidator:
|
||||||
|
"""Tests for SemanticValidator."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _skip_if_no_transformers() -> None:
|
||||||
|
"""Skip test if sentence-transformers is not installed."""
|
||||||
|
pytest.importorskip("sentence_transformers")
|
||||||
|
|
||||||
|
def test_semantic_validator_passes_when_score_meets_threshold(self) -> None:
|
||||||
|
"""Test that validator passes when semantic similarity meets threshold."""
|
||||||
|
self._skip_if_no_transformers()
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
validator = SemanticValidator(min_score=0.5)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("the cat sat on the mat", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "semantic"
|
||||||
|
assert result.actual >= 0.99 # Identical text
|
||||||
|
assert result.threshold == 0.5
|
||||||
|
|
||||||
|
def test_semantic_validator_fails_when_score_below_threshold(self) -> None:
|
||||||
|
"""Test that validator fails when semantic similarity is below threshold."""
|
||||||
|
self._skip_if_no_transformers()
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
validator = SemanticValidator(min_score=0.99)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check(
|
||||||
|
"quantum physics describes particle behaviour", context
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert result.name == "semantic"
|
||||||
|
assert result.actual < 0.99
|
||||||
|
assert "below minimum" in result.message
|
||||||
|
|
||||||
|
def test_semantic_validator_raises_on_missing_reference(self) -> None:
|
||||||
|
"""Test that validator raises when reference is missing."""
|
||||||
|
self._skip_if_no_transformers()
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
validator = SemanticValidator(min_score=0.5)
|
||||||
|
context = ValidationContext()
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError, match="requires reference text"):
|
||||||
|
validator.check("some text", context)
|
||||||
|
|
||||||
|
def test_semantic_validator_raises_on_invalid_min_score(self) -> None:
|
||||||
|
"""Test that invalid min_score raises error without loading model."""
|
||||||
|
# This test doesn't need sentence-transformers since validation happens first
|
||||||
|
with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
SemanticValidator(min_score=1.5)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
SemanticValidator(min_score=-0.1)
|
||||||
|
|
||||||
|
def test_semantic_factory_function(self) -> None:
|
||||||
|
"""Test the semantic() factory function."""
|
||||||
|
self._skip_if_no_transformers()
|
||||||
|
from veritext.validators import semantic
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
validator = semantic(min_score=0.6)
|
||||||
|
assert isinstance(validator, SemanticValidator)
|
||||||
|
assert validator.name == "semantic"
|
||||||
|
|||||||
Reference in New Issue
Block a user