feat(validators): add SemanticValidator
This commit is contained in:
@@ -27,7 +27,12 @@ from veritext.validators.constraint import (
|
||||
LengthValidator,
|
||||
ReadabilityValidator,
|
||||
)
|
||||
from veritext.validators.metric import BleuValidator, LexicalValidator, RougeValidator
|
||||
from veritext.validators.metric import (
|
||||
BleuValidator,
|
||||
LexicalValidator,
|
||||
RougeValidator,
|
||||
SemanticValidator,
|
||||
)
|
||||
|
||||
|
||||
# Factory functions for clean API
|
||||
@@ -187,6 +192,28 @@ def any_of(checks: list[Check]) -> AnyOf:
|
||||
return AnyOf(checks=checks)
|
||||
|
||||
|
||||
def semantic(
|
||||
min_score: float,
|
||||
model: str = "all-MiniLM-L6-v2",
|
||||
cache_embeddings: bool = True,
|
||||
) -> SemanticValidator:
|
||||
"""Create a semantic similarity validator.
|
||||
|
||||
Requires the `veritext[semantic]` extra to be installed.
|
||||
|
||||
Args:
|
||||
min_score: Minimum semantic similarity required (0.0 to 1.0).
|
||||
model: Name of the sentence-transformers model to use.
|
||||
cache_embeddings: Whether to cache embeddings for repeated texts.
|
||||
|
||||
Returns:
|
||||
SemanticValidator instance.
|
||||
"""
|
||||
return SemanticValidator(
|
||||
min_score=min_score, model=model, cache_embeddings=cache_embeddings
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AllOf",
|
||||
"AnyOf",
|
||||
@@ -198,6 +225,7 @@ __all__ = [
|
||||
"LexicalValidator",
|
||||
"ReadabilityValidator",
|
||||
"RougeValidator",
|
||||
"SemanticValidator",
|
||||
"all_of",
|
||||
"any_of",
|
||||
"bleu",
|
||||
@@ -207,4 +235,5 @@ __all__ = [
|
||||
"lexical",
|
||||
"readability",
|
||||
"rouge",
|
||||
"semantic",
|
||||
]
|
||||
|
||||
@@ -286,3 +286,85 @@ class LexicalValidator:
|
||||
threshold=threshold,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
class SemanticValidator:
|
||||
"""Validates that semantic similarity meets minimum threshold.
|
||||
|
||||
Requires the `veritext[semantic]` extra to be installed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_score: float,
|
||||
model: str = "all-MiniLM-L6-v2",
|
||||
cache_embeddings: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the semantic validator.
|
||||
|
||||
Args:
|
||||
min_score: Minimum semantic similarity required (0.0 to 1.0).
|
||||
model: Name of the sentence-transformers model to use.
|
||||
cache_embeddings: Whether to cache embeddings for repeated texts.
|
||||
|
||||
Raises:
|
||||
InvalidThresholdError: If min_score is not in range [0.0, 1.0].
|
||||
DependencyError: If sentence-transformers is not installed.
|
||||
"""
|
||||
if not 0.0 <= min_score <= 1.0:
|
||||
raise InvalidThresholdError(
|
||||
f"min_score must be between 0.0 and 1.0, got {min_score}"
|
||||
)
|
||||
|
||||
self._min_score = min_score
|
||||
# Lazy import to avoid loading PyTorch unless needed
|
||||
from veritext.semantic.similarity import SemanticSimilarity
|
||||
|
||||
self._metric: SemanticSimilarity = SemanticSimilarity(
|
||||
model=model, cache_embeddings=cache_embeddings
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the name of this check."""
|
||||
return "semantic"
|
||||
|
||||
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||
"""
|
||||
Run the semantic similarity check.
|
||||
|
||||
Args:
|
||||
text: The text to validate.
|
||||
context: Validation context containing reference text.
|
||||
|
||||
Returns:
|
||||
CheckResult with pass/fail status.
|
||||
|
||||
Raises:
|
||||
ValidationError: If reference text is missing from context.
|
||||
"""
|
||||
if context.reference is None:
|
||||
raise ValidationError(f"{self.name} requires reference text in context")
|
||||
|
||||
result = self._metric.score(text, context.reference)
|
||||
passed = result.similarity >= self._min_score
|
||||
|
||||
if passed:
|
||||
message = (
|
||||
f"Semantic similarity {result.similarity:.2f} "
|
||||
f"meets minimum {self._min_score:.2f}"
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
f"Semantic similarity {result.similarity:.2f} "
|
||||
f"below minimum {self._min_score:.2f}"
|
||||
)
|
||||
|
||||
return CheckResult(
|
||||
name=self.name,
|
||||
passed=passed,
|
||||
actual=result.similarity,
|
||||
threshold=self._min_score,
|
||||
message=message,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user