From b6c4bad96a58624740815934769638249f69535f Mon Sep 17 00:00:00 2001 From: Kai Chappell Date: Sat, 5 Apr 2025 10:03:52 +0000 Subject: [PATCH] feat: semantic similarity metric --- src/veritext/semantic/__init__.py | 16 +++ src/veritext/semantic/similarity.py | 171 ++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 src/veritext/semantic/__init__.py create mode 100644 src/veritext/semantic/similarity.py diff --git a/src/veritext/semantic/__init__.py b/src/veritext/semantic/__init__.py new file mode 100644 index 0000000..c540def --- /dev/null +++ b/src/veritext/semantic/__init__.py @@ -0,0 +1,16 @@ +"""Semantic similarity module: embedding-based text comparison. + +This module provides semantic similarity using sentence-transformers. +It requires the `veritext[semantic]` extra to be installed. + +Example: + >>> from veritext.semantic import SemanticSimilarity + >>> + >>> metric = SemanticSimilarity() + >>> result = metric.score("The cat sat on the mat", "A feline rested on the rug") + >>> print(f"Similarity: {result.similarity:.2f}") +""" + +from veritext.semantic.similarity import SemanticSimilarity + +__all__ = ["SemanticSimilarity"] diff --git a/src/veritext/semantic/similarity.py b/src/veritext/semantic/similarity.py new file mode 100644 index 0000000..bf89dd1 --- /dev/null +++ b/src/veritext/semantic/similarity.py @@ -0,0 +1,171 @@ +"""Embedding-based semantic similarity using sentence-transformers.""" + +from collections import OrderedDict +from typing import Any + +from veritext.core.exceptions import DependencyError +from veritext.metrics.base import AggregateStats, BatchResult +from veritext.metrics.results import SemanticResult + +DEFAULT_CACHE_MAX_SIZE = 1000 + + +class SemanticSimilarity: + """ + Embedding-based semantic similarity using sentence-transformers. + + Computes cosine similarity between text embeddings to measure semantic + relatedness. This metric captures meaning beyond lexical overlap. + + Requires the `veritext[semantic]` extra to be installed. + """ + + def __init__( + self, + model: str = "all-MiniLM-L6-v2", + cache_embeddings: bool = True, + cache_max_size: int = DEFAULT_CACHE_MAX_SIZE, + ) -> None: + """ + Initialise the semantic similarity metric. + + Args: + model: Name of the sentence-transformers model to use. + Defaults to "all-MiniLM-L6-v2" (22MB, good quality/size tradeoff). + cache_embeddings: Whether to cache embeddings for repeated texts. + Defaults to True. + cache_max_size: Maximum number of embeddings to cache. Oldest entries + are evicted when the limit is reached. Defaults to 1000. + + Raises: + DependencyError: If sentence-transformers is not installed. + """ + try: + from sentence_transformers import SentenceTransformer + except ImportError as err: + raise DependencyError( + "Install veritext[semantic] for semantic similarity: " + "pip install veritext[semantic]" + ) from err + + self._model_name = model + self._model: Any = SentenceTransformer(model) + self._cache: OrderedDict[str, Any] | None = ( + OrderedDict() if cache_embeddings else None + ) + self._cache_max_size = cache_max_size + + @property + def name(self) -> str: + return "semantic" + + @property + def requires_reference(self) -> bool: + return True + + def _get_embedding(self, text: str) -> Any: + if self._cache is not None and text in self._cache: + self._cache.move_to_end(text) + return self._cache[text] + + embedding = self._model.encode(text, convert_to_tensor=True) + + if self._cache is not None: + while len(self._cache) >= self._cache_max_size: + self._cache.popitem(last=False) + self._cache[text] = embedding + + return embedding + + def _cosine_similarity(self, embedding1: Any, embedding2: Any) -> float: + from sentence_transformers import util + + similarity: float = util.cos_sim(embedding1, embedding2).item() + return max(0.0, min(1.0, similarity)) + + def score( + self, candidate: str, reference: str | list[str] | None = None + ) -> SemanticResult: + """ + Compute semantic similarity between candidate and reference. + + When multiple references are provided, returns the maximum similarity + across all references. + + Args: + candidate: The text to score. + reference: Reference text(s) for comparison. + + Returns: + SemanticResult with similarity score and model name. + + Raises: + ValueError: If reference is None or empty. + """ + if reference is None: + raise ValueError("Semantic similarity requires reference text") + + references = [reference] if isinstance(reference, str) else reference + + if not references: + raise ValueError("Reference text cannot be empty") + + candidate_stripped = candidate.strip() + if not candidate_stripped: + return SemanticResult(similarity=0.0, model=self._model_name) + + valid_references = [r for r in references if r.strip()] + if not valid_references: + raise ValueError("Reference text cannot be empty") + + candidate_embedding = self._get_embedding(candidate_stripped) + + max_similarity = 0.0 + for ref in valid_references: + ref_embedding = self._get_embedding(ref.strip()) + similarity = self._cosine_similarity(candidate_embedding, ref_embedding) + max_similarity = max(max_similarity, similarity) + + return SemanticResult(similarity=max_similarity, model=self._model_name) + + def batch_score( + self, + candidates: list[str], + references: list[str] | list[list[str]] | None = None, + ) -> BatchResult[SemanticResult]: + """ + Compute semantic similarity for a batch of candidates. + + Args: + candidates: List of texts to score. + references: Reference text(s) for each candidate. + + Returns: + BatchResult containing individual results and aggregate statistics. + + Raises: + ValueError: If references is None or length mismatch. + """ + if references is None: + raise ValueError("Semantic similarity requires reference texts") + + if len(candidates) != len(references): + raise ValueError( + f"Number of candidates ({len(candidates)}) must match " + f"number of references ({len(references)})" + ) + + results: list[SemanticResult] = [] + for i, cand in enumerate(candidates): + ref: str | list[str] = references[i] + results.append(self.score(cand, ref)) + + stats = { + "similarity": AggregateStats.from_values([r.similarity for r in results]), + } + + return BatchResult(results=results, count=len(results), stats=stats) + + def clear_cache(self) -> None: + if self._cache is not None: + self._cache.clear()