feat: semantic similarity metric
This commit is contained in:
16
src/veritext/semantic/__init__.py
Normal file
16
src/veritext/semantic/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Semantic similarity module: embedding-based text comparison.
|
||||||
|
|
||||||
|
This module provides semantic similarity using sentence-transformers.
|
||||||
|
It requires the `veritext[semantic]` extra to be installed.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from veritext.semantic import SemanticSimilarity
|
||||||
|
>>>
|
||||||
|
>>> metric = SemanticSimilarity()
|
||||||
|
>>> result = metric.score("The cat sat on the mat", "A feline rested on the rug")
|
||||||
|
>>> print(f"Similarity: {result.similarity:.2f}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from veritext.semantic.similarity import SemanticSimilarity
|
||||||
|
|
||||||
|
__all__ = ["SemanticSimilarity"]
|
||||||
171
src/veritext/semantic/similarity.py
Normal file
171
src/veritext/semantic/similarity.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""Embedding-based semantic similarity using sentence-transformers."""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from veritext.core.exceptions import DependencyError
|
||||||
|
from veritext.metrics.base import AggregateStats, BatchResult
|
||||||
|
from veritext.metrics.results import SemanticResult
|
||||||
|
|
||||||
|
DEFAULT_CACHE_MAX_SIZE = 1000
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticSimilarity:
|
||||||
|
"""
|
||||||
|
Embedding-based semantic similarity using sentence-transformers.
|
||||||
|
|
||||||
|
Computes cosine similarity between text embeddings to measure semantic
|
||||||
|
relatedness. This metric captures meaning beyond lexical overlap.
|
||||||
|
|
||||||
|
Requires the `veritext[semantic]` extra to be installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "all-MiniLM-L6-v2",
|
||||||
|
cache_embeddings: bool = True,
|
||||||
|
cache_max_size: int = DEFAULT_CACHE_MAX_SIZE,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the semantic similarity metric.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Name of the sentence-transformers model to use.
|
||||||
|
Defaults to "all-MiniLM-L6-v2" (22MB, good quality/size tradeoff).
|
||||||
|
cache_embeddings: Whether to cache embeddings for repeated texts.
|
||||||
|
Defaults to True.
|
||||||
|
cache_max_size: Maximum number of embeddings to cache. Oldest entries
|
||||||
|
are evicted when the limit is reached. Defaults to 1000.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DependencyError: If sentence-transformers is not installed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
except ImportError as err:
|
||||||
|
raise DependencyError(
|
||||||
|
"Install veritext[semantic] for semantic similarity: "
|
||||||
|
"pip install veritext[semantic]"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
self._model_name = model
|
||||||
|
self._model: Any = SentenceTransformer(model)
|
||||||
|
self._cache: OrderedDict[str, Any] | None = (
|
||||||
|
OrderedDict() if cache_embeddings else None
|
||||||
|
)
|
||||||
|
self._cache_max_size = cache_max_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "semantic"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_reference(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _get_embedding(self, text: str) -> Any:
|
||||||
|
if self._cache is not None and text in self._cache:
|
||||||
|
self._cache.move_to_end(text)
|
||||||
|
return self._cache[text]
|
||||||
|
|
||||||
|
embedding = self._model.encode(text, convert_to_tensor=True)
|
||||||
|
|
||||||
|
if self._cache is not None:
|
||||||
|
while len(self._cache) >= self._cache_max_size:
|
||||||
|
self._cache.popitem(last=False)
|
||||||
|
self._cache[text] = embedding
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def _cosine_similarity(self, embedding1: Any, embedding2: Any) -> float:
|
||||||
|
from sentence_transformers import util
|
||||||
|
|
||||||
|
similarity: float = util.cos_sim(embedding1, embedding2).item()
|
||||||
|
return max(0.0, min(1.0, similarity))
|
||||||
|
|
||||||
|
def score(
|
||||||
|
self, candidate: str, reference: str | list[str] | None = None
|
||||||
|
) -> SemanticResult:
|
||||||
|
"""
|
||||||
|
Compute semantic similarity between candidate and reference.
|
||||||
|
|
||||||
|
When multiple references are provided, returns the maximum similarity
|
||||||
|
across all references.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidate: The text to score.
|
||||||
|
reference: Reference text(s) for comparison.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SemanticResult with similarity score and model name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If reference is None or empty.
|
||||||
|
"""
|
||||||
|
if reference is None:
|
||||||
|
raise ValueError("Semantic similarity requires reference text")
|
||||||
|
|
||||||
|
references = [reference] if isinstance(reference, str) else reference
|
||||||
|
|
||||||
|
if not references:
|
||||||
|
raise ValueError("Reference text cannot be empty")
|
||||||
|
|
||||||
|
candidate_stripped = candidate.strip()
|
||||||
|
if not candidate_stripped:
|
||||||
|
return SemanticResult(similarity=0.0, model=self._model_name)
|
||||||
|
|
||||||
|
valid_references = [r for r in references if r.strip()]
|
||||||
|
if not valid_references:
|
||||||
|
raise ValueError("Reference text cannot be empty")
|
||||||
|
|
||||||
|
candidate_embedding = self._get_embedding(candidate_stripped)
|
||||||
|
|
||||||
|
max_similarity = 0.0
|
||||||
|
for ref in valid_references:
|
||||||
|
ref_embedding = self._get_embedding(ref.strip())
|
||||||
|
similarity = self._cosine_similarity(candidate_embedding, ref_embedding)
|
||||||
|
max_similarity = max(max_similarity, similarity)
|
||||||
|
|
||||||
|
return SemanticResult(similarity=max_similarity, model=self._model_name)
|
||||||
|
|
||||||
|
def batch_score(
|
||||||
|
self,
|
||||||
|
candidates: list[str],
|
||||||
|
references: list[str] | list[list[str]] | None = None,
|
||||||
|
) -> BatchResult[SemanticResult]:
|
||||||
|
"""
|
||||||
|
Compute semantic similarity for a batch of candidates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidates: List of texts to score.
|
||||||
|
references: Reference text(s) for each candidate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BatchResult containing individual results and aggregate statistics.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If references is None or length mismatch.
|
||||||
|
"""
|
||||||
|
if references is None:
|
||||||
|
raise ValueError("Semantic similarity requires reference texts")
|
||||||
|
|
||||||
|
if len(candidates) != len(references):
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of candidates ({len(candidates)}) must match "
|
||||||
|
f"number of references ({len(references)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[SemanticResult] = []
|
||||||
|
for i, cand in enumerate(candidates):
|
||||||
|
ref: str | list[str] = references[i]
|
||||||
|
results.append(self.score(cand, ref))
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"similarity": AggregateStats.from_values([r.similarity for r in results]),
|
||||||
|
}
|
||||||
|
|
||||||
|
return BatchResult(results=results, count=len(results), stats=stats)
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
if self._cache is not None:
|
||||||
|
self._cache.clear()
|
||||||
Reference in New Issue
Block a user