From 8b3536873e37647ea409854f57f1840acbda888d Mon Sep 17 00:00:00 2001 From: Kai Chappell Date: Tue, 3 Feb 2026 17:31:01 +0000 Subject: [PATCH] feat(validators): add SemanticValidator --- src/veritext/validators/__init__.py | 31 ++++++++++- src/veritext/validators/metric.py | 82 +++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/src/veritext/validators/__init__.py b/src/veritext/validators/__init__.py index 1179e32..cadc55d 100644 --- a/src/veritext/validators/__init__.py +++ b/src/veritext/validators/__init__.py @@ -27,7 +27,12 @@ from veritext.validators.constraint import ( LengthValidator, ReadabilityValidator, ) -from veritext.validators.metric import BleuValidator, LexicalValidator, RougeValidator +from veritext.validators.metric import ( + BleuValidator, + LexicalValidator, + RougeValidator, + SemanticValidator, +) # Factory functions for clean API @@ -187,6 +192,28 @@ def any_of(checks: list[Check]) -> AnyOf: return AnyOf(checks=checks) +def semantic( + min_score: float, + model: str = "all-MiniLM-L6-v2", + cache_embeddings: bool = True, +) -> SemanticValidator: + """Create a semantic similarity validator. + + Requires the `veritext[semantic]` extra to be installed. + + Args: + min_score: Minimum semantic similarity required (0.0 to 1.0). + model: Name of the sentence-transformers model to use. + cache_embeddings: Whether to cache embeddings for repeated texts. + + Returns: + SemanticValidator instance. + """ + return SemanticValidator( + min_score=min_score, model=model, cache_embeddings=cache_embeddings + ) + + __all__ = [ "AllOf", "AnyOf", @@ -198,6 +225,7 @@ __all__ = [ "LexicalValidator", "ReadabilityValidator", "RougeValidator", + "SemanticValidator", "all_of", "any_of", "bleu", @@ -207,4 +235,5 @@ __all__ = [ "lexical", "readability", "rouge", + "semantic", ] diff --git a/src/veritext/validators/metric.py b/src/veritext/validators/metric.py index e87841b..3ac9174 100644 --- a/src/veritext/validators/metric.py +++ b/src/veritext/validators/metric.py @@ -286,3 +286,85 @@ class LexicalValidator: threshold=threshold, message=message, ) + + +class SemanticValidator: + """Validates that semantic similarity meets minimum threshold. + + Requires the `veritext[semantic]` extra to be installed. + """ + + def __init__( + self, + min_score: float, + model: str = "all-MiniLM-L6-v2", + cache_embeddings: bool = True, + ) -> None: + """ + Initialise the semantic validator. + + Args: + min_score: Minimum semantic similarity required (0.0 to 1.0). + model: Name of the sentence-transformers model to use. + cache_embeddings: Whether to cache embeddings for repeated texts. + + Raises: + InvalidThresholdError: If min_score is not in range [0.0, 1.0]. + DependencyError: If sentence-transformers is not installed. + """ + if not 0.0 <= min_score <= 1.0: + raise InvalidThresholdError( + f"min_score must be between 0.0 and 1.0, got {min_score}" + ) + + self._min_score = min_score + # Lazy import to avoid loading PyTorch unless needed + from veritext.semantic.similarity import SemanticSimilarity + + self._metric: SemanticSimilarity = SemanticSimilarity( + model=model, cache_embeddings=cache_embeddings + ) + + @property + def name(self) -> str: + """Return the name of this check.""" + return "semantic" + + def check(self, text: str, context: ValidationContext) -> CheckResult: + """ + Run the semantic similarity check. + + Args: + text: The text to validate. + context: Validation context containing reference text. + + Returns: + CheckResult with pass/fail status. + + Raises: + ValidationError: If reference text is missing from context. + """ + if context.reference is None: + raise ValidationError(f"{self.name} requires reference text in context") + + result = self._metric.score(text, context.reference) + passed = result.similarity >= self._min_score + + if passed: + message = ( + f"Semantic similarity {result.similarity:.2f} " + f"meets minimum {self._min_score:.2f}" + ) + else: + message = ( + f"Semantic similarity {result.similarity:.2f} " + f"below minimum {self._min_score:.2f}" + ) + + return CheckResult( + name=self.name, + passed=passed, + actual=result.similarity, + threshold=self._min_score, + message=message, + )