From 9a4ac359a3f2e3da6cfe23eda9868755c5686fc5 Mon Sep 17 00:00:00 2001 From: Kai Chappell Date: Tue, 3 Feb 2026 17:30:56 +0000 Subject: [PATCH] feat(semantic): add SemanticSimilarity metric --- src/veritext/semantic/__init__.py | 16 +++ src/veritext/semantic/similarity.py | 188 ++++++++++++++++++++++++++++ 2 files changed, 204 insertions(+) create mode 100644 src/veritext/semantic/__init__.py create mode 100644 src/veritext/semantic/similarity.py diff --git a/src/veritext/semantic/__init__.py b/src/veritext/semantic/__init__.py new file mode 100644 index 0000000..c540def --- /dev/null +++ b/src/veritext/semantic/__init__.py @@ -0,0 +1,16 @@ +"""Semantic similarity module: embedding-based text comparison. + +This module provides semantic similarity using sentence-transformers. +It requires the `veritext[semantic]` extra to be installed. + +Example: + >>> from veritext.semantic import SemanticSimilarity + >>> + >>> metric = SemanticSimilarity() + >>> result = metric.score("The cat sat on the mat", "A feline rested on the rug") + >>> print(f"Similarity: {result.similarity:.2f}") +""" + +from veritext.semantic.similarity import SemanticSimilarity + +__all__ = ["SemanticSimilarity"] diff --git a/src/veritext/semantic/similarity.py b/src/veritext/semantic/similarity.py new file mode 100644 index 0000000..5b1bc01 --- /dev/null +++ b/src/veritext/semantic/similarity.py @@ -0,0 +1,188 @@ +"""Embedding-based semantic similarity using sentence-transformers.""" + +from typing import Any + +from veritext.core.exceptions import DependencyError +from veritext.metrics.base import AggregateStats, BatchResult +from veritext.metrics.results import SemanticResult + + +class SemanticSimilarity: + """ + Embedding-based semantic similarity using sentence-transformers. + + Computes cosine similarity between text embeddings to measure semantic + relatedness. This metric captures meaning beyond lexical overlap. + + Requires the `veritext[semantic]` extra to be installed. + """ + + def __init__( + self, + model: str = "all-MiniLM-L6-v2", + cache_embeddings: bool = True, + ) -> None: + """ + Initialise the semantic similarity metric. + + Args: + model: Name of the sentence-transformers model to use. + Defaults to "all-MiniLM-L6-v2" (22MB, good quality/size tradeoff). + cache_embeddings: Whether to cache embeddings for repeated texts. + Defaults to True. + + Raises: + DependencyError: If sentence-transformers is not installed. + """ + try: + from sentence_transformers import SentenceTransformer + except ImportError as err: + raise DependencyError( + "Install veritext[semantic] for semantic similarity: " + "pip install veritext[semantic]" + ) from err + + self._model_name = model + self._model: Any = SentenceTransformer(model) + self._cache: dict[str, Any] | None = {} if cache_embeddings else None + + @property + def name(self) -> str: + """Return the name of this metric.""" + return "semantic" + + @property + def requires_reference(self) -> bool: + """Return whether this metric requires reference text.""" + return True + + def _get_embedding(self, text: str) -> Any: + """ + Get embedding for text, using cache if available. + + Args: + text: The text to embed. + + Returns: + The embedding tensor. + """ + if self._cache is not None and text in self._cache: + return self._cache[text] + + embedding = self._model.encode(text, convert_to_tensor=True) + + if self._cache is not None: + self._cache[text] = embedding + + return embedding + + def _cosine_similarity(self, embedding1: Any, embedding2: Any) -> float: + """ + Compute cosine similarity between two embeddings. + + Args: + embedding1: First embedding tensor. + embedding2: Second embedding tensor. + + Returns: + Cosine similarity score (0.0 to 1.0). + """ + from sentence_transformers import util + + similarity: float = util.cos_sim(embedding1, embedding2).item() + # Clamp to [0, 1] as negative similarities are possible but not meaningful + return max(0.0, min(1.0, similarity)) + + def score( + self, candidate: str, reference: str | list[str] | None = None + ) -> SemanticResult: + """ + Compute semantic similarity between candidate and reference. + + When multiple references are provided, returns the maximum similarity + across all references. + + Args: + candidate: The text to score. + reference: Reference text(s) for comparison. + + Returns: + SemanticResult with similarity score and model name. + + Raises: + ValueError: If reference is None or empty. + """ + if reference is None: + raise ValueError("Semantic similarity requires reference text") + + # Normalise reference to list + references = [reference] if isinstance(reference, str) else reference + + if not references: + raise ValueError("Reference text cannot be empty") + + # Handle empty candidate + candidate_stripped = candidate.strip() + if not candidate_stripped: + return SemanticResult(similarity=0.0, model=self._model_name) + + # Handle empty references + valid_references = [r for r in references if r.strip()] + if not valid_references: + raise ValueError("Reference text cannot be empty") + + # Get candidate embedding + candidate_embedding = self._get_embedding(candidate_stripped) + + # Compute similarity against each reference, take maximum + max_similarity = 0.0 + for ref in valid_references: + ref_embedding = self._get_embedding(ref.strip()) + similarity = self._cosine_similarity(candidate_embedding, ref_embedding) + max_similarity = max(max_similarity, similarity) + + return SemanticResult(similarity=max_similarity, model=self._model_name) + + def batch_score( + self, + candidates: list[str], + references: list[str] | list[list[str]] | None = None, + ) -> BatchResult[SemanticResult]: + """ + Compute semantic similarity for a batch of candidates. + + Args: + candidates: List of texts to score. + references: Reference text(s) for each candidate. + + Returns: + BatchResult containing individual results and aggregate statistics. + + Raises: + ValueError: If references is None or length mismatch. + """ + if references is None: + raise ValueError("Semantic similarity requires reference texts") + + if len(candidates) != len(references): + raise ValueError( + f"Number of candidates ({len(candidates)}) must match " + f"number of references ({len(references)})" + ) + + results: list[SemanticResult] = [] + for i, cand in enumerate(candidates): + ref: str | list[str] = references[i] + results.append(self.score(cand, ref)) + + # Compute aggregate statistics + stats = { + "similarity": AggregateStats.from_values([r.similarity for r in results]), + } + + return BatchResult(results=results, count=len(results), stats=stats) + + def clear_cache(self) -> None: + """Clear the embedding cache.""" + if self._cache is not None: + self._cache.clear()