Compare commits
12 Commits
feat/metri
...
feat/valid
| Author | SHA1 | Date | |
|---|---|---|---|
|
107fc4e275
|
|||
|
571b770281
|
|||
|
8b3536873e
|
|||
|
9a4ac359a3
|
|||
|
de5ad93524
|
|||
|
cab8099d06
|
|||
|
e2be3daffd
|
|||
|
9239300fd9
|
|||
|
b9f805b2f4
|
|||
|
75cd7b68de
|
|||
|
b2b5eb1518
|
|||
|
9e7b0131b3
|
10
changelog.md
10
changelog.md
@@ -21,3 +21,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
- ROUGE metric (ROUGE-1, ROUGE-2, ROUGE-L with precision/recall/F-measure)
|
- ROUGE metric (ROUGE-1, ROUGE-2, ROUGE-L with precision/recall/F-measure)
|
||||||
- Flesch-Kincaid readability metrics (grade level and reading ease)
|
- Flesch-Kincaid readability metrics (grade level and reading ease)
|
||||||
- Batch scoring with aggregate statistics for all metrics
|
- Batch scoring with aggregate statistics for all metrics
|
||||||
|
- Validators module with `Check` protocol for validation checks
|
||||||
|
- Metric-based validators: `BleuValidator`, `RougeValidator`, `LexicalValidator`
|
||||||
|
- Constraint validators: `LengthValidator`, `ReadabilityValidator`, `ContainsValidator`, `ExcludesValidator`
|
||||||
|
- Composite validators: `AllOf` (all checks must pass), `AnyOf` (any check must pass)
|
||||||
|
- Factory functions for clean validator API (`bleu()`, `rouge()`, `lexical()`, `length()`, `readability()`, `contains()`, `excludes()`, `all_of()`, `any_of()`)
|
||||||
|
- Semantic similarity module with embedding-based text comparison (requires `veritext[semantic]` extra)
|
||||||
|
- `SemanticSimilarity` metric using sentence-transformers for semantic relatedness
|
||||||
|
- `SemanticValidator` for threshold-based semantic similarity validation
|
||||||
|
- `semantic()` factory function for creating semantic validators
|
||||||
|
- Embedding caching for performance optimisation in repeated comparisons
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from veritext.metrics.results import (
|
|||||||
ReadabilityResult,
|
ReadabilityResult,
|
||||||
RougeResult,
|
RougeResult,
|
||||||
RougeScore,
|
RougeScore,
|
||||||
|
SemanticResult,
|
||||||
)
|
)
|
||||||
from veritext.metrics.rouge import Rouge
|
from veritext.metrics.rouge import Rouge
|
||||||
|
|
||||||
@@ -26,4 +27,5 @@ __all__ = [
|
|||||||
"Rouge",
|
"Rouge",
|
||||||
"RougeResult",
|
"RougeResult",
|
||||||
"RougeScore",
|
"RougeScore",
|
||||||
|
"SemanticResult",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -91,3 +91,20 @@ class ReadabilityResult(BaseModel):
|
|||||||
def score(self) -> float:
|
def score(self) -> float:
|
||||||
"""Return Flesch reading ease as the primary score."""
|
"""Return Flesch reading ease as the primary score."""
|
||||||
return self.flesch_reading_ease
|
return self.flesch_reading_ease
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticResult(BaseModel):
|
||||||
|
"""Result of semantic similarity computation."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(frozen=True)
|
||||||
|
|
||||||
|
similarity: float
|
||||||
|
"""Cosine similarity score (0.0 to 1.0)."""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
"""Name of the embedding model used."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def score(self) -> float:
|
||||||
|
"""Return the primary score for this result."""
|
||||||
|
return self.similarity
|
||||||
|
|||||||
16
src/veritext/semantic/__init__.py
Normal file
16
src/veritext/semantic/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Semantic similarity module: embedding-based text comparison.
|
||||||
|
|
||||||
|
This module provides semantic similarity using sentence-transformers.
|
||||||
|
It requires the `veritext[semantic]` extra to be installed.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from veritext.semantic import SemanticSimilarity
|
||||||
|
>>>
|
||||||
|
>>> metric = SemanticSimilarity()
|
||||||
|
>>> result = metric.score("The cat sat on the mat", "A feline rested on the rug")
|
||||||
|
>>> print(f"Similarity: {result.similarity:.2f}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from veritext.semantic.similarity import SemanticSimilarity
|
||||||
|
|
||||||
|
__all__ = ["SemanticSimilarity"]
|
||||||
188
src/veritext/semantic/similarity.py
Normal file
188
src/veritext/semantic/similarity.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
"""Embedding-based semantic similarity using sentence-transformers."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from veritext.core.exceptions import DependencyError
|
||||||
|
from veritext.metrics.base import AggregateStats, BatchResult
|
||||||
|
from veritext.metrics.results import SemanticResult
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticSimilarity:
|
||||||
|
"""
|
||||||
|
Embedding-based semantic similarity using sentence-transformers.
|
||||||
|
|
||||||
|
Computes cosine similarity between text embeddings to measure semantic
|
||||||
|
relatedness. This metric captures meaning beyond lexical overlap.
|
||||||
|
|
||||||
|
Requires the `veritext[semantic]` extra to be installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "all-MiniLM-L6-v2",
|
||||||
|
cache_embeddings: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the semantic similarity metric.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Name of the sentence-transformers model to use.
|
||||||
|
Defaults to "all-MiniLM-L6-v2" (22MB, good quality/size tradeoff).
|
||||||
|
cache_embeddings: Whether to cache embeddings for repeated texts.
|
||||||
|
Defaults to True.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DependencyError: If sentence-transformers is not installed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
except ImportError as err:
|
||||||
|
raise DependencyError(
|
||||||
|
"Install veritext[semantic] for semantic similarity: "
|
||||||
|
"pip install veritext[semantic]"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
self._model_name = model
|
||||||
|
self._model: Any = SentenceTransformer(model)
|
||||||
|
self._cache: dict[str, Any] | None = {} if cache_embeddings else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this metric."""
|
||||||
|
return "semantic"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_reference(self) -> bool:
|
||||||
|
"""Return whether this metric requires reference text."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _get_embedding(self, text: str) -> Any:
|
||||||
|
"""
|
||||||
|
Get embedding for text, using cache if available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The embedding tensor.
|
||||||
|
"""
|
||||||
|
if self._cache is not None and text in self._cache:
|
||||||
|
return self._cache[text]
|
||||||
|
|
||||||
|
embedding = self._model.encode(text, convert_to_tensor=True)
|
||||||
|
|
||||||
|
if self._cache is not None:
|
||||||
|
self._cache[text] = embedding
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def _cosine_similarity(self, embedding1: Any, embedding2: Any) -> float:
|
||||||
|
"""
|
||||||
|
Compute cosine similarity between two embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding1: First embedding tensor.
|
||||||
|
embedding2: Second embedding tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cosine similarity score (0.0 to 1.0).
|
||||||
|
"""
|
||||||
|
from sentence_transformers import util
|
||||||
|
|
||||||
|
similarity: float = util.cos_sim(embedding1, embedding2).item()
|
||||||
|
# Clamp to [0, 1] as negative similarities are possible but not meaningful
|
||||||
|
return max(0.0, min(1.0, similarity))
|
||||||
|
|
||||||
|
def score(
|
||||||
|
self, candidate: str, reference: str | list[str] | None = None
|
||||||
|
) -> SemanticResult:
|
||||||
|
"""
|
||||||
|
Compute semantic similarity between candidate and reference.
|
||||||
|
|
||||||
|
When multiple references are provided, returns the maximum similarity
|
||||||
|
across all references.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidate: The text to score.
|
||||||
|
reference: Reference text(s) for comparison.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SemanticResult with similarity score and model name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If reference is None or empty.
|
||||||
|
"""
|
||||||
|
if reference is None:
|
||||||
|
raise ValueError("Semantic similarity requires reference text")
|
||||||
|
|
||||||
|
# Normalise reference to list
|
||||||
|
references = [reference] if isinstance(reference, str) else reference
|
||||||
|
|
||||||
|
if not references:
|
||||||
|
raise ValueError("Reference text cannot be empty")
|
||||||
|
|
||||||
|
# Handle empty candidate
|
||||||
|
candidate_stripped = candidate.strip()
|
||||||
|
if not candidate_stripped:
|
||||||
|
return SemanticResult(similarity=0.0, model=self._model_name)
|
||||||
|
|
||||||
|
# Handle empty references
|
||||||
|
valid_references = [r for r in references if r.strip()]
|
||||||
|
if not valid_references:
|
||||||
|
raise ValueError("Reference text cannot be empty")
|
||||||
|
|
||||||
|
# Get candidate embedding
|
||||||
|
candidate_embedding = self._get_embedding(candidate_stripped)
|
||||||
|
|
||||||
|
# Compute similarity against each reference, take maximum
|
||||||
|
max_similarity = 0.0
|
||||||
|
for ref in valid_references:
|
||||||
|
ref_embedding = self._get_embedding(ref.strip())
|
||||||
|
similarity = self._cosine_similarity(candidate_embedding, ref_embedding)
|
||||||
|
max_similarity = max(max_similarity, similarity)
|
||||||
|
|
||||||
|
return SemanticResult(similarity=max_similarity, model=self._model_name)
|
||||||
|
|
||||||
|
def batch_score(
|
||||||
|
self,
|
||||||
|
candidates: list[str],
|
||||||
|
references: list[str] | list[list[str]] | None = None,
|
||||||
|
) -> BatchResult[SemanticResult]:
|
||||||
|
"""
|
||||||
|
Compute semantic similarity for a batch of candidates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidates: List of texts to score.
|
||||||
|
references: Reference text(s) for each candidate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BatchResult containing individual results and aggregate statistics.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If references is None or length mismatch.
|
||||||
|
"""
|
||||||
|
if references is None:
|
||||||
|
raise ValueError("Semantic similarity requires reference texts")
|
||||||
|
|
||||||
|
if len(candidates) != len(references):
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of candidates ({len(candidates)}) must match "
|
||||||
|
f"number of references ({len(references)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[SemanticResult] = []
|
||||||
|
for i, cand in enumerate(candidates):
|
||||||
|
ref: str | list[str] = references[i]
|
||||||
|
results.append(self.score(cand, ref))
|
||||||
|
|
||||||
|
# Compute aggregate statistics
|
||||||
|
stats = {
|
||||||
|
"similarity": AggregateStats.from_values([r.similarity for r in results]),
|
||||||
|
}
|
||||||
|
|
||||||
|
return BatchResult(results=results, count=len(results), stats=stats)
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear the embedding cache."""
|
||||||
|
if self._cache is not None:
|
||||||
|
self._cache.clear()
|
||||||
239
src/veritext/validators/__init__.py
Normal file
239
src/veritext/validators/__init__.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
"""Validators module: composable validation checks for text quality.
|
||||||
|
|
||||||
|
This module provides validators that apply thresholds to metrics and return
|
||||||
|
pass/fail decisions with diagnostics.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from veritext.validators import bleu, length, all_of
|
||||||
|
>>> from veritext.core.types import ValidationContext
|
||||||
|
>>>
|
||||||
|
>>> validator = all_of([
|
||||||
|
... bleu(min_score=0.5),
|
||||||
|
... length(min_words=10),
|
||||||
|
... ])
|
||||||
|
>>> context = ValidationContext(reference="The quick brown fox.")
|
||||||
|
>>> result = validator.check("The quick brown fox jumps.", context)
|
||||||
|
>>> print(result.passed)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from veritext.core.tokenisation import WordTokeniser
|
||||||
|
from veritext.validators.base import Check
|
||||||
|
from veritext.validators.composite import AllOf, AnyOf
|
||||||
|
from veritext.validators.constraint import (
|
||||||
|
ContainsValidator,
|
||||||
|
ExcludesValidator,
|
||||||
|
LengthValidator,
|
||||||
|
ReadabilityValidator,
|
||||||
|
)
|
||||||
|
from veritext.validators.metric import (
|
||||||
|
BleuValidator,
|
||||||
|
LexicalValidator,
|
||||||
|
RougeValidator,
|
||||||
|
SemanticValidator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Factory functions for clean API
|
||||||
|
def bleu(
|
||||||
|
min_score: float,
|
||||||
|
variant: Literal[1, 2, 3, 4] = 4,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> BleuValidator:
|
||||||
|
"""Create a BLEU validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum BLEU score required (0.0 to 1.0).
|
||||||
|
variant: BLEU variant to use (1, 2, 3, or 4). Defaults to 4.
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BleuValidator instance.
|
||||||
|
"""
|
||||||
|
return BleuValidator(min_score=min_score, variant=variant, tokeniser=tokeniser)
|
||||||
|
|
||||||
|
|
||||||
|
def rouge(
|
||||||
|
min_score: float,
|
||||||
|
variant: Literal["1", "2", "l"] = "l",
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> RougeValidator:
|
||||||
|
"""Create a ROUGE validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum ROUGE F-measure required (0.0 to 1.0).
|
||||||
|
variant: ROUGE variant ("1", "2", or "l"). Defaults to "l".
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RougeValidator instance.
|
||||||
|
"""
|
||||||
|
return RougeValidator(min_score=min_score, variant=variant, tokeniser=tokeniser)
|
||||||
|
|
||||||
|
|
||||||
|
def lexical(
|
||||||
|
min_jaccard: float | None = None,
|
||||||
|
min_overlap: float | None = None,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> LexicalValidator:
|
||||||
|
"""Create a lexical similarity validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_jaccard: Minimum Jaccard similarity required (0.0 to 1.0).
|
||||||
|
min_overlap: Minimum token overlap required (0.0 to 1.0).
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LexicalValidator instance.
|
||||||
|
"""
|
||||||
|
return LexicalValidator(
|
||||||
|
min_jaccard=min_jaccard, min_overlap=min_overlap, tokeniser=tokeniser
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def length(
|
||||||
|
min_chars: int | None = None,
|
||||||
|
max_chars: int | None = None,
|
||||||
|
min_words: int | None = None,
|
||||||
|
max_words: int | None = None,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> LengthValidator:
|
||||||
|
"""Create a length validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_chars: Minimum character count (inclusive).
|
||||||
|
max_chars: Maximum character count (inclusive).
|
||||||
|
min_words: Minimum word count (inclusive).
|
||||||
|
max_words: Maximum word count (inclusive).
|
||||||
|
tokeniser: Tokeniser to use for word counting. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LengthValidator instance.
|
||||||
|
"""
|
||||||
|
return LengthValidator(
|
||||||
|
min_chars=min_chars,
|
||||||
|
max_chars=max_chars,
|
||||||
|
min_words=min_words,
|
||||||
|
max_words=max_words,
|
||||||
|
tokeniser=tokeniser,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def readability(
|
||||||
|
max_grade: float | None = None,
|
||||||
|
min_ease: float | None = None,
|
||||||
|
) -> ReadabilityValidator:
|
||||||
|
"""Create a readability validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_grade: Maximum Flesch-Kincaid grade level allowed.
|
||||||
|
min_ease: Minimum Flesch Reading Ease score required.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadabilityValidator instance.
|
||||||
|
"""
|
||||||
|
return ReadabilityValidator(max_grade=max_grade, min_ease=min_ease)
|
||||||
|
|
||||||
|
|
||||||
|
def contains(
|
||||||
|
patterns: list[str],
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
) -> ContainsValidator:
|
||||||
|
"""Create a contains validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patterns: List of substrings or regex patterns that must be present.
|
||||||
|
case_sensitive: Whether matching is case-sensitive. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContainsValidator instance.
|
||||||
|
"""
|
||||||
|
return ContainsValidator(patterns=patterns, case_sensitive=case_sensitive)
|
||||||
|
|
||||||
|
|
||||||
|
def excludes(
|
||||||
|
patterns: list[str],
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
) -> ExcludesValidator:
|
||||||
|
"""Create an excludes validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patterns: List of substrings or regex patterns that must not be present.
|
||||||
|
case_sensitive: Whether matching is case-sensitive. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExcludesValidator instance.
|
||||||
|
"""
|
||||||
|
return ExcludesValidator(patterns=patterns, case_sensitive=case_sensitive)
|
||||||
|
|
||||||
|
|
||||||
|
def all_of(checks: list[Check]) -> AllOf:
|
||||||
|
"""Create an AllOf composite validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checks: List of checks that must all pass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AllOf instance.
|
||||||
|
"""
|
||||||
|
return AllOf(checks=checks)
|
||||||
|
|
||||||
|
|
||||||
|
def any_of(checks: list[Check]) -> AnyOf:
|
||||||
|
"""Create an AnyOf composite validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checks: List of checks where at least one must pass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AnyOf instance.
|
||||||
|
"""
|
||||||
|
return AnyOf(checks=checks)
|
||||||
|
|
||||||
|
|
||||||
|
def semantic(
|
||||||
|
min_score: float,
|
||||||
|
model: str = "all-MiniLM-L6-v2",
|
||||||
|
cache_embeddings: bool = True,
|
||||||
|
) -> SemanticValidator:
|
||||||
|
"""Create a semantic similarity validator.
|
||||||
|
|
||||||
|
Requires the `veritext[semantic]` extra to be installed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum semantic similarity required (0.0 to 1.0).
|
||||||
|
model: Name of the sentence-transformers model to use.
|
||||||
|
cache_embeddings: Whether to cache embeddings for repeated texts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SemanticValidator instance.
|
||||||
|
"""
|
||||||
|
return SemanticValidator(
|
||||||
|
min_score=min_score, model=model, cache_embeddings=cache_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AllOf",
|
||||||
|
"AnyOf",
|
||||||
|
"BleuValidator",
|
||||||
|
"Check",
|
||||||
|
"ContainsValidator",
|
||||||
|
"ExcludesValidator",
|
||||||
|
"LengthValidator",
|
||||||
|
"LexicalValidator",
|
||||||
|
"ReadabilityValidator",
|
||||||
|
"RougeValidator",
|
||||||
|
"SemanticValidator",
|
||||||
|
"all_of",
|
||||||
|
"any_of",
|
||||||
|
"bleu",
|
||||||
|
"contains",
|
||||||
|
"excludes",
|
||||||
|
"length",
|
||||||
|
"lexical",
|
||||||
|
"readability",
|
||||||
|
"rouge",
|
||||||
|
"semantic",
|
||||||
|
]
|
||||||
31
src/veritext/validators/base.py
Normal file
31
src/veritext/validators/base.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Base types and protocols for validation checks."""
|
||||||
|
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from veritext.core.types import CheckResult, ValidationContext
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class Check(Protocol):
|
||||||
|
"""Protocol for validation checks.
|
||||||
|
|
||||||
|
A Check computes a score or property of text and compares it against
|
||||||
|
a threshold to produce a pass/fail result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||||
|
"""Run the check and return a result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text and metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status and diagnostics.
|
||||||
|
"""
|
||||||
|
...
|
||||||
90
src/veritext/validators/composite.py
Normal file
90
src/veritext/validators/composite.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Composite validators for combining multiple checks."""
|
||||||
|
|
||||||
|
from veritext.core.types import CheckResult, ValidationContext, ValidationResult
|
||||||
|
from veritext.validators.base import Check
|
||||||
|
|
||||||
|
|
||||||
|
class AllOf:
|
||||||
|
"""Passes only if all checks pass."""
|
||||||
|
|
||||||
|
def __init__(self, checks: list[Check]) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the AllOf composite validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checks: List of checks that must all pass.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If checks list is empty.
|
||||||
|
"""
|
||||||
|
if not checks:
|
||||||
|
raise ValueError("checks list cannot be empty")
|
||||||
|
|
||||||
|
self._checks = checks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this composite check."""
|
||||||
|
return "all_of"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> ValidationResult:
|
||||||
|
"""
|
||||||
|
Run all checks and return aggregate result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text and metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ValidationResult that passes only if all checks pass.
|
||||||
|
"""
|
||||||
|
results: list[CheckResult] = []
|
||||||
|
for check in self._checks:
|
||||||
|
results.append(check.check(text, context))
|
||||||
|
|
||||||
|
all_passed = all(r.passed for r in results)
|
||||||
|
|
||||||
|
return ValidationResult(passed=all_passed, checks=results)
|
||||||
|
|
||||||
|
|
||||||
|
class AnyOf:
|
||||||
|
"""Passes if any check passes."""
|
||||||
|
|
||||||
|
def __init__(self, checks: list[Check]) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the AnyOf composite validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checks: List of checks where at least one must pass.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If checks list is empty.
|
||||||
|
"""
|
||||||
|
if not checks:
|
||||||
|
raise ValueError("checks list cannot be empty")
|
||||||
|
|
||||||
|
self._checks = checks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this composite check."""
|
||||||
|
return "any_of"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> ValidationResult:
|
||||||
|
"""
|
||||||
|
Run all checks and return aggregate result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text and metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ValidationResult that passes if any check passes.
|
||||||
|
"""
|
||||||
|
results: list[CheckResult] = []
|
||||||
|
for check in self._checks:
|
||||||
|
results.append(check.check(text, context))
|
||||||
|
|
||||||
|
any_passed = any(r.passed for r in results)
|
||||||
|
|
||||||
|
return ValidationResult(passed=any_passed, checks=results)
|
||||||
337
src/veritext/validators/constraint.py
Normal file
337
src/veritext/validators/constraint.py
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
"""Constraint validators that do not require reference text."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from veritext.core.exceptions import InvalidThresholdError
|
||||||
|
from veritext.core.tokenisation import WordTokeniser
|
||||||
|
from veritext.core.types import CheckResult, ValidationContext
|
||||||
|
from veritext.metrics.readability import Readability
|
||||||
|
|
||||||
|
|
||||||
|
class LengthValidator:
|
||||||
|
"""Validates text length constraints."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_chars: int | None = None,
|
||||||
|
max_chars: int | None = None,
|
||||||
|
min_words: int | None = None,
|
||||||
|
max_words: int | None = None,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the length validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_chars: Minimum character count (inclusive).
|
||||||
|
max_chars: Maximum character count (inclusive).
|
||||||
|
min_words: Minimum word count (inclusive).
|
||||||
|
max_words: Maximum word count (inclusive).
|
||||||
|
tokeniser: Tokeniser to use for word counting. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If no constraints provided or invalid values.
|
||||||
|
"""
|
||||||
|
if all(v is None for v in (min_chars, max_chars, min_words, max_words)):
|
||||||
|
raise InvalidThresholdError("At least one length constraint must be set")
|
||||||
|
|
||||||
|
if min_chars is not None and min_chars < 0:
|
||||||
|
raise InvalidThresholdError(f"min_chars must be >= 0, got {min_chars}")
|
||||||
|
if max_chars is not None and max_chars < 0:
|
||||||
|
raise InvalidThresholdError(f"max_chars must be >= 0, got {max_chars}")
|
||||||
|
if min_words is not None and min_words < 0:
|
||||||
|
raise InvalidThresholdError(f"min_words must be >= 0, got {min_words}")
|
||||||
|
if max_words is not None and max_words < 0:
|
||||||
|
raise InvalidThresholdError(f"max_words must be >= 0, got {max_words}")
|
||||||
|
|
||||||
|
if min_chars is not None and max_chars is not None and min_chars > max_chars:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_chars ({min_chars}) cannot exceed max_chars ({max_chars})"
|
||||||
|
)
|
||||||
|
if min_words is not None and max_words is not None and min_words > max_words:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_words ({min_words}) cannot exceed max_words ({max_words})"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._min_chars = min_chars
|
||||||
|
self._max_chars = max_chars
|
||||||
|
self._min_words = min_words
|
||||||
|
self._max_words = max_words
|
||||||
|
self._tokeniser = tokeniser or WordTokeniser()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "length"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult: # noqa: ARG002
|
||||||
|
"""
|
||||||
|
Run the length check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context (not used for length checks).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
"""
|
||||||
|
char_count = len(text)
|
||||||
|
words = self._tokeniser.tokenise(text)
|
||||||
|
word_count = len(words)
|
||||||
|
|
||||||
|
failures = []
|
||||||
|
|
||||||
|
if self._min_chars is not None and char_count < self._min_chars:
|
||||||
|
failures.append(f"{char_count} chars < min {self._min_chars}")
|
||||||
|
if self._max_chars is not None and char_count > self._max_chars:
|
||||||
|
failures.append(f"{char_count} chars > max {self._max_chars}")
|
||||||
|
if self._min_words is not None and word_count < self._min_words:
|
||||||
|
failures.append(f"{word_count} words < min {self._min_words}")
|
||||||
|
if self._max_words is not None and word_count > self._max_words:
|
||||||
|
failures.append(f"{word_count} words > max {self._max_words}")
|
||||||
|
|
||||||
|
passed = len(failures) == 0
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = f"Length check passed: {char_count} chars, {word_count} words"
|
||||||
|
else:
|
||||||
|
message = "Length check failed: " + "; ".join(failures)
|
||||||
|
|
||||||
|
actual = {"chars": char_count, "words": word_count}
|
||||||
|
threshold = {}
|
||||||
|
if self._min_chars is not None:
|
||||||
|
threshold["min_chars"] = self._min_chars
|
||||||
|
if self._max_chars is not None:
|
||||||
|
threshold["max_chars"] = self._max_chars
|
||||||
|
if self._min_words is not None:
|
||||||
|
threshold["min_words"] = self._min_words
|
||||||
|
if self._max_words is not None:
|
||||||
|
threshold["max_words"] = self._max_words
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=actual,
|
||||||
|
threshold=threshold,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadabilityValidator:
|
||||||
|
"""Validates Flesch-Kincaid readability."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_grade: float | None = None,
|
||||||
|
min_ease: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the readability validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_grade: Maximum Flesch-Kincaid grade level allowed.
|
||||||
|
min_ease: Minimum Flesch Reading Ease score required.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If no constraints provided.
|
||||||
|
"""
|
||||||
|
if max_grade is None and min_ease is None:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
"At least one of max_grade or min_ease must be provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._max_grade = max_grade
|
||||||
|
self._min_ease = min_ease
|
||||||
|
self._metric = Readability()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "readability"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult: # noqa: ARG002
|
||||||
|
"""
|
||||||
|
Run the readability check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context (not used for readability checks).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
"""
|
||||||
|
result = self._metric.score(text)
|
||||||
|
|
||||||
|
failures = []
|
||||||
|
if (
|
||||||
|
self._max_grade is not None
|
||||||
|
and result.flesch_kincaid_grade > self._max_grade
|
||||||
|
):
|
||||||
|
failures.append(
|
||||||
|
f"grade level {result.flesch_kincaid_grade:.1f} "
|
||||||
|
f"> max {self._max_grade:.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._min_ease is not None and result.flesch_reading_ease < self._min_ease:
|
||||||
|
failures.append(
|
||||||
|
f"reading ease {result.flesch_reading_ease:.1f} "
|
||||||
|
f"< min {self._min_ease:.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
passed = len(failures) == 0
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
parts = []
|
||||||
|
if self._max_grade is not None:
|
||||||
|
parts.append(
|
||||||
|
f"grade {result.flesch_kincaid_grade:.1f} <= {self._max_grade:.1f}"
|
||||||
|
)
|
||||||
|
if self._min_ease is not None:
|
||||||
|
parts.append(
|
||||||
|
f"ease {result.flesch_reading_ease:.1f} >= {self._min_ease:.1f}"
|
||||||
|
)
|
||||||
|
message = "Readability: " + ", ".join(parts)
|
||||||
|
else:
|
||||||
|
message = "Readability: " + "; ".join(failures)
|
||||||
|
|
||||||
|
actual = {
|
||||||
|
"grade": result.flesch_kincaid_grade,
|
||||||
|
"ease": result.flesch_reading_ease,
|
||||||
|
}
|
||||||
|
threshold = {}
|
||||||
|
if self._max_grade is not None:
|
||||||
|
threshold["max_grade"] = self._max_grade
|
||||||
|
if self._min_ease is not None:
|
||||||
|
threshold["min_ease"] = self._min_ease
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=actual,
|
||||||
|
threshold=threshold,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ContainsValidator:
|
||||||
|
"""Validates text contains required patterns."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patterns: list[str],
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the contains validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patterns: List of substrings or regex patterns that must be present.
|
||||||
|
case_sensitive: Whether matching is case-sensitive. Defaults to False.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If patterns list is empty.
|
||||||
|
"""
|
||||||
|
if not patterns:
|
||||||
|
raise InvalidThresholdError("patterns list cannot be empty")
|
||||||
|
|
||||||
|
self._patterns = patterns
|
||||||
|
self._case_sensitive = case_sensitive
|
||||||
|
self._flags = 0 if case_sensitive else re.IGNORECASE
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "contains"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult: # noqa: ARG002
|
||||||
|
"""
|
||||||
|
Run the contains check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context (not used for contains checks).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
"""
|
||||||
|
missing = []
|
||||||
|
for pattern in self._patterns:
|
||||||
|
if not re.search(pattern, text, self._flags):
|
||||||
|
missing.append(pattern)
|
||||||
|
|
||||||
|
passed = len(missing) == 0
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = f"Text contains all {len(self._patterns)} required pattern(s)"
|
||||||
|
else:
|
||||||
|
message = f"Text missing {len(missing)} pattern(s): {missing}"
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual={"found": len(self._patterns) - len(missing), "missing": missing},
|
||||||
|
threshold={"patterns": self._patterns},
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExcludesValidator:
|
||||||
|
"""Validates text excludes forbidden patterns."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patterns: list[str],
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the excludes validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patterns: List of substrings or regex patterns that must not be present.
|
||||||
|
case_sensitive: Whether matching is case-sensitive. Defaults to False.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If patterns list is empty.
|
||||||
|
"""
|
||||||
|
if not patterns:
|
||||||
|
raise InvalidThresholdError("patterns list cannot be empty")
|
||||||
|
|
||||||
|
self._patterns = patterns
|
||||||
|
self._case_sensitive = case_sensitive
|
||||||
|
self._flags = 0 if case_sensitive else re.IGNORECASE
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "excludes"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult: # noqa: ARG002
|
||||||
|
"""
|
||||||
|
Run the excludes check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context (not used for excludes checks).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
"""
|
||||||
|
found = []
|
||||||
|
for pattern in self._patterns:
|
||||||
|
if re.search(pattern, text, self._flags):
|
||||||
|
found.append(pattern)
|
||||||
|
|
||||||
|
passed = len(found) == 0
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = f"Text excludes all {len(self._patterns)} forbidden pattern(s)"
|
||||||
|
else:
|
||||||
|
message = f"Text contains {len(found)} forbidden pattern(s): {found}"
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual={"excluded": len(self._patterns) - len(found), "found": found},
|
||||||
|
threshold={"patterns": self._patterns},
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
370
src/veritext/validators/metric.py
Normal file
370
src/veritext/validators/metric.py
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
"""Metric-based validators that require reference text."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from veritext.core.exceptions import InvalidThresholdError, ValidationError
|
||||||
|
from veritext.core.tokenisation import WordTokeniser
|
||||||
|
from veritext.core.types import CheckResult, ValidationContext
|
||||||
|
from veritext.metrics.bleu import Bleu
|
||||||
|
from veritext.metrics.lexical import Lexical
|
||||||
|
from veritext.metrics.rouge import Rouge
|
||||||
|
|
||||||
|
|
||||||
|
class BleuValidator:
|
||||||
|
"""Validates that BLEU score meets minimum threshold."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_score: float,
|
||||||
|
variant: Literal[1, 2, 3, 4] = 4,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the BLEU validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum BLEU score required (0.0 to 1.0).
|
||||||
|
variant: BLEU variant to use (1, 2, 3, or 4). Defaults to 4.
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If min_score is not in range [0.0, 1.0].
|
||||||
|
"""
|
||||||
|
if not 0.0 <= min_score <= 1.0:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_score must be between 0.0 and 1.0, got {min_score}"
|
||||||
|
)
|
||||||
|
if variant not in (1, 2, 3, 4):
|
||||||
|
raise InvalidThresholdError(f"variant must be 1, 2, 3, or 4, got {variant}")
|
||||||
|
|
||||||
|
self._min_score = min_score
|
||||||
|
self._variant = variant
|
||||||
|
self._metric = Bleu(tokeniser=tokeniser)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return f"bleu-{self._variant}"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||||
|
"""
|
||||||
|
Run the BLEU check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If reference text is missing from context.
|
||||||
|
"""
|
||||||
|
if context.reference is None:
|
||||||
|
raise ValidationError(f"{self.name} requires reference text in context")
|
||||||
|
|
||||||
|
result = self._metric.score(text, context.reference)
|
||||||
|
|
||||||
|
# Select the appropriate BLEU variant
|
||||||
|
score_map = {
|
||||||
|
1: result.bleu1,
|
||||||
|
2: result.bleu2,
|
||||||
|
3: result.bleu3,
|
||||||
|
4: result.bleu4,
|
||||||
|
}
|
||||||
|
actual_score = score_map[self._variant]
|
||||||
|
passed = actual_score >= self._min_score
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = (
|
||||||
|
f"BLEU-{self._variant} score {actual_score:.2f} "
|
||||||
|
f"meets minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = (
|
||||||
|
f"BLEU-{self._variant} score {actual_score:.2f} "
|
||||||
|
f"below minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=actual_score,
|
||||||
|
threshold=self._min_score,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RougeValidator:
|
||||||
|
"""Validates that ROUGE score meets minimum threshold."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_score: float,
|
||||||
|
variant: Literal["1", "2", "l"] = "l",
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the ROUGE validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum ROUGE F-measure required (0.0 to 1.0).
|
||||||
|
variant: ROUGE variant ("1", "2", or "l"). Defaults to "l".
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If min_score is not in range [0.0, 1.0].
|
||||||
|
"""
|
||||||
|
if not 0.0 <= min_score <= 1.0:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_score must be between 0.0 and 1.0, got {min_score}"
|
||||||
|
)
|
||||||
|
if variant not in ("1", "2", "l"):
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"variant must be '1', '2', or 'l', got '{variant}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._min_score = min_score
|
||||||
|
self._variant = variant
|
||||||
|
self._metric = Rouge(tokeniser=tokeniser)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return f"rouge-{self._variant}"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||||
|
"""
|
||||||
|
Run the ROUGE check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If reference text is missing from context.
|
||||||
|
"""
|
||||||
|
if context.reference is None:
|
||||||
|
raise ValidationError(f"{self.name} requires reference text in context")
|
||||||
|
|
||||||
|
result = self._metric.score(text, context.reference)
|
||||||
|
|
||||||
|
# Select the appropriate ROUGE variant (use F-measure)
|
||||||
|
score_map = {
|
||||||
|
"1": result.rouge1.fmeasure,
|
||||||
|
"2": result.rouge2.fmeasure,
|
||||||
|
"l": result.rouge_l.fmeasure,
|
||||||
|
}
|
||||||
|
actual_score = score_map[self._variant]
|
||||||
|
passed = actual_score >= self._min_score
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = (
|
||||||
|
f"ROUGE-{self._variant.upper()} score {actual_score:.2f} "
|
||||||
|
f"meets minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = (
|
||||||
|
f"ROUGE-{self._variant.upper()} score {actual_score:.2f} "
|
||||||
|
f"below minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=actual_score,
|
||||||
|
threshold=self._min_score,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LexicalValidator:
|
||||||
|
"""Validates lexical similarity meets threshold."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_jaccard: float | None = None,
|
||||||
|
min_overlap: float | None = None,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the lexical validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_jaccard: Minimum Jaccard similarity required (0.0 to 1.0).
|
||||||
|
min_overlap: Minimum token overlap required (0.0 to 1.0).
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If thresholds are invalid or none provided.
|
||||||
|
"""
|
||||||
|
if min_jaccard is None and min_overlap is None:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
"At least one of min_jaccard or min_overlap must be provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
if min_jaccard is not None and not 0.0 <= min_jaccard <= 1.0:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_jaccard must be between 0.0 and 1.0, got {min_jaccard}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if min_overlap is not None and not 0.0 <= min_overlap <= 1.0:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_overlap must be between 0.0 and 1.0, got {min_overlap}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._min_jaccard = min_jaccard
|
||||||
|
self._min_overlap = min_overlap
|
||||||
|
self._metric = Lexical(tokeniser=tokeniser)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "lexical"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||||
|
"""
|
||||||
|
Run the lexical similarity check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If reference text is missing from context.
|
||||||
|
"""
|
||||||
|
if context.reference is None:
|
||||||
|
raise ValidationError(f"{self.name} requires reference text in context")
|
||||||
|
|
||||||
|
result = self._metric.score(text, context.reference)
|
||||||
|
|
||||||
|
# Check each threshold that was specified
|
||||||
|
failures = []
|
||||||
|
if self._min_jaccard is not None and result.jaccard < self._min_jaccard:
|
||||||
|
failures.append(
|
||||||
|
f"Jaccard {result.jaccard:.2f} below minimum {self._min_jaccard:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._min_overlap is not None and result.token_overlap < self._min_overlap:
|
||||||
|
failures.append(
|
||||||
|
f"token overlap {result.token_overlap:.2f} "
|
||||||
|
f"below minimum {self._min_overlap:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
passed = len(failures) == 0
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
parts = []
|
||||||
|
if self._min_jaccard is not None:
|
||||||
|
parts.append(f"Jaccard {result.jaccard:.2f} >= {self._min_jaccard:.2f}")
|
||||||
|
if self._min_overlap is not None:
|
||||||
|
parts.append(
|
||||||
|
f"overlap {result.token_overlap:.2f} >= {self._min_overlap:.2f}"
|
||||||
|
)
|
||||||
|
message = "Lexical similarity: " + ", ".join(parts)
|
||||||
|
else:
|
||||||
|
message = "Lexical similarity: " + "; ".join(failures)
|
||||||
|
|
||||||
|
# Build actual value dict
|
||||||
|
actual = {"jaccard": result.jaccard, "token_overlap": result.token_overlap}
|
||||||
|
threshold = {}
|
||||||
|
if self._min_jaccard is not None:
|
||||||
|
threshold["min_jaccard"] = self._min_jaccard
|
||||||
|
if self._min_overlap is not None:
|
||||||
|
threshold["min_overlap"] = self._min_overlap
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=actual,
|
||||||
|
threshold=threshold,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticValidator:
|
||||||
|
"""Validates that semantic similarity meets minimum threshold.
|
||||||
|
|
||||||
|
Requires the `veritext[semantic]` extra to be installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_score: float,
|
||||||
|
model: str = "all-MiniLM-L6-v2",
|
||||||
|
cache_embeddings: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the semantic validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum semantic similarity required (0.0 to 1.0).
|
||||||
|
model: Name of the sentence-transformers model to use.
|
||||||
|
cache_embeddings: Whether to cache embeddings for repeated texts.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If min_score is not in range [0.0, 1.0].
|
||||||
|
DependencyError: If sentence-transformers is not installed.
|
||||||
|
"""
|
||||||
|
if not 0.0 <= min_score <= 1.0:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_score must be between 0.0 and 1.0, got {min_score}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._min_score = min_score
|
||||||
|
# Lazy import to avoid loading PyTorch unless needed
|
||||||
|
from veritext.semantic.similarity import SemanticSimilarity
|
||||||
|
|
||||||
|
self._metric: SemanticSimilarity = SemanticSimilarity(
|
||||||
|
model=model, cache_embeddings=cache_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "semantic"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||||
|
"""
|
||||||
|
Run the semantic similarity check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If reference text is missing from context.
|
||||||
|
"""
|
||||||
|
if context.reference is None:
|
||||||
|
raise ValidationError(f"{self.name} requires reference text in context")
|
||||||
|
|
||||||
|
result = self._metric.score(text, context.reference)
|
||||||
|
passed = result.similarity >= self._min_score
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = (
|
||||||
|
f"Semantic similarity {result.similarity:.2f} "
|
||||||
|
f"meets minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = (
|
||||||
|
f"Semantic similarity {result.similarity:.2f} "
|
||||||
|
f"below minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=result.similarity,
|
||||||
|
threshold=self._min_score,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
1
tests/test_semantic/__init__.py
Normal file
1
tests/test_semantic/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for semantic similarity module."""
|
||||||
240
tests/test_semantic/test_similarity.py
Normal file
240
tests/test_semantic/test_similarity.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""Tests for the semantic similarity metric."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Skip all tests if sentence-transformers is not installed
|
||||||
|
pytest.importorskip("sentence_transformers")
|
||||||
|
|
||||||
|
from veritext.metrics.results import SemanticResult
|
||||||
|
from veritext.semantic import SemanticSimilarity
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemanticSimilarity:
|
||||||
|
"""Tests for the SemanticSimilarity metric class."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def semantic(self) -> SemanticSimilarity:
|
||||||
|
"""Provide a SemanticSimilarity metric instance."""
|
||||||
|
return SemanticSimilarity()
|
||||||
|
|
||||||
|
def test_name(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that name returns 'semantic'."""
|
||||||
|
assert semantic.name == "semantic"
|
||||||
|
|
||||||
|
def test_requires_reference(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that semantic similarity requires reference text."""
|
||||||
|
assert semantic.requires_reference is True
|
||||||
|
|
||||||
|
def test_identical_texts(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that identical texts produce high similarity."""
|
||||||
|
text = "The cat sat on the mat"
|
||||||
|
result = semantic.score(text, text)
|
||||||
|
|
||||||
|
# Identical texts should have very high similarity (close to 1.0)
|
||||||
|
assert result.similarity >= 0.99
|
||||||
|
assert result.model == "all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
def test_semantically_similar_texts(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that semantically similar texts have high similarity."""
|
||||||
|
candidate = "The cat sat on the mat"
|
||||||
|
reference = "A feline rested on the rug"
|
||||||
|
result = semantic.score(candidate, reference)
|
||||||
|
|
||||||
|
# Similar meanings should have reasonable similarity
|
||||||
|
assert result.similarity > 0.3
|
||||||
|
|
||||||
|
def test_unrelated_texts(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that unrelated texts have low similarity."""
|
||||||
|
candidate = "The quick brown fox"
|
||||||
|
reference = "Quantum physics describes particle behaviour"
|
||||||
|
result = semantic.score(candidate, reference)
|
||||||
|
|
||||||
|
# Unrelated texts should have low similarity
|
||||||
|
assert result.similarity < 0.5
|
||||||
|
|
||||||
|
def test_empty_candidate(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that empty candidate returns zero similarity."""
|
||||||
|
result = semantic.score("", "The cat sat on the mat")
|
||||||
|
assert result.similarity == 0.0
|
||||||
|
|
||||||
|
def test_whitespace_only_candidate(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that whitespace-only candidate returns zero similarity."""
|
||||||
|
result = semantic.score(" \t\n ", "The cat sat on the mat")
|
||||||
|
assert result.similarity == 0.0
|
||||||
|
|
||||||
|
def test_none_reference_raises(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that None reference raises ValueError."""
|
||||||
|
with pytest.raises(ValueError, match="requires reference"):
|
||||||
|
semantic.score("The cat sat", None)
|
||||||
|
|
||||||
|
def test_empty_reference_raises(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that empty reference raises ValueError."""
|
||||||
|
with pytest.raises(ValueError, match="cannot be empty"):
|
||||||
|
semantic.score("The cat sat", "")
|
||||||
|
|
||||||
|
def test_whitespace_reference_raises(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that whitespace-only reference raises ValueError."""
|
||||||
|
with pytest.raises(ValueError, match="cannot be empty"):
|
||||||
|
semantic.score("The cat sat", " \t\n ")
|
||||||
|
|
||||||
|
def test_multiple_references(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test semantic similarity with multiple references uses max."""
|
||||||
|
candidate = "The cat sat on the mat"
|
||||||
|
references = [
|
||||||
|
"A dog ran through the park",
|
||||||
|
"The cat sat on the mat", # Exact match
|
||||||
|
]
|
||||||
|
result = semantic.score(candidate, references)
|
||||||
|
|
||||||
|
# Should get high similarity due to exact match reference
|
||||||
|
assert result.similarity >= 0.99
|
||||||
|
|
||||||
|
def test_multiple_references_takes_max(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that multiple references returns maximum similarity."""
|
||||||
|
candidate = "The cat sat on the mat"
|
||||||
|
references = [
|
||||||
|
"Quantum physics is complex", # Low similarity
|
||||||
|
"A feline rested on the rug", # Higher similarity
|
||||||
|
]
|
||||||
|
result = semantic.score(candidate, references)
|
||||||
|
|
||||||
|
# Should use the higher similarity
|
||||||
|
assert result.similarity > 0.3
|
||||||
|
|
||||||
|
def test_result_score_property(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that result.score returns similarity."""
|
||||||
|
result = semantic.score("The cat sat", "The cat sat")
|
||||||
|
assert result.score == result.similarity
|
||||||
|
|
||||||
|
def test_caching_behaviour(self) -> None:
|
||||||
|
"""Test that caching works for repeated texts."""
|
||||||
|
semantic = SemanticSimilarity(cache_embeddings=True)
|
||||||
|
|
||||||
|
# Score same texts multiple times
|
||||||
|
text = "The cat sat on the mat"
|
||||||
|
result1 = semantic.score(text, text)
|
||||||
|
result2 = semantic.score(text, text)
|
||||||
|
|
||||||
|
# Results should be identical
|
||||||
|
assert result1.similarity == result2.similarity
|
||||||
|
|
||||||
|
# Clear cache and check again
|
||||||
|
semantic.clear_cache()
|
||||||
|
result3 = semantic.score(text, text)
|
||||||
|
assert result3.similarity == result1.similarity
|
||||||
|
|
||||||
|
def test_caching_disabled(self) -> None:
|
||||||
|
"""Test that caching can be disabled."""
|
||||||
|
semantic = SemanticSimilarity(cache_embeddings=False)
|
||||||
|
|
||||||
|
text = "The cat sat on the mat"
|
||||||
|
result1 = semantic.score(text, text)
|
||||||
|
result2 = semantic.score(text, text)
|
||||||
|
|
||||||
|
# Results should still be identical (just not cached)
|
||||||
|
assert result1.similarity == result2.similarity
|
||||||
|
|
||||||
|
# Clear cache should not raise even when disabled
|
||||||
|
semantic.clear_cache()
|
||||||
|
|
||||||
|
def test_custom_model(self) -> None:
|
||||||
|
"""Test that custom model name is recorded in result."""
|
||||||
|
# Use the same model but verify it's recorded correctly
|
||||||
|
semantic = SemanticSimilarity(model="all-MiniLM-L6-v2")
|
||||||
|
result = semantic.score("Test text", "Test text")
|
||||||
|
assert result.model == "all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemanticSimilarityBatch:
|
||||||
|
"""Tests for semantic similarity batch scoring."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def semantic(self) -> SemanticSimilarity:
|
||||||
|
"""Provide a SemanticSimilarity metric instance."""
|
||||||
|
return SemanticSimilarity()
|
||||||
|
|
||||||
|
def test_batch_score_basic(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test basic batch scoring."""
|
||||||
|
candidates = ["The cat sat on the mat", "A quick brown dog runs fast"]
|
||||||
|
references = ["The cat sat on the mat", "A quick brown dog runs fast"]
|
||||||
|
result = semantic.batch_score(candidates, references)
|
||||||
|
|
||||||
|
assert result.count == 2
|
||||||
|
assert len(result.results) == 2
|
||||||
|
# Identical texts should have very high similarity
|
||||||
|
assert all(r.similarity >= 0.99 for r in result.results)
|
||||||
|
|
||||||
|
def test_batch_score_statistics(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that batch scoring computes statistics."""
|
||||||
|
candidates = ["The cat sat", "Quantum physics is complex"]
|
||||||
|
references = ["The cat sat", "The cat sat"]
|
||||||
|
result = semantic.batch_score(candidates, references)
|
||||||
|
|
||||||
|
# Check statistics are computed
|
||||||
|
assert "similarity" in result.stats
|
||||||
|
|
||||||
|
# Mean should be between min and max
|
||||||
|
stats = result.stats["similarity"]
|
||||||
|
assert stats.min <= stats.mean <= stats.max
|
||||||
|
|
||||||
|
def test_batch_score_percentiles(self, semantic: SemanticSimilarity) -> None:
|
||||||
|
"""Test that batch scoring computes percentiles."""
|
||||||
|
candidates = ["a", "b", "c", "d", "e"]
|
||||||
|
references = ["a", "b", "c", "d", "e"]
|
||||||
|
result = semantic.batch_score(candidates, references)
|
||||||
|
|
||||||
|
stats = result.stats["similarity"]
|
||||||
|
assert 25 in stats.percentiles
|
||||||
|
assert 50 in stats.percentiles
|
||||||
|
assert 75 in stats.percentiles
|
||||||
|
assert 95 in stats.percentiles
|
||||||
|
|
||||||
|
def test_batch_score_none_references_raises(
|
||||||
|
self, semantic: SemanticSimilarity
|
||||||
|
) -> None:
|
||||||
|
"""Test that batch scoring raises for None references."""
|
||||||
|
with pytest.raises(ValueError, match="requires reference"):
|
||||||
|
semantic.batch_score(["text"], None)
|
||||||
|
|
||||||
|
def test_batch_score_length_mismatch_raises(
|
||||||
|
self, semantic: SemanticSimilarity
|
||||||
|
) -> None:
|
||||||
|
"""Test that batch scoring raises for mismatched lengths."""
|
||||||
|
with pytest.raises(ValueError, match="must match"):
|
||||||
|
semantic.batch_score(["a", "b"], ["a"])
|
||||||
|
|
||||||
|
def test_batch_score_with_multiple_references(
|
||||||
|
self, semantic: SemanticSimilarity
|
||||||
|
) -> None:
|
||||||
|
"""Test batch scoring with multiple references per candidate."""
|
||||||
|
candidates = [
|
||||||
|
"The cat sat on the mat",
|
||||||
|
"A quick brown dog runs fast",
|
||||||
|
]
|
||||||
|
references = [
|
||||||
|
["The cat sat on the mat", "A cat rests on floor"],
|
||||||
|
["A quick brown dog runs fast", "Dogs run very quickly"],
|
||||||
|
]
|
||||||
|
result = semantic.batch_score(candidates, references)
|
||||||
|
|
||||||
|
assert result.count == 2
|
||||||
|
# First pair has exact match
|
||||||
|
assert result.results[0].similarity >= 0.99
|
||||||
|
assert result.results[1].similarity >= 0.99
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemanticResult:
|
||||||
|
"""Tests for SemanticResult type."""
|
||||||
|
|
||||||
|
def test_frozen(self) -> None:
|
||||||
|
"""Test that SemanticResult is frozen."""
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
result = SemanticResult(similarity=0.85, model="test-model")
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
result.similarity = 0.9 # type: ignore[misc]
|
||||||
|
|
||||||
|
def test_score_property(self) -> None:
|
||||||
|
"""Test that score property returns similarity."""
|
||||||
|
result = SemanticResult(similarity=0.75, model="test-model")
|
||||||
|
assert result.score == 0.75
|
||||||
1
tests/test_validators/__init__.py
Normal file
1
tests/test_validators/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for the validators module."""
|
||||||
198
tests/test_validators/test_composite.py
Normal file
198
tests/test_validators/test_composite.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""Tests for composite validators."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from veritext.core.types import ValidationContext
|
||||||
|
from veritext.validators import all_of, any_of, bleu, contains, excludes, length
|
||||||
|
from veritext.validators.composite import AllOf, AnyOf
|
||||||
|
|
||||||
|
|
||||||
|
class TestAllOf:
|
||||||
|
"""Tests for AllOf composite validator."""
|
||||||
|
|
||||||
|
def test_all_of_passes_when_all_checks_pass(self) -> None:
|
||||||
|
"""Test that AllOf passes when all checks pass."""
|
||||||
|
validator = AllOf(
|
||||||
|
checks=[
|
||||||
|
length(min_words=2),
|
||||||
|
contains(patterns=["hello"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert len(result.checks) == 2
|
||||||
|
assert all(c.passed for c in result.checks)
|
||||||
|
|
||||||
|
def test_all_of_fails_when_one_check_fails(self) -> None:
|
||||||
|
"""Test that AllOf fails when any check fails."""
|
||||||
|
validator = AllOf(
|
||||||
|
checks=[
|
||||||
|
length(min_words=2),
|
||||||
|
contains(patterns=["goodbye"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert len(result.checks) == 2
|
||||||
|
assert len(result.failed_checks) == 1
|
||||||
|
|
||||||
|
def test_all_of_fails_when_all_checks_fail(self) -> None:
|
||||||
|
"""Test that AllOf fails when all checks fail."""
|
||||||
|
validator = AllOf(
|
||||||
|
checks=[
|
||||||
|
length(min_words=10),
|
||||||
|
contains(patterns=["goodbye"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert len(result.failed_checks) == 2
|
||||||
|
|
||||||
|
def test_all_of_with_metric_validators(self) -> None:
|
||||||
|
"""Test AllOf with metric-based validators."""
|
||||||
|
validator = AllOf(
|
||||||
|
checks=[
|
||||||
|
bleu(min_score=0.5),
|
||||||
|
length(min_words=3),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
context = ValidationContext(reference="the quick brown fox")
|
||||||
|
result = validator.check("the quick brown fox jumps", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert len(result.checks) == 2
|
||||||
|
|
||||||
|
def test_all_of_failure_summary(self) -> None:
|
||||||
|
"""Test the failure summary property."""
|
||||||
|
validator = AllOf(
|
||||||
|
checks=[
|
||||||
|
length(min_words=10),
|
||||||
|
contains(patterns=["goodbye"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello", context)
|
||||||
|
|
||||||
|
summary = result.failure_summary
|
||||||
|
assert "failed" in summary.lower()
|
||||||
|
assert "length" in summary
|
||||||
|
assert "contains" in summary
|
||||||
|
|
||||||
|
def test_all_of_raises_on_empty_checks(self) -> None:
|
||||||
|
"""Test that empty checks list raises error."""
|
||||||
|
with pytest.raises(ValueError, match="cannot be empty"):
|
||||||
|
AllOf(checks=[])
|
||||||
|
|
||||||
|
def test_all_of_name_property(self) -> None:
|
||||||
|
"""Test the name property."""
|
||||||
|
validator = AllOf(checks=[length(min_chars=1)])
|
||||||
|
assert validator.name == "all_of"
|
||||||
|
|
||||||
|
def test_all_of_factory_function(self) -> None:
|
||||||
|
"""Test the all_of() factory function."""
|
||||||
|
validator = all_of(checks=[length(min_chars=1)])
|
||||||
|
assert isinstance(validator, AllOf)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnyOf:
|
||||||
|
"""Tests for AnyOf composite validator."""
|
||||||
|
|
||||||
|
def test_any_of_passes_when_any_check_passes(self) -> None:
|
||||||
|
"""Test that AnyOf passes when any check passes."""
|
||||||
|
validator = AnyOf(
|
||||||
|
checks=[
|
||||||
|
length(min_words=10), # Will fail
|
||||||
|
contains(patterns=["hello"]), # Will pass
|
||||||
|
]
|
||||||
|
)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert len(result.checks) == 2
|
||||||
|
# At least one check passed
|
||||||
|
assert any(c.passed for c in result.checks)
|
||||||
|
|
||||||
|
def test_any_of_passes_when_all_checks_pass(self) -> None:
|
||||||
|
"""Test that AnyOf passes when all checks pass."""
|
||||||
|
validator = AnyOf(
|
||||||
|
checks=[
|
||||||
|
length(min_words=2),
|
||||||
|
contains(patterns=["hello"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert all(c.passed for c in result.checks)
|
||||||
|
|
||||||
|
def test_any_of_fails_when_all_checks_fail(self) -> None:
|
||||||
|
"""Test that AnyOf fails when all checks fail."""
|
||||||
|
validator = AnyOf(
|
||||||
|
checks=[
|
||||||
|
length(min_words=10),
|
||||||
|
contains(patterns=["goodbye"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert not any(c.passed for c in result.checks)
|
||||||
|
|
||||||
|
def test_any_of_with_metric_validators(self) -> None:
|
||||||
|
"""Test AnyOf with metric-based validators."""
|
||||||
|
validator = AnyOf(
|
||||||
|
checks=[
|
||||||
|
bleu(min_score=0.9), # Might fail
|
||||||
|
length(min_words=3), # Should pass
|
||||||
|
]
|
||||||
|
)
|
||||||
|
context = ValidationContext(reference="different text entirely")
|
||||||
|
result = validator.check("the quick brown fox jumps", context)
|
||||||
|
|
||||||
|
assert result.passed is True # Length check passes
|
||||||
|
|
||||||
|
def test_any_of_with_excludes(self) -> None:
|
||||||
|
"""Test AnyOf with excludes validator."""
|
||||||
|
validator = AnyOf(
|
||||||
|
checks=[
|
||||||
|
excludes(patterns=["error"]),
|
||||||
|
excludes(patterns=["warning"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
context = ValidationContext()
|
||||||
|
|
||||||
|
# Should pass - neither pattern found
|
||||||
|
result = validator.check("All is well", context)
|
||||||
|
assert result.passed is True
|
||||||
|
|
||||||
|
# Should pass - one pattern found, other not
|
||||||
|
result = validator.check("This is an error", context)
|
||||||
|
assert result.passed is True
|
||||||
|
|
||||||
|
# Should fail - both patterns found
|
||||||
|
result = validator.check("error and warning", context)
|
||||||
|
assert result.passed is False
|
||||||
|
|
||||||
|
def test_any_of_raises_on_empty_checks(self) -> None:
|
||||||
|
"""Test that empty checks list raises error."""
|
||||||
|
with pytest.raises(ValueError, match="cannot be empty"):
|
||||||
|
AnyOf(checks=[])
|
||||||
|
|
||||||
|
def test_any_of_name_property(self) -> None:
|
||||||
|
"""Test the name property."""
|
||||||
|
validator = AnyOf(checks=[length(min_chars=1)])
|
||||||
|
assert validator.name == "any_of"
|
||||||
|
|
||||||
|
def test_any_of_factory_function(self) -> None:
|
||||||
|
"""Test the any_of() factory function."""
|
||||||
|
validator = any_of(checks=[length(min_chars=1)])
|
||||||
|
assert isinstance(validator, AnyOf)
|
||||||
334
tests/test_validators/test_constraint.py
Normal file
334
tests/test_validators/test_constraint.py
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
"""Tests for constraint validators."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from veritext.core.exceptions import InvalidThresholdError
|
||||||
|
from veritext.core.types import ValidationContext
|
||||||
|
from veritext.validators import contains, excludes, length, readability
|
||||||
|
from veritext.validators.constraint import (
|
||||||
|
ContainsValidator,
|
||||||
|
ExcludesValidator,
|
||||||
|
LengthValidator,
|
||||||
|
ReadabilityValidator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLengthValidator:
|
||||||
|
"""Tests for LengthValidator."""
|
||||||
|
|
||||||
|
def test_length_validator_min_chars_passes(self) -> None:
|
||||||
|
"""Test that validator passes when char count meets minimum."""
|
||||||
|
validator = LengthValidator(min_chars=10)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world!", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "length"
|
||||||
|
assert result.actual["chars"] == 12
|
||||||
|
|
||||||
|
def test_length_validator_min_chars_fails(self) -> None:
|
||||||
|
"""Test that validator fails when char count below minimum."""
|
||||||
|
validator = LengthValidator(min_chars=20)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "< min" in result.message
|
||||||
|
|
||||||
|
def test_length_validator_max_chars_passes(self) -> None:
|
||||||
|
"""Test that validator passes when char count within maximum."""
|
||||||
|
validator = LengthValidator(max_chars=20)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.actual["chars"] == 11
|
||||||
|
|
||||||
|
def test_length_validator_max_chars_fails(self) -> None:
|
||||||
|
"""Test that validator fails when char count exceeds maximum."""
|
||||||
|
validator = LengthValidator(max_chars=5)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "> max" in result.message
|
||||||
|
|
||||||
|
def test_length_validator_min_words_passes(self) -> None:
|
||||||
|
"""Test that validator passes when word count meets minimum."""
|
||||||
|
validator = LengthValidator(min_words=3)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("the quick brown fox", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.actual["words"] == 4
|
||||||
|
|
||||||
|
def test_length_validator_min_words_fails(self) -> None:
|
||||||
|
"""Test that validator fails when word count below minimum."""
|
||||||
|
validator = LengthValidator(min_words=10)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "words < min" in result.message
|
||||||
|
|
||||||
|
def test_length_validator_max_words_passes(self) -> None:
|
||||||
|
"""Test that validator passes when word count within maximum."""
|
||||||
|
validator = LengthValidator(max_words=5)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
|
||||||
|
def test_length_validator_max_words_fails(self) -> None:
|
||||||
|
"""Test that validator fails when word count exceeds maximum."""
|
||||||
|
validator = LengthValidator(max_words=2)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("the quick brown fox", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "words > max" in result.message
|
||||||
|
|
||||||
|
def test_length_validator_combined_constraints(self) -> None:
|
||||||
|
"""Test validator with multiple constraints."""
|
||||||
|
validator = LengthValidator(
|
||||||
|
min_chars=5, max_chars=50, min_words=2, max_words=10
|
||||||
|
)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("the quick brown fox", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert "min_chars" in result.threshold
|
||||||
|
assert "max_chars" in result.threshold
|
||||||
|
assert "min_words" in result.threshold
|
||||||
|
assert "max_words" in result.threshold
|
||||||
|
|
||||||
|
def test_length_validator_raises_when_no_constraints(self) -> None:
|
||||||
|
"""Test that validator raises when no constraints provided."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="At least one"):
|
||||||
|
LengthValidator()
|
||||||
|
|
||||||
|
def test_length_validator_raises_on_negative_values(self) -> None:
|
||||||
|
"""Test that negative constraint values raise error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="min_chars must be >= 0"):
|
||||||
|
LengthValidator(min_chars=-1)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match="max_chars must be >= 0"):
|
||||||
|
LengthValidator(max_chars=-1)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match="min_words must be >= 0"):
|
||||||
|
LengthValidator(min_words=-1)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match="max_words must be >= 0"):
|
||||||
|
LengthValidator(max_words=-1)
|
||||||
|
|
||||||
|
def test_length_validator_raises_on_invalid_range(self) -> None:
|
||||||
|
"""Test that min > max raises error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="cannot exceed max_chars"):
|
||||||
|
LengthValidator(min_chars=100, max_chars=50)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match="cannot exceed max_words"):
|
||||||
|
LengthValidator(min_words=20, max_words=5)
|
||||||
|
|
||||||
|
def test_length_factory_function(self) -> None:
|
||||||
|
"""Test the length() factory function."""
|
||||||
|
validator = length(min_chars=10, max_words=100)
|
||||||
|
assert isinstance(validator, LengthValidator)
|
||||||
|
assert validator.name == "length"
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadabilityValidator:
|
||||||
|
"""Tests for ReadabilityValidator."""
|
||||||
|
|
||||||
|
def test_readability_validator_max_grade_passes(self) -> None:
|
||||||
|
"""Test that validator passes when grade level within maximum."""
|
||||||
|
validator = ReadabilityValidator(max_grade=12.0)
|
||||||
|
context = ValidationContext()
|
||||||
|
# Simple text should have low grade level
|
||||||
|
result = validator.check("The cat sat on the mat. It was a nice day.", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "readability"
|
||||||
|
assert "grade" in result.actual
|
||||||
|
|
||||||
|
def test_readability_validator_max_grade_fails(self) -> None:
|
||||||
|
"""Test that validator fails when grade level exceeds maximum."""
|
||||||
|
validator = ReadabilityValidator(max_grade=1.0)
|
||||||
|
context = ValidationContext()
|
||||||
|
# Complex text
|
||||||
|
result = validator.check(
|
||||||
|
"The implementation of sophisticated methodologies necessitates "
|
||||||
|
"comprehensive analytical frameworks for systematic evaluation.",
|
||||||
|
context,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "grade level" in result.message
|
||||||
|
assert "> max" in result.message
|
||||||
|
|
||||||
|
def test_readability_validator_min_ease_passes(self) -> None:
|
||||||
|
"""Test that validator passes when reading ease meets minimum."""
|
||||||
|
validator = ReadabilityValidator(min_ease=30.0)
|
||||||
|
context = ValidationContext()
|
||||||
|
# Simple text should have high reading ease
|
||||||
|
result = validator.check("The cat sat. The dog ran. It was fun.", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert "ease" in result.actual
|
||||||
|
|
||||||
|
def test_readability_validator_min_ease_fails(self) -> None:
|
||||||
|
"""Test that validator fails when reading ease below minimum."""
|
||||||
|
validator = ReadabilityValidator(min_ease=100.0)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check(
|
||||||
|
"The implementation of sophisticated methodologies necessitates "
|
||||||
|
"comprehensive analytical frameworks.",
|
||||||
|
context,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "reading ease" in result.message
|
||||||
|
assert "< min" in result.message
|
||||||
|
|
||||||
|
def test_readability_validator_combined_constraints(self) -> None:
|
||||||
|
"""Test validator with both grade and ease constraints."""
|
||||||
|
validator = ReadabilityValidator(max_grade=12.0, min_ease=30.0)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("The cat sat on the mat.", context)
|
||||||
|
|
||||||
|
assert "max_grade" in result.threshold
|
||||||
|
assert "min_ease" in result.threshold
|
||||||
|
|
||||||
|
def test_readability_validator_raises_when_no_constraints(self) -> None:
|
||||||
|
"""Test that validator raises when no constraints provided."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="At least one"):
|
||||||
|
ReadabilityValidator()
|
||||||
|
|
||||||
|
def test_readability_factory_function(self) -> None:
|
||||||
|
"""Test the readability() factory function."""
|
||||||
|
validator = readability(max_grade=8.0, min_ease=60.0)
|
||||||
|
assert isinstance(validator, ReadabilityValidator)
|
||||||
|
assert validator.name == "readability"
|
||||||
|
|
||||||
|
|
||||||
|
class TestContainsValidator:
|
||||||
|
"""Tests for ContainsValidator."""
|
||||||
|
|
||||||
|
def test_contains_validator_passes_when_pattern_found(self) -> None:
|
||||||
|
"""Test that validator passes when all patterns are found."""
|
||||||
|
validator = ContainsValidator(patterns=["hello", "world"])
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("Hello World!", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "contains"
|
||||||
|
assert result.actual["found"] == 2
|
||||||
|
assert result.actual["missing"] == []
|
||||||
|
|
||||||
|
def test_contains_validator_fails_when_pattern_missing(self) -> None:
|
||||||
|
"""Test that validator fails when a pattern is missing."""
|
||||||
|
validator = ContainsValidator(patterns=["hello", "goodbye"])
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("Hello World!", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "goodbye" in result.actual["missing"]
|
||||||
|
assert "missing" in result.message
|
||||||
|
|
||||||
|
def test_contains_validator_case_insensitive_by_default(self) -> None:
|
||||||
|
"""Test that matching is case-insensitive by default."""
|
||||||
|
validator = ContainsValidator(patterns=["HELLO"])
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
|
||||||
|
def test_contains_validator_case_sensitive(self) -> None:
|
||||||
|
"""Test case-sensitive matching."""
|
||||||
|
validator = ContainsValidator(patterns=["HELLO"], case_sensitive=True)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
|
||||||
|
def test_contains_validator_regex_patterns(self) -> None:
|
||||||
|
"""Test regex pattern matching."""
|
||||||
|
validator = ContainsValidator(patterns=[r"\d{3}-\d{4}"])
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("Call 555-1234 for info", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
|
||||||
|
def test_contains_validator_raises_on_empty_patterns(self) -> None:
|
||||||
|
"""Test that empty patterns list raises error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="cannot be empty"):
|
||||||
|
ContainsValidator(patterns=[])
|
||||||
|
|
||||||
|
def test_contains_factory_function(self) -> None:
|
||||||
|
"""Test the contains() factory function."""
|
||||||
|
validator = contains(patterns=["test"], case_sensitive=True)
|
||||||
|
assert isinstance(validator, ContainsValidator)
|
||||||
|
assert validator.name == "contains"
|
||||||
|
|
||||||
|
|
||||||
|
class TestExcludesValidator:
|
||||||
|
"""Tests for ExcludesValidator."""
|
||||||
|
|
||||||
|
def test_excludes_validator_passes_when_pattern_absent(self) -> None:
|
||||||
|
"""Test that validator passes when all patterns are absent."""
|
||||||
|
validator = ExcludesValidator(patterns=["bad", "forbidden"])
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("This is good text.", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "excludes"
|
||||||
|
assert result.actual["found"] == []
|
||||||
|
|
||||||
|
def test_excludes_validator_fails_when_pattern_found(self) -> None:
|
||||||
|
"""Test that validator fails when a forbidden pattern is found."""
|
||||||
|
validator = ExcludesValidator(patterns=["bad", "forbidden"])
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("This is bad text.", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "bad" in result.actual["found"]
|
||||||
|
assert "forbidden" in result.message
|
||||||
|
|
||||||
|
def test_excludes_validator_case_insensitive_by_default(self) -> None:
|
||||||
|
"""Test that matching is case-insensitive by default."""
|
||||||
|
validator = ExcludesValidator(patterns=["BAD"])
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("This is bad text.", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
|
||||||
|
def test_excludes_validator_case_sensitive(self) -> None:
|
||||||
|
"""Test case-sensitive matching."""
|
||||||
|
validator = ExcludesValidator(patterns=["BAD"], case_sensitive=True)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("This is bad text.", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
|
||||||
|
def test_excludes_validator_regex_patterns(self) -> None:
|
||||||
|
"""Test regex pattern matching."""
|
||||||
|
validator = ExcludesValidator(patterns=[r"\b\d{4}\b"]) # 4-digit numbers
|
||||||
|
context = ValidationContext()
|
||||||
|
|
||||||
|
# Should fail when pattern found
|
||||||
|
result = validator.check("PIN is 1234", context)
|
||||||
|
assert result.passed is False
|
||||||
|
|
||||||
|
# Should pass when pattern absent
|
||||||
|
result = validator.check("No numbers here", context)
|
||||||
|
assert result.passed is True
|
||||||
|
|
||||||
|
def test_excludes_validator_raises_on_empty_patterns(self) -> None:
|
||||||
|
"""Test that empty patterns list raises error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="cannot be empty"):
|
||||||
|
ExcludesValidator(patterns=[])
|
||||||
|
|
||||||
|
def test_excludes_factory_function(self) -> None:
|
||||||
|
"""Test the excludes() factory function."""
|
||||||
|
validator = excludes(patterns=["test"], case_sensitive=True)
|
||||||
|
assert isinstance(validator, ExcludesValidator)
|
||||||
|
assert validator.name == "excludes"
|
||||||
283
tests/test_validators/test_metric.py
Normal file
283
tests/test_validators/test_metric.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
"""Tests for metric-based validators."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from veritext.core.exceptions import InvalidThresholdError, ValidationError
|
||||||
|
from veritext.core.types import ValidationContext
|
||||||
|
from veritext.validators import bleu, lexical, rouge
|
||||||
|
from veritext.validators.metric import BleuValidator, LexicalValidator, RougeValidator
|
||||||
|
|
||||||
|
|
||||||
|
class TestBleuValidator:
|
||||||
|
"""Tests for BleuValidator."""
|
||||||
|
|
||||||
|
def test_bleu_validator_passes_when_score_meets_threshold(self) -> None:
|
||||||
|
"""Test that validator passes when BLEU score meets threshold."""
|
||||||
|
validator = BleuValidator(min_score=0.5, variant=4)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("the cat sat on the mat", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "bleu-4"
|
||||||
|
assert result.actual == 1.0 # Identical text
|
||||||
|
assert result.threshold == 0.5
|
||||||
|
|
||||||
|
def test_bleu_validator_fails_when_score_below_threshold(self) -> None:
|
||||||
|
"""Test that validator fails when BLEU score is below threshold."""
|
||||||
|
validator = BleuValidator(min_score=0.9, variant=4)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("a dog ran through the park", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert result.name == "bleu-4"
|
||||||
|
assert result.actual < 0.9
|
||||||
|
assert "below minimum" in result.message
|
||||||
|
|
||||||
|
def test_bleu_validator_variant_selection(self) -> None:
|
||||||
|
"""Test different BLEU variants."""
|
||||||
|
context = ValidationContext(reference="the quick brown fox jumps")
|
||||||
|
|
||||||
|
for variant in (1, 2, 3, 4):
|
||||||
|
validator = BleuValidator(min_score=0.0, variant=variant) # type: ignore[arg-type]
|
||||||
|
result = validator.check("the quick brown fox", context)
|
||||||
|
assert result.name == f"bleu-{variant}"
|
||||||
|
|
||||||
|
def test_bleu_validator_raises_on_missing_reference(self) -> None:
|
||||||
|
"""Test that validator raises when reference is missing."""
|
||||||
|
validator = BleuValidator(min_score=0.5)
|
||||||
|
context = ValidationContext()
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError, match="requires reference text"):
|
||||||
|
validator.check("some text", context)
|
||||||
|
|
||||||
|
def test_bleu_validator_raises_on_invalid_min_score(self) -> None:
|
||||||
|
"""Test that invalid min_score raises error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
|
||||||
|
BleuValidator(min_score=1.5)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
|
||||||
|
BleuValidator(min_score=-0.1)
|
||||||
|
|
||||||
|
def test_bleu_validator_raises_on_invalid_variant(self) -> None:
|
||||||
|
"""Test that invalid variant raises error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="variant must be"):
|
||||||
|
BleuValidator(min_score=0.5, variant=5) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def test_bleu_factory_function(self) -> None:
|
||||||
|
"""Test the bleu() factory function."""
|
||||||
|
validator = bleu(min_score=0.6, variant=2)
|
||||||
|
assert isinstance(validator, BleuValidator)
|
||||||
|
assert validator.name == "bleu-2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRougeValidator:
|
||||||
|
"""Tests for RougeValidator."""
|
||||||
|
|
||||||
|
def test_rouge_validator_passes_when_score_meets_threshold(self) -> None:
|
||||||
|
"""Test that validator passes when ROUGE score meets threshold."""
|
||||||
|
validator = RougeValidator(min_score=0.5, variant="l")
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("the cat sat on the mat", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "rouge-l"
|
||||||
|
assert result.actual == 1.0 # Identical text
|
||||||
|
assert result.threshold == 0.5
|
||||||
|
|
||||||
|
def test_rouge_validator_fails_when_score_below_threshold(self) -> None:
|
||||||
|
"""Test that validator fails when ROUGE score is below threshold."""
|
||||||
|
validator = RougeValidator(min_score=0.9, variant="l")
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("a dog ran through the park", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert result.actual < 0.9
|
||||||
|
assert "below minimum" in result.message
|
||||||
|
|
||||||
|
def test_rouge_validator_variant_selection(self) -> None:
|
||||||
|
"""Test different ROUGE variants."""
|
||||||
|
context = ValidationContext(reference="the quick brown fox jumps")
|
||||||
|
|
||||||
|
for variant in ("1", "2", "l"):
|
||||||
|
validator = RougeValidator(min_score=0.0, variant=variant) # type: ignore[arg-type]
|
||||||
|
result = validator.check("the quick brown fox", context)
|
||||||
|
assert result.name == f"rouge-{variant}"
|
||||||
|
|
||||||
|
def test_rouge_validator_raises_on_missing_reference(self) -> None:
|
||||||
|
"""Test that validator raises when reference is missing."""
|
||||||
|
validator = RougeValidator(min_score=0.5)
|
||||||
|
context = ValidationContext()
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError, match="requires reference text"):
|
||||||
|
validator.check("some text", context)
|
||||||
|
|
||||||
|
def test_rouge_validator_raises_on_invalid_min_score(self) -> None:
|
||||||
|
"""Test that invalid min_score raises error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
|
||||||
|
RougeValidator(min_score=1.5)
|
||||||
|
|
||||||
|
def test_rouge_validator_raises_on_invalid_variant(self) -> None:
|
||||||
|
"""Test that invalid variant raises error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="variant must be"):
|
||||||
|
RougeValidator(min_score=0.5, variant="3") # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def test_rouge_factory_function(self) -> None:
|
||||||
|
"""Test the rouge() factory function."""
|
||||||
|
validator = rouge(min_score=0.6, variant="2")
|
||||||
|
assert isinstance(validator, RougeValidator)
|
||||||
|
assert validator.name == "rouge-2"
|
||||||
|
|
||||||
|
|
||||||
|
class TestLexicalValidator:
|
||||||
|
"""Tests for LexicalValidator."""
|
||||||
|
|
||||||
|
def test_lexical_validator_passes_on_jaccard(self) -> None:
|
||||||
|
"""Test that validator passes when Jaccard similarity meets threshold."""
|
||||||
|
validator = LexicalValidator(min_jaccard=0.5)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("the cat sat on the mat", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "lexical"
|
||||||
|
assert result.actual["jaccard"] == 1.0
|
||||||
|
|
||||||
|
def test_lexical_validator_fails_on_jaccard(self) -> None:
|
||||||
|
"""Test that validator fails when Jaccard is below threshold."""
|
||||||
|
validator = LexicalValidator(min_jaccard=0.9)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("a dog ran through the park", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "Jaccard" in result.message
|
||||||
|
assert "below minimum" in result.message
|
||||||
|
|
||||||
|
def test_lexical_validator_passes_on_overlap(self) -> None:
|
||||||
|
"""Test that validator passes when token overlap meets threshold."""
|
||||||
|
validator = LexicalValidator(min_overlap=0.5)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("the cat sat on the mat", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.actual["token_overlap"] == 1.0
|
||||||
|
|
||||||
|
def test_lexical_validator_fails_on_overlap(self) -> None:
|
||||||
|
"""Test that validator fails when overlap is below threshold."""
|
||||||
|
validator = LexicalValidator(min_overlap=0.9)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("a dog ran through", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "overlap" in result.message
|
||||||
|
|
||||||
|
def test_lexical_validator_with_both_thresholds(self) -> None:
|
||||||
|
"""Test validator with both Jaccard and overlap thresholds."""
|
||||||
|
validator = LexicalValidator(min_jaccard=0.3, min_overlap=0.5)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("the cat sat", context)
|
||||||
|
|
||||||
|
# Should check both thresholds
|
||||||
|
assert "min_jaccard" in result.threshold
|
||||||
|
assert "min_overlap" in result.threshold
|
||||||
|
|
||||||
|
def test_lexical_validator_raises_when_no_threshold(self) -> None:
|
||||||
|
"""Test that validator raises when no threshold is provided."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="At least one"):
|
||||||
|
LexicalValidator()
|
||||||
|
|
||||||
|
def test_lexical_validator_raises_on_invalid_jaccard(self) -> None:
|
||||||
|
"""Test that invalid Jaccard threshold raises error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="min_jaccard"):
|
||||||
|
LexicalValidator(min_jaccard=1.5)
|
||||||
|
|
||||||
|
def test_lexical_validator_raises_on_invalid_overlap(self) -> None:
|
||||||
|
"""Test that invalid overlap threshold raises error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="min_overlap"):
|
||||||
|
LexicalValidator(min_overlap=-0.1)
|
||||||
|
|
||||||
|
def test_lexical_validator_raises_on_missing_reference(self) -> None:
|
||||||
|
"""Test that validator raises when reference is missing."""
|
||||||
|
validator = LexicalValidator(min_jaccard=0.5)
|
||||||
|
context = ValidationContext()
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError, match="requires reference text"):
|
||||||
|
validator.check("some text", context)
|
||||||
|
|
||||||
|
def test_lexical_factory_function(self) -> None:
|
||||||
|
"""Test the lexical() factory function."""
|
||||||
|
validator = lexical(min_jaccard=0.5, min_overlap=0.6)
|
||||||
|
assert isinstance(validator, LexicalValidator)
|
||||||
|
assert validator.name == "lexical"
|
||||||
|
|
||||||
|
|
||||||
|
# SemanticValidator tests - conditionally run if sentence-transformers is installed
|
||||||
|
class TestSemanticValidator:
|
||||||
|
"""Tests for SemanticValidator."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _skip_if_no_transformers() -> None:
|
||||||
|
"""Skip test if sentence-transformers is not installed."""
|
||||||
|
pytest.importorskip("sentence_transformers")
|
||||||
|
|
||||||
|
def test_semantic_validator_passes_when_score_meets_threshold(self) -> None:
|
||||||
|
"""Test that validator passes when semantic similarity meets threshold."""
|
||||||
|
self._skip_if_no_transformers()
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
validator = SemanticValidator(min_score=0.5)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check("the cat sat on the mat", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "semantic"
|
||||||
|
assert result.actual >= 0.99 # Identical text
|
||||||
|
assert result.threshold == 0.5
|
||||||
|
|
||||||
|
def test_semantic_validator_fails_when_score_below_threshold(self) -> None:
|
||||||
|
"""Test that validator fails when semantic similarity is below threshold."""
|
||||||
|
self._skip_if_no_transformers()
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
validator = SemanticValidator(min_score=0.99)
|
||||||
|
context = ValidationContext(reference="the cat sat on the mat")
|
||||||
|
result = validator.check(
|
||||||
|
"quantum physics describes particle behaviour", context
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert result.name == "semantic"
|
||||||
|
assert result.actual < 0.99
|
||||||
|
assert "below minimum" in result.message
|
||||||
|
|
||||||
|
def test_semantic_validator_raises_on_missing_reference(self) -> None:
|
||||||
|
"""Test that validator raises when reference is missing."""
|
||||||
|
self._skip_if_no_transformers()
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
validator = SemanticValidator(min_score=0.5)
|
||||||
|
context = ValidationContext()
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError, match="requires reference text"):
|
||||||
|
validator.check("some text", context)
|
||||||
|
|
||||||
|
def test_semantic_validator_raises_on_invalid_min_score(self) -> None:
|
||||||
|
"""Test that invalid min_score raises error without loading model."""
|
||||||
|
# This test doesn't need sentence-transformers since validation happens first
|
||||||
|
with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
SemanticValidator(min_score=1.5)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
SemanticValidator(min_score=-0.1)
|
||||||
|
|
||||||
|
def test_semantic_factory_function(self) -> None:
|
||||||
|
"""Test the semantic() factory function."""
|
||||||
|
self._skip_if_no_transformers()
|
||||||
|
from veritext.validators import semantic
|
||||||
|
from veritext.validators.metric import SemanticValidator
|
||||||
|
|
||||||
|
validator = semantic(min_score=0.6)
|
||||||
|
assert isinstance(validator, SemanticValidator)
|
||||||
|
assert validator.name == "semantic"
|
||||||
Reference in New Issue
Block a user