Compare commits
48 Commits
feat/metri
...
docs/polis
| Author | SHA1 | Date | |
|---|---|---|---|
| | 0699e97e1d | | |
| | 7de4505e31 | | |
| | 564d663c78 | | |
| | 0b2bc6c688 | | |
| | aa687f43cd | | |
| | f18427e123 | | |
| | 1754556c99 | | |
| | 13c869f5d6 | | |
| | 93515707cc | | |
| | 3cde5aba77 | | |
| | 69966d171c | | |
| | d5df8b52e6 | | |
| | 8b7c087de7 | | |
| | c54f8c3f6f | | |
| | 0cadfd4d23 | | |
| | e128720917 | | |
| | f713d5e8a6 | | |
| | 9853b57843 | | |
| | 55faae3e1b | | |
| | 07ac70e835 | | |
| | 6d1bece815 | | |
| | 40fa39485e | | |
| | 9115f0c25b | | |
| | 83c4b4bee5 | | |
| | 44e3e8f4ea | | |
| | 45dfe07772 | | |
| | 6bafc43754 | | |
| | 012b306749 | | |
| | ac7c5c69cf | | |
| | cd36c54e22 | | |
| | 107fc4e275 | | |
| | 571b770281 | | |
| | 8b3536873e | | |
| | 9a4ac359a3 | | |
| | de5ad93524 | | |
| | cab8099d06 | | |
| | e2be3daffd | | |
| | 9239300fd9 | | |
| | b9f805b2f4 | | |
| | 75cd7b68de | | |
| | b2b5eb1518 | | |
| | 9e7b0131b3 | | |
| | b8ab5811dd | | |
| | 62fac688e4 | | |
| | 14ac7dbbb9 | | |
| | aad933f9c4 | | |
| | 2a7476046d | | |
| | 914c738013 | | |

@@ -83,6 +83,11 @@ Each layer depends only on layers below it.
|
|||||||
|
|
||||||
## Git Workflow
|
## Git Workflow
|
||||||
|
|
||||||
|
### Before Starting Work
|
||||||
|
|
||||||
|
When starting work from a plan, create a new branch matching the plan's scope before
|
||||||
|
making any changes. Do not reuse an existing branch from previous work, even if related.
|
||||||
|
|
||||||
### Commits
|
### Commits
|
||||||
|
|
||||||
- Format: `type(scope): description`
|
- Format: `type(scope): description`
|
||||||
|
|||||||
93
changelog.md
@@ -7,15 +7,108 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- Refactored CLI metric computation to eliminate code duplication
|
||||||
|
- Version format updated from `0.1.0-dev` to `0.1.0.dev0` (PEP 440 compliance)
|
||||||
|
- Settings instance is now cached via `@lru_cache` for better performance
|
||||||
|
- Documented composite validators' intentional deviation from `Check` protocol return type
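
The `@lru_cache` entry above can be pictured with a minimal sketch. The `Settings` dataclass, its fields, and `get_settings()` are hypothetical stand-ins and are not taken from Veritext's configuration module; only `functools.lru_cache` is the real standard-library decorator.

```python
from dataclasses import dataclass
from functools import lru_cache


@dataclass(frozen=True)
class Settings:
    """Hypothetical stand-in for a settings object (not Veritext's real class)."""

    log_level: str = "INFO"
    cache_max_size: int = 1000


@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Build the settings once; later calls return the same cached instance."""
    return Settings()


first = get_settings()
second = get_settings()
print(first is second)  # True: the second call is served from the cache

# Code that needs fresh configuration (tests, for example) can reset the cache.
get_settings.cache_clear()
print(get_settings() is first)  # False: a new instance was built
```

The trade-off with this pattern is that configuration changes made after the first call are invisible until `cache_clear()` is called, which mostly matters in test suites.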
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Consolidated redundant empty checks in ROUGE-L computation
|
||||||
|
- Fixed README example using incorrect property names (`grade_level` → `flesch_kincaid_grade`, `reading_ease` → `flesch_reading_ease`)
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
|
||||||
|
- Added Phase 10 (Portfolio Demos) to implementation plan: Streamlit demo and Jupyter notebooks
|
||||||
|
- Updated project plan with portfolio demo section
|
||||||
|
- Fixed potential crash in ROUGE metric when all references are empty after tokenisation
|
||||||
|
- Fixed potential division by zero in readability metric when text has no sentence endings
|
||||||
|
- Fixed unbounded cache growth in `SemanticSimilarity` by implementing LRU eviction with configurable max size
|
||||||
|
- Fixed mutable list aliasing in `AllOf` and `AnyOf` composite validators
|
||||||
|
- Fixed regex pattern validation in `ContainsValidator` and `ExcludesValidator` to fail at init time rather than during `check()`
|
||||||
|
- Fixed pytest plugin tests failing with duplicate plugin registration error
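
Two of the fixes listed above (regex validation at init time and mutable list aliasing) follow common patterns that can be sketched as below. `PatternCheck` and `AllOfSketch` are hypothetical classes written for illustration; they are not the signatures of `ContainsValidator` or `AllOf`.

```python
import re


class PatternCheck:
    """Hypothetical validator that compiles its patterns eagerly."""

    def __init__(self, patterns: list[str]) -> None:
        # re.compile raises re.error for an invalid pattern, so a bad
        # pattern fails here, at construction, rather than inside check().
        self._compiled = [re.compile(p) for p in patterns]

    def check(self, text: str) -> bool:
        """Return True when every pattern is found somewhere in the text."""
        return all(rx.search(text) is not None for rx in self._compiled)


class AllOfSketch:
    """Hypothetical composite that copies the caller's list of checks."""

    def __init__(self, checks: list[PatternCheck]) -> None:
        # list(...) makes a shallow copy, so mutating the caller's list
        # later does not silently change this validator.
        self._checks = list(checks)

    def check(self, text: str) -> bool:
        """Return True only when every contained check passes."""
        return all(c.check(text) for c in self._checks)


try:
    PatternCheck(["(unclosed"])
except re.error as exc:
    print(f"rejected at init: {exc}")
```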
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
|
- Added `.score` property to `LexicalResult` for API consistency with other result types
|
||||||
|
- Added `cache_max_size` parameter to `SemanticSimilarity` (default: 1000 embeddings)
|
||||||
|
- Added test coverage for `core/config.py` and `core/logging.py` modules
|
||||||
|
|
||||||
|
## [0.1.0] — 2026-02-03
|
||||||
|
|
||||||
|
Initial release of Veritext, a semantic text validation framework for Python.
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
#### Core
|
||||||
|
|
||||||
- Project scaffold with pyproject.toml and development tooling
|
- Project scaffold with pyproject.toml and development tooling
|
||||||
- Core exception hierarchy (`VeritextError` and subclasses)
|
- Core exception hierarchy (`VeritextError` and subclasses)
|
||||||
- Core types: `ValidationContext`, `CheckResult`, `ValidationResult`
|
- Core types: `ValidationContext`, `CheckResult`, `ValidationResult`
|
||||||
- Word tokeniser with Unicode normalisation support
|
- Word tokeniser with Unicode normalisation support
|
||||||
- Configuration module with pydantic-settings
|
- Configuration module with pydantic-settings
|
||||||
- Structured logging with structlog
|
- Structured logging with structlog
|
||||||
|
|
||||||
|
#### Metrics
|
||||||
|
|
||||||
- Metrics module with `Metric` protocol, `AggregateStats`, and `BatchResult` types
|
- Metrics module with `Metric` protocol, `AggregateStats`, and `BatchResult` types
|
||||||
- BLEU metric implementation (BLEU-1 through BLEU-4 with brevity penalty)
|
- BLEU metric implementation (BLEU-1 through BLEU-4 with brevity penalty)
|
||||||
|
- ROUGE metric (ROUGE-1, ROUGE-2, ROUGE-L with precision/recall/F-measure)
|
||||||
- Lexical similarity metric (Jaccard similarity and token overlap)
|
- Lexical similarity metric (Jaccard similarity and token overlap)
|
||||||
|
- Flesch-Kincaid readability metrics (grade level and reading ease)
|
||||||
- Batch scoring with aggregate statistics for all metrics
|
- Batch scoring with aggregate statistics for all metrics
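
As a reminder of what two of the listed quantities compute, the sketch below implements textbook Jaccard similarity and the standard BLEU brevity penalty. Whitespace tokenisation, lower-casing, and the function names are simplifying assumptions; Veritext's own tokeniser and metric classes may differ.

```python
import math


def jaccard_similarity(candidate: str, reference: str) -> float:
    """Size of the token intersection divided by the size of the token union."""
    a = set(candidate.lower().split())
    b = set(reference.lower().split())
    if not a and not b:
        return 1.0  # convention chosen for this sketch: two empty texts match
    return len(a & b) / len(a | b)


def brevity_penalty(candidate_len: int, reference_len: int) -> float:
    """BLEU brevity penalty: 1 if c > r, otherwise exp(1 - r / c)."""
    if candidate_len == 0:
        return 0.0
    if candidate_len > reference_len:
        return 1.0
    return math.exp(1.0 - reference_len / candidate_len)


print(jaccard_similarity("the cat sat", "the cat ran"))  # 0.5 (2 shared of 4 total)
print(brevity_penalty(10, 10))                           # 1.0 (equal lengths)
```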
|
||||||
|
|
||||||
|
#### Validators
|
||||||
|
|
||||||
|
- Validators module with `Check` protocol for validation checks
|
||||||
|
- Metric-based validators: `BleuValidator`, `RougeValidator`, `LexicalValidator`
|
||||||
|
- Constraint validators: `LengthValidator`, `ReadabilityValidator`, `ContainsValidator`, `ExcludesValidator`
|
||||||
|
- Composite validators: `AllOf` (all checks must pass), `AnyOf` (any check must pass)
|
||||||
|
- Factory functions for clean validator API (`bleu()`, `rouge()`, `lexical()`, `length()`, `readability()`, `contains()`, `excludes()`, `all_of()`, `any_of()`)
|
||||||
|
|
||||||
|
#### Semantic Similarity
|
||||||
|
|
||||||
|
- Semantic similarity module with embedding-based text comparison (requires `veritext[semantic]` extra)
|
||||||
|
- `SemanticSimilarity` metric using sentence-transformers for semantic relatedness
|
||||||
|
- `SemanticValidator` for threshold-based semantic similarity validation
|
||||||
|
- `semantic()` factory function for creating semantic validators
|
||||||
|
- Embedding caching for performance optimisation in repeated comparisons
|
||||||
|
|
||||||
|
#### Pytest Plugin
|
||||||
|
|
||||||
|
- Native pytest plugin for CI/CD integration (entry point: `pytest11`)
|
||||||
|
- `validate_text()` assertion function for expressive test assertions
|
||||||
|
- `text_validation` marker for filtering validation tests
|
||||||
|
- Pytest fixtures: `text_validator` factory and `validation_context` helper
|
||||||
|
- Detailed failure messages with text preview and check diagnostics
|
||||||
|
|
||||||
|
#### Benchmarking
|
||||||
|
|
||||||
|
- Benchmark module for quality tracking and regression detection
|
||||||
|
- `Benchmark` class for evaluating text quality over time with metric storage
|
||||||
|
- `BenchmarkRun` and `RegressionReport` data models for tracking runs
|
||||||
|
- SQLite storage backend with WAL mode for concurrent access
|
||||||
|
- Rolling window baseline computation for historical comparison
|
||||||
|
- `check_regression()` for statistical comparison against baseline
|
||||||
|
- `assert_no_regression()` raises `RegressionDetectedError` for CI integration
|
||||||
|
- Customisable tolerance threshold and window size for regression detection
|
||||||
|
- Metadata support for tracking git SHA, model versions, etc.
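
The rolling-window baseline and tolerance check described above could work roughly as follows. This is a sketch under stated assumptions: the baseline is taken as the mean of the last `window` runs, higher scores are treated as better, and `tolerance` is read as an absolute drop. The changelog does not say which statistic or comparison `check_regression()` actually uses, and the function name here is hypothetical.

```python
from statistics import fmean


def check_regression_sketch(
    history: list[float],
    current: float,
    tolerance: float = 0.05,
    window: int = 5,
) -> tuple[bool, float, float]:
    """Return (detected, baseline, delta) for a single metric."""
    if window < 1:
        raise ValueError("window must be at least 1")
    recent = history[-window:]  # the rolling window of most recent runs
    if not recent:
        # No history yet: nothing to compare against, so report no regression.
        return False, current, 0.0
    baseline = fmean(recent)
    delta = current - baseline
    # Assumption: a regression is a drop larger than the absolute tolerance.
    return delta < -tolerance, baseline, delta


history = [0.62, 0.60, 0.61, 0.63, 0.59]
detected, baseline, delta = check_regression_sketch(history, current=0.40)
print(f"detected={detected} baseline={baseline:.2f} delta={delta:+.2f}")
```

A relative tolerance (a percentage of the baseline) would be an equally plausible reading; the sketch uses an absolute drop only because it is the simplest to state.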
|
||||||
|
|
||||||
|
#### CLI
|
||||||
|
|
||||||
|
- Command-line interface (CLI) via `veritext` command
|
||||||
|
- `veritext validate` command for inline and file-based text validation
|
||||||
|
- JSONL input format support for batch validation (`--file` option)
|
||||||
|
- Separate candidate/reference file support (`--reference-file` option)
|
||||||
|
- Multiple output formats: table (default), JSON, and simple text
|
||||||
|
- `veritext benchmark run` command for running evaluations and storing results
|
||||||
|
- `veritext benchmark show` command for viewing benchmark history
|
||||||
|
- `veritext benchmark check` command for regression detection with exit code 1 on failure
|
||||||
|
- Rich-formatted terminal output with tables and coloured panels
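
For the JSONL batch input mentioned above, a reader might look like the sketch below. The field names `candidate` and `reference` and the function `read_pairs()` are assumptions made for illustration; the diff does not show the schema that `veritext validate --file` expects.

```python
import json
from collections.abc import Iterable, Iterator


def read_pairs(lines: Iterable[str]) -> Iterator[tuple[str, str]]:
    """Yield (candidate, reference) pairs from JSONL lines, skipping blanks."""
    for number, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue
        try:
            record = json.loads(line)
            pair = (record["candidate"], record["reference"])
        except (json.JSONDecodeError, KeyError, TypeError) as exc:
            # Report the offending line number so bad input is easy to locate.
            raise ValueError(f"line {number}: invalid record ({exc})") from exc
        yield pair


sample = [
    '{"candidate": "A fast brown fox.", "reference": "The quick brown fox."}',
    "",
    '{"candidate": "Fish discovered.", "reference": "A new fish species was found."}',
]
for candidate, reference in read_pairs(sample):
    print(f"{candidate!r} vs {reference!r}")
```

With a real file, the same function accepts an open file handle, since iterating over a text file yields its lines.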
|
||||||
|
|
||||||
|
#### Documentation
|
||||||
|
|
||||||
|
- Comprehensive readme with usage examples
|
||||||
|
- Example scripts: basic validation, chatbot testing, benchmark regression
|
||||||
|
|||||||
@@ -871,6 +871,59 @@ uv run pytest --cov=src/veritext --cov-report=term-missing
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
### Phase 10: Portfolio Demos
|
||||||
|
|
||||||
|
**Goal:** Interactive demos for showcasing Veritext without installation.
|
||||||
|
|
||||||
|
**Step 1 — Streamlit Demo:**
|
||||||
|
|
||||||
|
Build a quick interactive web UI for general visitors.
|
||||||
|
|
||||||
|
- [ ] Create `demo/streamlit_app.py`
|
||||||
|
- [ ] Text input boxes (candidate + reference)
|
||||||
|
- [ ] Metric selector (BLEU, ROUGE, lexical, readability)
|
||||||
|
- [ ] Threshold sliders for pass/fail validation
|
||||||
|
- [ ] Results table with scores and status
|
||||||
|
- [ ] Deploy to homeserver (e.g., `veritext.kschappell.com`)
|
||||||
|
|
||||||
|
**Step 2 — Jupyter Notebook Collection:**
|
||||||
|
|
||||||
|
Deep-dive notebooks targeting data science and ML recruiters.
|
||||||
|
|
||||||
|
- [ ] Create `notebooks/` directory
|
||||||
|
- [ ] `01-metrics-overview.ipynb` — Introduction to each metric with visualisations
|
||||||
|
- [ ] `02-batch-evaluation.ipynb` — Evaluating model outputs at scale
|
||||||
|
- [ ] `03-regression-detection.ipynb` — Tracking quality over time
|
||||||
|
- [ ] `04-chatbot-validation.ipynb` — Real-world use case
|
||||||
|
|
||||||
|
**Step 3 — JupyterLite Deployment:**
|
||||||
|
|
||||||
|
Host notebooks as static files running in the browser.
|
||||||
|
|
||||||
|
- [ ] Configure JupyterLite build with veritext pre-installed
|
||||||
|
- [ ] Bundle notebooks into static site
|
||||||
|
- [ ] Deploy alongside Streamlit demo
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- `demo/streamlit_app.py`
|
||||||
|
- `notebooks/01-metrics-overview.ipynb`
|
||||||
|
- `notebooks/02-batch-evaluation.ipynb`
|
||||||
|
- `notebooks/03-regression-detection.ipynb`
|
||||||
|
- `notebooks/04-chatbot-validation.ipynb`
|
||||||
|
- `notebooks/jupyterlite-config.json`
|
||||||
|
|
||||||
|
**Verification:**
|
||||||
|
```bash
|
||||||
|
# Streamlit
|
||||||
|
uv run streamlit run demo/streamlit_app.py
|
||||||
|
|
||||||
|
# JupyterLite (local preview)
|
||||||
|
jupyter lite build --contents notebooks/
|
||||||
|
jupyter lite serve
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Dependencies
|
## Dependencies
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
|
|||||||
@@ -488,3 +488,47 @@ benchmark.assert_no_regression(tolerance=0.03)
|
|||||||
|
|
||||||
5. **Natural portfolio narrative** — "I was building X and needed a better way to test
|
5. **Natural portfolio narrative** — "I was building X and needed a better way to test
|
||||||
it, so I built this tool." Every interviewer has faced similar problems.
|
it, so I built this tool." Every interviewer has faced similar problems.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Portfolio Demos (Future)
|
||||||
|
|
||||||
|
Interactive demos to showcase Veritext without requiring installation.
|
||||||
|
|
||||||
|
### Streamlit Demo
|
||||||
|
|
||||||
|
A quick interactive web UI for general visitors and recruiters.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Text input boxes (candidate + reference)
|
||||||
|
- Metric selector (BLEU, ROUGE, lexical, readability)
|
||||||
|
- Threshold sliders for pass/fail validation
|
||||||
|
- Results table with scores and status
|
||||||
|
|
||||||
|
**Deployment:** Self-hosted on homeserver (e.g., `veritext.kschappell.com`)
|
||||||
|
|
||||||
|
**Effort:** ~30 minutes
|
||||||
|
|
||||||
|
### Jupyter Notebook Collection
|
||||||
|
|
||||||
|
Deep-dive notebooks targeting data science and ML recruiters.
|
||||||
|
|
||||||
|
**Notebooks:**
|
||||||
|
|
||||||
|
| Notebook | Purpose |
|
||||||
|
|----------|---------|
|
||||||
|
| `01-metrics-overview.ipynb` | Introduction to each metric with visualisations |
|
||||||
|
| `02-batch-evaluation.ipynb` | Evaluating model outputs at scale, statistical analysis |
|
||||||
|
| `03-regression-detection.ipynb` | Tracking quality over time, detecting degradation |
|
||||||
|
| `04-chatbot-validation.ipynb` | Real-world use case: validating chatbot responses |
|
||||||
|
|
||||||
|
**Hosting:** JupyterLite (static files, runs in browser via WebAssembly)
|
||||||
|
|
||||||
|
**Deployment:** Self-hosted alongside Streamlit demo
|
||||||
|
|
||||||
|
**Why both:**
|
||||||
|
|
||||||
|
| Demo Type | Audience | Value |
|
||||||
|
|-----------|----------|-------|
|
||||||
|
| Streamlit | General visitors | Quick, interactive, no friction |
|
||||||
|
| Notebooks | Data/ML recruiters | Shows analytical depth, speaks their language |
|
||||||
|
|||||||
135
examples/basic_validation.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""Basic text validation examples.
|
||||||
|
|
||||||
|
Demonstrates core Veritext functionality:
|
||||||
|
- Single metric scoring (BLEU, ROUGE)
|
||||||
|
- Validator usage with thresholds
|
||||||
|
- Composite validators (all_of, any_of)
|
||||||
|
- Constraint validators (length, readability)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from veritext.core.types import ValidationContext
|
||||||
|
from veritext.metrics import Bleu, Rouge
|
||||||
|
from veritext.validators import (
|
||||||
|
all_of,
|
||||||
|
any_of,
|
||||||
|
bleu,
|
||||||
|
contains,
|
||||||
|
excludes,
|
||||||
|
length,
|
||||||
|
readability,
|
||||||
|
rouge,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def metric_scoring_example() -> None:
|
||||||
|
"""Score text using individual metrics."""
|
||||||
|
candidate = "The quick brown fox jumps over the lazy dog."
|
||||||
|
reference = "A fast brown fox leaps over a sleepy dog."
|
||||||
|
|
||||||
|
# BLEU scoring (translation quality)
|
||||||
|
bleu_metric = Bleu()
|
||||||
|
bleu_result = bleu_metric.score(candidate, reference)
|
||||||
|
print("BLEU Scores:")
|
||||||
|
print(f" BLEU-1: {bleu_result.bleu1:.3f}")
|
||||||
|
print(f" BLEU-4: {bleu_result.bleu4:.3f}")
|
||||||
|
print(f" Brevity penalty: {bleu_result.brevity_penalty:.3f}")
|
||||||
|
|
||||||
|
# ROUGE scoring (summary quality)
|
||||||
|
rouge_metric = Rouge()
|
||||||
|
rouge_result = rouge_metric.score(candidate, reference)
|
||||||
|
print("\nROUGE Scores:")
|
||||||
|
print(f" ROUGE-1 F1: {rouge_result.rouge1.fmeasure:.3f}")
|
||||||
|
print(f" ROUGE-L F1: {rouge_result.rouge_l.fmeasure:.3f}")
|
||||||
|
|
||||||
|
|
||||||
|
def validator_example() -> None:
|
||||||
|
"""Use validators to make pass/fail decisions."""
|
||||||
|
reference = "Machine learning models require training data."
|
||||||
|
candidate = "ML models need training data to learn patterns."
|
||||||
|
|
||||||
|
context = ValidationContext(reference=reference)
|
||||||
|
|
||||||
|
# BLEU validator with minimum threshold
|
||||||
|
bleu_validator = bleu(min_score=0.3)
|
||||||
|
result = bleu_validator.check(candidate, context)
|
||||||
|
print(f"\nBLEU validation (min 0.3): {'PASS' if result.passed else 'FAIL'}")
|
||||||
|
|
||||||
|
# ROUGE validator
|
||||||
|
rouge_validator = rouge(min_score=0.5)
|
||||||
|
result = rouge_validator.check(candidate, context)
|
||||||
|
print(f"ROUGE validation (min 0.5): {'PASS' if result.passed else 'FAIL'}")
|
||||||
|
|
||||||
|
|
||||||
|
def composite_validator_example() -> None:
|
||||||
|
"""Combine validators with all_of and any_of."""
|
||||||
|
reference = "The product launch exceeded all expectations."
|
||||||
|
candidate = "The product release performed beyond expectations."
|
||||||
|
|
||||||
|
context = ValidationContext(reference=reference)
|
||||||
|
|
||||||
|
# All checks must pass
|
||||||
|
strict_validator = all_of(
|
||||||
|
[
|
||||||
|
bleu(min_score=0.2),
|
||||||
|
rouge(min_score=0.4),
|
||||||
|
length(max_chars=100),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
result = strict_validator.check(candidate, context)
|
||||||
|
print(f"\nStrict (all_of): {'PASS' if result.passed else 'FAIL'}")
|
||||||
|
if not result.passed:
|
||||||
|
print(f" Failures: {result.failure_summary}")
|
||||||
|
|
||||||
|
# At least one check must pass
|
||||||
|
flexible_validator = any_of(
|
||||||
|
[
|
||||||
|
bleu(min_score=0.8), # Unlikely to pass
|
||||||
|
rouge(min_score=0.4), # More likely
|
||||||
|
]
|
||||||
|
)
|
||||||
|
result = flexible_validator.check(candidate, context)
|
||||||
|
print(f"Flexible (any_of): {'PASS' if result.passed else 'FAIL'}")
|
||||||
|
|
||||||
|
|
||||||
|
def constraint_validator_example() -> None:
|
||||||
|
"""Use constraint validators for text properties."""
|
||||||
|
text = "This short guide explains the basics clearly."
|
||||||
|
context = ValidationContext() # No reference needed for constraints
|
||||||
|
|
||||||
|
# Length constraints
|
||||||
|
length_validator = length(min_chars=20, max_chars=100, min_words=5, max_words=20)
|
||||||
|
result = length_validator.check(text, context)
|
||||||
|
print(f"\nLength check: {'PASS' if result.passed else 'FAIL'}")
|
||||||
|
|
||||||
|
# Readability (Flesch-Kincaid)
|
||||||
|
readability_validator = readability(max_grade=10.0)
|
||||||
|
result = readability_validator.check(text, context)
|
||||||
|
print(f"Readability (grade <= 10): {'PASS' if result.passed else 'FAIL'}")
|
||||||
|
|
||||||
|
# Content patterns
|
||||||
|
contains_validator = contains(patterns=["guide", "basics"])
|
||||||
|
result = contains_validator.check(text, context)
|
||||||
|
print(f"Contains required terms: {'PASS' if result.passed else 'FAIL'}")
|
||||||
|
|
||||||
|
excludes_validator = excludes(patterns=["error", "warning"])
|
||||||
|
result = excludes_validator.check(text, context)
|
||||||
|
print(f"Excludes forbidden terms: {'PASS' if result.passed else 'FAIL'}")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Run all examples."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Veritext Basic Validation Examples")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
metric_scoring_example()
|
||||||
|
validator_example()
|
||||||
|
composite_validator_example()
|
||||||
|
constraint_validator_example()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("All examples completed.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
160
examples/benchmark_regression.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
"""Benchmark quality tracking with regression detection.
|
||||||
|
|
||||||
|
Demonstrates Veritext's benchmark module for CI integration:
|
||||||
|
- Creating a benchmark suite
|
||||||
|
- Running evaluations and storing results
|
||||||
|
- Checking for quality regression
|
||||||
|
- CI integration pattern with exit codes
|
||||||
|
"""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from veritext.benchmark import Benchmark
|
||||||
|
from veritext.core.exceptions import RegressionDetectedError
|
||||||
|
|
||||||
|
|
||||||
|
def create_sample_data() -> tuple[list[str], list[str]]:
|
||||||
|
"""Create sample candidate/reference pairs for benchmarking."""
|
||||||
|
# Simulated summarisation outputs and references
|
||||||
|
candidates = [
|
||||||
|
"The new policy aims to reduce carbon emissions by 50% by 2030.",
|
||||||
|
"Scientists discovered a new species of deep-sea fish.",
|
||||||
|
"The company reported record profits in the third quarter.",
|
||||||
|
"Researchers developed a breakthrough treatment for the disease.",
|
||||||
|
"The city plans to expand public transportation routes.",
|
||||||
|
]
|
||||||
|
references = [
|
||||||
|
"The policy targets a 50% reduction in carbon emissions by 2030.",
|
||||||
|
"A new deep-sea fish species was discovered by marine biologists.",
|
||||||
|
"Record profits were announced by the company for Q3.",
|
||||||
|
"A breakthrough disease treatment was developed by researchers.",
|
||||||
|
"Public transport expansion is planned for the city.",
|
||||||
|
]
|
||||||
|
return candidates, references
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmark_example() -> None:
|
||||||
|
"""Run a benchmark evaluation and view results."""
|
||||||
|
# Use a temp directory for this example
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
storage_path = Path(tmpdir) / "benchmarks"
|
||||||
|
|
||||||
|
# Create benchmark suite
|
||||||
|
bench = Benchmark("summariser_quality", storage_path=storage_path)
|
||||||
|
|
||||||
|
candidates, references = create_sample_data()
|
||||||
|
|
||||||
|
# Run evaluation
|
||||||
|
print("Running benchmark evaluation...")
|
||||||
|
run = bench.evaluate(
|
||||||
|
candidates=candidates,
|
||||||
|
references=references,
|
||||||
|
metrics=["rouge_l", "bleu4"],
|
||||||
|
metadata={"model": "v1.0", "dataset": "test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\nBenchmark run completed:")
|
||||||
|
print(f" Run ID: {run.id[:8]}...")
|
||||||
|
print(f" Samples: {run.sample_count}")
|
||||||
|
print(" Metrics:")
|
||||||
|
for name, value in run.metrics.items():
|
||||||
|
print(f" {name}: {value:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
def regression_detection_example() -> None:
|
||||||
|
"""Demonstrate regression detection with historical comparison."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
storage_path = Path(tmpdir) / "benchmarks"
|
||||||
|
bench = Benchmark("summariser_quality", storage_path=storage_path)
|
||||||
|
|
||||||
|
candidates, references = create_sample_data()
|
||||||
|
|
||||||
|
# Simulate historical runs with stable quality
|
||||||
|
print("\nBuilding baseline with historical runs...")
|
||||||
|
for i in range(5):
|
||||||
|
bench.evaluate(
|
||||||
|
candidates=candidates,
|
||||||
|
references=references,
|
||||||
|
metrics=["rouge_l", "bleu4"],
|
||||||
|
metadata={"run": f"baseline_{i}"},
|
||||||
|
)
|
||||||
|
print(f" Baseline run {i + 1} recorded")
|
||||||
|
|
||||||
|
# Check regression (no degradation expected)
|
||||||
|
report = bench.check_regression(tolerance=0.05, window=5)
|
||||||
|
print(f"\nRegression check: {'DETECTED' if report.detected else 'NONE'}")
|
||||||
|
|
||||||
|
# Simulate a degraded model
|
||||||
|
print("\nSimulating degraded model output...")
|
||||||
|
degraded_candidates = [
|
||||||
|
"Policy carbon emissions.", # Much shorter/worse
|
||||||
|
"Fish discovered.",
|
||||||
|
"Company profits.",
|
||||||
|
"Treatment developed.",
|
||||||
|
"Transport expansion.",
|
||||||
|
]
|
||||||
|
bench.evaluate(
|
||||||
|
candidates=degraded_candidates,
|
||||||
|
references=references,
|
||||||
|
metrics=["rouge_l", "bleu4"],
|
||||||
|
metadata={"model": "v1.1-broken"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check regression (should detect)
|
||||||
|
report = bench.check_regression(tolerance=0.05, window=5)
|
||||||
|
print(f"Regression check: {'DETECTED' if report.detected else 'NONE'}")
|
||||||
|
if report.detected:
|
||||||
|
print("\nRegression details:")
|
||||||
|
for metric, delta in report.deltas.items():
|
||||||
|
baseline = report.baseline.get(metric, 0)
|
||||||
|
current = report.current.get(metric, 0)
|
||||||
|
print(f" {metric}: {baseline:.4f} -> {current:.4f} ({delta:+.4f})")
|
||||||
|
|
||||||
|
|
||||||
|
def ci_integration_example() -> None:
|
||||||
|
"""CI integration pattern using assert_no_regression()."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
storage_path = Path(tmpdir) / "benchmarks"
|
||||||
|
bench = Benchmark("ci_check", storage_path=storage_path)
|
||||||
|
|
||||||
|
candidates, references = create_sample_data()
|
||||||
|
|
||||||
|
# Build baseline
|
||||||
|
for _ in range(3):
|
||||||
|
bench.evaluate(candidates, references, metrics=["rouge_l"])
|
||||||
|
|
||||||
|
# Simulate CI check
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("CI Integration Example")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
print("\nRunning evaluation...")
|
||||||
|
bench.evaluate(candidates, references, metrics=["rouge_l"])
|
||||||
|
|
||||||
|
print("Checking for regression...")
|
||||||
|
try:
|
||||||
|
bench.assert_no_regression(tolerance=0.05, window=3)
|
||||||
|
print("No regression detected.")
|
||||||
|
print("CI status: EXIT 0")
|
||||||
|
except RegressionDetectedError as e:
|
||||||
|
print(f"Regression detected: {e}")
|
||||||
|
print("CI status: EXIT 1")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Run all benchmark examples."""
|
||||||
|
print("=" * 60)
|
||||||
|
print("Veritext Benchmark & Regression Detection Examples")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
run_benchmark_example()
|
||||||
|
regression_detection_example()
|
||||||
|
ci_integration_example()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("All examples completed.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
140
examples/chatbot_testing.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""Pytest integration for chatbot testing.
|
||||||
|
|
||||||
|
Demonstrates Veritext's pytest plugin for testing chatbot responses:
|
||||||
|
- validate_text() assertion function
|
||||||
|
- Custom test fixtures
|
||||||
|
- Test organisation with markers
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from veritext.pytest_plugin import validate_text
|
||||||
|
|
||||||
|
# Sample chatbot responses for testing
|
||||||
|
CHATBOT_RESPONSES = {
|
||||||
|
"greeting": {
|
||||||
|
"input": "Hello!",
|
||||||
|
"response": "Hi there! How can I help you today?",
|
||||||
|
"expected_keywords": ["help", "hi"],
|
||||||
|
},
|
||||||
|
"weather": {
|
||||||
|
"input": "What's the weather like?",
|
||||||
|
"response": "I don't have access to real-time weather data, but you can "
|
||||||
|
"check a weather service like weather.com for current conditions.",
|
||||||
|
"expected_keywords": ["weather", "check"],
|
||||||
|
},
|
||||||
|
"farewell": {
|
||||||
|
"input": "Goodbye!",
|
||||||
|
"response": "Goodbye! Have a great day!",
|
||||||
|
"expected_keywords": ["goodbye", "day"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Fixtures for common test setup
|
||||||
|
@pytest.fixture
|
||||||
|
def greeting_response() -> str:
|
||||||
|
"""Provide a sample greeting response."""
|
||||||
|
return CHATBOT_RESPONSES["greeting"]["response"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def weather_response() -> str:
|
||||||
|
"""Provide a sample weather response."""
|
||||||
|
return CHATBOT_RESPONSES["weather"]["response"]
|
||||||
|
|
||||||
|
|
||||||
|
# Basic validation tests
|
||||||
|
class TestResponseQuality:
|
||||||
|
"""Test chatbot response quality using Veritext."""
|
||||||
|
|
||||||
|
def test_greeting_length(self, greeting_response: str) -> None:
|
||||||
|
"""Greeting responses should be concise."""
|
||||||
|
validate_text(
|
||||||
|
greeting_response,
|
||||||
|
min_length=10,
|
||||||
|
max_length=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_greeting_readability(self, greeting_response: str) -> None:
|
||||||
|
"""Greeting responses should be easy to read."""
|
||||||
|
validate_text(
|
||||||
|
greeting_response,
|
||||||
|
max_reading_grade=8.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_greeting_contains_keywords(self, greeting_response: str) -> None:
|
||||||
|
"""Greeting should contain expected terms."""
|
||||||
|
validate_text(
|
||||||
|
greeting_response,
|
||||||
|
must_contain=["help"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_weather_response_quality(self, weather_response: str) -> None:
|
||||||
|
"""Weather response should be informative and readable."""
|
||||||
|
validate_text(
|
||||||
|
weather_response,
|
||||||
|
min_length=50,
|
||||||
|
max_length=500,
|
||||||
|
max_reading_grade=10.0,
|
||||||
|
must_contain=["weather"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Tests with reference comparison
|
||||||
|
class TestResponseSimilarity:
|
||||||
|
"""Test response similarity against reference texts."""
|
||||||
|
|
||||||
|
def test_greeting_similarity(self) -> None:
|
||||||
|
"""Greeting should match expected style."""
|
||||||
|
reference = "Hello! How may I assist you today?"
|
||||||
|
response = CHATBOT_RESPONSES["greeting"]["response"]
|
||||||
|
|
||||||
|
validate_text(
|
||||||
|
response,
|
||||||
|
reference=reference,
|
||||||
|
min_rouge=0.3, # Allow variation in wording
|
||||||
|
min_length=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_farewell_similarity(self) -> None:
|
||||||
|
"""Farewell should match expected style."""
|
||||||
|
reference = "Goodbye! Have a wonderful day!"
|
||||||
|
response = CHATBOT_RESPONSES["farewell"]["response"]
|
||||||
|
|
||||||
|
validate_text(
|
||||||
|
response,
|
||||||
|
reference=reference,
|
||||||
|
min_rouge=0.5,
|
||||||
|
must_contain=["goodbye"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Content safety tests
|
||||||
|
class TestContentSafety:
|
||||||
|
"""Test responses for inappropriate content."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("response_key", ["greeting", "weather", "farewell"])
|
||||||
|
def test_no_profanity(self, response_key: str) -> None:
|
||||||
|
"""Responses should not contain profanity."""
|
||||||
|
response = CHATBOT_RESPONSES[response_key]["response"]
|
||||||
|
validate_text(
|
||||||
|
response,
|
||||||
|
must_exclude=["damn", "hell", "crap"],
|
||||||
|
min_length=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("response_key", ["greeting", "weather", "farewell"])
|
||||||
|
def test_no_harmful_content(self, response_key: str) -> None:
|
||||||
|
"""Responses should not contain harmful instructions."""
|
||||||
|
response = CHATBOT_RESPONSES[response_key]["response"]
|
||||||
|
validate_text(
|
||||||
|
response,
|
||||||
|
must_exclude=["hack", "exploit", "attack"],
|
||||||
|
min_length=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Run tests when executed directly
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "veritext"
|
name = "veritext"
|
||||||
version = "0.1.0-dev"
|
version = "0.1.0.dev0"
|
||||||
description = "Semantic text validation framework"
|
description = "Semantic text validation framework"
|
||||||
readme = "readme.md"
|
readme = "readme.md"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
|
|||||||
386
readme.md
@@ -2,48 +2,398 @@
|
|||||||
|
|
||||||
Semantic text validation framework for Python.
|
Semantic text validation framework for Python.
|
||||||
|
|
||||||
Validates text outputs against quality criteria using metrics like BLEU, ROUGE,
|
Veritext validates text outputs against quality criteria using metrics like BLEU,
|
||||||
and semantic similarity. Designed for developers building systems that produce
|
ROUGE, and semantic similarity. Designed for developers building systems that produce
|
||||||
text (chatbots, content generators, summarisation tools) who need automated
|
text (chatbots, content generators, summarisation tools) who need automated quality
|
||||||
quality assurance beyond simple string matching.
|
assurance beyond simple string matching.
|
||||||
|
|
||||||
## Status
|
## Features
|
||||||
|
|
||||||
Under active development. See [changelog.md](changelog.md) for progress.
|
- **Multiple metrics** — BLEU, ROUGE, lexical similarity, readability, semantic
|
||||||
|
embeddings
|
||||||
|
- **Composable validators** — Build complex checks from simple primitives
|
||||||
|
- **Native pytest integration** — `validate_text()` assertion for test suites
|
||||||
|
- **Quality benchmarking** — Track metrics over time with regression detection
|
||||||
|
- **CLI tools** — Command-line validation and benchmark management
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install veritext
|
pip install veritext
|
||||||
|
|
||||||
# With semantic similarity support
|
# With semantic similarity support (sentence-transformers)
|
||||||
pip install veritext[semantic]
|
pip install veritext[semantic]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from veritext import validators as v
|
|
||||||
from veritext.core.types import ValidationContext
|
from veritext.core.types import ValidationContext
|
||||||
|
from veritext.validators import all_of, bleu, length, rouge
|
||||||
|
|
||||||
# Create validators
|
# Create a validator
|
||||||
validator = v.all_of([
|
validator = all_of([
|
||||||
v.bleu(min_score=0.7),
|
bleu(min_score=0.5),
|
||||||
v.length(max_chars=500),
|
rouge(min_score=0.6),
|
||||||
|
length(max_chars=500),
|
||||||
])
|
])
|
||||||
|
|
||||||
# Validate text
|
# Validate text
|
||||||
context = ValidationContext(reference="The cat sat on the mat.")
|
context = ValidationContext(reference="The quick brown fox jumps over the lazy dog.")
|
||||||
result = validator.check("A cat is sitting on the mat.", context)
|
result = validator.check("A fast brown fox leaps over a sleepy dog.", context)
|
||||||
|
|
||||||
if not result.passed:
|
if result.passed:
|
||||||
|
print("Validation passed!")
|
||||||
|
else:
|
||||||
print(result.failure_summary)
|
print(result.failure_summary)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Documentation
|
## Metrics
|
||||||
|
|
||||||
- [Project Plan](docs/project-plan.md)
|
Veritext provides several metrics for text evaluation.
|
||||||
- [Implementation Plan](docs/implementation-plan.md)
|
|
||||||
|
### BLEU
|
||||||
|
|
||||||
|
Measures n-gram precision against reference text. Useful for translation and
|
||||||
|
generation quality.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.metrics import Bleu
|
||||||
|
|
||||||
|
bleu = Bleu()
|
||||||
|
result = bleu.score(
|
||||||
|
candidate="The cat sat on the mat.",
|
||||||
|
reference="A cat is sitting on the mat.",
|
||||||
|
)
|
||||||
|
print(f"BLEU-4: {result.bleu4:.3f}") # Uses 1-4 gram precision
|
||||||
|
print(f"BLEU-1: {result.bleu1:.3f}") # Unigram precision only
|
||||||
|
```
|
||||||
|
|
||||||
|
### ROUGE
|
||||||
|
|
||||||
|
Measures recall-oriented overlap with reference text. Useful for summarisation.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.metrics import Rouge
|
||||||
|
|
||||||
|
rouge = Rouge()
|
||||||
|
result = rouge.score(
|
||||||
|
candidate="Scientists found a new planet.",
|
||||||
|
reference="Researchers discovered a new planet in the solar system.",
|
||||||
|
)
|
||||||
|
print(f"ROUGE-1 F1: {result.rouge1.fmeasure:.3f}") # Unigram overlap
|
||||||
|
print(f"ROUGE-L F1: {result.rouge_l.fmeasure:.3f}") # Longest common subsequence
|
||||||
|
```
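To make "longest common subsequence" concrete, here is a minimal plain-Python sketch of a ROUGE-L style F1 score. It is an illustration only, not Veritext's implementation: the `lcs_length` and `rouge_l_f1` helpers and the lowercase whitespace tokenisation are assumptions made for this example.

```python
def lcs_length(a: list[str], b: list[str]) -> int:
    """Length of the longest common subsequence, via row-by-row dynamic programming."""
    prev = [0] * (len(b) + 1)
    for token_a in a:
        curr = [0]
        for j, token_b in enumerate(b, start=1):
            if token_a == token_b:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(candidate: str, reference: str) -> float:
    """F1 of LCS-based precision (vs. candidate) and recall (vs. reference)."""
    cand = candidate.lower().split()
    ref = reference.lower().split()
    lcs = lcs_length(cand, ref)
    if lcs == 0:
        # Also covers empty inputs, so the divisions below are safe.
        return 0.0
    precision = lcs / len(cand)
    recall = lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


# LCS of "a b c d" and "a c d e f" is "a c d": precision 3/4, recall 3/5.
print(f"{rouge_l_f1('a b c d', 'a c d e f'):.3f}")
```

Real ROUGE implementations typically add stemming and more careful tokenisation, so scores from this sketch should not be expected to match `Rouge().score(...)`.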
|
||||||
|
|
||||||
|
### Lexical Similarity
|
||||||
|
|
||||||
|
Measures token overlap using Jaccard similarity.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.metrics import Lexical
|
||||||
|
|
||||||
|
lexical = Lexical()
|
||||||
|
result = lexical.score(
|
||||||
|
candidate="The quick brown fox",
|
||||||
|
reference="The fast brown fox",
|
||||||
|
)
|
||||||
|
print(f"Jaccard: {result.jaccard:.3f}")
|
||||||
|
print(f"Token overlap: {result.token_overlap:.3f}")
|
||||||
|
```
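For intuition, Jaccard similarity is the size of the intersection of two token sets divided by the size of their union. The sketch below is a plain-Python illustration, not Veritext's implementation; the `jaccard` helper and its lowercase whitespace tokenisation are assumptions made for the example.

```python
def jaccard(candidate: str, reference: str) -> float:
    """Jaccard similarity over lowercase, whitespace-separated tokens."""
    cand_tokens = set(candidate.lower().split())
    ref_tokens = set(reference.lower().split())
    union = cand_tokens | ref_tokens
    if not union:
        # Two empty texts: treated as identical (a choice made for this sketch).
        return 1.0
    return len(cand_tokens & ref_tokens) / len(union)


# The two texts share 3 of 5 distinct tokens.
print(f"{jaccard('The quick brown fox', 'The fast brown fox'):.3f}")
```

A different tokeniser (for example one that strips punctuation) would change the score, which is one reason to treat thresholds such as `min_jaccard` as something to tune per dataset.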
|
||||||
|
|
||||||
|
### Readability
|
||||||
|
|
||||||
|
Computes Flesch-Kincaid scores for text complexity.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.metrics import Readability
|
||||||
|
|
||||||
|
readability = Readability()
|
||||||
|
result = readability.score("This is a simple sentence.")
|
||||||
|
print(f"Grade level: {result.flesch_kincaid_grade:.1f}")
|
||||||
|
print(f"Reading ease: {result.flesch_reading_ease:.1f}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Semantic Similarity (Optional)
|
||||||
|
|
||||||
|
Requires `pip install veritext[semantic]`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.semantic import SemanticSimilarity
|
||||||
|
|
||||||
|
semantic = SemanticSimilarity()
|
||||||
|
result = semantic.score(
|
||||||
|
candidate="The dog is running in the park.",
|
||||||
|
reference="A canine is jogging through the garden.",
|
||||||
|
)
|
||||||
|
print(f"Similarity: {result.score:.3f}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Validators
|
||||||
|
|
||||||
|
Validators wrap metrics with thresholds to make pass/fail decisions.
|
||||||
|
|
||||||
|
### Metric-Based Validators
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.core.types import ValidationContext
|
||||||
|
from veritext.validators import bleu, lexical, rouge
|
||||||
|
|
||||||
|
context = ValidationContext(reference="Reference text here.")
|
||||||
|
|
||||||
|
# BLEU validation
|
||||||
|
validator = bleu(min_score=0.5, variant=4) # BLEU-4
|
||||||
|
result = validator.check("Candidate text here.", context)
|
||||||
|
|
||||||
|
# ROUGE validation
|
||||||
|
validator = rouge(min_score=0.6, variant="l") # ROUGE-L
|
||||||
|
result = validator.check("Candidate text here.", context)
|
||||||
|
|
||||||
|
# Lexical validation
|
||||||
|
validator = lexical(min_jaccard=0.3, min_overlap=0.5)
|
||||||
|
result = validator.check("Candidate text here.", context)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Constraint Validators
|
||||||
|
|
||||||
|
These don't require reference text.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.core.types import ValidationContext
|
||||||
|
from veritext.validators import contains, excludes, length, readability
|
||||||
|
|
||||||
|
context = ValidationContext() # No reference needed
|
||||||
|
|
||||||
|
# Length constraints
|
||||||
|
validator = length(min_chars=50, max_chars=500, min_words=10)
|
||||||
|
result = validator.check("Your text here...", context)
|
||||||
|
|
||||||
|
# Readability constraints
|
||||||
|
validator = readability(max_grade=8.0, min_ease=60.0)
|
||||||
|
result = validator.check("Your text here...", context)
|
||||||
|
|
||||||
|
# Content requirements
|
||||||
|
validator = contains(patterns=["important", "keyword"])
|
||||||
|
result = validator.check("This important text has a keyword.", context)
|
||||||
|
|
||||||
|
# Content exclusions
|
||||||
|
validator = excludes(patterns=["forbidden", "banned"])
|
||||||
|
result = validator.check("This text is clean.", context)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Composite Validators
|
||||||
|
|
||||||
|
Combine multiple checks with logical operators.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.validators import all_of, any_of, bleu, length, rouge
|
||||||
|
|
||||||
|
# All checks must pass
|
||||||
|
validator = all_of([
|
||||||
|
bleu(min_score=0.5),
|
||||||
|
rouge(min_score=0.6),
|
||||||
|
length(max_chars=500),
|
||||||
|
])
|
||||||
|
|
||||||
|
# At least one check must pass
|
||||||
|
validator = any_of([
|
||||||
|
bleu(min_score=0.7),
|
||||||
|
rouge(min_score=0.7),
|
||||||
|
])
|
||||||
|
```
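The combinators behave like logical AND and OR over individual checks. As a rough mental model only, the sketch below uses plain predicates instead of Veritext validators; `combine_all`, `combine_any`, `max_chars` and `has_word` are hypothetical names invented for this example and are not part of the library.

```python
from collections.abc import Callable

Check = Callable[[str], bool]


def combine_all(checks: list[Check]) -> Check:
    """Pass only if every check passes (logical AND)."""
    return lambda text: all(check(text) for check in checks)


def combine_any(checks: list[Check]) -> Check:
    """Pass if at least one check passes (logical OR)."""
    return lambda text: any(check(text) for check in checks)


def max_chars(limit: int) -> Check:
    return lambda text: len(text) <= limit


def has_word(word: str) -> Check:
    return lambda text: word in text


strict = combine_all([max_chars(20), has_word("fox")])
lenient = combine_any([max_chars(5), has_word("fox")])

print(strict("a quick brown fox"))      # both checks hold
print(lenient("no match here at all"))  # neither check holds
```

Because the combinators return the same kind of object they accept, they can be nested, which is presumably what makes the real `all_of` and `any_of` composable as well.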
|
||||||
|
|
||||||
|
## Pytest Plugin
|
||||||
|
|
||||||
|
Veritext provides native pytest integration for testing text quality.
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.pytest_plugin import validate_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_response_quality():
|
||||||
|
response = "This is a helpful response to your question."
|
||||||
|
|
||||||
|
validate_text(
|
||||||
|
response,
|
||||||
|
min_length=20,
|
||||||
|
max_length=200,
|
||||||
|
max_reading_grade=10.0,
|
||||||
|
must_contain=["helpful"],
|
||||||
|
must_exclude=["error", "sorry"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_summary_similarity():
|
||||||
|
summary = "Scientists discovered a new planet."
|
||||||
|
reference = "Researchers found a new planet in our solar system."
|
||||||
|
|
||||||
|
validate_text(
|
||||||
|
summary,
|
||||||
|
reference=reference,
|
||||||
|
min_rouge=0.5,
|
||||||
|
min_length=10,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Available Parameters
|
||||||
|
|
||||||
|
| Parameter | Description |
|
||||||
|
|-----------|-------------|
|
||||||
|
| `reference` | Reference text for comparison metrics |
|
||||||
|
| `min_bleu` | Minimum BLEU-4 score (0.0-1.0) |
|
||||||
|
| `min_rouge` | Minimum ROUGE-L F1 score (0.0-1.0) |
|
||||||
|
| `min_semantic` | Minimum semantic similarity (0.0-1.0) |
|
||||||
|
| `min_length` | Minimum character count |
|
||||||
|
| `max_length` | Maximum character count |
|
||||||
|
| `max_reading_grade` | Maximum Flesch-Kincaid grade level |
|
||||||
|
| `must_contain` | List of required patterns |
|
||||||
|
| `must_exclude` | List of forbidden patterns |
|
||||||
|
|
||||||
|
## Benchmarking
|
||||||
|
|
||||||
|
Track text quality over time and detect regressions.
|
||||||
|
|
||||||
|
### Running Benchmarks
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.benchmark import Benchmark
|
||||||
|
|
||||||
|
# Create a benchmark suite
|
||||||
|
bench = Benchmark("summariser_quality", storage_path="benchmarks/")
|
||||||
|
|
||||||
|
# Evaluate a batch of outputs
|
||||||
|
candidates = ["Summary 1...", "Summary 2...", "Summary 3..."]
|
||||||
|
references = ["Reference 1...", "Reference 2...", "Reference 3..."]
|
||||||
|
|
||||||
|
run = bench.evaluate(
|
||||||
|
candidates=candidates,
|
||||||
|
references=references,
|
||||||
|
metrics=["rouge_l", "bleu4"],
|
||||||
|
metadata={"model": "v1.2", "git_sha": "abc123"},
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"ROUGE-L: {run.metrics['rouge_l']:.4f}")
|
||||||
|
print(f"BLEU-4: {run.metrics['bleu4']:.4f}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Regression Detection
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.benchmark import Benchmark
|
||||||
|
from veritext.core.exceptions import RegressionDetectedError
|
||||||
|
|
||||||
|
bench = Benchmark("summariser_quality")
|
||||||
|
|
||||||
|
# Check for regression against historical baseline
|
||||||
|
report = bench.check_regression(tolerance=0.05, window=10)
|
||||||
|
if report.detected:
|
||||||
|
print("Quality regression detected!")
|
||||||
|
for metric, delta in report.deltas.items():
|
||||||
|
print(f" {metric}: {delta:+.4f}")
|
||||||
|
|
||||||
|
# Or raise an exception for CI integration
|
||||||
|
try:
|
||||||
|
bench.assert_no_regression(tolerance=0.05)
|
||||||
|
except RegressionDetectedError as e:
|
||||||
|
print(f"CI failure: {e}")
|
||||||
|
raise SystemExit(1)
|
||||||
|
```
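Going by `src/veritext/benchmark/regression.py` in this change, the baseline is the per-metric mean of up to `window` earlier runs, and a regression is flagged when a metric falls more than `tolerance` below that mean as an absolute difference. A stand-alone sketch of that arithmetic follows; it uses plain dicts rather than `BenchmarkRun` objects, and the `rolling_baseline` and `regressed_metrics` names are made up for illustration.

```python
def rolling_baseline(
    history: list[dict[str, float]], window: int = 10
) -> dict[str, float]:
    """Average each metric over the most recent `window` runs (newest first)."""
    values: dict[str, list[float]] = {}
    for run in history[:window]:
        for name, value in run.items():
            values.setdefault(name, []).append(value)
    return {name: sum(vals) / len(vals) for name, vals in values.items()}


def regressed_metrics(
    current: dict[str, float],
    baseline: dict[str, float],
    tolerance: float = 0.05,
) -> dict[str, float]:
    """Return the deltas of metrics that dropped by more than `tolerance`."""
    deltas = {
        name: current.get(name, 0.0) - base for name, base in baseline.items()
    }
    return {name: delta for name, delta in deltas.items() if delta < -tolerance}


history = [{"rouge_l": 0.80}, {"rouge_l": 0.84}, {"rouge_l": 0.82}]
baseline = rolling_baseline(history)  # mean is about 0.82

print(regressed_metrics({"rouge_l": 0.70}, baseline))  # drop of about 0.12: flagged
print(regressed_metrics({"rouge_l": 0.80}, baseline))  # drop of about 0.02: within tolerance
```

Note that a metric missing from the current run is treated as 0.0 here, mirroring the `current.get(metric, 0.0)` call in the module, so dropping a metric from a run will normally register as a regression.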
|
||||||
|
|
||||||
|
### Viewing History
|
||||||
|
|
||||||
|
```python
|
||||||
|
from veritext.benchmark import Benchmark

bench = Benchmark("summariser_quality")
|
||||||
|
|
||||||
|
for run in bench.get_history(limit=10):
|
||||||
|
print(f"{run.timestamp}: rouge_l={run.metrics.get('rouge_l', 0):.4f}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## CLI
|
||||||
|
|
||||||
|
Veritext provides a command-line interface for validation and benchmarking.
|
||||||
|
|
||||||
|
### Validate Text
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Inline validation
|
||||||
|
veritext validate "Candidate text" -r "Reference text" -m bleu,rouge
|
||||||
|
|
||||||
|
# File-based batch validation (JSONL with "candidate" and "reference" fields)
|
||||||
|
veritext validate -f outputs.jsonl -m bleu,rouge,lexical
|
||||||
|
|
||||||
|
# With threshold for pass/fail
|
||||||
|
veritext validate "Text" -r "Reference" -m bleu -t 0.5 -o simple
|
||||||
|
|
||||||
|
# Output formats: table (default), json, simple
|
||||||
|
veritext validate "Text" -r "Reference" -m bleu -o json
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run a benchmark evaluation
|
||||||
|
veritext benchmark run my_bench -f data.jsonl -m rouge_l,bleu4
|
||||||
|
|
||||||
|
# View benchmark history
|
||||||
|
veritext benchmark show my_bench --last 10
|
||||||
|
|
||||||
|
# Check for regression (exits with code 1 if detected)
|
||||||
|
veritext benchmark check my_bench --tolerance 0.05 --window 10
|
||||||
|
```
|
||||||
|
|
||||||
|
### JSONL Format
|
||||||
|
|
||||||
|
For file-based operations, use JSONL with `candidate` and `reference` fields:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"candidate": "Model output 1", "reference": "Expected output 1"}
|
||||||
|
{"candidate": "Model output 2", "reference": "Expected output 2"}
|
||||||
|
```
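If you want to check a file before handing it to the CLI, a few lines of Python are enough. This is a minimal sketch that assumes one JSON object per line with exactly the two fields above; `load_pairs` is a hypothetical helper, not a Veritext function.

```python
import json
from collections.abc import Iterable


def load_pairs(lines: Iterable[str]) -> list[tuple[str, str]]:
    """Parse JSONL lines into (candidate, reference) tuples, skipping blank lines."""
    pairs = []
    for line_number, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        try:
            pairs.append((record["candidate"], record["reference"]))
        except KeyError as exc:
            raise ValueError(f"line {line_number}: missing field {exc}") from exc
    return pairs


sample = [
    '{"candidate": "Model output 1", "reference": "Expected output 1"}',
    "",
    '{"candidate": "Model output 2", "reference": "Expected output 2"}',
]
print(load_pairs(sample))
```

For a file on disk, pass the open file object: `load_pairs(open("outputs.jsonl", encoding="utf-8"))`.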
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Veritext uses environment variables for configuration:
|
||||||
|
|
||||||
|
| Variable | Default | Description |
|
||||||
|
|----------|---------|-------------|
|
||||||
|
| `VERITEXT_LOG_LEVEL` | `INFO` | Logging level |
|
||||||
|
| `VERITEXT_LOG_FORMAT` | `console` | Log format (`console` or `json`) |
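
The changelog notes that the settings instance is cached with `@lru_cache`. The settings code itself is not shown in this diff, so the following is only a sketch of how such variables might be read once and reused; the `Settings` dataclass and `get_settings` function here are assumptions for illustration, not the project's actual classes.

```python
import os
from dataclasses import dataclass
from functools import lru_cache


@dataclass(frozen=True)
class Settings:
    log_level: str = "INFO"
    log_format: str = "console"


@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Read settings from the environment once and reuse the result."""
    log_format = os.environ.get("VERITEXT_LOG_FORMAT", "console")
    if log_format not in ("console", "json"):
        raise ValueError(f"Unsupported VERITEXT_LOG_FORMAT: {log_format!r}")
    return Settings(
        log_level=os.environ.get("VERITEXT_LOG_LEVEL", "INFO").upper(),
        log_format=log_format,
    )


print(get_settings())
```

One consequence of caching is that changes to the environment after the first call are not picked up until `get_settings.cache_clear()` is called, which matters mostly in tests.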
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://gitea.kschappell.com/kschappell/veritext.git
|
||||||
|
cd veritext
|
||||||
|
uv sync --all-extras
|
||||||
|
```
|
||||||
|
|
||||||
|
### Quality Checks
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Linting
|
||||||
|
uv run ruff check .
|
||||||
|
|
||||||
|
# Formatting
|
||||||
|
uv run ruff format --check .
|
||||||
|
|
||||||
|
# Type checking
|
||||||
|
uv run mypy src/
|
||||||
|
|
||||||
|
# Tests
|
||||||
|
uv run pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
### Running Examples
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python examples/basic_validation.py
|
||||||
|
uv run pytest examples/chatbot_testing.py -v
|
||||||
|
uv run python examples/benchmark_regression.py
|
||||||
|
```
|
||||||
|
|
||||||
## Licence
|
## Licence
|
||||||
|
|
||||||
|
|||||||
12
src/veritext/benchmark/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""Benchmark module for quality tracking and regression detection."""
|
||||||
|
|
||||||
|
from veritext.benchmark.models import BenchmarkRun, RegressionReport
|
||||||
|
from veritext.benchmark.runner import Benchmark
|
||||||
|
from veritext.benchmark.storage import BenchmarkStorage
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Benchmark",
|
||||||
|
"BenchmarkRun",
|
||||||
|
"BenchmarkStorage",
|
||||||
|
"RegressionReport",
|
||||||
|
]
|
||||||
72
src/veritext/benchmark/models.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""Benchmark data models."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkRun(BaseModel):
|
||||||
|
"""Record of a single benchmark execution."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(frozen=True)
|
||||||
|
|
||||||
|
id: str
|
||||||
|
"""UUID for this run."""
|
||||||
|
|
||||||
|
benchmark_name: str
|
||||||
|
"""Name identifying this benchmark suite."""
|
||||||
|
|
||||||
|
timestamp: datetime
|
||||||
|
"""When the benchmark was executed."""
|
||||||
|
|
||||||
|
veritext_version: str
|
||||||
|
"""Version of veritext used."""
|
||||||
|
|
||||||
|
metrics: dict[str, float]
|
||||||
|
"""Metric results, e.g. {"rouge_l": 0.82, "bleu4": 0.71}."""
|
||||||
|
|
||||||
|
sample_count: int
|
||||||
|
"""Number of samples evaluated."""
|
||||||
|
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""Optional metadata (git_sha, model version, etc.)."""
|
||||||
|
|
||||||
|
|
||||||
|
class RegressionReport(BaseModel):
|
||||||
|
"""Report comparing current run against baseline."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(frozen=True)
|
||||||
|
|
||||||
|
detected: bool
|
||||||
|
"""Whether a regression was detected."""
|
||||||
|
|
||||||
|
baseline: dict[str, float]
|
||||||
|
"""Baseline metric values (rolling average)."""
|
||||||
|
|
||||||
|
current: dict[str, float]
|
||||||
|
"""Current run metric values."""
|
||||||
|
|
||||||
|
deltas: dict[str, float]
|
||||||
|
"""Difference from baseline (negative = regression)."""
|
||||||
|
|
||||||
|
tolerance: float
|
||||||
|
"""Tolerance threshold used for detection."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def summary(self) -> str:
|
||||||
|
"""Human-readable summary of the report."""
|
||||||
|
if not self.detected:
|
||||||
|
return "No regression detected. All metrics within tolerance."
|
||||||
|
|
||||||
|
regressions = [
|
||||||
|
f" {metric}: {self.current.get(metric, 0.0):.4f} "
|
||||||
|
f"(baseline: {self.baseline.get(metric, 0.0):.4f}, "
|
||||||
|
f"delta: {delta:+.4f})"
|
||||||
|
for metric, delta in self.deltas.items()
|
||||||
|
if delta < -self.tolerance
|
||||||
|
]
|
||||||
|
|
||||||
|
return f"Regression detected (tolerance: {self.tolerance:.2%}):\n" + "\n".join(
|
||||||
|
regressions
|
||||||
|
)
|
||||||
87
src/veritext/benchmark/regression.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""Regression detection using rolling window comparison."""
|
||||||
|
|
||||||
|
from veritext.benchmark.models import BenchmarkRun, RegressionReport
|
||||||
|
|
||||||
|
|
||||||
|
def compute_baseline(
|
||||||
|
runs: list[BenchmarkRun],
|
||||||
|
window: int = 10,
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""
|
||||||
|
Compute rolling average baseline from recent runs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runs: List of benchmark runs (most recent first).
|
||||||
|
window: Number of runs to include in the baseline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of metric names to their average values.
|
||||||
|
"""
|
||||||
|
if not runs:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Take up to `window` runs
|
||||||
|
recent_runs = runs[:window]
|
||||||
|
|
||||||
|
# Collect all metric values
|
||||||
|
metric_values: dict[str, list[float]] = {}
|
||||||
|
for run in recent_runs:
|
||||||
|
for metric_name, value in run.metrics.items():
|
||||||
|
if metric_name not in metric_values:
|
||||||
|
metric_values[metric_name] = []
|
||||||
|
metric_values[metric_name].append(value)
|
||||||
|
|
||||||
|
# Compute averages
|
||||||
|
return {
|
||||||
|
metric: sum(values) / len(values) for metric, values in metric_values.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def detect_regression(
|
||||||
|
current: dict[str, float],
|
||||||
|
baseline: dict[str, float],
|
||||||
|
tolerance: float = 0.05,
|
||||||
|
) -> RegressionReport:
|
||||||
|
"""
|
||||||
|
Compare current metrics against baseline.
|
||||||
|
|
||||||
|
A regression is detected if any metric drops by more than the tolerance
|
||||||
|
threshold (measured as an absolute difference from its baseline value).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current: Current metric values.
|
||||||
|
baseline: Baseline metric values.
|
||||||
|
tolerance: Maximum allowed absolute drop before a regression is flagged (e.g., 0.05).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RegressionReport with comparison results.
|
||||||
|
"""
|
||||||
|
if not baseline:
|
||||||
|
# No baseline means no regression possible
|
||||||
|
return RegressionReport(
|
||||||
|
detected=False,
|
||||||
|
baseline=baseline,
|
||||||
|
current=current,
|
||||||
|
deltas={},
|
||||||
|
tolerance=tolerance,
|
||||||
|
)
|
||||||
|
|
||||||
|
deltas: dict[str, float] = {}
|
||||||
|
detected = False
|
||||||
|
|
||||||
|
for metric, baseline_value in baseline.items():
|
||||||
|
current_value = current.get(metric, 0.0)
|
||||||
|
delta = current_value - baseline_value
|
||||||
|
deltas[metric] = delta
|
||||||
|
|
||||||
|
# Check if this metric regressed beyond tolerance
|
||||||
|
if delta < -tolerance:
|
||||||
|
detected = True
|
||||||
|
|
||||||
|
return RegressionReport(
|
||||||
|
detected=detected,
|
||||||
|
baseline=baseline,
|
||||||
|
current=current,
|
||||||
|
deltas=deltas,
|
||||||
|
tolerance=tolerance,
|
||||||
|
)
|
||||||
186
src/veritext/benchmark/runner.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""Benchmark execution and tracking."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import veritext
|
||||||
|
from veritext.benchmark.models import BenchmarkRun, RegressionReport
|
||||||
|
from veritext.benchmark.regression import compute_baseline, detect_regression
|
||||||
|
from veritext.benchmark.storage import BenchmarkStorage
|
||||||
|
from veritext.core.exceptions import RegressionDetectedError
|
||||||
|
from veritext.metrics.bleu import Bleu
|
||||||
|
from veritext.metrics.rouge import Rouge
|
||||||
|
|
||||||
|
# Default metrics to use for evaluation
|
||||||
|
DEFAULT_METRICS = ["rouge_l", "bleu4"]
|
||||||
|
|
||||||
|
|
||||||
|
class Benchmark:
|
||||||
|
"""Track text quality over time."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
storage_path: str | Path = "benchmarks/",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise a benchmark tracker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name identifying this benchmark suite.
|
||||||
|
storage_path: Directory for storing benchmark data.
|
||||||
|
"""
|
||||||
|
self._name = name
|
||||||
|
self._storage_path = Path(storage_path)
|
||||||
|
self._storage = BenchmarkStorage(self._storage_path / f"{name}.db")
|
||||||
|
|
||||||
|
# Initialise metrics
|
||||||
|
self._bleu = Bleu()
|
||||||
|
self._rouge = Rouge()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the benchmark name."""
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def _compute_metrics(
|
||||||
|
self,
|
||||||
|
candidates: list[str],
|
||||||
|
references: list[str] | list[list[str]],
|
||||||
|
metric_names: list[str],
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Compute requested metrics for the given samples."""
|
||||||
|
results: dict[str, float] = {}
|
||||||
|
|
||||||
|
for metric_name in metric_names:
|
||||||
|
if metric_name in ("bleu1", "bleu2", "bleu3", "bleu4"):
|
||||||
|
batch_result = self._bleu.batch_score(candidates, references)
|
||||||
|
stats = batch_result.stats.get(metric_name)
|
||||||
|
if stats:
|
||||||
|
results[metric_name] = stats.mean
|
||||||
|
|
||||||
|
elif metric_name in (
|
||||||
|
"rouge1",
|
||||||
|
"rouge2",
|
||||||
|
"rouge_l",
|
||||||
|
"rouge1_fmeasure",
|
||||||
|
"rouge2_fmeasure",
|
||||||
|
"rouge_l_fmeasure",
|
||||||
|
):
|
||||||
|
rouge_result = self._rouge.batch_score(candidates, references)
|
||||||
|
# Map short names to stat names
|
||||||
|
stat_name = metric_name
|
||||||
|
if metric_name == "rouge1":
|
||||||
|
stat_name = "rouge1_fmeasure"
|
||||||
|
elif metric_name == "rouge2":
|
||||||
|
stat_name = "rouge2_fmeasure"
|
||||||
|
elif metric_name == "rouge_l":
|
||||||
|
stat_name = "rouge_l_fmeasure"
|
||||||
|
|
||||||
|
stats = rouge_result.stats.get(stat_name)
|
||||||
|
if stats:
|
||||||
|
results[metric_name] = stats.mean
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
self,
|
||||||
|
candidates: list[str],
|
||||||
|
references: list[str] | list[list[str]],
|
||||||
|
metrics: list[str] | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> BenchmarkRun:
|
||||||
|
"""
|
||||||
|
Evaluate candidates against references, store results, and return the run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidates: List of candidate texts to evaluate.
|
||||||
|
references: Reference text(s) for each candidate.
|
||||||
|
metrics: List of metrics to compute. Defaults to ["rouge_l", "bleu4"].
|
||||||
|
metadata: Optional metadata (git_sha, model version, etc.).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The BenchmarkRun record that was created and stored.
|
||||||
|
"""
|
||||||
|
metric_names = metrics or DEFAULT_METRICS
|
||||||
|
metric_results = self._compute_metrics(candidates, references, metric_names)
|
||||||
|
|
||||||
|
run = BenchmarkRun(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
benchmark_name=self._name,
|
||||||
|
timestamp=datetime.now(UTC),
|
||||||
|
veritext_version=veritext.__version__,
|
||||||
|
metrics=metric_results,
|
||||||
|
sample_count=len(candidates),
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._storage.save_run(run)
|
||||||
|
return run
|
||||||
|
|
||||||
|
def check_regression(
|
||||||
|
self,
|
||||||
|
tolerance: float = 0.05,
|
||||||
|
window: int = 10,
|
||||||
|
) -> RegressionReport:
|
||||||
|
"""
|
||||||
|
Compare latest run against historical baseline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tolerance: Maximum allowed metric drop before regression is flagged.
|
||||||
|
window: Number of historical runs to include in baseline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RegressionReport with comparison results.
|
||||||
|
"""
|
||||||
|
runs = self._storage.get_runs(self._name)
|
||||||
|
|
||||||
|
if not runs:
|
||||||
|
# No runs at all
|
||||||
|
return RegressionReport(
|
||||||
|
detected=False,
|
||||||
|
baseline={},
|
||||||
|
current={},
|
||||||
|
deltas={},
|
||||||
|
tolerance=tolerance,
|
||||||
|
)
|
||||||
|
|
||||||
|
current_run = runs[0]
|
||||||
|
# Baseline excludes the current run
|
||||||
|
historical_runs = runs[1:]
|
||||||
|
baseline = compute_baseline(historical_runs, window=window)
|
||||||
|
|
||||||
|
return detect_regression(current_run.metrics, baseline, tolerance)
|
||||||
|
|
||||||
|
def assert_no_regression(
|
||||||
|
self,
|
||||||
|
tolerance: float = 0.05,
|
||||||
|
window: int = 10,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Raise RegressionDetectedError if quality dropped.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tolerance: Maximum allowed metric drop before regression is flagged.
|
||||||
|
window: Number of historical runs to include in baseline.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RegressionDetectedError: If a regression is detected.
|
||||||
|
"""
|
||||||
|
report = self.check_regression(tolerance=tolerance, window=window)
|
||||||
|
if report.detected:
|
||||||
|
raise RegressionDetectedError(report.summary)
|
||||||
|
|
||||||
|
def get_history(self, limit: int = 20) -> list[BenchmarkRun]:
|
||||||
|
"""
|
||||||
|
Get recent benchmark runs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: Maximum number of runs to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BenchmarkRun objects, most recent first.
|
||||||
|
"""
|
||||||
|
return self._storage.get_runs(self._name, limit=limit)
|
||||||
179
src/veritext/benchmark/storage.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
"""SQLite storage for benchmark history."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from veritext.benchmark.models import BenchmarkRun
|
||||||
|
from veritext.core.exceptions import StorageError
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkStorage:
|
||||||
|
"""SQLite-backed storage for benchmark runs."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: Path) -> None:
|
||||||
|
"""
|
||||||
|
Initialise storage, creating tables if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to the SQLite database file.
|
||||||
|
"""
|
||||||
|
self._db_path = db_path
|
||||||
|
self._ensure_parent_exists()
|
||||||
|
self._init_database()
|
||||||
|
|
||||||
|
def _ensure_parent_exists(self) -> None:
|
||||||
|
"""Ensure the parent directory exists."""
|
||||||
|
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def _get_connection(self) -> sqlite3.Connection:
|
||||||
|
"""Get a database connection with WAL mode enabled."""
|
||||||
|
conn = sqlite3.connect(str(self._db_path), timeout=30.0)
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute("PRAGMA foreign_keys=ON")
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _init_database(self) -> None:
|
||||||
|
"""Create tables if they don't exist."""
|
||||||
|
try:
|
||||||
|
with self._get_connection() as conn:
|
||||||
|
conn.executescript("""
|
||||||
|
CREATE TABLE IF NOT EXISTS benchmark_runs (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
benchmark_name TEXT NOT NULL,
|
||||||
|
timestamp TEXT NOT NULL,
|
||||||
|
veritext_version TEXT NOT NULL,
|
||||||
|
sample_count INTEGER NOT NULL,
|
||||||
|
metadata TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS benchmark_metrics (
|
||||||
|
run_id TEXT REFERENCES benchmark_runs(id) ON DELETE CASCADE,
|
||||||
|
metric_name TEXT NOT NULL,
|
||||||
|
value REAL NOT NULL,
|
||||||
|
PRIMARY KEY (run_id, metric_name)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_benchmark_name
|
||||||
|
ON benchmark_runs(benchmark_name, timestamp DESC);
|
||||||
|
""")
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
raise StorageError(f"Failed to initialise database: {e}") from e
|
||||||
|
|
||||||
|
def save_run(self, run: BenchmarkRun) -> None:
|
||||||
|
"""
|
||||||
|
Persist a benchmark run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run: The benchmark run to save.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
StorageError: If the save operation fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self._get_connection() as conn:
|
||||||
|
# Insert the run
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO benchmark_runs
|
||||||
|
(id, benchmark_name, timestamp, veritext_version, sample_count, metadata)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
run.id,
|
||||||
|
run.benchmark_name,
|
||||||
|
run.timestamp.isoformat(),
|
||||||
|
run.veritext_version,
|
||||||
|
run.sample_count,
|
||||||
|
json.dumps(run.metadata) if run.metadata else None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Insert metrics
|
||||||
|
for metric_name, value in run.metrics.items():
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO benchmark_metrics (run_id, metric_name, value)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
""",
|
||||||
|
(run.id, metric_name, value),
|
||||||
|
)
|
||||||
|
except sqlite3.IntegrityError as e:
|
||||||
|
raise StorageError(f"Run with id '{run.id}' already exists") from e
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
raise StorageError(f"Failed to save benchmark run: {e}") from e
|
||||||
|
|
||||||
|
def get_runs(
|
||||||
|
self,
|
||||||
|
benchmark_name: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> list[BenchmarkRun]:
|
||||||
|
"""
|
||||||
|
Retrieve runs for a benchmark, most recent first.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
benchmark_name: Name of the benchmark to retrieve runs for.
|
||||||
|
limit: Maximum number of runs to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BenchmarkRun objects, most recent first.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
StorageError: If the retrieval fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with self._get_connection() as conn:
|
||||||
|
query = """
|
||||||
|
SELECT id, benchmark_name, timestamp, veritext_version,
|
||||||
|
sample_count, metadata
|
||||||
|
FROM benchmark_runs
|
||||||
|
WHERE benchmark_name = ?
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
"""
|
||||||
|
if limit is not None:
|
||||||
|
query += " LIMIT ?"
|
||||||
|
rows = conn.execute(query, (benchmark_name, limit)).fetchall()
|
||||||
|
else:
|
||||||
|
rows = conn.execute(query, (benchmark_name,)).fetchall()
|
||||||
|
|
||||||
|
runs = []
|
||||||
|
for row in rows:
|
||||||
|
# Get metrics for this run
|
||||||
|
metrics_rows = conn.execute(
|
||||||
|
"SELECT metric_name, value FROM benchmark_metrics WHERE run_id = ?",
|
||||||
|
(row["id"],),
|
||||||
|
).fetchall()
|
||||||
|
metrics = {m["metric_name"]: m["value"] for m in metrics_rows}
|
||||||
|
|
||||||
|
metadata = json.loads(row["metadata"]) if row["metadata"] else {}
|
||||||
|
|
||||||
|
runs.append(
|
||||||
|
BenchmarkRun(
|
||||||
|
id=row["id"],
|
||||||
|
benchmark_name=row["benchmark_name"],
|
||||||
|
timestamp=datetime.fromisoformat(row["timestamp"]),
|
||||||
|
veritext_version=row["veritext_version"],
|
||||||
|
sample_count=row["sample_count"],
|
||||||
|
metrics=metrics,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return runs
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
raise StorageError(f"Failed to retrieve benchmark runs: {e}") from e
|
||||||
|
|
||||||
|
def get_latest_run(self, benchmark_name: str) -> BenchmarkRun | None:
|
||||||
|
"""
|
||||||
|
Get the most recent run for a benchmark.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
benchmark_name: Name of the benchmark.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The most recent BenchmarkRun, or None if no runs exist.
|
||||||
|
"""
|
||||||
|
runs = self.get_runs(benchmark_name, limit=1)
|
||||||
|
return runs[0] if runs else None
|
||||||
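The `get_runs` query above sorts on an ISO-8601 timestamp string and appends `LIMIT ?` only when a limit is supplied. The sketch below reproduces that pattern against an in-memory SQLite database. The table is trimmed to three columns and `fetch_ids` is a hypothetical helper for illustration, not part of Veritext. It also assumes every timestamp carries the same UTC offset, since lexicographic ordering of ISO strings only matches chronological ordering in that case.

```python
import sqlite3
from datetime import datetime, timedelta, timezone


def fetch_ids(conn, benchmark_name, limit=None):
    """Return run ids for one benchmark, most recent first (hypothetical helper)."""
    query = (
        "SELECT id FROM benchmark_runs "
        "WHERE benchmark_name = ? ORDER BY timestamp DESC"
    )
    params = [benchmark_name]
    if limit is not None:
        # Bind the limit as a parameter instead of formatting it into the SQL.
        query += " LIMIT ?"
        params.append(limit)
    return [row["id"] for row in conn.execute(query, params)]


conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row  # allows row["id"] access, as in get_runs
conn.execute(
    "CREATE TABLE benchmark_runs ("
    "id TEXT PRIMARY KEY, benchmark_name TEXT, timestamp TEXT)"
)

base = datetime(2024, 1, 1, tzinfo=timezone.utc)
for i in range(3):
    conn.execute(
        "INSERT INTO benchmark_runs VALUES (?, ?, ?)",
        (f"run-{i}", "demo", (base + timedelta(days=i)).isoformat()),
    )

latest = fetch_ids(conn, "demo", limit=1)
everything = fetch_ids(conn, "demo")
```

Building the parameter list alongside the query keeps a single `execute` call for both the limited and unlimited cases.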
5
src/veritext/cli/__init__.py
Normal file
@@ -0,0 +1,5 @@
+"""CLI module: Command-line interface for Veritext."""
+
+from veritext.cli.main import app
+
+__all__ = ["app"]
166
src/veritext/cli/benchmark.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""Benchmark commands for quality tracking."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from veritext.benchmark import Benchmark
|
||||||
|
from veritext.cli.formatters import (
|
||||||
|
console,
|
||||||
|
format_benchmark_history,
|
||||||
|
format_regression_report,
|
||||||
|
)
|
||||||
|
from veritext.cli.readers import read_jsonl
|
||||||
|
|
||||||
|
benchmark_app = typer.Typer(
|
||||||
|
name="benchmark",
|
||||||
|
help="Track and compare text quality over time.",
|
||||||
|
no_args_is_help=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@benchmark_app.command("run")
|
||||||
|
def benchmark_run(
|
||||||
|
name: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Argument(help="Name for this benchmark suite."),
|
||||||
|
],
|
||||||
|
file: Annotated[
|
||||||
|
Path,
|
||||||
|
typer.Option("--file", "-f", help="JSONL file with candidate/reference pairs."),
|
||||||
|
],
|
||||||
|
metrics: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Option(
|
||||||
|
"--metrics",
|
||||||
|
"-m",
|
||||||
|
help="Comma-separated metrics to track (e.g., rouge_l,bleu4).",
|
||||||
|
),
|
||||||
|
] = "rouge_l,bleu4",
|
||||||
|
storage_path: Annotated[
|
||||||
|
Path,
|
||||||
|
typer.Option(
|
||||||
|
"--storage",
|
||||||
|
"-s",
|
||||||
|
help="Directory for benchmark data storage.",
|
||||||
|
),
|
||||||
|
] = Path("benchmarks"),
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Run a benchmark evaluation and store the results.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
veritext benchmark run my_bench -f data.jsonl -m rouge_l,bleu4
|
||||||
|
"""
|
||||||
|
# Read text pairs
|
||||||
|
try:
|
||||||
|
pairs = read_jsonl(file)
|
||||||
|
except (FileNotFoundError, ValueError) as e:
|
||||||
|
console.print(f"[red]Error:[/red] {e}")
|
||||||
|
raise typer.Exit(code=1) from e
|
||||||
|
|
||||||
|
if not pairs:
|
||||||
|
console.print("[yellow]Warning:[/yellow] No text pairs found in file.")
|
||||||
|
raise typer.Exit(code=0)
|
||||||
|
|
||||||
|
# Parse metrics
|
||||||
|
    metric_names = [m.strip() for m in metrics.split(",") if m.strip()]
|
||||||
|
|
||||||
|
candidates = [p.candidate for p in pairs]
|
||||||
|
references = [p.reference for p in pairs]
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
bench = Benchmark(name, storage_path=storage_path)
|
||||||
|
run = bench.evaluate(candidates, references, metrics=metric_names)
|
||||||
|
|
||||||
|
console.print(f"[green]Benchmark '{name}' completed.[/green]")
|
||||||
|
console.print(f"Samples: {run.sample_count}")
|
||||||
|
console.print("\nMetrics:")
|
||||||
|
for metric_name, value in sorted(run.metrics.items()):
|
||||||
|
console.print(f" {metric_name}: {value:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
@benchmark_app.command("show")
|
||||||
|
def benchmark_show(
|
||||||
|
name: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Argument(help="Name of the benchmark suite."),
|
||||||
|
],
|
||||||
|
last: Annotated[
|
||||||
|
int,
|
||||||
|
typer.Option("--last", "-n", help="Number of recent runs to show."),
|
||||||
|
] = 20,
|
||||||
|
storage_path: Annotated[
|
||||||
|
Path,
|
||||||
|
typer.Option(
|
||||||
|
"--storage",
|
||||||
|
"-s",
|
||||||
|
help="Directory for benchmark data storage.",
|
||||||
|
),
|
||||||
|
] = Path("benchmarks"),
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Show benchmark history for a suite.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
veritext benchmark show my_bench --last 10
|
||||||
|
"""
|
||||||
|
bench = Benchmark(name, storage_path=storage_path)
|
||||||
|
runs = bench.get_history(limit=last)
|
||||||
|
|
||||||
|
if not runs:
|
||||||
|
console.print(f"[yellow]No benchmark runs found for '{name}'.[/yellow]")
|
||||||
|
raise typer.Exit(code=0)
|
||||||
|
|
||||||
|
table = format_benchmark_history(runs)
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@benchmark_app.command("check")
|
||||||
|
def benchmark_check(
|
||||||
|
name: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Argument(help="Name of the benchmark suite."),
|
||||||
|
],
|
||||||
|
tolerance: Annotated[
|
||||||
|
float,
|
||||||
|
typer.Option(
|
||||||
|
"--tolerance",
|
||||||
|
"-t",
|
||||||
|
help="Maximum allowed metric drop (e.g., 0.05 = 5%).",
|
||||||
|
),
|
||||||
|
] = 0.05,
|
||||||
|
window: Annotated[
|
||||||
|
int,
|
||||||
|
typer.Option(
|
||||||
|
"--window",
|
||||||
|
"-w",
|
||||||
|
help="Number of historical runs for baseline.",
|
||||||
|
),
|
||||||
|
] = 10,
|
||||||
|
storage_path: Annotated[
|
||||||
|
Path,
|
||||||
|
typer.Option(
|
||||||
|
"--storage",
|
||||||
|
"-s",
|
||||||
|
help="Directory for benchmark data storage.",
|
||||||
|
),
|
||||||
|
] = Path("benchmarks"),
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Check for quality regression against historical baseline.
|
||||||
|
|
||||||
|
    Exits with code 1 if a regression is detected (for CI integration).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
veritext benchmark check my_bench --tolerance 0.05
|
||||||
|
"""
|
||||||
|
bench = Benchmark(name, storage_path=storage_path)
|
||||||
|
report = bench.check_regression(tolerance=tolerance, window=window)
|
||||||
|
|
||||||
|
panel = format_regression_report(report)
|
||||||
|
console.print(panel)
|
||||||
|
|
||||||
|
if report.detected:
|
||||||
|
raise typer.Exit(code=1)
|
||||||
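The `check` command above exits with code 1 when `report.detected` is true, so that a CI job fails on a regression. `Benchmark.check_regression` is not shown in this diff, so the following is only a sketch of one plausible rule: it assumes the delta is the absolute difference `current - baseline` and flags a metric when that delta falls below `-tolerance`, which mirrors the comparison used in `format_regression_report`. The helper name `find_regressions` and the sample scores are hypothetical.

```python
def find_regressions(baseline, current, tolerance):
    """Return {metric: delta} for metrics that dropped by more than `tolerance`.

    Assumption: delta is an absolute difference (current - baseline).
    """
    regressions = {}
    for metric, base_value in baseline.items():
        if metric not in current:
            continue  # nothing to compare against
        delta = current[metric] - base_value
        if delta < -tolerance:
            regressions[metric] = delta
    return regressions


baseline = {"rouge_l": 0.50, "bleu4": 0.30}
current = {"rouge_l": 0.40, "bleu4": 0.29}

regressed = find_regressions(baseline, current, tolerance=0.05)
# A non-empty result maps to a failing exit status, as in `benchmark check`.
exit_code = 1 if regressed else 0
```

Here only `rouge_l` is flagged, because its drop of roughly 0.10 exceeds the tolerance while the `bleu4` drop of roughly 0.01 does not.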
170
src/veritext/cli/formatters.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""Rich output formatters for CLI display."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
from veritext.benchmark.models import BenchmarkRun, RegressionReport
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
def format_validation_table(
|
||||||
|
results: dict[str, float],
|
||||||
|
threshold: float | None = None,
|
||||||
|
) -> Table:
|
||||||
|
"""
|
||||||
|
Format validation results as a Rich table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: Dictionary of metric names to scores.
|
||||||
|
threshold: Optional threshold for pass/fail colouring.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rich Table object.
|
||||||
|
"""
|
||||||
|
table = Table(title="Validation Results", show_header=True, header_style="bold")
|
||||||
|
table.add_column("Metric", style="cyan")
|
||||||
|
table.add_column("Score", justify="right")
|
||||||
|
|
||||||
|
if threshold is not None:
|
||||||
|
table.add_column("Status", justify="center")
|
||||||
|
|
||||||
|
for metric, score in sorted(results.items()):
|
||||||
|
score_str = f"{score:.4f}"
|
||||||
|
|
||||||
|
if threshold is not None:
|
||||||
|
status = "[green]PASS[/green]" if score >= threshold else "[red]FAIL[/red]"
|
||||||
|
table.add_row(metric, score_str, status)
|
||||||
|
else:
|
||||||
|
table.add_row(metric, score_str)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def format_validation_json(results: dict[str, float]) -> str:
|
||||||
|
"""
|
||||||
|
Format validation results as JSON.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: Dictionary of metric names to scores.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON string.
|
||||||
|
"""
|
||||||
|
return json.dumps(results, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def format_validation_simple(results: dict[str, float]) -> str:
|
||||||
|
"""
|
||||||
|
Format validation results as simple text output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: Dictionary of metric names to scores.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Simple text string with one metric per line.
|
||||||
|
"""
|
||||||
|
lines = [f"{metric}: {score:.4f}" for metric, score in sorted(results.items())]
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def format_benchmark_history(runs: list[BenchmarkRun]) -> Table:
|
||||||
|
"""
|
||||||
|
Format benchmark run history as a Rich table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runs: List of BenchmarkRun objects (most recent first).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rich Table object.
|
||||||
|
"""
|
||||||
|
if not runs:
|
||||||
|
table = Table(title="Benchmark History")
|
||||||
|
table.add_column("No runs found")
|
||||||
|
return table
|
||||||
|
|
||||||
|
# Get all metric names from the runs
|
||||||
|
metric_names: set[str] = set()
|
||||||
|
for run in runs:
|
||||||
|
metric_names.update(run.metrics.keys())
|
||||||
|
sorted_metrics = sorted(metric_names)
|
||||||
|
|
||||||
|
table = Table(title="Benchmark History", show_header=True, header_style="bold")
|
||||||
|
table.add_column("Timestamp", style="cyan")
|
||||||
|
table.add_column("Samples", justify="right")
|
||||||
|
for metric in sorted_metrics:
|
||||||
|
table.add_column(metric, justify="right")
|
||||||
|
|
||||||
|
for run in runs:
|
||||||
|
timestamp = run.timestamp.strftime("%Y-%m-%d %H:%M")
|
||||||
|
samples = str(run.sample_count)
|
||||||
|
        # Show "-" rather than 0.0000 when a run did not record a metric.
        metric_values = [
            f"{run.metrics[m]:.4f}" if m in run.metrics else "-" for m in sorted_metrics
        ]
|
||||||
|
table.add_row(timestamp, samples, *metric_values)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def format_regression_report(report: RegressionReport) -> Panel:
|
||||||
|
"""
|
||||||
|
Format a regression report as a Rich panel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
report: RegressionReport object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rich Panel object with formatted report.
|
||||||
|
"""
|
||||||
|
if not report.detected:
|
||||||
|
content = (
|
||||||
|
f"[green]No regression detected.[/green]\nTolerance: {report.tolerance:.2%}"
|
||||||
|
)
|
||||||
|
return Panel(content, title="Regression Check", border_style="green")
|
||||||
|
|
||||||
|
# Build regression details
|
||||||
|
lines = [
|
||||||
|
"[red]Regression detected![/red]",
|
||||||
|
f"Tolerance: {report.tolerance:.2%}",
|
||||||
|
"",
|
||||||
|
"Metric details:",
|
||||||
|
]
|
||||||
|
|
||||||
|
for metric in sorted(report.deltas.keys()):
|
||||||
|
baseline = report.baseline.get(metric, 0.0)
|
||||||
|
current = report.current.get(metric, 0.0)
|
||||||
|
delta = report.deltas[metric]
|
||||||
|
|
||||||
|
if delta < -report.tolerance:
|
||||||
|
status = "[red]REGRESSED[/red]"
|
||||||
|
else:
|
||||||
|
status = "[green]OK[/green]"
|
||||||
|
|
||||||
|
lines.append(
|
||||||
|
f" {metric}: {current:.4f} (baseline: {baseline:.4f}, "
|
||||||
|
f"delta: {delta:+.4f}) {status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return Panel("\n".join(lines), title="Regression Check", border_style="red")
|
||||||
|
|
||||||
|
|
||||||
|
def print_validation_output(
|
||||||
|
results: dict[str, float],
|
||||||
|
output_format: str = "table",
|
||||||
|
threshold: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Print validation results in the specified format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: Dictionary of metric names to scores.
|
||||||
|
output_format: Output format ('table', 'json', or 'simple').
|
||||||
|
threshold: Optional threshold for pass/fail colouring (table only).
|
||||||
|
"""
|
||||||
|
if output_format == "json":
|
||||||
|
console.print(format_validation_json(results))
|
||||||
|
elif output_format == "simple":
|
||||||
|
console.print(format_validation_simple(results))
|
||||||
|
else:
|
||||||
|
console.print(format_validation_table(results, threshold))
|
||||||
37
src/veritext/cli/main.py
Normal file
@@ -0,0 +1,37 @@
+"""Veritext CLI entry point."""
+
+import typer
+
+import veritext
+from veritext.cli.benchmark import benchmark_app
+from veritext.cli.validate import validate
+
+app = typer.Typer(
+    name="veritext",
+    help="Semantic text validation framework.",
+    no_args_is_help=True,
+)
+
+# Register commands
+app.command()(validate)
+app.add_typer(benchmark_app)
+
+
+@app.callback(invoke_without_command=True)
+def main(
+    version: bool | None = typer.Option(
+        None,
+        "--version",
+        "-V",
+        help="Show version and exit.",
+        is_eager=True,
+    ),
+) -> None:
+    """Veritext: Semantic text validation framework for Python."""
+    if version:
+        typer.echo(f"veritext {veritext.__version__}")
+        raise typer.Exit()
+
+
+if __name__ == "__main__":
+    app()
120
src/veritext/cli/readers.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Input readers for CLI operations."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextPair:
|
||||||
|
"""A candidate-reference text pair for validation."""
|
||||||
|
|
||||||
|
candidate: str
|
||||||
|
reference: str
|
||||||
|
|
||||||
|
|
||||||
|
def read_jsonl(path: Path) -> list[TextPair]:
|
||||||
|
"""
|
||||||
|
Read text pairs from a JSONL file.
|
||||||
|
|
||||||
|
Each line must be a JSON object with 'candidate' and 'reference' keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to the JSONL file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of TextPair objects.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the file does not exist.
|
||||||
|
ValueError: If any line is malformed or missing required keys.
|
||||||
|
"""
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {path}")
|
||||||
|
|
||||||
|
pairs: list[TextPair] = []
|
||||||
|
    with path.open(encoding="utf-8") as f:
|
||||||
|
for line_num, line in enumerate(f, start=1):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"Invalid JSON on line {line_num}: {e}") from e
|
||||||
|
|
||||||
|
if "candidate" not in data:
|
||||||
|
raise ValueError(f"Missing 'candidate' key on line {line_num}")
|
||||||
|
if "reference" not in data:
|
||||||
|
raise ValueError(f"Missing 'reference' key on line {line_num}")
|
||||||
|
|
||||||
|
pairs.append(
|
||||||
|
TextPair(
|
||||||
|
candidate=str(data["candidate"]),
|
||||||
|
reference=str(data["reference"]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
|
def read_paired_jsonl(candidates_path: Path, references_path: Path) -> list[TextPair]:
|
||||||
|
"""
|
||||||
|
Read text pairs from separate candidate and reference JSONL files.
|
||||||
|
|
||||||
|
Each file should contain one JSON object per line with a 'text' key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidates_path: Path to the candidates JSONL file.
|
||||||
|
references_path: Path to the references JSONL file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of TextPair objects.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If either file does not exist.
|
||||||
|
ValueError: If files have different lengths or are malformed.
|
||||||
|
"""
|
||||||
|
candidates = _read_text_jsonl(candidates_path, "candidates")
|
||||||
|
references = _read_text_jsonl(references_path, "references")
|
||||||
|
|
||||||
|
if len(candidates) != len(references):
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of candidates ({len(candidates)}) does not match "
|
||||||
|
f"number of references ({len(references)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
TextPair(candidate=c, reference=r)
|
||||||
|
for c, r in zip(candidates, references, strict=True)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _read_text_jsonl(path: Path, label: str) -> list[str]:
|
||||||
|
"""Read text values from a JSONL file with 'text' key per line."""
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"{label.capitalize()} file not found: {path}")
|
||||||
|
|
||||||
|
texts: list[str] = []
|
||||||
|
    with path.open(encoding="utf-8") as f:
|
||||||
|
for line_num, line in enumerate(f, start=1):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid JSON in {label} file on line {line_num}: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if "text" not in data:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing 'text' key in {label} file on line {line_num}"
|
||||||
|
)
|
||||||
|
|
||||||
|
texts.append(str(data["text"]))
|
||||||
|
|
||||||
|
return texts
|
||||||
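The readers above expect JSONL input: one JSON object per line, with blank lines skipped and a `ValueError` raised when a required key is missing. The sketch below writes a small sample file and parses it with the same rules. `parse_pairs` is a simplified stand-in for `read_jsonl` that returns plain tuples instead of `TextPair` objects, and the sample sentences are made up for illustration.

```python
import json
import tempfile
from pathlib import Path

sample = [
    {"candidate": "The cat sat on the mat.", "reference": "A cat was sitting on the mat."},
    {"candidate": "Hello world", "reference": "Hello, world!"},
]


def parse_pairs(path):
    """Parse (candidate, reference) tuples from a JSONL file (simplified stand-in)."""
    pairs = []
    with path.open(encoding="utf-8") as f:
        for line_num, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # blank lines are ignored
            data = json.loads(line)
            for key in ("candidate", "reference"):
                if key not in data:
                    raise ValueError(f"Missing '{key}' key on line {line_num}")
            pairs.append((str(data["candidate"]), str(data["reference"])))
    return pairs


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "pairs.jsonl"
    # One object per line, followed by a trailing blank line.
    body = "\n".join(json.dumps(obj) for obj in sample) + "\n\n"
    path.write_text(body, encoding="utf-8")
    pairs = parse_pairs(path)
```

Reporting the 1-based line number in the error, as the original readers do, makes malformed records easy to locate in large files.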
250
src/veritext/cli/validate.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""Validate command for computing text metrics."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from veritext.cli.formatters import console, print_validation_output
|
||||||
|
from veritext.cli.readers import read_jsonl, read_paired_jsonl
|
||||||
|
from veritext.metrics.bleu import Bleu
|
||||||
|
from veritext.metrics.lexical import Lexical
|
||||||
|
from veritext.metrics.rouge import Rouge
|
||||||
|
|
||||||
|
# Available metrics
|
||||||
|
AVAILABLE_METRICS = frozenset(
|
||||||
|
{"bleu", "bleu1", "bleu2", "bleu3", "bleu4", "rouge", "rouge_l", "lexical"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Lazily-initialised metric instances
|
||||||
|
_bleu: Bleu | None = None
|
||||||
|
_rouge: Rouge | None = None
|
||||||
|
_lexical: Lexical | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_bleu() -> Bleu:
|
||||||
|
"""Get or create the BLEU metric instance."""
|
||||||
|
global _bleu
|
||||||
|
if _bleu is None:
|
||||||
|
_bleu = Bleu()
|
||||||
|
return _bleu
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rouge() -> Rouge:
|
||||||
|
"""Get or create the ROUGE metric instance."""
|
||||||
|
global _rouge
|
||||||
|
if _rouge is None:
|
||||||
|
_rouge = Rouge()
|
||||||
|
return _rouge
|
||||||
|
|
||||||
|
|
||||||
|
def _get_lexical() -> Lexical:
|
||||||
|
"""Get or create the lexical metric instance."""
|
||||||
|
global _lexical
|
||||||
|
if _lexical is None:
|
||||||
|
_lexical = Lexical()
|
||||||
|
return _lexical
|
||||||
|
|
||||||
|
|
||||||
|
# Metric extractor helpers. Each returns a dict mapping output keys to scores:
# - *_single(candidate, reference, ...) scores one text pair
# - *_batch(candidates, references, ...) returns the mean score over a batch
|
||||||
|
def _bleu_single(candidate: str, reference: str, key: str) -> dict[str, float]:
|
||||||
|
"""Extract a BLEU score for single mode."""
|
||||||
|
result = _get_bleu().score(candidate, reference)
|
||||||
|
return {key: getattr(result, key)}
|
||||||
|
|
||||||
|
|
||||||
|
def _bleu_batch(
|
||||||
|
candidates: list[str], references: list[str], key: str
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Extract a BLEU score for batch mode."""
|
||||||
|
batch = _get_bleu().batch_score(candidates, references)
|
||||||
|
stats = batch.stats.get(key)
|
||||||
|
return {key: stats.mean} if stats else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _rouge_single(candidate: str, reference: str) -> dict[str, float]:
|
||||||
|
"""Extract ROUGE-L F-measure for single mode."""
|
||||||
|
result = _get_rouge().score(candidate, reference)
|
||||||
|
return {"rouge_l": result.rouge_l.fmeasure}
|
||||||
|
|
||||||
|
|
||||||
|
def _rouge_batch(candidates: list[str], references: list[str]) -> dict[str, float]:
|
||||||
|
"""Extract ROUGE-L F-measure for batch mode."""
|
||||||
|
batch = _get_rouge().batch_score(candidates, references)
|
||||||
|
stats = batch.stats.get("rouge_l_fmeasure")
|
||||||
|
return {"rouge_l": stats.mean} if stats else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _lexical_single(candidate: str, reference: str) -> dict[str, float]:
|
||||||
|
"""Extract lexical scores for single mode."""
|
||||||
|
result = _get_lexical().score(candidate, reference)
|
||||||
|
return {"jaccard": result.jaccard, "token_overlap": result.token_overlap}
|
||||||
|
|
||||||
|
|
||||||
|
def _lexical_batch(candidates: list[str], references: list[str]) -> dict[str, float]:
|
||||||
|
"""Extract lexical scores for batch mode."""
|
||||||
|
batch = _get_lexical().batch_score(candidates, references)
|
||||||
|
results: dict[str, float] = {}
|
||||||
|
jaccard_stats = batch.stats.get("jaccard")
|
||||||
|
overlap_stats = batch.stats.get("token_overlap")
|
||||||
|
if jaccard_stats:
|
||||||
|
results["jaccard"] = jaccard_stats.mean
|
||||||
|
if overlap_stats:
|
||||||
|
results["token_overlap"] = overlap_stats.mean
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_metrics(
|
||||||
|
candidate: str,
|
||||||
|
reference: str,
|
||||||
|
metric_names: list[str],
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Compute requested metrics for a single text pair."""
|
||||||
|
results: dict[str, float] = {}
|
||||||
|
|
||||||
|
for metric in metric_names:
|
||||||
|
if metric in ("bleu", "bleu4"):
|
||||||
|
results.update(_bleu_single(candidate, reference, "bleu4"))
|
||||||
|
elif metric in ("bleu1", "bleu2", "bleu3"):
|
||||||
|
results.update(_bleu_single(candidate, reference, metric))
|
||||||
|
elif metric in ("rouge", "rouge_l"):
|
||||||
|
results.update(_rouge_single(candidate, reference))
|
||||||
|
elif metric == "lexical":
|
||||||
|
results.update(_lexical_single(candidate, reference))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_batch_metrics(
|
||||||
|
candidates: list[str],
|
||||||
|
references: list[str],
|
||||||
|
metric_names: list[str],
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Compute average metrics for a batch of text pairs."""
|
||||||
|
results: dict[str, float] = {}
|
||||||
|
|
||||||
|
for metric in metric_names:
|
||||||
|
if metric in ("bleu", "bleu4"):
|
||||||
|
results.update(_bleu_batch(candidates, references, "bleu4"))
|
||||||
|
elif metric in ("bleu1", "bleu2", "bleu3"):
|
||||||
|
results.update(_bleu_batch(candidates, references, metric))
|
||||||
|
elif metric in ("rouge", "rouge_l"):
|
||||||
|
results.update(_rouge_batch(candidates, references))
|
||||||
|
elif metric == "lexical":
|
||||||
|
results.update(_lexical_batch(candidates, references))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_metrics(metrics_str: str) -> list[str]:
|
||||||
|
"""Parse comma-separated metric names."""
|
||||||
|
metrics = [m.strip().lower() for m in metrics_str.split(",")]
|
||||||
|
|
||||||
|
# Validate metric names
|
||||||
|
invalid = [m for m in metrics if m not in AVAILABLE_METRICS]
|
||||||
|
if invalid:
|
||||||
|
raise typer.BadParameter(
|
||||||
|
f"Unknown metrics: {', '.join(invalid)}. "
|
||||||
|
f"Available: {', '.join(sorted(AVAILABLE_METRICS))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def validate(
|
||||||
|
text: Annotated[
|
||||||
|
str | None,
|
||||||
|
typer.Argument(help="Candidate text to validate (inline mode)."),
|
||||||
|
] = None,
|
||||||
|
reference: Annotated[
|
||||||
|
str | None,
|
||||||
|
typer.Option("--reference", "-r", help="Reference text for comparison."),
|
||||||
|
] = None,
|
||||||
|
file: Annotated[
|
||||||
|
Path | None,
|
||||||
|
typer.Option("--file", "-f", help="JSONL file with candidate/reference pairs."),
|
||||||
|
] = None,
|
||||||
|
reference_file: Annotated[
|
||||||
|
Path | None,
|
||||||
|
typer.Option(
|
||||||
|
"--reference-file",
|
||||||
|
"-R",
|
||||||
|
help="Separate JSONL file with references (requires --file).",
|
||||||
|
),
|
||||||
|
] = None,
|
||||||
|
metrics: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Option(
|
||||||
|
"--metrics",
|
||||||
|
"-m",
|
||||||
|
help="Comma-separated metrics: bleu, bleu1-4, rouge, rouge_l, lexical.",
|
||||||
|
),
|
||||||
|
] = "bleu,rouge",
|
||||||
|
output: Annotated[
|
||||||
|
str,
|
||||||
|
typer.Option("--output", "-o", help="Output format: table, json, or simple."),
|
||||||
|
] = "table",
|
||||||
|
threshold: Annotated[
|
||||||
|
float | None,
|
||||||
|
typer.Option("--threshold", "-t", help="Score threshold for pass/fail status."),
|
||||||
|
] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Validate text quality using various metrics.
|
||||||
|
|
||||||
|
Use inline mode for single texts:
|
||||||
|
veritext validate "text" -r "reference" -m bleu,rouge
|
||||||
|
|
||||||
|
Use file mode for batches:
|
||||||
|
veritext validate -f outputs.jsonl -m bleu,rouge
|
||||||
|
"""
|
||||||
|
# Parse and validate metric names
|
||||||
|
try:
|
||||||
|
metric_names = _parse_metrics(metrics)
|
||||||
|
except typer.BadParameter as e:
|
||||||
|
console.print(f"[red]Error:[/red] {e}")
|
||||||
|
raise typer.Exit(code=1) from e
|
||||||
|
|
||||||
|
# Validate output format
|
||||||
|
if output not in ("table", "json", "simple"):
|
||||||
|
console.print(f"[red]Error:[/red] Invalid output format: {output}")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# Determine mode: inline vs file
|
||||||
|
if file is not None:
|
||||||
|
# File mode
|
||||||
|
try:
|
||||||
|
if reference_file is not None:
|
||||||
|
pairs = read_paired_jsonl(file, reference_file)
|
||||||
|
else:
|
||||||
|
pairs = read_jsonl(file)
|
||||||
|
except (FileNotFoundError, ValueError) as e:
|
||||||
|
console.print(f"[red]Error:[/red] {e}")
|
||||||
|
raise typer.Exit(code=1) from e
|
||||||
|
|
||||||
|
if not pairs:
|
||||||
|
console.print("[yellow]Warning:[/yellow] No text pairs found in file.")
|
||||||
|
raise typer.Exit(code=0)
|
||||||
|
|
||||||
|
candidates = [p.candidate for p in pairs]
|
||||||
|
references = [p.reference for p in pairs]
|
||||||
|
|
||||||
|
results = _compute_batch_metrics(candidates, references, metric_names)
|
||||||
|
console.print(f"[dim]Evaluated {len(pairs)} text pairs.[/dim]\n")
|
||||||
|
|
||||||
|
elif text is not None and reference is not None:
|
||||||
|
# Inline mode
|
||||||
|
results = _compute_metrics(text, reference, metric_names)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Invalid usage
|
||||||
|
console.print(
|
||||||
|
"[red]Error:[/red] Provide either text and --reference, "
|
||||||
|
"or --file for batch mode."
|
||||||
|
)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
print_validation_output(results, output, threshold)
|
||||||
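`_compute_metrics` and `_compute_batch_metrics` above dispatch on the metric name with parallel `if`/`elif` chains. One alternative, sketched here under stated assumptions, is a lookup table that maps each name or alias to a scoring callable, so a new metric is registered in one place. The scorers `_jaccard` and `_length_ratio` are simple stand-ins written for this example and are not the Veritext `Bleu`, `Rouge`, or `Lexical` classes. `REGISTRY` and `compute` are likewise hypothetical names.

```python
from typing import Callable

Scorer = Callable[[str, str], dict[str, float]]


def _jaccard(candidate: str, reference: str) -> dict[str, float]:
    """Token-set Jaccard similarity (stand-in scorer)."""
    a, b = set(candidate.lower().split()), set(reference.lower().split())
    union = a | b
    return {"jaccard": len(a & b) / len(union) if union else 0.0}


def _length_ratio(candidate: str, reference: str) -> dict[str, float]:
    """Candidate length divided by reference length, in tokens (stand-in scorer)."""
    ref_len = len(reference.split())
    return {"length_ratio": len(candidate.split()) / ref_len if ref_len else 0.0}


# Several names may point at the same scorer, which covers aliases.
REGISTRY: dict[str, Scorer] = {
    "jaccard": _jaccard,
    "lexical": _jaccard,
    "length": _length_ratio,
}


def compute(candidate: str, reference: str, metric_names: list[str]) -> dict[str, float]:
    """Run each requested scorer and merge the resulting score dicts."""
    results: dict[str, float] = {}
    for name in metric_names:
        try:
            scorer = REGISTRY[name]
        except KeyError:
            raise ValueError(f"Unknown metric: {name}") from None
        results.update(scorer(candidate, reference))
    return results


scores = compute("the cat sat", "the cat slept", ["lexical", "length"])
```

With a table like this, the set of valid metric names can be derived from the registry keys rather than maintained separately.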
@@ -1,5 +1,6 @@
 """Configuration management using pydantic-settings."""
 
+from functools import lru_cache
 from pathlib import Path
 from typing import Literal
 
@@ -54,6 +55,7 @@ class VeritextSettings(BaseSettings):
     )
 
 
+@lru_cache
 def get_settings() -> VeritextSettings:
-    """Get the current settings instance."""
+    """Get the cached settings instance."""
     return VeritextSettings()
@@ -1,9 +1,18 @@
-"""Metrics module: BLEU, lexical similarity, and batch processing."""
+"""Metrics module: BLEU, ROUGE, lexical similarity, readability, and batch processing."""
 
 from veritext.metrics.base import AggregateStats, BatchResult, Metric
 from veritext.metrics.bleu import Bleu
 from veritext.metrics.lexical import Lexical
-from veritext.metrics.results import BleuResult, LexicalResult
+from veritext.metrics.readability import Readability
+from veritext.metrics.results import (
+    BleuResult,
+    LexicalResult,
+    ReadabilityResult,
+    RougeResult,
+    RougeScore,
+    SemanticResult,
+)
+from veritext.metrics.rouge import Rouge
 
 __all__ = [
     "AggregateStats",
@@ -13,4 +22,10 @@ __all__ = [
     "Lexical",
     "LexicalResult",
     "Metric",
+    "Readability",
+    "ReadabilityResult",
+    "Rouge",
+    "RougeResult",
+    "RougeScore",
+    "SemanticResult",
 ]
195
src/veritext/metrics/readability.py
Normal file
@@ -0,0 +1,195 @@
"""Readability metrics implementation (Flesch-Kincaid)."""

import re

from veritext.metrics.base import AggregateStats, BatchResult
from veritext.metrics.results import ReadabilityResult

# Sentence-ending punctuation pattern
_SENTENCE_ENDINGS = re.compile(r"[.!?]+")

# Vowel pattern for syllable counting
_VOWELS = re.compile(r"[aeiouy]+", re.IGNORECASE)


def _count_syllables(word: str) -> int:
    """
    Count syllables in a word using a heuristic approach.

    Uses vowel group counting with adjustments for common patterns.

    Args:
        word: The word to count syllables for.

    Returns:
        Estimated syllable count (minimum 1 for non-empty words).
    """
    if not word:
        return 0

    word = word.lower().strip()
    if not word:
        return 0

    # Count vowel groups
    vowel_groups = _VOWELS.findall(word)
    count = len(vowel_groups)

    # Adjust for silent 'e' at end
    if word.endswith("e") and count > 1:
        count -= 1

    # Adjust for 'le' ending (e.g., "table", "able")
    if word.endswith("le") and len(word) > 2 and word[-3] not in "aeiouy":
        count += 1

    # Adjust for 'ed' ending when not adding syllable
    if word.endswith("ed") and len(word) > 2 and word[-3] not in "dt":
        count = max(count - 1, 1)

    # Ensure at least 1 syllable for any word
    return max(count, 1)


def _count_sentences(text: str) -> int:
    """
    Count sentences in text.

    Splits on sentence-ending punctuation (.!?).

    Args:
        text: The text to count sentences in.

    Returns:
        Number of sentences (minimum 1 for non-empty text).
    """
    if not text or not text.strip():
        return 0

    # Split on runs of sentence-ending punctuation
    sentences = _SENTENCE_ENDINGS.split(text)
    # Filter out empty segments
    sentences = [s for s in sentences if s.strip()]

    return max(len(sentences), 1)


def _count_words(text: str) -> tuple[list[str], int]:
    """
    Extract words from text and count them.

    Args:
        text: The text to process.

    Returns:
        Tuple of (word list, word count).
    """
    # Extract words (sequences of ASCII letters and apostrophes)
    words = re.findall(r"[a-zA-Z']+", text)
    # Filter out standalone apostrophes
    words = [w for w in words if w.replace("'", "")]
    return words, len(words)


class Readability:
    """
    Readability metric using Flesch-Kincaid formulas.

    Computes:
    - Flesch-Kincaid Grade Level: US grade level required to understand text
    - Flesch Reading Ease: Typically 0-100, not clamped (higher = easier to read)

    This metric does NOT require reference text.
    """

    @property
    def name(self) -> str:
        """Return the name of this metric."""
        return "readability"

    @property
    def requires_reference(self) -> bool:
        """Return whether this metric requires reference text."""
        return False

    def score(
        self,
        candidate: str,
        reference: str | list[str] | None = None,  # noqa: ARG002
    ) -> ReadabilityResult:
        """
        Compute readability scores for a text.

        Args:
            candidate: The text to score.
            reference: Ignored (readability doesn't use reference text).

        Returns:
            ReadabilityResult with Flesch-Kincaid scores.
        """
        # Extract words and count
        words, word_count = _count_words(candidate)

        # Handle empty or trivial text
        if word_count == 0:
            return ReadabilityResult(
                flesch_kincaid_grade=0.0,
                flesch_reading_ease=0.0,
            )

        # Count sentences (ensure at least 1 to avoid division by zero)
        sentence_count = max(_count_sentences(candidate), 1)

        # Count syllables
        syllable_count = sum(_count_syllables(word) for word in words)

        # Compute ratios
        words_per_sentence = word_count / sentence_count
        syllables_per_word = syllable_count / word_count

        # Flesch-Kincaid Grade Level
        # Formula: 0.39 * (words/sentences) + 11.8 * (syllables/words) - 15.59
        grade_level = 0.39 * words_per_sentence + 11.8 * syllables_per_word - 15.59

        # Flesch Reading Ease
        # Formula: 206.835 - 1.015 * (words/sentences) - 84.6 * (syllables/words)
        reading_ease = 206.835 - 1.015 * words_per_sentence - 84.6 * syllables_per_word

        return ReadabilityResult(
            flesch_kincaid_grade=grade_level,
            flesch_reading_ease=reading_ease,
        )

    def batch_score(
        self,
        candidates: list[str],
        references: list[str] | list[list[str]] | None = None,  # noqa: ARG002
    ) -> BatchResult[ReadabilityResult]:
        """
        Compute readability scores for a batch of texts.

        Args:
            candidates: List of texts to score.
            references: Ignored (readability doesn't use reference text).

        Returns:
            BatchResult containing individual results and aggregate statistics.
        """
        if not candidates:
            raise ValueError("Cannot compute batch statistics from empty list")

        results: list[ReadabilityResult] = []
        for cand in candidates:
            results.append(self.score(cand))

        # Compute aggregate statistics
        stats = {
            "flesch_kincaid_grade": AggregateStats.from_values(
                [r.flesch_kincaid_grade for r in results]
            ),
            "flesch_reading_ease": AggregateStats.from_values(
                [r.flesch_reading_ease for r in results]
            ),
        }

        return BatchResult(results=results, count=len(results), stats=stats)
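As a worked check of the two formulas used in `Readability.score`, the sketch below plugs raw counts straight into them. It is standalone and does not import veritext. The helper name `flesch_scores` is hypothetical, and the example counts (9 words, 1 sentence, 11 syllables) are what the syllable heuristic in this file appears to yield for "The quick brown fox jumps over the lazy dog."

```python
def flesch_scores(words: int, sentences: int, syllables: int) -> tuple[float, float]:
    """Return (grade_level, reading_ease) from raw counts (hypothetical helper)."""
    if words <= 0 or sentences <= 0:
        raise ValueError("words and sentences must be positive")
    words_per_sentence = words / sentences
    syllables_per_word = syllables / words
    # Same constants as the Flesch-Kincaid formulas quoted in the comments above.
    grade = 0.39 * words_per_sentence + 11.8 * syllables_per_word - 15.59
    ease = 206.835 - 1.015 * words_per_sentence - 84.6 * syllables_per_word
    return grade, ease


grade, ease = flesch_scores(words=9, sentences=1, syllables=11)
print(f"grade={grade:.2f} ease={ease:.2f}")

# A very short, simple sentence pushes reading ease above 100,
# so the raw formula is not confined to 0-100 unless it is clamped.
_, short_ease = flesch_scores(words=3, sentences=1, syllables=3)
print(f"short ease={short_ease:.2f}")
```

For the pangram this gives a grade of roughly 2.3 and a reading ease of roughly 94.3, comfortably under a `max_reading_grade=8.0` threshold.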
@@ -39,3 +39,77 @@ class LexicalResult(BaseModel):

    token_overlap: float
    """Proportion of candidate tokens found in reference."""

    @property
    def score(self) -> float:
        """Return Jaccard similarity as the primary score."""
        return self.jaccard


class RougeScore(BaseModel):
    """Individual ROUGE variant score with precision, recall, F-measure."""

    model_config = ConfigDict(frozen=True)

    precision: float
    """Precision: overlap / candidate length."""

    recall: float
    """Recall: overlap / reference length."""

    fmeasure: float
    """F1-measure: harmonic mean of precision and recall."""


class RougeResult(BaseModel):
    """Result of ROUGE score computation."""

    model_config = ConfigDict(frozen=True)

    rouge1: RougeScore
    """ROUGE-1 (unigram) score."""

    rouge2: RougeScore
    """ROUGE-2 (bigram) score."""

    rouge_l: RougeScore
    """ROUGE-L (longest common subsequence) score."""

    @property
    def score(self) -> float:
        """Return ROUGE-L F-measure as the primary score."""
        return self.rouge_l.fmeasure


class ReadabilityResult(BaseModel):
    """Result of readability computation."""

    model_config = ConfigDict(frozen=True)

    flesch_kincaid_grade: float
    """US grade level (e.g., 8.0 = 8th grade reading level)."""

    flesch_reading_ease: float
    """Typically 0-100 (not clamped); higher = easier to read."""

    @property
    def score(self) -> float:
        """Return Flesch reading ease as the primary score."""
        return self.flesch_reading_ease


class SemanticResult(BaseModel):
    """Result of semantic similarity computation."""

    model_config = ConfigDict(frozen=True)

    similarity: float
    """Cosine similarity score (0.0 to 1.0)."""

    model: str
    """Name of the embedding model used."""

    @property
    def score(self) -> float:
        """Return the primary score for this result."""
        return self.similarity
282
src/veritext/metrics/rouge.py
Normal file
@@ -0,0 +1,282 @@
"""ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric implementation."""

from collections import Counter

from veritext.core.tokenisation import WordTokeniser
from veritext.metrics.base import AggregateStats, BatchResult
from veritext.metrics.results import RougeResult, RougeScore


def _get_ngrams(tokens: list[str], n: int) -> Counter[tuple[str, ...]]:
    """Extract n-grams from a list of tokens."""
    if n > len(tokens):
        return Counter()
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def _ngram_overlap(
    candidate_ngrams: Counter[tuple[str, ...]],
    reference_ngrams: Counter[tuple[str, ...]],
) -> int:
    """Compute the overlap count between candidate and reference n-grams."""
    overlap = 0
    for ngram, count in candidate_ngrams.items():
        overlap += min(count, reference_ngrams.get(ngram, 0))
    return overlap


def _compute_rouge_score(
    candidate_tokens: list[str],
    reference_tokens: list[str],
    n: int,
) -> RougeScore:
    """
    Compute ROUGE-n score for given n-gram size.

    Args:
        candidate_tokens: Tokenised candidate text.
        reference_tokens: Tokenised reference text.
        n: N-gram size.

    Returns:
        RougeScore with precision, recall, and F-measure.
    """
    candidate_ngrams = _get_ngrams(candidate_tokens, n)
    reference_ngrams = _get_ngrams(reference_tokens, n)

    candidate_count = sum(candidate_ngrams.values())
    reference_count = sum(reference_ngrams.values())

    if candidate_count == 0 and reference_count == 0:
        return RougeScore(precision=0.0, recall=0.0, fmeasure=0.0)

    overlap = _ngram_overlap(candidate_ngrams, reference_ngrams)

    precision = overlap / candidate_count if candidate_count > 0 else 0.0
    recall = overlap / reference_count if reference_count > 0 else 0.0

    if precision + recall > 0:
        fmeasure = 2 * precision * recall / (precision + recall)
    else:
        fmeasure = 0.0

    return RougeScore(precision=precision, recall=recall, fmeasure=fmeasure)


def _lcs_length(seq1: list[str], seq2: list[str]) -> int:
    """
    Compute the length of the longest common subsequence.

    Uses dynamic programming with O(m*n) time and O(min(m,n)) space.
    """
    if not seq1 or not seq2:
        return 0

    # Optimise by using shorter sequence for columns
    if len(seq1) < len(seq2):
        seq1, seq2 = seq2, seq1

    m, n = len(seq1), len(seq2)

    # Only need two rows at a time
    prev = [0] * (n + 1)
    curr = [0] * (n + 1)

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if seq1[i - 1] == seq2[j - 1]:
                curr[j] = prev[j - 1] + 1
            else:
                curr[j] = max(prev[j], curr[j - 1])
        prev, curr = curr, prev

    return prev[n]


def _compute_rouge_l(
    candidate_tokens: list[str],
    reference_tokens: list[str],
) -> RougeScore:
    """
    Compute ROUGE-L score using longest common subsequence.

    Args:
        candidate_tokens: Tokenised candidate text.
        reference_tokens: Tokenised reference text.

    Returns:
        RougeScore with precision, recall, and F-measure.
    """
    if not candidate_tokens or not reference_tokens:
        return RougeScore(precision=0.0, recall=0.0, fmeasure=0.0)

    lcs = _lcs_length(candidate_tokens, reference_tokens)

    precision = lcs / len(candidate_tokens)
    recall = lcs / len(reference_tokens)

    if precision + recall > 0:
        fmeasure = 2 * precision * recall / (precision + recall)
    else:
        fmeasure = 0.0

    return RougeScore(precision=precision, recall=recall, fmeasure=fmeasure)


def _max_rouge_scores(scores: list[RougeScore]) -> RougeScore:
    """Select the RougeScore with the highest F-measure from a list."""
    return max(scores, key=lambda s: s.fmeasure)


class Rouge:
    """
    ROUGE metric for measuring summary/generation quality.

    Computes ROUGE-1 (unigram), ROUGE-2 (bigram), and ROUGE-L (LCS) scores.
    ROUGE is recall-oriented, measuring how much of the reference is captured.
    """

    def __init__(self, tokeniser: WordTokeniser | None = None) -> None:
        """
        Initialise the ROUGE metric.

        Args:
            tokeniser: Tokeniser to use. Defaults to WordTokeniser().
        """
        self._tokeniser = tokeniser or WordTokeniser()

    @property
    def name(self) -> str:
        """Return the name of this metric."""
        return "rouge"

    @property
    def requires_reference(self) -> bool:
        """Return whether this metric requires reference text."""
        return True

    def score(
        self, candidate: str, reference: str | list[str] | None = None
    ) -> RougeResult:
        """
        Compute ROUGE scores for a candidate text.

        Args:
            candidate: The text to score.
            reference: Reference text(s) for comparison. If multiple references
                are provided, returns the maximum score for each variant.

        Returns:
            RougeResult with ROUGE-1, ROUGE-2, and ROUGE-L scores.

        Raises:
            ValueError: If reference is None or empty.
        """
        if reference is None:
            raise ValueError("ROUGE requires reference text")

        # Normalise reference to list
        references = [reference] if isinstance(reference, str) else reference

        # Tokenise
        candidate_tokens = self._tokeniser.tokenise(candidate)
        reference_token_lists = [self._tokeniser.tokenise(r) for r in references]

        # Handle empty references
        if all(not ref for ref in reference_token_lists):
            raise ValueError("Reference text cannot be empty")

        # Handle empty candidate
        if not candidate_tokens:
            return RougeResult(
                rouge1=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0),
                rouge2=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0),
                rouge_l=RougeScore(precision=0.0, recall=0.0, fmeasure=0.0),
            )

        # Compute scores for each reference and take max
        rouge1_scores = []
        rouge2_scores = []
        rouge_l_scores = []

        for ref_tokens in reference_token_lists:
            if not ref_tokens:
                continue
            rouge1_scores.append(_compute_rouge_score(candidate_tokens, ref_tokens, 1))
            rouge2_scores.append(_compute_rouge_score(candidate_tokens, ref_tokens, 2))
            rouge_l_scores.append(_compute_rouge_l(candidate_tokens, ref_tokens))

        # All references were empty after tokenisation
        if not rouge1_scores:
            raise ValueError("Reference text cannot be empty")

        return RougeResult(
            rouge1=_max_rouge_scores(rouge1_scores),
            rouge2=_max_rouge_scores(rouge2_scores),
            rouge_l=_max_rouge_scores(rouge_l_scores),
        )

    def batch_score(
        self,
        candidates: list[str],
        references: list[str] | list[list[str]] | None = None,
    ) -> BatchResult[RougeResult]:
        """
        Compute ROUGE scores for a batch of candidates.

        Args:
            candidates: List of texts to score.
            references: Reference text(s) for each candidate.

        Returns:
            BatchResult containing individual results and aggregate statistics.

        Raises:
            ValueError: If references is None or length mismatch.
        """
        if references is None:
            raise ValueError("ROUGE requires reference texts")

        if len(candidates) != len(references):
            raise ValueError(
                f"Number of candidates ({len(candidates)}) must match "
                f"number of references ({len(references)})"
            )

        results: list[RougeResult] = []
        for i, cand in enumerate(candidates):
            ref: str | list[str] = references[i]
            results.append(self.score(cand, ref))

        # Compute aggregate statistics for each score type
        stats = {
            "rouge1_precision": AggregateStats.from_values(
                [r.rouge1.precision for r in results]
            ),
            "rouge1_recall": AggregateStats.from_values(
                [r.rouge1.recall for r in results]
            ),
            "rouge1_fmeasure": AggregateStats.from_values(
                [r.rouge1.fmeasure for r in results]
            ),
            "rouge2_precision": AggregateStats.from_values(
                [r.rouge2.precision for r in results]
            ),
            "rouge2_recall": AggregateStats.from_values(
                [r.rouge2.recall for r in results]
            ),
            "rouge2_fmeasure": AggregateStats.from_values(
                [r.rouge2.fmeasure for r in results]
            ),
            "rouge_l_precision": AggregateStats.from_values(
                [r.rouge_l.precision for r in results]
            ),
            "rouge_l_recall": AggregateStats.from_values(
                [r.rouge_l.recall for r in results]
            ),
            "rouge_l_fmeasure": AggregateStats.from_values(
                [r.rouge_l.fmeasure for r in results]
            ),
        }

        return BatchResult(results=results, count=len(results), stats=stats)
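To make the clipped n-gram overlap and the LCS-based ROUGE-L concrete, here is a small standalone sketch on pre-split tokens. It does not import veritext or use `WordTokeniser`; the helper names (`ngram_counts`, `prf`, `rouge_n`, `lcs_length`, `rouge_l`) and the whitespace tokenisation are assumptions for illustration, not the module's API.

```python
from collections import Counter


def ngram_counts(tokens: list[str], n: int) -> Counter:
    """Count n-grams; the result is empty when the sequence is shorter than n."""
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def prf(overlap: int, cand_total: int, ref_total: int) -> tuple[float, float, float]:
    """Precision, recall and F1 from an overlap count."""
    precision = overlap / cand_total if cand_total else 0.0
    recall = overlap / ref_total if ref_total else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1


def rouge_n(cand: list[str], ref: list[str], n: int) -> tuple[float, float, float]:
    c, r = ngram_counts(cand, n), ngram_counts(ref, n)
    # Counter intersection keeps the minimum count per n-gram (clipped overlap).
    overlap = sum((c & r).values())
    return prf(overlap, sum(c.values()), sum(r.values()))


def lcs_length(a: list[str], b: list[str]) -> int:
    """Row-by-row dynamic programme for longest common subsequence length."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l(cand: list[str], ref: list[str]) -> tuple[float, float, float]:
    return prf(lcs_length(cand, ref), len(cand), len(ref))


candidate = "the cat sat on the mat".split()
reference = "the cat is on the mat".split()
print("ROUGE-1:", rouge_n(candidate, reference, 1))
print("ROUGE-2:", rouge_n(candidate, reference, 2))
print("ROUGE-L:", rouge_l(candidate, reference))
```

In this pair, five of six unigrams and three of five bigrams overlap, and the longest common subsequence ("the cat on the mat") has length five, so ROUGE-1 and ROUGE-L coincide while ROUGE-2 is lower.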
22
src/veritext/pytest_plugin/__init__.py
Normal file
@@ -0,0 +1,22 @@
"""Pytest plugin for text validation.

This plugin provides native pytest integration for Veritext, enabling
text validation assertions in test suites.

Example:
    >>> from veritext.pytest_plugin import validate_text
    >>>
    >>> def test_summary_quality():
    ...     text = "The quick brown fox jumps over the lazy dog."
    ...     validate_text(
    ...         text,
    ...         min_length=10,
    ...         max_length=100,
    ...         max_reading_grade=8.0,
    ...     )
"""

from veritext.pytest_plugin.assertions import validate_text
from veritext.pytest_plugin.plugin import pytest_configure

__all__ = ["pytest_configure", "validate_text"]
141
src/veritext/pytest_plugin/assertions.py
Normal file
@@ -0,0 +1,141 @@
"""Assertion functions for text validation in pytest."""

from typing import TYPE_CHECKING

from veritext.core.types import ValidationContext, ValidationResult
from veritext.validators import all_of

if TYPE_CHECKING:
    from veritext.validators.base import Check


def validate_text(
    text: str,
    *,
    reference: str | list[str] | None = None,
    min_bleu: float | None = None,
    min_rouge: float | None = None,
    min_semantic: float | None = None,
    max_length: int | None = None,
    min_length: int | None = None,
    max_reading_grade: float | None = None,
    must_contain: list[str] | None = None,
    must_exclude: list[str] | None = None,
) -> None:
    """Assert text passes all specified validation criteria.

    This is the primary assertion function for text validation in pytest.
    It builds validators from keyword arguments and raises AssertionError
    with detailed failure information if validation fails.

    Args:
        text: The text to validate.
        reference: Reference text for comparison metrics (BLEU, ROUGE, semantic).
        min_bleu: Minimum BLEU-4 score required (0.0 to 1.0).
        min_rouge: Minimum ROUGE-L F-measure required (0.0 to 1.0).
        min_semantic: Minimum semantic similarity required (0.0 to 1.0).
        max_length: Maximum character count allowed.
        min_length: Minimum character count required.
        max_reading_grade: Maximum Flesch-Kincaid grade level.
        must_contain: Patterns that must be present in the text.
        must_exclude: Patterns that must not be present in the text.

    Raises:
        AssertionError: With detailed failure information if validation fails.
        ValueError: If comparison metrics requested but reference not provided,
            or if no validation criteria are specified.

    Example:
        >>> validate_text(
        ...     "The quick brown fox jumps over the lazy dog.",
        ...     min_length=10,
        ...     max_length=100,
        ...     max_reading_grade=8.0,
        ... )
    """
    # Require a reference whenever a comparison threshold is set (0.0 counts)
    if reference is None and (min_bleu, min_rouge, min_semantic) != (None, None, None):
        raise ValueError(
            "Reference text required for comparison metrics "
            "(min_bleu, min_rouge, min_semantic)"
        )

    # Build list of validators from kwargs
    checks: list[Check] = []

    if min_bleu is not None:
        from veritext.validators import bleu

        checks.append(bleu(min_score=min_bleu))

    if min_rouge is not None:
        from veritext.validators import rouge

        checks.append(rouge(min_score=min_rouge))

    if min_semantic is not None:
        # Lazy import to avoid loading sentence-transformers unless needed
        from veritext.validators import semantic

        checks.append(semantic(min_score=min_semantic))

    if max_length is not None or min_length is not None:
        from veritext.validators import length

        checks.append(length(min_chars=min_length, max_chars=max_length))

    if max_reading_grade is not None:
        from veritext.validators import readability

        checks.append(readability(max_grade=max_reading_grade))

    if must_contain is not None:
        from veritext.validators import contains

        checks.append(contains(patterns=must_contain))

    if must_exclude is not None:
        from veritext.validators import excludes

        checks.append(excludes(patterns=must_exclude))

    if not checks:
        raise ValueError("At least one validation criterion must be specified")

    # Run validation
    context = ValidationContext(reference=reference)
    validator = all_of(checks)
    result = validator.check(text, context)

    if not result.passed:
        raise AssertionError(_format_failure(text, result))


def _format_failure(text: str, result: ValidationResult) -> str:
    """Format a detailed failure message for pytest output.

    Args:
        text: The text that was validated.
        result: The validation result containing check failures.

    Returns:
        Formatted failure message with check details.
    """
    lines = ["Text validation failed:"]
    lines.append("")

    # Show a preview of the text (truncated if long)
    preview = text[:100] + "..." if len(text) > 100 else text
    lines.append(f" Text: {preview!r}")
    lines.append("")

    # List all failed checks with details
    lines.append(" Failed checks:")
    for check in result.failed_checks:
        lines.append(f" - {check.name}:")
        lines.append(f" {check.message}")
        if check.threshold is not None:
            # Thresholds can be lower or upper bounds, so no direction is implied
            lines.append(f" Threshold: {check.threshold}")
            lines.append(f" Actual: {check.actual}")

    return "\n".join(lines)
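Thresholds such as `min_bleu` are documented as ranging from 0.0 to 1.0, and 0.0 is falsy in Python. A presence check on optional thresholds therefore behaves differently under truthiness than under an explicit `is not None` comparison. The standalone sketch below illustrates the difference; `needs_reference` is a hypothetical helper and not part of Veritext.

```python
from typing import Optional


def needs_reference(*thresholds: Optional[float]) -> bool:
    """True when at least one comparison threshold was supplied (hypothetical helper)."""
    return any(t is not None for t in thresholds)


# Truthiness treats an explicit 0.0 threshold as "not set" ...
print(any([0.0, None, None]))             # False
# ... whereas an identity check against None still sees it.
print(needs_reference(0.0, None, None))   # True
print(needs_reference(None, None, None))  # False
```

The same reasoning applies to any optional numeric argument where zero is a legitimate value.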
80
src/veritext/pytest_plugin/fixtures.py
Normal file
@@ -0,0 +1,80 @@
"""Pytest fixtures for text validation."""

from typing import TYPE_CHECKING, Any

import pytest

from veritext.core.types import ValidationContext, ValidationResult
from veritext.validators import all_of
from veritext.validators.base import Check

if TYPE_CHECKING:
    from collections.abc import Callable


class ValidatorFactory:
    """Factory for building validator callables from a list of checks."""

    def __call__(
        self,
        checks: list[Check],
        reference: str | list[str] | None = None,
    ) -> "Callable[[str], ValidationResult]":
        """Create a validator function from a list of checks.

        Args:
            checks: List of validation checks to apply.
            reference: Optional reference text for comparison metrics.

        Returns:
            A callable that takes text and returns a ValidationResult.
        """
        validator = all_of(checks)
        context = ValidationContext(reference=reference)

        def validate(text: str) -> ValidationResult:
            return validator.check(text, context)

        return validate


@pytest.fixture
def text_validator() -> ValidatorFactory:
    """Provide a factory for building validators.

    Example:
        >>> def test_with_factory(text_validator):
        ...     from veritext.validators import bleu, length
        ...     validate = text_validator(
        ...         checks=[bleu(min_score=0.5), length(min_words=10)],
        ...         reference="The reference text.",
        ...     )
        ...     result = validate("Some candidate text.")
        ...     assert result.passed

    Returns:
        ValidatorFactory instance.
    """
    return ValidatorFactory()


@pytest.fixture
def validation_context() -> "Callable[..., ValidationContext]":
    """Provide a factory for creating ValidationContext objects.

    Example:
        >>> def test_with_context(validation_context):
        ...     ctx = validation_context(reference="The reference text.")
        ...     assert ctx.reference == "The reference text."

    Returns:
        A callable that creates ValidationContext objects.
    """

    def _create(
        reference: str | list[str] | None = None,
        **metadata: Any,
    ) -> ValidationContext:
        return ValidationContext(reference=reference, metadata=metadata)

    return _create
18
src/veritext/pytest_plugin/plugin.py
Normal file
@@ -0,0 +1,18 @@
"""Pytest hooks for Veritext plugin."""

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pytest


def pytest_configure(config: "pytest.Config") -> None:
    """Register Veritext markers.

    Args:
        config: Pytest configuration object.
    """
    config.addinivalue_line(
        "markers",
        "text_validation: mark test as a text validation test",
    )
16
src/veritext/semantic/__init__.py
Normal file
@@ -0,0 +1,16 @@
"""Semantic similarity module: embedding-based text comparison.

This module provides semantic similarity using sentence-transformers.
It requires the `veritext[semantic]` extra to be installed.

Example:
    >>> from veritext.semantic import SemanticSimilarity
    >>>
    >>> metric = SemanticSimilarity()
    >>> result = metric.score("The cat sat on the mat", "A feline rested on the rug")
    >>> print(f"Similarity: {result.similarity:.2f}")
"""

from veritext.semantic.similarity import SemanticSimilarity

__all__ = ["SemanticSimilarity"]
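The embedding comparison described above comes down to cosine similarity between two vectors. Below is a minimal pure-Python sketch that uses plain lists as stand-ins for the tensors sentence-transformers would produce. The function name and the toy vectors are assumptions for illustration only, and the sketch says nothing about how Veritext itself post-processes the value.

```python
import math


def cosine_similarity(u: list[float], v: list[float]) -> float:
    """Cosine of the angle between two equal-length, non-zero vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return sum(x * y for x, y in zip(u, v)) / (norm_u * norm_v)


print(cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))  # same direction
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))            # orthogonal
print(cosine_similarity([1.0, 0.0], [-1.0, 0.0]))           # opposite direction
```

Mathematically the raw value lies between -1 and 1, so a documented range of 0.0 to 1.0 would rely on the embeddings rarely pointing in opposing directions or on the result being clamped.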
203
src/veritext/semantic/similarity.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""Embedding-based semantic similarity using sentence-transformers."""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from veritext.core.exceptions import DependencyError
|
||||||
|
from veritext.metrics.base import AggregateStats, BatchResult
|
||||||
|
from veritext.metrics.results import SemanticResult
|
||||||
|
|
||||||
|
# Default maximum cache size (number of embeddings to store)
|
||||||
|
DEFAULT_CACHE_MAX_SIZE = 1000
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticSimilarity:
|
||||||
|
"""
|
||||||
|
Embedding-based semantic similarity using sentence-transformers.
|
||||||
|
|
||||||
|
Computes cosine similarity between text embeddings to measure semantic
|
||||||
|
relatedness. This metric captures meaning beyond lexical overlap.
|
||||||
|
|
||||||
|
Requires the `veritext[semantic]` extra to be installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "all-MiniLM-L6-v2",
|
||||||
|
cache_embeddings: bool = True,
|
||||||
|
cache_max_size: int = DEFAULT_CACHE_MAX_SIZE,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the semantic similarity metric.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Name of the sentence-transformers model to use.
|
||||||
|
Defaults to "all-MiniLM-L6-v2" (22MB, good quality/size tradeoff).
|
||||||
|
cache_embeddings: Whether to cache embeddings for repeated texts.
|
||||||
|
Defaults to True.
|
||||||
|
cache_max_size: Maximum number of embeddings to cache. Oldest entries
|
||||||
|
are evicted when the limit is reached. Defaults to 1000.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DependencyError: If sentence-transformers is not installed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
except ImportError as err:
|
||||||
|
raise DependencyError(
|
||||||
|
"Install veritext[semantic] for semantic similarity: "
|
||||||
|
"pip install veritext[semantic]"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
self._model_name = model
|
||||||
|
self._model: Any = SentenceTransformer(model)
|
||||||
|
self._cache: OrderedDict[str, Any] | None = (
|
||||||
|
OrderedDict() if cache_embeddings else None
|
||||||
|
)
|
||||||
|
self._cache_max_size = cache_max_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this metric."""
|
||||||
|
return "semantic"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_reference(self) -> bool:
|
||||||
|
"""Return whether this metric requires reference text."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _get_embedding(self, text: str) -> Any:
|
||||||
|
"""
|
||||||
|
Get embedding for text, using LRU cache if available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The embedding tensor.
|
||||||
|
"""
|
||||||
|
if self._cache is not None and text in self._cache:
|
||||||
|
# Move to end to mark as recently used
|
||||||
|
self._cache.move_to_end(text)
|
||||||
|
return self._cache[text]
|
||||||
|
|
||||||
|
embedding = self._model.encode(text, convert_to_tensor=True)
|
||||||
|
|
||||||
|
if self._cache is not None and self._cache_max_size > 0:
|
||||||
|
# Evict least recently used entries if cache is full
|
||||||
|
while len(self._cache) >= self._cache_max_size:
|
||||||
|
self._cache.popitem(last=False)
|
||||||
|
self._cache[text] = embedding
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def _cosine_similarity(self, embedding1: Any, embedding2: Any) -> float:
|
||||||
|
"""
|
||||||
|
Compute cosine similarity between two embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding1: First embedding tensor.
|
||||||
|
embedding2: Second embedding tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cosine similarity score, clamped to the range 0.0 to 1.0.
|
||||||
|
"""
|
||||||
|
from sentence_transformers import util
|
||||||
|
|
||||||
|
similarity: float = util.cos_sim(embedding1, embedding2).item()
|
||||||
|
# Clamp to [0, 1] as negative similarities are possible but not meaningful
|
||||||
|
return max(0.0, min(1.0, similarity))
|
||||||
|
|
||||||
|
def score(
|
||||||
|
self, candidate: str, reference: str | list[str] | None = None
|
||||||
|
) -> SemanticResult:
|
||||||
|
"""
|
||||||
|
Compute semantic similarity between candidate and reference.
|
||||||
|
|
||||||
|
When multiple references are provided, returns the maximum similarity
|
||||||
|
across all references.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidate: The text to score.
|
||||||
|
reference: Reference text(s) for comparison.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SemanticResult with similarity score and model name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If reference is None or empty.
|
||||||
|
"""
|
||||||
|
if reference is None:
|
||||||
|
raise ValueError("Semantic similarity requires reference text")
|
||||||
|
|
||||||
|
# Normalise reference to list
|
||||||
|
references = [reference] if isinstance(reference, str) else reference
|
||||||
|
|
||||||
|
if not references:
|
||||||
|
raise ValueError("Reference text cannot be empty")
|
||||||
|
|
||||||
|
# Handle empty candidate
|
||||||
|
candidate_stripped = candidate.strip()
|
||||||
|
if not candidate_stripped:
|
||||||
|
return SemanticResult(similarity=0.0, model=self._model_name)
|
||||||
|
|
||||||
|
# Handle empty references
|
||||||
|
valid_references = [r for r in references if r.strip()]
|
||||||
|
if not valid_references:
|
||||||
|
raise ValueError("Reference text cannot be empty")
|
||||||
|
|
||||||
|
# Get candidate embedding
|
||||||
|
candidate_embedding = self._get_embedding(candidate_stripped)
|
||||||
|
|
||||||
|
# Compute similarity against each reference, take maximum
|
||||||
|
max_similarity = 0.0
|
||||||
|
for ref in valid_references:
|
||||||
|
ref_embedding = self._get_embedding(ref.strip())
|
||||||
|
similarity = self._cosine_similarity(candidate_embedding, ref_embedding)
|
||||||
|
max_similarity = max(max_similarity, similarity)
|
||||||
|
|
||||||
|
return SemanticResult(similarity=max_similarity, model=self._model_name)
|
||||||
|
|
||||||
|
def batch_score(
|
||||||
|
self,
|
||||||
|
candidates: list[str],
|
||||||
|
references: list[str] | list[list[str]] | None = None,
|
||||||
|
) -> BatchResult[SemanticResult]:
|
||||||
|
"""
|
||||||
|
Compute semantic similarity for a batch of candidates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidates: List of texts to score.
|
||||||
|
references: Reference text(s) for each candidate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BatchResult containing individual results and aggregate statistics.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If references is None or its length differs from candidates.
|
||||||
|
"""
|
||||||
|
if references is None:
|
||||||
|
raise ValueError("Semantic similarity requires reference texts")
|
||||||
|
|
||||||
|
if len(candidates) != len(references):
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of candidates ({len(candidates)}) must match "
|
||||||
|
f"number of references ({len(references)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[SemanticResult] = []
|
||||||
|
for i, cand in enumerate(candidates):
|
||||||
|
ref: str | list[str] = references[i]
|
||||||
|
results.append(self.score(cand, ref))
|
||||||
|
|
||||||
|
# Compute aggregate statistics
|
||||||
|
stats = {
|
||||||
|
"similarity": AggregateStats.from_values([r.similarity for r in results]),
|
||||||
|
}
|
||||||
|
|
||||||
|
return BatchResult(results=results, count=len(results), stats=stats)
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear the embedding cache."""
|
||||||
|
if self._cache is not None:
|
||||||
|
self._cache.clear()
|
||||||
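Review note on the embedding cache in `similarity.py`: the eviction logic is a standard `OrderedDict`-based LRU. The sketch below isolates that pattern so it can be exercised without loading a model. `LRUCache` and `fake_embed` are hypothetical names used only for illustration, and the `max_size > 0` guard is an assumption about how a zero-sized cache should behave (skip caching rather than fail).

```python
from collections import OrderedDict
from typing import Callable


class LRUCache:
    """Tiny LRU cache keyed by text; a stand-in for the embedding cache."""

    def __init__(self, compute: Callable[[str], list[float]], max_size: int = 2) -> None:
        self._compute = compute
        self._max_size = max_size
        self._store: OrderedDict[str, list[float]] = OrderedDict()

    def get(self, text: str) -> list[float]:
        if text in self._store:
            # Cache hit: mark the entry as most recently used.
            self._store.move_to_end(text)
            return self._store[text]

        value = self._compute(text)
        if self._max_size > 0:
            # Evict from the front (least recently used) until there is room.
            while len(self._store) >= self._max_size:
                self._store.popitem(last=False)
            self._store[text] = value
        return value

    def keys(self) -> list[str]:
        # Keys ordered from least to most recently used.
        return list(self._store)


def fake_embed(text: str) -> list[float]:
    # Hypothetical stand-in for a model's encode() call.
    return [float(len(text))]


cache = LRUCache(fake_embed, max_size=2)
cache.get("a")
cache.get("b")
cache.get("a")  # "a" becomes the most recently used entry
cache.get("c")  # evicts "b", the least recently used entry
print(cache.keys())
```

Because a hit calls `move_to_end`, eviction follows recency of use rather than insertion order, which is why "b" is dropped here even though "a" was inserted first.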
239
src/veritext/validators/__init__.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
"""Validators module: composable validation checks for text quality.
|
||||||
|
|
||||||
|
This module provides validators that apply thresholds to metrics and return
|
||||||
|
pass/fail decisions with diagnostics.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from veritext.validators import bleu, length, all_of
|
||||||
|
>>> from veritext.core.types import ValidationContext
|
||||||
|
>>>
|
||||||
|
>>> validator = all_of([
|
||||||
|
... bleu(min_score=0.5),
|
||||||
|
... length(min_words=10),
|
||||||
|
... ])
|
||||||
|
>>> context = ValidationContext(reference="The quick brown fox.")
|
||||||
|
>>> result = validator.check("The quick brown fox jumps.", context)
|
||||||
|
>>> print(result.passed)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from veritext.core.tokenisation import WordTokeniser
|
||||||
|
from veritext.validators.base import Check
|
||||||
|
from veritext.validators.composite import AllOf, AnyOf
|
||||||
|
from veritext.validators.constraint import (
|
||||||
|
ContainsValidator,
|
||||||
|
ExcludesValidator,
|
||||||
|
LengthValidator,
|
||||||
|
ReadabilityValidator,
|
||||||
|
)
|
||||||
|
from veritext.validators.metric import (
|
||||||
|
BleuValidator,
|
||||||
|
LexicalValidator,
|
||||||
|
RougeValidator,
|
||||||
|
SemanticValidator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Factory functions for clean API
|
||||||
|
def bleu(
|
||||||
|
min_score: float,
|
||||||
|
variant: Literal[1, 2, 3, 4] = 4,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> BleuValidator:
|
||||||
|
"""Create a BLEU validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum BLEU score required (0.0 to 1.0).
|
||||||
|
variant: BLEU variant to use (1, 2, 3, or 4). Defaults to 4.
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BleuValidator instance.
|
||||||
|
"""
|
||||||
|
return BleuValidator(min_score=min_score, variant=variant, tokeniser=tokeniser)
|
||||||
|
|
||||||
|
|
||||||
|
def rouge(
|
||||||
|
min_score: float,
|
||||||
|
variant: Literal["1", "2", "l"] = "l",
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> RougeValidator:
|
||||||
|
"""Create a ROUGE validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum ROUGE F-measure required (0.0 to 1.0).
|
||||||
|
variant: ROUGE variant ("1", "2", or "l"). Defaults to "l".
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RougeValidator instance.
|
||||||
|
"""
|
||||||
|
return RougeValidator(min_score=min_score, variant=variant, tokeniser=tokeniser)
|
||||||
|
|
||||||
|
|
||||||
|
def lexical(
|
||||||
|
min_jaccard: float | None = None,
|
||||||
|
min_overlap: float | None = None,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> LexicalValidator:
|
||||||
|
"""Create a lexical similarity validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_jaccard: Minimum Jaccard similarity required (0.0 to 1.0).
|
||||||
|
min_overlap: Minimum token overlap required (0.0 to 1.0).
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LexicalValidator instance.
|
||||||
|
"""
|
||||||
|
return LexicalValidator(
|
||||||
|
min_jaccard=min_jaccard, min_overlap=min_overlap, tokeniser=tokeniser
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def length(
|
||||||
|
min_chars: int | None = None,
|
||||||
|
max_chars: int | None = None,
|
||||||
|
min_words: int | None = None,
|
||||||
|
max_words: int | None = None,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> LengthValidator:
|
||||||
|
"""Create a length validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_chars: Minimum character count (inclusive).
|
||||||
|
max_chars: Maximum character count (inclusive).
|
||||||
|
min_words: Minimum word count (inclusive).
|
||||||
|
max_words: Maximum word count (inclusive).
|
||||||
|
tokeniser: Tokeniser to use for word counting. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LengthValidator instance.
|
||||||
|
"""
|
||||||
|
return LengthValidator(
|
||||||
|
min_chars=min_chars,
|
||||||
|
max_chars=max_chars,
|
||||||
|
min_words=min_words,
|
||||||
|
max_words=max_words,
|
||||||
|
tokeniser=tokeniser,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def readability(
|
||||||
|
max_grade: float | None = None,
|
||||||
|
min_ease: float | None = None,
|
||||||
|
) -> ReadabilityValidator:
|
||||||
|
"""Create a readability validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_grade: Maximum Flesch-Kincaid grade level allowed.
|
||||||
|
min_ease: Minimum Flesch Reading Ease score required.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadabilityValidator instance.
|
||||||
|
"""
|
||||||
|
return ReadabilityValidator(max_grade=max_grade, min_ease=min_ease)
|
||||||
|
|
||||||
|
|
||||||
|
def contains(
|
||||||
|
patterns: list[str],
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
) -> ContainsValidator:
|
||||||
|
"""Create a contains validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patterns: List of regex patterns that must be present (escape literal substrings).
|
||||||
|
case_sensitive: Whether matching is case-sensitive. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContainsValidator instance.
|
||||||
|
"""
|
||||||
|
return ContainsValidator(patterns=patterns, case_sensitive=case_sensitive)
|
||||||
|
|
||||||
|
|
||||||
|
def excludes(
|
||||||
|
patterns: list[str],
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
) -> ExcludesValidator:
|
||||||
|
"""Create an excludes validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patterns: List of regex patterns that must not be present (escape literal substrings).
|
||||||
|
case_sensitive: Whether matching is case-sensitive. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExcludesValidator instance.
|
||||||
|
"""
|
||||||
|
return ExcludesValidator(patterns=patterns, case_sensitive=case_sensitive)
|
||||||
|
|
||||||
|
|
||||||
|
def all_of(checks: list[Check]) -> AllOf:
|
||||||
|
"""Create an AllOf composite validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checks: List of checks that must all pass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AllOf instance.
|
||||||
|
"""
|
||||||
|
return AllOf(checks=checks)
|
||||||
|
|
||||||
|
|
||||||
|
def any_of(checks: list[Check]) -> AnyOf:
|
||||||
|
"""Create an AnyOf composite validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checks: List of checks where at least one must pass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AnyOf instance.
|
||||||
|
"""
|
||||||
|
return AnyOf(checks=checks)
|
||||||
|
|
||||||
|
|
||||||
|
def semantic(
|
||||||
|
min_score: float,
|
||||||
|
model: str = "all-MiniLM-L6-v2",
|
||||||
|
cache_embeddings: bool = True,
|
||||||
|
) -> SemanticValidator:
|
||||||
|
"""Create a semantic similarity validator.
|
||||||
|
|
||||||
|
Requires the `veritext[semantic]` extra to be installed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum semantic similarity required (0.0 to 1.0).
|
||||||
|
model: Name of the sentence-transformers model to use.
|
||||||
|
cache_embeddings: Whether to cache embeddings for repeated texts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SemanticValidator instance.
|
||||||
|
"""
|
||||||
|
return SemanticValidator(
|
||||||
|
min_score=min_score, model=model, cache_embeddings=cache_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AllOf",
|
||||||
|
"AnyOf",
|
||||||
|
"BleuValidator",
|
||||||
|
"Check",
|
||||||
|
"ContainsValidator",
|
||||||
|
"ExcludesValidator",
|
||||||
|
"LengthValidator",
|
||||||
|
"LexicalValidator",
|
||||||
|
"ReadabilityValidator",
|
||||||
|
"RougeValidator",
|
||||||
|
"SemanticValidator",
|
||||||
|
"all_of",
|
||||||
|
"any_of",
|
||||||
|
"bleu",
|
||||||
|
"contains",
|
||||||
|
"excludes",
|
||||||
|
"length",
|
||||||
|
"lexical",
|
||||||
|
"readability",
|
||||||
|
"rouge",
|
||||||
|
"semantic",
|
||||||
|
]
|
||||||
31
src/veritext/validators/base.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Base types and protocols for validation checks."""
|
||||||
|
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from veritext.core.types import CheckResult, ValidationContext
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class Check(Protocol):
|
||||||
|
"""Protocol for validation checks.
|
||||||
|
|
||||||
|
A Check computes a score or property of text and compares it against
|
||||||
|
a threshold to produce a pass/fail result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||||
|
"""Run the check and return a result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text and metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status and diagnostics.
|
||||||
|
"""
|
||||||
|
...
|
||||||
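Review note on `base.py`: since `Check` is a `runtime_checkable` protocol, any object exposing `name` and `check` satisfies it without inheriting from it. A minimal sketch of that behaviour follows. `Result` and `NonEmpty` are hypothetical stand-ins, and `context` is typed as `object` here only to keep the example free of project imports.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@dataclass
class Result:
    # Hypothetical, simplified stand-in for CheckResult.
    name: str
    passed: bool


@runtime_checkable
class Check(Protocol):
    @property
    def name(self) -> str: ...

    def check(self, text: str, context: object) -> Result: ...


class NonEmpty:
    """Satisfies Check structurally; it does not inherit from it."""

    @property
    def name(self) -> str:
        return "non_empty"

    def check(self, text: str, context: object) -> Result:
        return Result(name=self.name, passed=bool(text.strip()))


# isinstance() only verifies that the members exist on the object.
print(isinstance(NonEmpty(), Check))
print(isinstance("just a string", Check))
```

Note that the runtime check looks at member presence only. Signatures and return types are not verified, so a static type checker remains the real safeguard.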
102
src/veritext/validators/composite.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""Composite validators for combining multiple checks.
|
||||||
|
|
||||||
|
Note: The composite classes (AllOf, AnyOf) intentionally return ValidationResult
|
||||||
|
rather than CheckResult. This allows callers to inspect individual check results
|
||||||
|
for detailed error reporting. They implement a compatible interface but are not
|
||||||
|
substitutable where Check is expected as a type constraint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from veritext.core.types import CheckResult, ValidationContext, ValidationResult
|
||||||
|
from veritext.validators.base import Check
|
||||||
|
|
||||||
|
|
||||||
|
class AllOf:
|
||||||
|
"""Passes only if all checks pass.
|
||||||
|
|
||||||
|
Note: Returns ValidationResult (not CheckResult) to expose child results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, checks: list[Check]) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the AllOf composite validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checks: List of checks that must all pass.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If checks list is empty.
|
||||||
|
"""
|
||||||
|
if not checks:
|
||||||
|
raise ValueError("checks list cannot be empty")
|
||||||
|
|
||||||
|
self._checks = list(checks)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this composite check."""
|
||||||
|
return "all_of"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> ValidationResult:
|
||||||
|
"""
|
||||||
|
Run all checks and return aggregate result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text and metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ValidationResult that passes only if all checks pass.
|
||||||
|
"""
|
||||||
|
results: list[CheckResult] = []
|
||||||
|
for check in self._checks:
|
||||||
|
results.append(check.check(text, context))
|
||||||
|
|
||||||
|
all_passed = all(r.passed for r in results)
|
||||||
|
|
||||||
|
return ValidationResult(passed=all_passed, checks=results)
|
||||||
|
|
||||||
|
|
||||||
|
class AnyOf:
|
||||||
|
"""Passes if any check passes.
|
||||||
|
|
||||||
|
Note: Returns ValidationResult (not CheckResult) to expose child results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, checks: list[Check]) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the AnyOf composite validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checks: List of checks where at least one must pass.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If checks list is empty.
|
||||||
|
"""
|
||||||
|
if not checks:
|
||||||
|
raise ValueError("checks list cannot be empty")
|
||||||
|
|
||||||
|
self._checks = list(checks)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this composite check."""
|
||||||
|
return "any_of"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> ValidationResult:
|
||||||
|
"""
|
||||||
|
Run all checks and return aggregate result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text and metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ValidationResult that passes if any check passes.
|
||||||
|
"""
|
||||||
|
results: list[CheckResult] = []
|
||||||
|
for check in self._checks:
|
||||||
|
results.append(check.check(text, context))
|
||||||
|
|
||||||
|
any_passed = any(r.passed for r in results)
|
||||||
|
|
||||||
|
return ValidationResult(passed=any_passed, checks=results)
|
||||||
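Review note on `composite.py`: both composites run every child check and differ only in how the pass flags are combined. The sketch below shows that aggregation with simplified stand-ins. `MinWords`, `MaxChars`, `run_all_of`, and `run_any_of` are hypothetical names, and the result classes here carry fewer fields than the ones in the diff.

```python
from dataclasses import dataclass, field


@dataclass
class CheckResult:
    # Simplified stand-in: the real CheckResult also carries actual/threshold/message.
    name: str
    passed: bool


@dataclass
class ValidationResult:
    passed: bool
    checks: list[CheckResult] = field(default_factory=list)


class MinWords:
    """Hypothetical check: passes when the text has at least n words."""

    def __init__(self, n: int) -> None:
        self._n = n

    def check(self, text: str) -> CheckResult:
        return CheckResult("min_words", len(text.split()) >= self._n)


class MaxChars:
    """Hypothetical check: passes when the text has at most n characters."""

    def __init__(self, n: int) -> None:
        self._n = n

    def check(self, text: str) -> CheckResult:
        return CheckResult("max_chars", len(text) <= self._n)


def run_all_of(checks, text: str) -> ValidationResult:
    # Every check runs (no short-circuit) so all child results are reported.
    results = [c.check(text) for c in checks]
    return ValidationResult(all(r.passed for r in results), results)


def run_any_of(checks, text: str) -> ValidationResult:
    results = [c.check(text) for c in checks]
    return ValidationResult(any(r.passed for r in results), results)


checks = [MinWords(3), MaxChars(10)]
text = "The quick brown fox"  # 4 words, 19 characters

strict = run_all_of(checks, text)
lenient = run_any_of(checks, text)
print(strict.passed, lenient.passed)
print([r.name for r in strict.checks if not r.passed])
```

Running all checks costs a little more than short-circuiting, but it lets a caller list every failing check in one pass, which is the stated reason for returning `ValidationResult`.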
359
src/veritext/validators/constraint.py
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
"""Constraint validators that do not require reference text."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from veritext.core.exceptions import InvalidThresholdError
|
||||||
|
from veritext.core.tokenisation import WordTokeniser
|
||||||
|
from veritext.core.types import CheckResult, ValidationContext
|
||||||
|
from veritext.metrics.readability import Readability
|
||||||
|
|
||||||
|
|
||||||
|
class LengthValidator:
|
||||||
|
"""Validates text length constraints."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_chars: int | None = None,
|
||||||
|
max_chars: int | None = None,
|
||||||
|
min_words: int | None = None,
|
||||||
|
max_words: int | None = None,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the length validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_chars: Minimum character count (inclusive).
|
||||||
|
max_chars: Maximum character count (inclusive).
|
||||||
|
min_words: Minimum word count (inclusive).
|
||||||
|
max_words: Maximum word count (inclusive).
|
||||||
|
tokeniser: Tokeniser to use for word counting. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If no constraints provided or invalid values.
|
||||||
|
"""
|
||||||
|
if all(v is None for v in (min_chars, max_chars, min_words, max_words)):
|
||||||
|
raise InvalidThresholdError("At least one length constraint must be set")
|
||||||
|
|
||||||
|
if min_chars is not None and min_chars < 0:
|
||||||
|
raise InvalidThresholdError(f"min_chars must be >= 0, got {min_chars}")
|
||||||
|
if max_chars is not None and max_chars < 0:
|
||||||
|
raise InvalidThresholdError(f"max_chars must be >= 0, got {max_chars}")
|
||||||
|
if min_words is not None and min_words < 0:
|
||||||
|
raise InvalidThresholdError(f"min_words must be >= 0, got {min_words}")
|
||||||
|
if max_words is not None and max_words < 0:
|
||||||
|
raise InvalidThresholdError(f"max_words must be >= 0, got {max_words}")
|
||||||
|
|
||||||
|
if min_chars is not None and max_chars is not None and min_chars > max_chars:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_chars ({min_chars}) cannot exceed max_chars ({max_chars})"
|
||||||
|
)
|
||||||
|
if min_words is not None and max_words is not None and min_words > max_words:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_words ({min_words}) cannot exceed max_words ({max_words})"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._min_chars = min_chars
|
||||||
|
self._max_chars = max_chars
|
||||||
|
self._min_words = min_words
|
||||||
|
self._max_words = max_words
|
||||||
|
self._tokeniser = tokeniser or WordTokeniser()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "length"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult: # noqa: ARG002
|
||||||
|
"""
|
||||||
|
Run the length check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context (not used for length checks).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
"""
|
||||||
|
char_count = len(text)
|
||||||
|
words = self._tokeniser.tokenise(text)
|
||||||
|
word_count = len(words)
|
||||||
|
|
||||||
|
failures = []
|
||||||
|
|
||||||
|
if self._min_chars is not None and char_count < self._min_chars:
|
||||||
|
failures.append(f"{char_count} chars < min {self._min_chars}")
|
||||||
|
if self._max_chars is not None and char_count > self._max_chars:
|
||||||
|
failures.append(f"{char_count} chars > max {self._max_chars}")
|
||||||
|
if self._min_words is not None and word_count < self._min_words:
|
||||||
|
failures.append(f"{word_count} words < min {self._min_words}")
|
||||||
|
if self._max_words is not None and word_count > self._max_words:
|
||||||
|
failures.append(f"{word_count} words > max {self._max_words}")
|
||||||
|
|
||||||
|
passed = len(failures) == 0
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = f"Length check passed: {char_count} chars, {word_count} words"
|
||||||
|
else:
|
||||||
|
message = "Length check failed: " + "; ".join(failures)
|
||||||
|
|
||||||
|
actual = {"chars": char_count, "words": word_count}
|
||||||
|
threshold = {}
|
||||||
|
if self._min_chars is not None:
|
||||||
|
threshold["min_chars"] = self._min_chars
|
||||||
|
if self._max_chars is not None:
|
||||||
|
threshold["max_chars"] = self._max_chars
|
||||||
|
if self._min_words is not None:
|
||||||
|
threshold["min_words"] = self._min_words
|
||||||
|
if self._max_words is not None:
|
||||||
|
threshold["max_words"] = self._max_words
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=actual,
|
||||||
|
threshold=threshold,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadabilityValidator:
|
||||||
|
"""Validates Flesch-Kincaid readability."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_grade: float | None = None,
|
||||||
|
min_ease: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the readability validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_grade: Maximum Flesch-Kincaid grade level allowed.
|
||||||
|
min_ease: Minimum Flesch Reading Ease score required.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If no constraints provided.
|
||||||
|
"""
|
||||||
|
if max_grade is None and min_ease is None:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
"At least one of max_grade or min_ease must be provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._max_grade = max_grade
|
||||||
|
self._min_ease = min_ease
|
||||||
|
self._metric = Readability()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "readability"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult: # noqa: ARG002
|
||||||
|
"""
|
||||||
|
Run the readability check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context (not used for readability checks).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
"""
|
||||||
|
result = self._metric.score(text)
|
||||||
|
|
||||||
|
failures = []
|
||||||
|
if (
|
||||||
|
self._max_grade is not None
|
||||||
|
and result.flesch_kincaid_grade > self._max_grade
|
||||||
|
):
|
||||||
|
failures.append(
|
||||||
|
f"grade level {result.flesch_kincaid_grade:.1f} "
|
||||||
|
f"> max {self._max_grade:.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._min_ease is not None and result.flesch_reading_ease < self._min_ease:
|
||||||
|
failures.append(
|
||||||
|
f"reading ease {result.flesch_reading_ease:.1f} "
|
||||||
|
f"< min {self._min_ease:.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
passed = len(failures) == 0
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
parts = []
|
||||||
|
if self._max_grade is not None:
|
||||||
|
parts.append(
|
||||||
|
f"grade {result.flesch_kincaid_grade:.1f} <= {self._max_grade:.1f}"
|
||||||
|
)
|
||||||
|
if self._min_ease is not None:
|
||||||
|
parts.append(
|
||||||
|
f"ease {result.flesch_reading_ease:.1f} >= {self._min_ease:.1f}"
|
||||||
|
)
|
||||||
|
message = "Readability: " + ", ".join(parts)
|
||||||
|
else:
|
||||||
|
message = "Readability: " + "; ".join(failures)
|
||||||
|
|
||||||
|
actual = {
|
||||||
|
"grade": result.flesch_kincaid_grade,
|
||||||
|
"ease": result.flesch_reading_ease,
|
||||||
|
}
|
||||||
|
threshold = {}
|
||||||
|
if self._max_grade is not None:
|
||||||
|
threshold["max_grade"] = self._max_grade
|
||||||
|
if self._min_ease is not None:
|
||||||
|
threshold["min_ease"] = self._min_ease
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=actual,
|
||||||
|
threshold=threshold,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ContainsValidator:
|
||||||
|
"""Validates text contains required patterns."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patterns: list[str],
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the contains validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patterns: List of regex patterns that must be present (escape literal substrings).
|
||||||
|
case_sensitive: Whether matching is case-sensitive. Defaults to False.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If patterns list is empty or contains invalid regex.
|
||||||
|
"""
|
||||||
|
if not patterns:
|
||||||
|
raise InvalidThresholdError("patterns list cannot be empty")
|
||||||
|
|
||||||
|
self._patterns = patterns
|
||||||
|
self._case_sensitive = case_sensitive
|
||||||
|
self._flags = 0 if case_sensitive else re.IGNORECASE
|
||||||
|
|
||||||
|
self._compiled_patterns: list[re.Pattern[str]] = []
|
||||||
|
for pattern in patterns:
|
||||||
|
try:
|
||||||
|
self._compiled_patterns.append(re.compile(pattern, self._flags))
|
||||||
|
except re.error as e:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"Invalid regex pattern '{pattern}': {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "contains"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult: # noqa: ARG002
|
||||||
|
"""
|
||||||
|
Run the contains check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context (not used for contains checks).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
"""
|
||||||
|
missing = []
|
||||||
|
for pattern, compiled in zip(
|
||||||
|
self._patterns, self._compiled_patterns, strict=True
|
||||||
|
):
|
||||||
|
if not compiled.search(text):
|
||||||
|
missing.append(pattern)
|
||||||
|
|
||||||
|
passed = len(missing) == 0
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = f"Text contains all {len(self._patterns)} required pattern(s)"
|
||||||
|
else:
|
||||||
|
message = f"Text missing {len(missing)} pattern(s): {missing}"
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual={"found": len(self._patterns) - len(missing), "missing": missing},
|
||||||
|
threshold={"patterns": self._patterns},
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExcludesValidator:
|
||||||
|
"""Validates text excludes forbidden patterns."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
patterns: list[str],
|
||||||
|
case_sensitive: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the excludes validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patterns: List of regex patterns that must not be present (escape literal substrings).
|
||||||
|
case_sensitive: Whether matching is case-sensitive. Defaults to False.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If patterns list is empty or contains invalid regex.
|
||||||
|
"""
|
||||||
|
if not patterns:
|
||||||
|
raise InvalidThresholdError("patterns list cannot be empty")
|
||||||
|
|
||||||
|
self._patterns = patterns
|
||||||
|
self._case_sensitive = case_sensitive
|
||||||
|
self._flags = 0 if case_sensitive else re.IGNORECASE
|
||||||
|
|
||||||
|
self._compiled_patterns: list[re.Pattern[str]] = []
|
||||||
|
for pattern in patterns:
|
||||||
|
try:
|
||||||
|
self._compiled_patterns.append(re.compile(pattern, self._flags))
|
||||||
|
except re.error as e:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"Invalid regex pattern '{pattern}': {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "excludes"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult: # noqa: ARG002
|
||||||
|
"""
|
||||||
|
Run the excludes check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context (not used for excludes checks).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
"""
|
||||||
|
found = []
|
||||||
|
for pattern, compiled in zip(
|
||||||
|
self._patterns, self._compiled_patterns, strict=True
|
||||||
|
):
|
||||||
|
if compiled.search(text):
|
||||||
|
found.append(pattern)
|
||||||
|
|
||||||
|
passed = len(found) == 0
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = f"Text excludes all {len(self._patterns)} forbidden pattern(s)"
|
||||||
|
else:
|
||||||
|
message = f"Text contains {len(found)} forbidden pattern(s): {found}"
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual={"excluded": len(self._patterns) - len(found), "found": found},
|
||||||
|
threshold={"patterns": self._patterns},
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
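The excludes check above can be sketched in isolation as follows. This is a minimal standalone illustration, not the project's implementation: `find_forbidden` is a hypothetical helper name, and it raises a plain `ValueError` where the validator raises `InvalidThresholdError`.

```python
import re


def find_forbidden(
    text: str, patterns: list[str], case_sensitive: bool = False
) -> list[str]:
    """Return the patterns that match somewhere in ``text`` (hypothetical helper)."""
    if not patterns:
        # The validator raises InvalidThresholdError here; ValueError stands in for it.
        raise ValueError("patterns list cannot be empty")

    # Case-insensitive matching is the default, as in the validator.
    flags = 0 if case_sensitive else re.IGNORECASE
    compiled = [re.compile(pattern, flags) for pattern in patterns]

    # Report the original pattern strings, not the compiled objects.
    return [p for p, c in zip(patterns, compiled) if c.search(text)]


found = find_forbidden("As an AI, I cannot help.", ["as an ai", r"\bsorry\b"])
print(found)
```

Because every entry goes through `re.compile`, a pattern such as `a.b` matches `axb` as well as the literal `a.b`; callers who want literal matching would presumably wrap entries in `re.escape` first.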
370
src/veritext/validators/metric.py
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
"""Metric-based validators that require reference text."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from veritext.core.exceptions import InvalidThresholdError, ValidationError
|
||||||
|
from veritext.core.tokenisation import WordTokeniser
|
||||||
|
from veritext.core.types import CheckResult, ValidationContext
|
||||||
|
from veritext.metrics.bleu import Bleu
|
||||||
|
from veritext.metrics.lexical import Lexical
|
||||||
|
from veritext.metrics.rouge import Rouge
|
||||||
|
|
||||||
|
|
||||||
|
class BleuValidator:
|
||||||
|
"""Validates that BLEU score meets minimum threshold."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_score: float,
|
||||||
|
variant: Literal[1, 2, 3, 4] = 4,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the BLEU validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum BLEU score required (0.0 to 1.0).
|
||||||
|
variant: BLEU variant to use (1, 2, 3, or 4). Defaults to 4.
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If min_score is not in range [0.0, 1.0].
|
||||||
|
"""
|
||||||
|
if not 0.0 <= min_score <= 1.0:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_score must be between 0.0 and 1.0, got {min_score}"
|
||||||
|
)
|
||||||
|
if variant not in (1, 2, 3, 4):
|
||||||
|
raise InvalidThresholdError(f"variant must be 1, 2, 3, or 4, got {variant}")
|
||||||
|
|
||||||
|
self._min_score = min_score
|
||||||
|
self._variant = variant
|
||||||
|
self._metric = Bleu(tokeniser=tokeniser)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return f"bleu-{self._variant}"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||||
|
"""
|
||||||
|
Run the BLEU check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If reference text is missing from context.
|
||||||
|
"""
|
||||||
|
if context.reference is None:
|
||||||
|
raise ValidationError(f"{self.name} requires reference text in context")
|
||||||
|
|
||||||
|
result = self._metric.score(text, context.reference)
|
||||||
|
|
||||||
|
# Select the appropriate BLEU variant
|
||||||
|
score_map = {
|
||||||
|
1: result.bleu1,
|
||||||
|
2: result.bleu2,
|
||||||
|
3: result.bleu3,
|
||||||
|
4: result.bleu4,
|
||||||
|
}
|
||||||
|
actual_score = score_map[self._variant]
|
||||||
|
passed = actual_score >= self._min_score
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = (
|
||||||
|
f"BLEU-{self._variant} score {actual_score:.2f} "
|
||||||
|
f"meets minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = (
|
||||||
|
f"BLEU-{self._variant} score {actual_score:.2f} "
|
||||||
|
f"below minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=actual_score,
|
||||||
|
threshold=self._min_score,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RougeValidator:
|
||||||
|
"""Validates that ROUGE score meets minimum threshold."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_score: float,
|
||||||
|
variant: Literal["1", "2", "l"] = "l",
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the ROUGE validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum ROUGE F-measure required (0.0 to 1.0).
|
||||||
|
variant: ROUGE variant ("1", "2", or "l"). Defaults to "l".
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If min_score is not in range [0.0, 1.0].
|
||||||
|
"""
|
||||||
|
if not 0.0 <= min_score <= 1.0:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_score must be between 0.0 and 1.0, got {min_score}"
|
||||||
|
)
|
||||||
|
if variant not in ("1", "2", "l"):
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"variant must be '1', '2', or 'l', got '{variant}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._min_score = min_score
|
||||||
|
self._variant = variant
|
||||||
|
self._metric = Rouge(tokeniser=tokeniser)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return f"rouge-{self._variant}"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||||
|
"""
|
||||||
|
Run the ROUGE check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If reference text is missing from context.
|
||||||
|
"""
|
||||||
|
if context.reference is None:
|
||||||
|
raise ValidationError(f"{self.name} requires reference text in context")
|
||||||
|
|
||||||
|
result = self._metric.score(text, context.reference)
|
||||||
|
|
||||||
|
# Select the appropriate ROUGE variant (use F-measure)
|
||||||
|
score_map = {
|
||||||
|
"1": result.rouge1.fmeasure,
|
||||||
|
"2": result.rouge2.fmeasure,
|
||||||
|
"l": result.rouge_l.fmeasure,
|
||||||
|
}
|
||||||
|
actual_score = score_map[self._variant]
|
||||||
|
passed = actual_score >= self._min_score
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = (
|
||||||
|
f"ROUGE-{self._variant.upper()} score {actual_score:.2f} "
|
||||||
|
f"meets minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = (
|
||||||
|
f"ROUGE-{self._variant.upper()} score {actual_score:.2f} "
|
||||||
|
f"below minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=actual_score,
|
||||||
|
threshold=self._min_score,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LexicalValidator:
|
||||||
|
"""Validates lexical similarity meets threshold."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_jaccard: float | None = None,
|
||||||
|
min_overlap: float | None = None,
|
||||||
|
tokeniser: WordTokeniser | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the lexical validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_jaccard: Minimum Jaccard similarity required (0.0 to 1.0).
|
||||||
|
min_overlap: Minimum token overlap required (0.0 to 1.0).
|
||||||
|
tokeniser: Tokeniser to use. Defaults to WordTokeniser().
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If thresholds are invalid or none provided.
|
||||||
|
"""
|
||||||
|
if min_jaccard is None and min_overlap is None:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
"At least one of min_jaccard or min_overlap must be provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
if min_jaccard is not None and not 0.0 <= min_jaccard <= 1.0:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_jaccard must be between 0.0 and 1.0, got {min_jaccard}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if min_overlap is not None and not 0.0 <= min_overlap <= 1.0:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_overlap must be between 0.0 and 1.0, got {min_overlap}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._min_jaccard = min_jaccard
|
||||||
|
self._min_overlap = min_overlap
|
||||||
|
self._metric = Lexical(tokeniser=tokeniser)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "lexical"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||||
|
"""
|
||||||
|
Run the lexical similarity check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If reference text is missing from context.
|
||||||
|
"""
|
||||||
|
if context.reference is None:
|
||||||
|
raise ValidationError(f"{self.name} requires reference text in context")
|
||||||
|
|
||||||
|
result = self._metric.score(text, context.reference)
|
||||||
|
|
||||||
|
# Check each threshold that was specified
|
||||||
|
failures = []
|
||||||
|
if self._min_jaccard is not None and result.jaccard < self._min_jaccard:
|
||||||
|
failures.append(
|
||||||
|
f"Jaccard {result.jaccard:.2f} below minimum {self._min_jaccard:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._min_overlap is not None and result.token_overlap < self._min_overlap:
|
||||||
|
failures.append(
|
||||||
|
f"token overlap {result.token_overlap:.2f} "
|
||||||
|
f"below minimum {self._min_overlap:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
passed = len(failures) == 0
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
parts = []
|
||||||
|
if self._min_jaccard is not None:
|
||||||
|
parts.append(f"Jaccard {result.jaccard:.2f} >= {self._min_jaccard:.2f}")
|
||||||
|
if self._min_overlap is not None:
|
||||||
|
parts.append(
|
||||||
|
f"overlap {result.token_overlap:.2f} >= {self._min_overlap:.2f}"
|
||||||
|
)
|
||||||
|
message = "Lexical similarity: " + ", ".join(parts)
|
||||||
|
else:
|
||||||
|
message = "Lexical similarity: " + "; ".join(failures)
|
||||||
|
|
||||||
|
# Build actual value dict
|
||||||
|
actual = {"jaccard": result.jaccard, "token_overlap": result.token_overlap}
|
||||||
|
threshold = {}
|
||||||
|
if self._min_jaccard is not None:
|
||||||
|
threshold["min_jaccard"] = self._min_jaccard
|
||||||
|
if self._min_overlap is not None:
|
||||||
|
threshold["min_overlap"] = self._min_overlap
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=actual,
|
||||||
|
threshold=threshold,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticValidator:
|
||||||
|
"""Validates that semantic similarity meets minimum threshold.
|
||||||
|
|
||||||
|
Requires the `veritext[semantic]` extra to be installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_score: float,
|
||||||
|
model: str = "all-MiniLM-L6-v2",
|
||||||
|
cache_embeddings: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the semantic validator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score: Minimum semantic similarity required (0.0 to 1.0).
|
||||||
|
model: Name of the sentence-transformers model to use.
|
||||||
|
cache_embeddings: Whether to cache embeddings for repeated texts.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidThresholdError: If min_score is not in range [0.0, 1.0].
|
||||||
|
DependencyError: If sentence-transformers is not installed.
|
||||||
|
"""
|
||||||
|
if not 0.0 <= min_score <= 1.0:
|
||||||
|
raise InvalidThresholdError(
|
||||||
|
f"min_score must be between 0.0 and 1.0, got {min_score}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._min_score = min_score
|
||||||
|
# Lazy import to avoid loading PyTorch unless needed
|
||||||
|
from veritext.semantic.similarity import SemanticSimilarity
|
||||||
|
|
||||||
|
self._metric: SemanticSimilarity = SemanticSimilarity(
|
||||||
|
model=model, cache_embeddings=cache_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the name of this check."""
|
||||||
|
return "semantic"
|
||||||
|
|
||||||
|
def check(self, text: str, context: ValidationContext) -> CheckResult:
|
||||||
|
"""
|
||||||
|
Run the semantic similarity check.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to validate.
|
||||||
|
context: Validation context containing reference text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CheckResult with pass/fail status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValidationError: If reference text is missing from context.
|
||||||
|
"""
|
||||||
|
if context.reference is None:
|
||||||
|
raise ValidationError(f"{self.name} requires reference text in context")
|
||||||
|
|
||||||
|
result = self._metric.score(text, context.reference)
|
||||||
|
passed = result.similarity >= self._min_score
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
message = (
|
||||||
|
f"Semantic similarity {result.similarity:.2f} "
|
||||||
|
f"meets minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
message = (
|
||||||
|
f"Semantic similarity {result.similarity:.2f} "
|
||||||
|
f"below minimum {self._min_score:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return CheckResult(
|
||||||
|
name=self.name,
|
||||||
|
passed=passed,
|
||||||
|
actual=result.similarity,
|
||||||
|
threshold=self._min_score,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
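The BLEU, ROUGE and semantic validators above share one pattern: validate the threshold at construction time, compare the score with `>=`, and build a pass or fail message. A rough sketch of that shared pattern, assuming a hypothetical `SketchResult` dataclass in place of `CheckResult` and `ValueError` in place of `InvalidThresholdError`:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchResult:
    """Hypothetical stand-in for CheckResult, reduced to the fields used above."""

    name: str
    passed: bool
    actual: float
    threshold: float
    message: str


def check_min_score(name: str, actual: float, min_score: float) -> SketchResult:
    """Compare a metric score against a minimum threshold (hypothetical helper)."""
    if not 0.0 <= min_score <= 1.0:
        raise ValueError(f"min_score must be between 0.0 and 1.0, got {min_score}")

    # A score exactly equal to the threshold passes, matching the `>=` comparison.
    passed = actual >= min_score
    verdict = "meets" if passed else "below"
    message = f"{name} score {actual:.2f} {verdict} minimum {min_score:.2f}"
    return SketchResult(name, passed, actual, min_score, message)


print(check_min_score("BLEU-4", 0.42, 0.50).message)
```

The messages are rounded to two decimals while the comparison uses the full float, so a score of 0.499 is reported as 0.50 and still fails a 0.50 minimum.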
1
tests/test_benchmark/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for the benchmark module."""
|
||||||
145
tests/test_benchmark/test_models.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Tests for benchmark data models."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from veritext.benchmark.models import BenchmarkRun, RegressionReport
|
||||||
|
|
||||||
|
|
||||||
|
class TestBenchmarkRun:
|
||||||
|
"""Tests for BenchmarkRun model."""
|
||||||
|
|
||||||
|
def test_create_benchmark_run(self) -> None:
|
||||||
|
"""BenchmarkRun can be created with required fields."""
|
||||||
|
run = BenchmarkRun(
|
||||||
|
id="test-id-123",
|
||||||
|
benchmark_name="test-benchmark",
|
||||||
|
timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC),
|
||||||
|
veritext_version="0.1.0-dev",
|
||||||
|
metrics={"bleu4": 0.75, "rouge_l": 0.82},
|
||||||
|
sample_count=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert run.id == "test-id-123"
|
||||||
|
assert run.benchmark_name == "test-benchmark"
|
||||||
|
assert run.veritext_version == "0.1.0-dev"
|
||||||
|
assert run.metrics == {"bleu4": 0.75, "rouge_l": 0.82}
|
||||||
|
assert run.sample_count == 100
|
||||||
|
assert run.metadata == {}
|
||||||
|
|
||||||
|
def test_create_with_metadata(self) -> None:
|
||||||
|
"""BenchmarkRun can include optional metadata."""
|
||||||
|
run = BenchmarkRun(
|
||||||
|
id="test-id-456",
|
||||||
|
benchmark_name="test-benchmark",
|
||||||
|
timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC),
|
||||||
|
veritext_version="0.1.0-dev",
|
||||||
|
metrics={"bleu4": 0.75},
|
||||||
|
sample_count=50,
|
||||||
|
metadata={"git_sha": "abc123", "model_version": "gpt-4"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert run.metadata == {"git_sha": "abc123", "model_version": "gpt-4"}
|
||||||
|
|
||||||
|
def test_frozen_model(self) -> None:
|
||||||
|
"""BenchmarkRun is immutable."""
|
||||||
|
run = BenchmarkRun(
|
||||||
|
id="test-id",
|
||||||
|
benchmark_name="test",
|
||||||
|
timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC),
|
||||||
|
veritext_version="0.1.0",
|
||||||
|
metrics={"bleu4": 0.5},
|
||||||
|
sample_count=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
run.id = "new-id" # type: ignore[misc]
|
||||||
|
|
||||||
|
def test_serialisation(self) -> None:
|
||||||
|
"""BenchmarkRun can be serialised to dict."""
|
||||||
|
run = BenchmarkRun(
|
||||||
|
id="test-id",
|
||||||
|
benchmark_name="test",
|
||||||
|
timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC),
|
||||||
|
veritext_version="0.1.0",
|
||||||
|
metrics={"bleu4": 0.5},
|
||||||
|
sample_count=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = run.model_dump()
|
||||||
|
assert data["id"] == "test-id"
|
||||||
|
assert data["benchmark_name"] == "test"
|
||||||
|
assert data["metrics"] == {"bleu4": 0.5}
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegressionReport:
|
||||||
|
"""Tests for RegressionReport model."""
|
||||||
|
|
||||||
|
def test_no_regression_summary(self) -> None:
|
||||||
|
"""Summary indicates no regression when detected is False."""
|
||||||
|
report = RegressionReport(
|
||||||
|
detected=False,
|
||||||
|
baseline={"bleu4": 0.75, "rouge_l": 0.80},
|
||||||
|
current={"bleu4": 0.76, "rouge_l": 0.81},
|
||||||
|
deltas={"bleu4": 0.01, "rouge_l": 0.01},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "No regression detected" in report.summary
|
||||||
|
|
||||||
|
def test_regression_summary(self) -> None:
|
||||||
|
"""Summary lists regressed metrics when detected is True."""
|
||||||
|
report = RegressionReport(
|
||||||
|
detected=True,
|
||||||
|
baseline={"bleu4": 0.75, "rouge_l": 0.80},
|
||||||
|
current={"bleu4": 0.65, "rouge_l": 0.78},
|
||||||
|
deltas={"bleu4": -0.10, "rouge_l": -0.02},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Regression detected" in report.summary
|
||||||
|
assert "bleu4" in report.summary
|
||||||
|
assert "0.6500" in report.summary
|
||||||
|
assert "baseline: 0.7500" in report.summary
|
||||||
|
|
||||||
|
def test_regression_excludes_within_tolerance(self) -> None:
|
||||||
|
"""Summary only shows metrics that exceed tolerance."""
|
||||||
|
report = RegressionReport(
|
||||||
|
detected=True,
|
||||||
|
baseline={"bleu4": 0.75, "rouge_l": 0.80},
|
||||||
|
current={"bleu4": 0.65, "rouge_l": 0.78},
|
||||||
|
deltas={"bleu4": -0.10, "rouge_l": -0.02},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# rouge_l is -0.02, within tolerance of 0.05, so shouldn't appear
|
||||||
|
assert "rouge_l" not in report.summary
|
||||||
|
# bleu4 is -0.10, exceeds tolerance, so should appear
|
||||||
|
assert "bleu4" in report.summary
|
||||||
|
|
||||||
|
def test_frozen_model(self) -> None:
|
||||||
|
"""RegressionReport is immutable."""
|
||||||
|
report = RegressionReport(
|
||||||
|
detected=False,
|
||||||
|
baseline={},
|
||||||
|
current={},
|
||||||
|
deltas={},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
report.detected = True # type: ignore[misc]
|
||||||
|
|
||||||
|
def test_tolerance_in_summary(self) -> None:
|
||||||
|
"""Summary includes tolerance threshold."""
|
||||||
|
report = RegressionReport(
|
||||||
|
detected=True,
|
||||||
|
baseline={"metric": 0.80},
|
||||||
|
current={"metric": 0.50},
|
||||||
|
deltas={"metric": -0.30},
|
||||||
|
tolerance=0.10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "10.00%" in report.summary
|
||||||
229
tests/test_benchmark/test_regression.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
"""Tests for regression detection."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from veritext.benchmark.models import BenchmarkRun
|
||||||
|
from veritext.benchmark.regression import compute_baseline, detect_regression
|
||||||
|
|
||||||
|
|
||||||
|
def make_run(
|
||||||
|
run_id: str,
|
||||||
|
metrics: dict[str, float],
|
||||||
|
day: int = 1,
|
||||||
|
) -> BenchmarkRun:
|
||||||
|
"""Helper to create a BenchmarkRun."""
|
||||||
|
return BenchmarkRun(
|
||||||
|
id=run_id,
|
||||||
|
benchmark_name="test",
|
||||||
|
timestamp=datetime(2025, 1, day, 12, 0, 0, tzinfo=UTC),
|
||||||
|
veritext_version="0.1.0",
|
||||||
|
metrics=metrics,
|
||||||
|
sample_count=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeBaseline:
|
||||||
|
"""Tests for baseline computation."""
|
||||||
|
|
||||||
|
def test_empty_runs(self) -> None:
|
||||||
|
"""Returns empty baseline for empty runs list."""
|
||||||
|
baseline = compute_baseline([])
|
||||||
|
assert baseline == {}
|
||||||
|
|
||||||
|
def test_single_run(self) -> None:
|
||||||
|
"""Single run produces baseline equal to that run's metrics."""
|
||||||
|
runs = [make_run("r1", {"bleu4": 0.75, "rouge_l": 0.80})]
|
||||||
|
|
||||||
|
baseline = compute_baseline(runs)
|
||||||
|
|
||||||
|
assert baseline["bleu4"] == 0.75
|
||||||
|
assert baseline["rouge_l"] == 0.80
|
||||||
|
|
||||||
|
def test_multiple_runs_average(self) -> None:
|
||||||
|
"""Baseline is the average of all runs in window."""
|
||||||
|
runs = [
|
||||||
|
make_run("r1", {"bleu4": 0.70}, day=3),
|
||||||
|
make_run("r2", {"bleu4": 0.80}, day=2),
|
||||||
|
make_run("r3", {"bleu4": 0.90}, day=1),
|
||||||
|
]
|
||||||
|
|
||||||
|
baseline = compute_baseline(runs, window=3)
|
||||||
|
|
||||||
|
assert baseline["bleu4"] == pytest.approx(0.80) # (0.70+0.80+0.90)/3
|
||||||
|
|
||||||
|
def test_window_limits_runs(self) -> None:
|
||||||
|
"""Only includes runs within the window size."""
|
||||||
|
runs = [
|
||||||
|
make_run("r1", {"bleu4": 0.70}, day=5), # most recent
|
||||||
|
make_run("r2", {"bleu4": 0.80}, day=4),
|
||||||
|
make_run("r3", {"bleu4": 0.90}, day=3),
|
||||||
|
make_run("r4", {"bleu4": 0.60}, day=2), # excluded
|
||||||
|
make_run("r5", {"bleu4": 0.50}, day=1), # excluded
|
||||||
|
]
|
||||||
|
|
||||||
|
baseline = compute_baseline(runs, window=3)
|
||||||
|
|
||||||
|
# Only first 3 runs: (0.70 + 0.80 + 0.90) / 3 = 0.80
|
||||||
|
assert baseline["bleu4"] == pytest.approx(0.80)
|
||||||
|
|
||||||
|
def test_partial_history(self) -> None:
|
||||||
|
"""Works when fewer runs than window size exist."""
|
||||||
|
runs = [
|
||||||
|
make_run("r1", {"bleu4": 0.70}),
|
||||||
|
make_run("r2", {"bleu4": 0.80}),
|
||||||
|
]
|
||||||
|
|
||||||
|
baseline = compute_baseline(runs, window=10)
|
||||||
|
|
||||||
|
# Only 2 runs available: (0.70 + 0.80) / 2 = 0.75
|
||||||
|
assert baseline["bleu4"] == pytest.approx(0.75)
|
||||||
|
|
||||||
|
def test_multiple_metrics(self) -> None:
|
||||||
|
"""Computes baseline for all metrics present."""
|
||||||
|
runs = [
|
||||||
|
make_run("r1", {"bleu4": 0.70, "rouge_l": 0.75}),
|
||||||
|
make_run("r2", {"bleu4": 0.80, "rouge_l": 0.85}),
|
||||||
|
]
|
||||||
|
|
||||||
|
baseline = compute_baseline(runs)
|
||||||
|
|
||||||
|
assert baseline["bleu4"] == pytest.approx(0.75)
|
||||||
|
assert baseline["rouge_l"] == pytest.approx(0.80)
|
||||||
|
|
||||||
|
def test_varying_metrics(self) -> None:
|
||||||
|
"""Handles runs with different metric sets."""
|
||||||
|
runs = [
|
||||||
|
make_run("r1", {"bleu4": 0.70, "rouge_l": 0.75}),
|
||||||
|
make_run("r2", {"bleu4": 0.80}), # No rouge_l
|
||||||
|
]
|
||||||
|
|
||||||
|
baseline = compute_baseline(runs)
|
||||||
|
|
||||||
|
# bleu4 appears in both runs
|
||||||
|
assert baseline["bleu4"] == pytest.approx(0.75)
|
||||||
|
# rouge_l only appears in one run
|
||||||
|
assert baseline["rouge_l"] == pytest.approx(0.75)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetectRegression:
|
||||||
|
"""Tests for regression detection."""
|
||||||
|
|
||||||
|
def test_no_baseline(self) -> None:
|
||||||
|
"""No regression when baseline is empty."""
|
||||||
|
report = detect_regression(
|
||||||
|
current={"bleu4": 0.70},
|
||||||
|
baseline={},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not report.detected
|
||||||
|
assert report.deltas == {}
|
||||||
|
|
||||||
|
def test_no_regression_stable(self) -> None:
|
||||||
|
"""No regression when metrics are stable."""
|
||||||
|
report = detect_regression(
|
||||||
|
current={"bleu4": 0.75},
|
||||||
|
baseline={"bleu4": 0.75},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not report.detected
|
||||||
|
assert report.deltas["bleu4"] == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_no_regression_improved(self) -> None:
|
||||||
|
"""No regression when metrics improved."""
|
||||||
|
report = detect_regression(
|
||||||
|
current={"bleu4": 0.85},
|
||||||
|
baseline={"bleu4": 0.75},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not report.detected
|
||||||
|
assert report.deltas["bleu4"] == pytest.approx(0.10)
|
||||||
|
|
||||||
|
def test_no_regression_within_tolerance(self) -> None:
|
||||||
|
"""No regression when drop is within tolerance."""
|
||||||
|
report = detect_regression(
|
||||||
|
current={"bleu4": 0.73},
|
||||||
|
baseline={"bleu4": 0.75},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not report.detected
|
||||||
|
assert report.deltas["bleu4"] == pytest.approx(-0.02)
|
||||||
|
|
||||||
|
def test_regression_detected(self) -> None:
|
||||||
|
"""Regression detected when metric drops beyond tolerance."""
|
||||||
|
report = detect_regression(
|
||||||
|
current={"bleu4": 0.65},
|
||||||
|
baseline={"bleu4": 0.75},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert report.detected
|
||||||
|
assert report.deltas["bleu4"] == pytest.approx(-0.10)
|
||||||
|
|
||||||
|
def test_regression_at_tolerance_boundary(self) -> None:
|
||||||
|
"""Drop at tolerance boundary is not a regression."""
|
||||||
|
# Use identical values (delta 0.0): an exact-boundary drop is sensitive to float rounding
|
||||||
|
# The implementation checks delta < -tolerance (strictly less than)
|
||||||
|
report = detect_regression(
|
||||||
|
current={"bleu4": 0.50},
|
||||||
|
baseline={"bleu4": 0.50},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delta is 0.0, well within tolerance
|
||||||
|
assert not report.detected
|
||||||
|
assert report.deltas["bleu4"] == 0.0
|
||||||
|
|
||||||
|
def test_regression_just_beyond_tolerance(self) -> None:
|
||||||
|
"""Just beyond tolerance is a regression."""
|
||||||
|
report = detect_regression(
|
||||||
|
current={"bleu4": 0.6999},
|
||||||
|
baseline={"bleu4": 0.75},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delta is -0.0501, which is < -tolerance
|
||||||
|
assert report.detected
|
||||||
|
|
||||||
|
def test_multiple_metrics_any_regresses(self) -> None:
|
||||||
|
"""Regression detected if any metric exceeds tolerance."""
|
||||||
|
report = detect_regression(
|
||||||
|
current={"bleu4": 0.65, "rouge_l": 0.80},
|
||||||
|
baseline={"bleu4": 0.75, "rouge_l": 0.80},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert report.detected
|
||||||
|
# Only bleu4 regressed
|
||||||
|
assert report.deltas["bleu4"] == pytest.approx(-0.10)
|
||||||
|
assert report.deltas["rouge_l"] == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_report_contains_all_values(self) -> None:
|
||||||
|
"""Report includes baseline, current, and deltas."""
|
||||||
|
baseline = {"bleu4": 0.75, "rouge_l": 0.80}
|
||||||
|
current = {"bleu4": 0.65, "rouge_l": 0.82}
|
||||||
|
|
||||||
|
report = detect_regression(current, baseline, tolerance=0.05)
|
||||||
|
|
||||||
|
assert report.baseline == baseline
|
||||||
|
assert report.current == current
|
||||||
|
assert report.tolerance == 0.05
|
||||||
|
assert "bleu4" in report.deltas
|
||||||
|
assert "rouge_l" in report.deltas
|
||||||
|
|
||||||
|
def test_missing_metric_in_current(self) -> None:
|
||||||
|
"""Missing metric in current treated as zero."""
|
||||||
|
report = detect_regression(
|
||||||
|
current={},
|
||||||
|
baseline={"bleu4": 0.75},
|
||||||
|
tolerance=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 0.0 - 0.75 = -0.75, which is a regression
|
||||||
|
assert report.detected
|
||||||
|
assert report.deltas["bleu4"] == pytest.approx(-0.75)
|
||||||
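The behaviour these tests pin down can be sketched without the project's models. The version below is inferred from the assertions only and is not the code in `veritext.benchmark.regression`: it takes plain metric dicts instead of `BenchmarkRun` objects, returns a tuple instead of a `RegressionReport`, and the default `window=10` is an assumption.

```python
def compute_baseline(
    runs: list[dict[str, float]], window: int = 10
) -> dict[str, float]:
    """Average each metric over the first ``window`` runs (assumed newest-first)."""
    totals: dict[str, float] = {}
    counts: dict[str, int] = {}
    for metrics in runs[:window]:
        for name, value in metrics.items():
            totals[name] = totals.get(name, 0.0) + value
            counts[name] = counts.get(name, 0) + 1
    # Each metric is averaged over the runs that actually report it.
    return {name: totals[name] / counts[name] for name in totals}


def detect_regression(
    current: dict[str, float], baseline: dict[str, float], tolerance: float = 0.05
) -> tuple[bool, dict[str, float]]:
    """Return (detected, deltas) where each delta is current minus baseline."""
    # A metric missing from `current` is treated as 0.0; an empty baseline
    # yields no deltas and therefore no regression.
    deltas = {name: current.get(name, 0.0) - base for name, base in baseline.items()}
    # Strict comparison: a drop of exactly `tolerance` is not a regression.
    detected = any(delta < -tolerance for delta in deltas.values())
    return detected, deltas


baseline = compute_baseline([{"bleu4": 0.70}, {"bleu4": 0.80}, {"bleu4": 0.90}], window=3)
print(detect_regression({"bleu4": 0.65}, baseline))
```

Only metrics present in the baseline produce deltas, so a metric that first appears in the current run cannot trigger a regression in this sketch.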
247
tests/test_benchmark/test_runner.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
"""Tests for benchmark runner."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from veritext.benchmark.models import BenchmarkRun
|
||||||
|
from veritext.benchmark.runner import Benchmark
|
||||||
|
from veritext.core.exceptions import RegressionDetectedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def benchmark(tmp_path: Path) -> Benchmark:
|
||||||
|
"""Create a Benchmark instance with temporary storage."""
|
||||||
|
return Benchmark("test-suite", storage_path=tmp_path / "benchmarks")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_data() -> tuple[list[str], list[str]]:
|
||||||
|
"""Sample candidates and references for testing."""
|
||||||
|
candidates = [
|
||||||
|
"The quick brown fox jumps over the lazy dog.",
|
||||||
|
"A fast auburn fox leaps above the sleepy hound.",
|
||||||
|
]
|
||||||
|
references = [
|
||||||
|
"The quick brown fox jumps over the lazy dog.",
|
||||||
|
"The swift brown fox jumps over the lazy dog.",
|
||||||
|
]
|
||||||
|
return candidates, references
|
||||||
|
|
||||||
|
|
||||||
|
class TestBenchmarkInit:
|
||||||
|
"""Tests for Benchmark initialisation."""
|
||||||
|
|
||||||
|
def test_creates_storage_directory(self, tmp_path: Path) -> None:
|
||||||
|
"""Benchmark creates storage directory on init."""
|
||||||
|
storage_path = tmp_path / "benchmarks"
|
||||||
|
Benchmark("my-suite", storage_path=storage_path)
|
||||||
|
|
||||||
|
assert storage_path.exists()
|
||||||
|
|
||||||
|
def test_name_property(self, benchmark: Benchmark) -> None:
|
||||||
|
"""Benchmark exposes its name."""
|
||||||
|
assert benchmark.name == "test-suite"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEvaluate:
|
||||||
|
"""Tests for the evaluate method."""
|
||||||
|
|
||||||
|
    def test_evaluate_stores_run(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """Evaluate creates and stores a benchmark run."""
        candidates, references = sample_data

        run = benchmark.evaluate(candidates, references)

        assert isinstance(run, BenchmarkRun)
        assert run.benchmark_name == "test-suite"
        assert run.sample_count == 2

    def test_evaluate_returns_metrics(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """Evaluate computes default metrics."""
        candidates, references = sample_data

        run = benchmark.evaluate(candidates, references)

        # Default metrics are rouge_l and bleu4
        assert "rouge_l" in run.metrics
        assert "bleu4" in run.metrics
        assert 0.0 <= run.metrics["rouge_l"] <= 1.0
        assert 0.0 <= run.metrics["bleu4"] <= 1.0

    def test_evaluate_custom_metrics(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """Evaluate can compute custom metrics."""
        candidates, references = sample_data

        run = benchmark.evaluate(
            candidates, references, metrics=["bleu1", "bleu2", "rouge1"]
        )

        assert "bleu1" in run.metrics
        assert "bleu2" in run.metrics
        assert "rouge1" in run.metrics
        assert "bleu4" not in run.metrics  # Not requested

    def test_evaluate_with_metadata(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """Evaluate can include metadata."""
        candidates, references = sample_data

        run = benchmark.evaluate(
            candidates, references, metadata={"git_sha": "abc123", "model": "gpt-4"}
        )

        assert run.metadata == {"git_sha": "abc123", "model": "gpt-4"}

    def test_evaluate_stores_retrievable(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """Stored run can be retrieved."""
        candidates, references = sample_data
        run = benchmark.evaluate(candidates, references)

        history = benchmark.get_history()

        assert len(history) == 1
        assert history[0].id == run.id


class TestCheckRegression:
    """Tests for regression checking."""

    def test_check_no_runs(self, benchmark: Benchmark) -> None:
        """No regression when no runs exist."""
        report = benchmark.check_regression()

        assert not report.detected
        assert report.baseline == {}
        assert report.current == {}

    def test_check_single_run(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """No regression with single run (no baseline)."""
        candidates, references = sample_data
        benchmark.evaluate(candidates, references)

        report = benchmark.check_regression()

        # First run has no baseline to compare against
        assert not report.detected

    def test_check_stable_metrics(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """No regression when metrics are stable."""
        candidates, references = sample_data

        # Run multiple times with same data
        for _ in range(3):
            benchmark.evaluate(candidates, references)

        report = benchmark.check_regression()
        assert not report.detected

    def test_check_reports_regression(self, tmp_path: Path) -> None:
        """Reports regression when metrics drop significantly."""
        benchmark = Benchmark("regress-test", storage_path=tmp_path / "benchmarks")

        # First run with good metrics
        good_candidates = ["The quick brown fox jumps."]
        good_references = ["The quick brown fox jumps."]
        benchmark.evaluate(good_candidates, good_references)

        # Second run with worse metrics (different text)
        bad_candidates = ["Something completely different here."]
        benchmark.evaluate(bad_candidates, good_references)

        report = benchmark.check_regression(tolerance=0.05)

        # Should detect regression since second run is very different
        assert report.detected or any(d < -0.05 for d in report.deltas.values())


class TestAssertNoRegression:
    """Tests for assert_no_regression method."""

    def test_passes_when_stable(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """Does not raise when metrics are stable."""
        candidates, references = sample_data

        for _ in range(3):
            benchmark.evaluate(candidates, references)

        # Should not raise
        benchmark.assert_no_regression()

    def test_raises_on_regression(self, tmp_path: Path) -> None:
        """Raises RegressionDetectedError when quality drops."""
        benchmark = Benchmark("regress-test", storage_path=tmp_path / "benchmarks")

        # Establish baseline with perfect match
        perfect = ["The quick brown fox."]
        benchmark.evaluate(perfect, perfect)

        # Second run with terrible match
        terrible = ["Completely unrelated text."]
        benchmark.evaluate(terrible, perfect)

        with pytest.raises(RegressionDetectedError):
            benchmark.assert_no_regression(tolerance=0.05)


class TestGetHistory:
    """Tests for get_history method."""

    def test_empty_history(self, benchmark: Benchmark) -> None:
        """Returns empty list when no runs."""
        history = benchmark.get_history()
        assert history == []

    def test_returns_runs(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """Returns benchmark runs."""
        candidates, references = sample_data

        run1 = benchmark.evaluate(candidates, references)
        run2 = benchmark.evaluate(candidates, references)

        history = benchmark.get_history()

        assert len(history) == 2
        assert history[0].id == run2.id  # Most recent first
        assert history[1].id == run1.id

    def test_respects_limit(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """Respects limit parameter."""
        candidates, references = sample_data

        for _ in range(5):
            benchmark.evaluate(candidates, references)

        history = benchmark.get_history(limit=3)
        assert len(history) == 3

    def test_default_limit(
        self, benchmark: Benchmark, sample_data: tuple[list[str], list[str]]
    ) -> None:
        """Default limit is 20."""
        candidates, references = sample_data

        for _ in range(25):
            benchmark.evaluate(candidates, references)

        history = benchmark.get_history()
        assert len(history) == 20
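The regression tests above pin down a contract: no runs or a single run never counts as a regression, stable metrics pass, and a drop larger than `tolerance` is reported through `deltas`. The sketch below is one way such a check could be written. It is not veritext's implementation: the function name, the plain-dict report, the `window` default, and the choice to average earlier runs as the baseline are all assumptions made for illustration.

```python
from statistics import mean


def check_regression(
    history: list[dict[str, float]],
    tolerance: float = 0.05,
    window: int = 5,
) -> dict[str, object]:
    """Compare the newest run with the mean of up to `window` earlier runs.

    `history` holds one metrics dict per run, most recent first
    (the same ordering the get_history tests expect).
    """
    if len(history) < 2:
        # Zero or one run: there is no baseline, so nothing can regress.
        current = history[0] if history else {}
        return {"detected": False, "baseline": {}, "current": current, "deltas": {}}

    current, earlier = history[0], history[1 : 1 + window]
    # Assumed baseline: per-metric mean over the earlier runs that report it.
    baseline = {
        name: mean(run[name] for run in earlier if name in run)
        for name in current
        if any(name in run for run in earlier)
    }
    deltas = {name: current[name] - baseline[name] for name in baseline}
    # A regression is any metric that fell by more than the tolerance.
    detected = any(delta < -tolerance for delta in deltas.values())
    return {
        "detected": detected,
        "baseline": baseline,
        "current": current,
        "deltas": deltas,
    }
```

With this shape, the `assert report.detected or any(d < -0.05 ...)` check in `test_check_reports_regression` reduces to a single condition, because `detected` is defined as exactly that comparison.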
297
tests/test_benchmark/test_storage.py
Normal file
@@ -0,0 +1,297 @@
"""Tests for benchmark SQLite storage."""

import sqlite3
import threading
from datetime import UTC, datetime
from pathlib import Path

import pytest

from veritext.benchmark.models import BenchmarkRun
from veritext.benchmark.storage import BenchmarkStorage
from veritext.core.exceptions import StorageError


@pytest.fixture
def db_path(tmp_path: Path) -> Path:
    """Return a temporary database path."""
    return tmp_path / "benchmarks" / "test.db"


@pytest.fixture
def storage(db_path: Path) -> BenchmarkStorage:
    """Create a BenchmarkStorage instance."""
    return BenchmarkStorage(db_path)


@pytest.fixture
def sample_run() -> BenchmarkRun:
    """Create a sample benchmark run."""
    return BenchmarkRun(
        id="run-001",
        benchmark_name="test-suite",
        timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC),
        veritext_version="0.1.0-dev",
        metrics={"bleu4": 0.75, "rouge_l": 0.82},
        sample_count=100,
        metadata={"git_sha": "abc123"},
    )


class TestDatabaseCreation:
    """Tests for database initialisation."""

    def test_creates_database_file(self, db_path: Path) -> None:
        """Storage creates the database file on init."""
        assert not db_path.exists()
        BenchmarkStorage(db_path)
        assert db_path.exists()

    def test_creates_parent_directories(self, tmp_path: Path) -> None:
        """Storage creates parent directories if needed."""
        nested_path = tmp_path / "deep" / "nested" / "path" / "test.db"
        BenchmarkStorage(nested_path)
        assert nested_path.exists()

    def test_creates_tables(self, db_path: Path) -> None:
        """Storage creates required tables."""
        BenchmarkStorage(db_path)

        conn = sqlite3.connect(str(db_path))
        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = {row[0] for row in cursor.fetchall()}
        conn.close()

        assert "benchmark_runs" in tables
        assert "benchmark_metrics" in tables

    def test_creates_index(self, db_path: Path) -> None:
        """Storage creates index on benchmark_name and timestamp."""
        BenchmarkStorage(db_path)

        conn = sqlite3.connect(str(db_path))
        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='index'")
        indices = {row[0] for row in cursor.fetchall()}
        conn.close()

        assert "idx_benchmark_name" in indices


class TestSaveRun:
    """Tests for saving benchmark runs."""

    def test_save_run(
        self, storage: BenchmarkStorage, sample_run: BenchmarkRun
    ) -> None:
        """Storage can save a benchmark run."""
        storage.save_run(sample_run)

        runs = storage.get_runs("test-suite")
        assert len(runs) == 1
        assert runs[0].id == "run-001"

    def test_save_preserves_all_fields(
        self, storage: BenchmarkStorage, sample_run: BenchmarkRun
    ) -> None:
        """Saved run preserves all fields correctly."""
        storage.save_run(sample_run)

        runs = storage.get_runs("test-suite")
        run = runs[0]

        assert run.id == sample_run.id
        assert run.benchmark_name == sample_run.benchmark_name
        assert run.timestamp == sample_run.timestamp
        assert run.veritext_version == sample_run.veritext_version
        assert run.metrics == sample_run.metrics
        assert run.sample_count == sample_run.sample_count
        assert run.metadata == sample_run.metadata

    def test_save_duplicate_id_raises(
        self, storage: BenchmarkStorage, sample_run: BenchmarkRun
    ) -> None:
        """Saving a run with duplicate ID raises StorageError."""
        storage.save_run(sample_run)

        with pytest.raises(StorageError, match="already exists"):
            storage.save_run(sample_run)

    def test_save_run_empty_metadata(self, storage: BenchmarkStorage) -> None:
        """Run with empty metadata saves correctly."""
        run = BenchmarkRun(
            id="run-no-meta",
            benchmark_name="test-suite",
            timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0-dev",
            metrics={"bleu4": 0.5},
            sample_count=10,
        )

        storage.save_run(run)
        retrieved = storage.get_latest_run("test-suite")

        assert retrieved is not None
        assert retrieved.metadata == {}


class TestGetRuns:
    """Tests for retrieving benchmark runs."""

    def test_get_runs_empty_database(self, storage: BenchmarkStorage) -> None:
        """Returns empty list for empty database."""
        runs = storage.get_runs("nonexistent")
        assert runs == []

    def test_get_runs_filters_by_name(self, storage: BenchmarkStorage) -> None:
        """Returns only runs matching the benchmark name."""
        run1 = BenchmarkRun(
            id="run-1",
            benchmark_name="suite-a",
            timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"bleu4": 0.5},
            sample_count=10,
        )
        run2 = BenchmarkRun(
            id="run-2",
            benchmark_name="suite-b",
            timestamp=datetime(2025, 1, 15, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"bleu4": 0.6},
            sample_count=10,
        )

        storage.save_run(run1)
        storage.save_run(run2)

        runs_a = storage.get_runs("suite-a")
        runs_b = storage.get_runs("suite-b")

        assert len(runs_a) == 1
        assert runs_a[0].id == "run-1"
        assert len(runs_b) == 1
        assert runs_b[0].id == "run-2"

    def test_get_runs_ordered_by_timestamp(self, storage: BenchmarkStorage) -> None:
        """Returns runs ordered by timestamp, most recent first."""
        run_old = BenchmarkRun(
            id="run-old",
            benchmark_name="test",
            timestamp=datetime(2025, 1, 10, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"bleu4": 0.5},
            sample_count=10,
        )
        run_new = BenchmarkRun(
            id="run-new",
            benchmark_name="test",
            timestamp=datetime(2025, 1, 20, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"bleu4": 0.6},
            sample_count=10,
        )

        # Save in reverse order
        storage.save_run(run_new)
        storage.save_run(run_old)

        runs = storage.get_runs("test")
        assert runs[0].id == "run-new"
        assert runs[1].id == "run-old"

    def test_get_runs_with_limit(self, storage: BenchmarkStorage) -> None:
        """Respects limit parameter."""
        for i in range(5):
            run = BenchmarkRun(
                id=f"run-{i}",
                benchmark_name="test",
                timestamp=datetime(2025, 1, i + 1, 12, 0, 0, tzinfo=UTC),
                veritext_version="0.1.0",
                metrics={"bleu4": 0.5 + i * 0.1},
                sample_count=10,
            )
            storage.save_run(run)

        runs = storage.get_runs("test", limit=3)
        assert len(runs) == 3


class TestGetLatestRun:
    """Tests for getting the latest run."""

    def test_get_latest_run_empty(self, storage: BenchmarkStorage) -> None:
        """Returns None for empty database."""
        result = storage.get_latest_run("nonexistent")
        assert result is None

    def test_get_latest_run(self, storage: BenchmarkStorage) -> None:
        """Returns the most recent run."""
        run_old = BenchmarkRun(
            id="run-old",
            benchmark_name="test",
            timestamp=datetime(2025, 1, 10, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"bleu4": 0.5},
            sample_count=10,
        )
        run_new = BenchmarkRun(
            id="run-new",
            benchmark_name="test",
            timestamp=datetime(2025, 1, 20, 12, 0, 0, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"bleu4": 0.6},
            sample_count=10,
        )

        storage.save_run(run_old)
        storage.save_run(run_new)

        latest = storage.get_latest_run("test")
        assert latest is not None
        assert latest.id == "run-new"


class TestConcurrentAccess:
    """Tests for concurrent database access."""

    def test_concurrent_writes(self, db_path: Path) -> None:
        """Multiple threads can write concurrently with WAL mode."""
        errors: list[Exception] = []

        def write_run(run_id: int) -> None:
            try:
                storage = BenchmarkStorage(db_path)
                run = BenchmarkRun(
                    id=f"run-{run_id}",
                    benchmark_name="test",
                    timestamp=datetime(2025, 1, 15, 12, 0, run_id, tzinfo=UTC),
                    veritext_version="0.1.0",
                    metrics={"bleu4": 0.5},
                    sample_count=10,
                )
                storage.save_run(run)
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=write_run, args=(i,)) for i in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors, f"Concurrent writes failed: {errors}"

        storage = BenchmarkStorage(db_path)
        runs = storage.get_runs("test")
        assert len(runs) == 10

    def test_wal_mode_enabled(self, db_path: Path) -> None:
        """Database uses WAL journal mode."""
        BenchmarkStorage(db_path)

        conn = sqlite3.connect(str(db_path))
        cursor = conn.execute("PRAGMA journal_mode")
        mode = cursor.fetchone()[0]
        conn.close()

        assert mode.lower() == "wal"
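The storage tests expect four things of the database layer: parent directories are created, WAL journal mode is on, an index named `idx_benchmark_name` exists, and a duplicate run ID surfaces as a `StorageError` whose message contains "already exists". The following is a minimal sketch of how those behaviours can be obtained with the standard `sqlite3` module. The helper names `open_store` and `save_run`, the three-column schema, and the local `StorageError` class are illustrative assumptions, not the code under test.

```python
import sqlite3
import tempfile
from pathlib import Path


class StorageError(Exception):
    """Local stand-in for veritext's StorageError (assumed shape)."""


def open_store(db_path: Path) -> sqlite3.Connection:
    """Create parent directories, switch to WAL, and ensure the schema exists."""
    db_path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(db_path))
    # WAL lets readers proceed while a writer is active; the setting is
    # stored in the database file, so later connections see it too.
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute(
        "CREATE TABLE IF NOT EXISTS benchmark_runs ("
        "id TEXT PRIMARY KEY, benchmark_name TEXT NOT NULL, timestamp TEXT NOT NULL)"
    )
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_benchmark_name "
        "ON benchmark_runs (benchmark_name, timestamp)"
    )
    conn.commit()
    return conn


def save_run(conn: sqlite3.Connection, run_id: str, name: str, timestamp: str) -> None:
    """Insert one run, translating a primary-key clash into StorageError."""
    try:
        with conn:  # commits on success, rolls back on error
            conn.execute(
                "INSERT INTO benchmark_runs VALUES (?, ?, ?)",
                (run_id, name, timestamp),
            )
    except sqlite3.IntegrityError as exc:
        raise StorageError(f"Run {run_id!r} already exists") from exc


def demo() -> tuple[str, bool]:
    """Return the journal mode and whether a duplicate insert was rejected."""
    with tempfile.TemporaryDirectory() as tmp:
        conn = open_store(Path(tmp) / "benchmarks" / "test.db")
        try:
            mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
            save_run(conn, "run-001", "test-suite", "2025-01-15T12:00:00+00:00")
            try:
                save_run(conn, "run-001", "test-suite", "2025-01-15T12:00:00+00:00")
                duplicate_rejected = False
            except StorageError:
                duplicate_rejected = True
        finally:
            conn.close()
    return mode, duplicate_rejected
```

WAL mode only applies to file-backed databases, which is why the sketch (like the tests) uses a temporary path rather than `:memory:`.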
1
tests/test_cli/__init__.py
Normal file
@@ -0,0 +1 @@
"""CLI test suite."""
337
tests/test_cli/test_benchmark.py
Normal file
@@ -0,0 +1,337 @@
"""Tests for CLI benchmark commands."""

from pathlib import Path

from typer.testing import CliRunner

from veritext.cli.main import app

runner = CliRunner()


class TestBenchmarkRun:
    """Tests for benchmark run command."""

    def test_benchmark_run_basic(self, tmp_path: Path) -> None:
        """Test basic benchmark run."""
        data_file = tmp_path / "data.jsonl"
        data_file.write_text(
            '{"candidate": "hello world today", "reference": "hello world today"}\n'
            '{"candidate": "foo bar baz qux", "reference": "foo bar baz qux"}'
        )
        storage_path = tmp_path / "benchmarks"

        result = runner.invoke(
            app,
            [
                "benchmark",
                "run",
                "test_bench",
                "-f",
                str(data_file),
                "-m",
                "rouge_l,bleu4",
                "-s",
                str(storage_path),
            ],
        )
        assert result.exit_code == 0
        assert "Benchmark 'test_bench' completed" in result.stdout
        assert "Samples: 2" in result.stdout
        assert "rouge_l:" in result.stdout
        assert "bleu4:" in result.stdout

    def test_benchmark_run_file_not_found(self, tmp_path: Path) -> None:
        """Test benchmark run with non-existent file."""
        result = runner.invoke(
            app,
            [
                "benchmark",
                "run",
                "test_bench",
                "-f",
                "/nonexistent/file.jsonl",
                "-s",
                str(tmp_path / "benchmarks"),
            ],
        )
        assert result.exit_code == 1
        assert "Error" in result.stdout

    def test_benchmark_run_creates_storage(self, tmp_path: Path) -> None:
        """Test that benchmark run creates storage directory."""
        data_file = tmp_path / "data.jsonl"
        data_file.write_text('{"candidate": "hello", "reference": "hello"}')
        storage_path = tmp_path / "new_benchmarks"

        result = runner.invoke(
            app,
            [
                "benchmark",
                "run",
                "test_bench",
                "-f",
                str(data_file),
                "-s",
                str(storage_path),
            ],
        )
        assert result.exit_code == 0
        assert storage_path.exists()


class TestBenchmarkShow:
    """Tests for benchmark show command."""

    def test_benchmark_show_no_runs(self, tmp_path: Path) -> None:
        """Test showing benchmark with no runs."""
        storage_path = tmp_path / "benchmarks"
        storage_path.mkdir()

        result = runner.invoke(
            app,
            [
                "benchmark",
                "show",
                "nonexistent_bench",
                "-s",
                str(storage_path),
            ],
        )
        assert result.exit_code == 0
        assert "No benchmark runs found" in result.stdout

    def test_benchmark_show_with_runs(self, tmp_path: Path) -> None:
        """Test showing benchmark history with runs."""
        # First create some runs
        data_file = tmp_path / "data.jsonl"
        data_file.write_text('{"candidate": "hello world", "reference": "hello world"}')
        storage_path = tmp_path / "benchmarks"

        # Run benchmark twice
        for _ in range(2):
            runner.invoke(
                app,
                [
                    "benchmark",
                    "run",
                    "test_bench",
                    "-f",
                    str(data_file),
                    "-s",
                    str(storage_path),
                ],
            )

        # Show history
        result = runner.invoke(
            app,
            [
                "benchmark",
                "show",
                "test_bench",
                "-s",
                str(storage_path),
            ],
        )
        assert result.exit_code == 0
        assert "Benchmark History" in result.stdout

    def test_benchmark_show_limit(self, tmp_path: Path) -> None:
        """Test showing limited benchmark history."""
        data_file = tmp_path / "data.jsonl"
        data_file.write_text('{"candidate": "hello", "reference": "hello"}')
        storage_path = tmp_path / "benchmarks"

        # Run benchmark 3 times
        for _ in range(3):
            runner.invoke(
                app,
                [
                    "benchmark",
                    "run",
                    "test_bench",
                    "-f",
                    str(data_file),
                    "-s",
                    str(storage_path),
                ],
            )

        # Show only last 2
        result = runner.invoke(
            app,
            [
                "benchmark",
                "show",
                "test_bench",
                "--last",
                "2",
                "-s",
                str(storage_path),
            ],
        )
        assert result.exit_code == 0


class TestBenchmarkCheck:
    """Tests for benchmark check command."""

    def test_benchmark_check_no_regression(self, tmp_path: Path) -> None:
        """Test checking for regression with no regression."""
        data_file = tmp_path / "data.jsonl"
        data_file.write_text(
            '{"candidate": "hello world today", "reference": "hello world today"}'
        )
        storage_path = tmp_path / "benchmarks"

        # Run benchmark twice with same data (no regression)
        for _ in range(2):
            runner.invoke(
                app,
                [
                    "benchmark",
                    "run",
                    "test_bench",
                    "-f",
                    str(data_file),
                    "-s",
                    str(storage_path),
                ],
            )

        # Check for regression
        result = runner.invoke(
            app,
            [
                "benchmark",
                "check",
                "test_bench",
                "-s",
                str(storage_path),
            ],
        )
        assert result.exit_code == 0
        assert "No regression detected" in result.stdout

    def test_benchmark_check_with_regression(self, tmp_path: Path) -> None:
        """Test checking for regression when regression occurs."""
        storage_path = tmp_path / "benchmarks"

        # First run with good data
        good_file = tmp_path / "good.jsonl"
        good_file.write_text(
            '{"candidate": "hello world today", "reference": "hello world today"}'
        )
        runner.invoke(
            app,
            [
                "benchmark",
                "run",
                "test_bench",
                "-f",
                str(good_file),
                "-s",
                str(storage_path),
            ],
        )

        # Second run with bad data (regression)
        bad_file = tmp_path / "bad.jsonl"
        bad_file.write_text(
            '{"candidate": "completely different", "reference": "hello world today"}'
        )
        runner.invoke(
            app,
            [
                "benchmark",
                "run",
                "test_bench",
                "-f",
                str(bad_file),
                "-s",
                str(storage_path),
            ],
        )

        # Check for regression
        result = runner.invoke(
            app,
            [
                "benchmark",
                "check",
                "test_bench",
                "-t",
                "0.05",
                "-s",
                str(storage_path),
            ],
        )
        assert result.exit_code == 1
        assert "Regression detected" in result.stdout

    def test_benchmark_check_custom_tolerance(self, tmp_path: Path) -> None:
        """Test checking regression with custom tolerance."""
        data_file = tmp_path / "data.jsonl"
        data_file.write_text('{"candidate": "hello", "reference": "hello"}')
        storage_path = tmp_path / "benchmarks"

        runner.invoke(
            app,
            [
                "benchmark",
                "run",
                "test_bench",
                "-f",
                str(data_file),
                "-s",
                str(storage_path),
            ],
        )

        result = runner.invoke(
            app,
            [
                "benchmark",
                "check",
                "test_bench",
                "--tolerance",
                "0.10",
                "-s",
                str(storage_path),
            ],
        )
        assert result.exit_code == 0
        assert "10.00%" in result.stdout


class TestBenchmarkHelp:
    """Tests for benchmark help output."""

    def test_benchmark_help(self) -> None:
        """Test benchmark help output."""
        result = runner.invoke(app, ["benchmark", "--help"])
        assert result.exit_code == 0
        assert "run" in result.stdout
        assert "show" in result.stdout
        assert "check" in result.stdout

    def test_benchmark_run_help(self) -> None:
        """Test benchmark run help output."""
        result = runner.invoke(app, ["benchmark", "run", "--help"])
        assert result.exit_code == 0
        assert "--file" in result.stdout
        assert "--metrics" in result.stdout

    def test_benchmark_show_help(self) -> None:
        """Test benchmark show help output."""
        result = runner.invoke(app, ["benchmark", "show", "--help"])
        assert result.exit_code == 0
        assert "--last" in result.stdout

    def test_benchmark_check_help(self) -> None:
        """Test benchmark check help output."""
        result = runner.invoke(app, ["benchmark", "check", "--help"])
        assert result.exit_code == 0
        assert "--tolerance" in result.stdout
        assert "--window" in result.stdout
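Every CLI test above feeds `benchmark run` a JSONL file in which each line is an object with `candidate` and `reference` keys. As a rough illustration of that input shape, here is a small standard-library parser. `Pair` and `parse_paired_jsonl` are hypothetical names chosen for this sketch; the project's own readers live in `veritext.cli.readers` and may behave differently, for example in how they report malformed lines.

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class Pair:
    """Illustrative candidate/reference record (not veritext's TextPair)."""

    candidate: str
    reference: str


def parse_paired_jsonl(text: str) -> list[Pair]:
    """Parse JSONL text where each non-blank line holds one candidate/reference pair."""
    pairs: list[Pair] = []
    for line_no, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        record = json.loads(line)
        try:
            pairs.append(Pair(record["candidate"], record["reference"]))
        except KeyError as exc:
            # Assumed policy: fail loudly and name the offending line.
            raise ValueError(f"line {line_no}: missing field {exc}") from exc
    return pairs


# The same two-line payload that test_benchmark_run_basic writes to disk.
sample = (
    '{"candidate": "hello world today", "reference": "hello world today"}\n'
    '{"candidate": "foo bar baz qux", "reference": "foo bar baz qux"}'
)
pairs = parse_paired_jsonl(sample)
```

Two parsed pairs here correspond to the `Samples: 2` line the basic run test looks for in the command output.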
141
tests/test_cli/test_formatters.py
Normal file
@@ -0,0 +1,141 @@
"""Tests for CLI output formatters."""

from datetime import UTC, datetime

from veritext.benchmark.models import BenchmarkRun, RegressionReport
from veritext.cli.formatters import (
    format_benchmark_history,
    format_regression_report,
    format_validation_json,
    format_validation_simple,
    format_validation_table,
)


class TestFormatValidationTable:
    """Tests for format_validation_table function."""

    def test_format_empty_results(self) -> None:
        """Test formatting empty results."""
        table = format_validation_table({})
        assert table.title == "Validation Results"
        assert table.row_count == 0

    def test_format_single_metric(self) -> None:
        """Test formatting a single metric."""
        results = {"bleu4": 0.8523}
        table = format_validation_table(results)
        assert table.row_count == 1

    def test_format_multiple_metrics(self) -> None:
        """Test formatting multiple metrics."""
        results = {"bleu4": 0.85, "rouge_l": 0.92, "jaccard": 0.75}
        table = format_validation_table(results)
        assert table.row_count == 3

    def test_format_with_threshold(self) -> None:
        """Test formatting with threshold for pass/fail."""
        results = {"bleu4": 0.85, "rouge_l": 0.45}
        table = format_validation_table(results, threshold=0.5)
        # Should have 3 columns: Metric, Score, Status
        assert table.row_count == 2


class TestFormatValidationJson:
    """Tests for format_validation_json function."""

    def test_format_empty_results(self) -> None:
        """Test formatting empty results as JSON."""
        result = format_validation_json({})
        assert result == "{}"

    def test_format_results(self) -> None:
        """Test formatting results as JSON."""
        results = {"bleu4": 0.85, "rouge_l": 0.92}
        result = format_validation_json(results)
        assert '"bleu4": 0.85' in result
        assert '"rouge_l": 0.92' in result


class TestFormatValidationSimple:
    """Tests for format_validation_simple function."""

    def test_format_empty_results(self) -> None:
        """Test formatting empty results as simple text."""
        result = format_validation_simple({})
        assert result == ""

    def test_format_results(self) -> None:
        """Test formatting results as simple text."""
        results = {"bleu4": 0.8523, "rouge_l": 0.9234}
        result = format_validation_simple(results)
        assert "bleu4: 0.8523" in result
        assert "rouge_l: 0.9234" in result


class TestFormatBenchmarkHistory:
    """Tests for format_benchmark_history function."""

    def test_format_empty_history(self) -> None:
        """Test formatting empty benchmark history."""
        table = format_benchmark_history([])
        assert table.title == "Benchmark History"

    def test_format_single_run(self) -> None:
        """Test formatting a single benchmark run."""
        run = BenchmarkRun(
            id="test-id",
            benchmark_name="test",
            timestamp=datetime(2024, 1, 15, 10, 30, tzinfo=UTC),
            veritext_version="0.1.0",
            metrics={"rouge_l": 0.85, "bleu4": 0.72},
            sample_count=100,
        )
        table = format_benchmark_history([run])
        assert table.row_count == 1

    def test_format_multiple_runs(self) -> None:
        """Test formatting multiple benchmark runs."""
        runs = [
            BenchmarkRun(
                id=f"test-id-{i}",
                benchmark_name="test",
                timestamp=datetime(2024, 1, i + 1, 10, 30, tzinfo=UTC),
                veritext_version="0.1.0",
                metrics={"rouge_l": 0.8 + i * 0.01},
                sample_count=100,
            )
            for i in range(3)
        ]
        table = format_benchmark_history(runs)
        assert table.row_count == 3


class TestFormatRegressionReport:
    """Tests for format_regression_report function."""

    def test_format_no_regression(self) -> None:
        """Test formatting report with no regression."""
        report = RegressionReport(
            detected=False,
            baseline={"rouge_l": 0.85},
            current={"rouge_l": 0.86},
            deltas={"rouge_l": 0.01},
            tolerance=0.05,
        )
        panel = format_regression_report(report)
        assert panel.title == "Regression Check"
        assert panel.border_style == "green"

    def test_format_with_regression(self) -> None:
        """Test formatting report with regression detected."""
        report = RegressionReport(
            detected=True,
            baseline={"rouge_l": 0.85, "bleu4": 0.72},
            current={"rouge_l": 0.70, "bleu4": 0.70},
            deltas={"rouge_l": -0.15, "bleu4": -0.02},
            tolerance=0.05,
        )
        panel = format_regression_report(report)
        assert panel.title == "Regression Check"
        assert panel.border_style == "red"
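The regression tests above only build `RegressionReport` objects by hand, so the rule that sets `detected` is not visible in this diff. A minimal sketch of one plausible rule follows, assuming a regression means any metric dropping by more than `tolerance`. The names `RegressionSketch` and `check_regression` are hypothetical and are not part of veritext.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RegressionSketch:
    """Hypothetical stand-in for RegressionReport, reusing the field names from the tests."""

    detected: bool
    baseline: dict[str, float]
    current: dict[str, float]
    deltas: dict[str, float]
    tolerance: float


def check_regression(
    baseline: dict[str, float],
    current: dict[str, float],
    tolerance: float = 0.05,
) -> RegressionSketch:
    """Compare two metric dictionaries and flag drops larger than the tolerance."""
    # Assumption: only metrics present in both runs are compared.
    deltas = {
        name: current[name] - baseline[name]
        for name in baseline
        if name in current
    }
    # Assumed rule: any metric falling by more than the tolerance is a regression.
    detected = any(delta < -tolerance for delta in deltas.values())
    return RegressionSketch(detected, dict(baseline), dict(current), deltas, tolerance)
```

Under this assumed rule, the fixture values in `test_format_with_regression` are consistent: `rouge_l` drops by 0.15, which exceeds the 0.05 tolerance, while `bleu4` drops by only 0.02.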
145
tests/test_cli/test_readers.py
Normal file
@@ -0,0 +1,145 @@
"""Tests for CLI input readers."""

import json
from pathlib import Path

import pytest

from veritext.cli.readers import TextPair, read_jsonl, read_paired_jsonl


class TestTextPair:
    """Tests for TextPair dataclass."""

    def test_create_text_pair(self) -> None:
        """Test creating a TextPair."""
        pair = TextPair(candidate="hello", reference="world")
        assert pair.candidate == "hello"
        assert pair.reference == "world"


class TestReadJsonl:
    """Tests for read_jsonl function."""

    def test_read_valid_jsonl(self, tmp_path: Path) -> None:
        """Test reading a valid JSONL file."""
        data = [
            {"candidate": "foo", "reference": "bar"},
            {"candidate": "baz", "reference": "qux"},
        ]
        jsonl_file = tmp_path / "data.jsonl"
        jsonl_file.write_text("\n".join(json.dumps(d) for d in data))

        pairs = read_jsonl(jsonl_file)

        assert len(pairs) == 2
        assert pairs[0].candidate == "foo"
        assert pairs[0].reference == "bar"
        assert pairs[1].candidate == "baz"
        assert pairs[1].reference == "qux"

    def test_read_empty_file(self, tmp_path: Path) -> None:
        """Test reading an empty JSONL file."""
        jsonl_file = tmp_path / "empty.jsonl"
        jsonl_file.write_text("")

        pairs = read_jsonl(jsonl_file)

        assert pairs == []

    def test_read_file_with_blank_lines(self, tmp_path: Path) -> None:
        """Test reading a JSONL file with blank lines."""
        jsonl_file = tmp_path / "data.jsonl"
        content = '{"candidate": "a", "reference": "b"}\n\n{"candidate": "c", "reference": "d"}\n'
        jsonl_file.write_text(content)

        pairs = read_jsonl(jsonl_file)

        assert len(pairs) == 2

    def test_read_file_not_found(self, tmp_path: Path) -> None:
        """Test reading a non-existent file."""
        with pytest.raises(FileNotFoundError):
            read_jsonl(tmp_path / "nonexistent.jsonl")

    def test_read_invalid_json(self, tmp_path: Path) -> None:
        """Test reading a file with invalid JSON."""
        jsonl_file = tmp_path / "invalid.jsonl"
        jsonl_file.write_text("not valid json")

        with pytest.raises(ValueError, match="Invalid JSON on line 1"):
            read_jsonl(jsonl_file)

    def test_read_missing_candidate_key(self, tmp_path: Path) -> None:
        """Test reading a file missing the candidate key."""
        jsonl_file = tmp_path / "data.jsonl"
        jsonl_file.write_text('{"reference": "bar"}')

        with pytest.raises(ValueError, match="Missing 'candidate' key on line 1"):
            read_jsonl(jsonl_file)

    def test_read_missing_reference_key(self, tmp_path: Path) -> None:
        """Test reading a file missing the reference key."""
        jsonl_file = tmp_path / "data.jsonl"
        jsonl_file.write_text('{"candidate": "foo"}')

        with pytest.raises(ValueError, match="Missing 'reference' key on line 1"):
            read_jsonl(jsonl_file)


class TestReadPairedJsonl:
    """Tests for read_paired_jsonl function."""

    def test_read_paired_valid(self, tmp_path: Path) -> None:
        """Test reading valid paired JSONL files."""
        candidates_file = tmp_path / "candidates.jsonl"
        references_file = tmp_path / "references.jsonl"

        candidates_file.write_text('{"text": "foo"}\n{"text": "bar"}')
        references_file.write_text('{"text": "baz"}\n{"text": "qux"}')

        pairs = read_paired_jsonl(candidates_file, references_file)

        assert len(pairs) == 2
        assert pairs[0].candidate == "foo"
        assert pairs[0].reference == "baz"
        assert pairs[1].candidate == "bar"
        assert pairs[1].reference == "qux"

    def test_read_paired_length_mismatch(self, tmp_path: Path) -> None:
        """Test reading paired files with different lengths."""
        candidates_file = tmp_path / "candidates.jsonl"
        references_file = tmp_path / "references.jsonl"

        candidates_file.write_text('{"text": "foo"}\n{"text": "bar"}')
        references_file.write_text('{"text": "baz"}')

        with pytest.raises(ValueError, match="does not match"):
            read_paired_jsonl(candidates_file, references_file)

    def test_read_paired_candidates_not_found(self, tmp_path: Path) -> None:
        """Test reading when candidates file doesn't exist."""
        references_file = tmp_path / "references.jsonl"
        references_file.write_text('{"text": "baz"}')

        with pytest.raises(FileNotFoundError, match="Candidates file not found"):
            read_paired_jsonl(tmp_path / "nonexistent.jsonl", references_file)

    def test_read_paired_references_not_found(self, tmp_path: Path) -> None:
        """Test reading when references file doesn't exist."""
        candidates_file = tmp_path / "candidates.jsonl"
        candidates_file.write_text('{"text": "foo"}')

        with pytest.raises(FileNotFoundError, match="References file not found"):
            read_paired_jsonl(candidates_file, tmp_path / "nonexistent.jsonl")

    def test_read_paired_missing_text_key(self, tmp_path: Path) -> None:
        """Test reading paired files with missing text key."""
        candidates_file = tmp_path / "candidates.jsonl"
        references_file = tmp_path / "references.jsonl"

        candidates_file.write_text('{"value": "foo"}')
        references_file.write_text('{"text": "baz"}')

        with pytest.raises(ValueError, match="Missing 'text' key in candidates file"):
            read_paired_jsonl(candidates_file, references_file)
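The reader tests above pin down a small contract for `read_jsonl`: blank lines are skipped, a missing file raises `FileNotFoundError`, and malformed records raise `ValueError` with the line number in the message. The sketch below is one way such a reader could be written using only the standard library. It is not the veritext implementation, and `PairSketch` and `read_jsonl_sketch` are hypothetical names. It assumes line numbers count physical lines, blank ones included, which the tests do not specify.

```python
import json
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class PairSketch:
    """Hypothetical stand-in for TextPair."""

    candidate: str
    reference: str


def read_jsonl_sketch(path: Path) -> list[PairSketch]:
    """Read candidate/reference pairs from a JSONL file, one JSON object per line."""
    pairs: list[PairSketch] = []
    # Path.open raises FileNotFoundError on its own when the file is missing.
    with path.open(encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue  # blank lines are skipped, as the tests expect
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON on line {line_number}") from exc
            for key in ("candidate", "reference"):
                # A non-object record is reported as a missing key (an assumption).
                if not isinstance(record, dict) or key not in record:
                    raise ValueError(f"Missing '{key}' key on line {line_number}")
            pairs.append(PairSketch(record["candidate"], record["reference"]))
    return pairs
```

Checking `candidate` before `reference` matches the error messages the tests match against when only one key is present.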
233
tests/test_cli/test_validate.py
Normal file
@@ -0,0 +1,233 @@
"""Tests for CLI validate command."""

import json
from pathlib import Path

from typer.testing import CliRunner

from veritext.cli.main import app

runner = CliRunner()


class TestValidateInline:
    """Tests for inline validation mode."""

    def test_validate_inline_basic(self) -> None:
        """Test basic inline validation."""
        result = runner.invoke(
            app,
            [
                "validate",
                "The quick brown fox jumps",
                "-r",
                "The quick brown fox jumps",
                "-m",
                "bleu",
            ],
        )
        assert result.exit_code == 0
        assert "bleu4" in result.stdout

    def test_validate_inline_with_rouge(self) -> None:
        """Test inline validation with ROUGE metric."""
        result = runner.invoke(
            app,
            [
                "validate",
                "hello world today",
                "-r",
                "hello world here",
                "-m",
                "rouge",
            ],
        )
        assert result.exit_code == 0
        assert "rouge_l" in result.stdout

    def test_validate_inline_with_lexical(self) -> None:
        """Test inline validation with lexical metric."""
        result = runner.invoke(
            app,
            [
                "validate",
                "hello world",
                "-r",
                "hello everyone",
                "-m",
                "lexical",
            ],
        )
        assert result.exit_code == 0
        assert "jaccard" in result.stdout
        assert "token_overlap" in result.stdout

    def test_validate_inline_json_output(self) -> None:
        """Test inline validation with JSON output."""
        result = runner.invoke(
            app,
            [
                "validate",
                "hello world today",
                "-r",
                "hello world today",
                "-m",
                "bleu",
                "-o",
                "json",
            ],
        )
        assert result.exit_code == 0
        data = json.loads(result.stdout)
        assert "bleu4" in data

    def test_validate_inline_simple_output(self) -> None:
        """Test inline validation with simple output."""
        result = runner.invoke(
            app,
            [
                "validate",
                "hello world today",
                "-r",
                "hello world today",
                "-m",
                "rouge",
                "-o",
                "simple",
            ],
        )
        assert result.exit_code == 0
        assert "rouge_l:" in result.stdout

    def test_validate_inline_missing_reference(self) -> None:
        """Test inline validation without reference."""
        result = runner.invoke(
            app,
            ["validate", "hello world", "-m", "bleu"],
        )
        assert result.exit_code == 1
        assert "Error" in result.stdout

    def test_validate_inline_invalid_metric(self) -> None:
        """Test inline validation with invalid metric."""
        result = runner.invoke(
            app,
            ["validate", "hello", "-r", "world", "-m", "invalid_metric"],
        )
        assert result.exit_code == 1
        assert "Unknown metrics" in result.stdout


class TestValidateFile:
    """Tests for file-based validation mode."""

    def test_validate_file_basic(self, tmp_path: Path) -> None:
        """Test basic file-based validation."""
        data_file = tmp_path / "data.jsonl"
        data_file.write_text(
            '{"candidate": "hello world today", "reference": "hello world today"}\n'
            '{"candidate": "foo bar baz", "reference": "foo bar baz"}'
        )

        result = runner.invoke(
            app,
            ["validate", "-f", str(data_file), "-m", "bleu"],
        )
        assert result.exit_code == 0
        assert "bleu4" in result.stdout
        assert "Evaluated 2 text pairs" in result.stdout

    def test_validate_file_not_found(self) -> None:
        """Test file-based validation with non-existent file."""
        result = runner.invoke(
            app,
            ["validate", "-f", "/nonexistent/file.jsonl", "-m", "bleu"],
        )
        assert result.exit_code == 1
        assert "Error" in result.stdout

    def test_validate_paired_files(self, tmp_path: Path) -> None:
        """Test validation with separate candidate and reference files."""
        candidates_file = tmp_path / "candidates.jsonl"
        references_file = tmp_path / "references.jsonl"

        candidates_file.write_text(
            '{"text": "hello world today"}\n{"text": "foo bar baz"}'
        )
        references_file.write_text(
            '{"text": "hello world today"}\n{"text": "foo bar baz"}'
        )

        result = runner.invoke(
            app,
            [
                "validate",
                "-f",
                str(candidates_file),
                "-R",
                str(references_file),
                "-m",
                "bleu",
            ],
        )
        assert result.exit_code == 0
        assert "Evaluated 2 text pairs" in result.stdout


class TestValidateOptions:
    """Tests for validate command options."""

    def test_validate_with_threshold(self) -> None:
        """Test validation with threshold option."""
        result = runner.invoke(
            app,
            [
                "validate",
                "hello world today",
                "-r",
                "hello world today",
                "-m",
                "bleu",
                "-t",
                "0.5",
            ],
        )
        assert result.exit_code == 0
        # Table output should include Status column
        assert "Status" in result.stdout or "PASS" in result.stdout

    def test_validate_invalid_output_format(self) -> None:
        """Test validation with invalid output format."""
        result = runner.invoke(
            app,
            [
                "validate",
                "hello",
                "-r",
                "world",
                "-m",
                "bleu",
                "-o",
                "invalid",
            ],
        )
        assert result.exit_code == 1
        assert "Invalid output format" in result.stdout

    def test_validate_multiple_metrics(self) -> None:
        """Test validation with multiple metrics."""
        result = runner.invoke(
            app,
            [
                "validate",
                "The quick brown fox",
                "-r",
                "The quick brown fox",
                "-m",
                "bleu,rouge,lexical",
            ],
        )
        assert result.exit_code == 0
        assert "bleu4" in result.stdout
        assert "rouge_l" in result.stdout
        assert "jaccard" in result.stdout
73
tests/test_core/test_config.py
Normal file
@@ -0,0 +1,73 @@
"""Tests for configuration module."""

from pathlib import Path

import pytest

from veritext.core.config import VeritextSettings, get_settings


class TestVeritextSettings:
    """Tests for VeritextSettings."""

    def test_default_log_level(self) -> None:
        """Test default log level is INFO."""
        settings = VeritextSettings()
        assert settings.log_level == "INFO"

    def test_default_log_format(self) -> None:
        """Test default log format is console."""
        settings = VeritextSettings()
        assert settings.log_format == "console"

    def test_default_benchmark_path(self) -> None:
        """Test default benchmark storage path."""
        settings = VeritextSettings()
        assert settings.benchmark_storage_path == Path("benchmarks")

    def test_default_tokeniser_lowercase(self) -> None:
        """Test default tokeniser lowercase setting."""
        settings = VeritextSettings()
        assert settings.tokeniser_lowercase is True

    def test_default_tokeniser_remove_punctuation(self) -> None:
        """Test default tokeniser remove punctuation setting."""
        settings = VeritextSettings()
        assert settings.tokeniser_remove_punctuation is True

    def test_default_semantic_model(self) -> None:
        """Test default semantic model name."""
        settings = VeritextSettings()
        assert settings.semantic_model == "all-MiniLM-L6-v2"

    def test_default_semantic_cache_enabled(self) -> None:
        """Test semantic cache is enabled by default."""
        settings = VeritextSettings()
        assert settings.semantic_cache_embeddings is True

    def test_env_var_override(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Test environment variable overrides default settings."""
        monkeypatch.setenv("VERITEXT_LOG_LEVEL", "DEBUG")
        settings = VeritextSettings()
        assert settings.log_level == "DEBUG"

    def test_env_var_override_log_format(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Test environment variable overrides log format."""
        monkeypatch.setenv("VERITEXT_LOG_FORMAT", "json")
        settings = VeritextSettings()
        assert settings.log_format == "json"


class TestGetSettings:
    """Tests for get_settings function."""

    def test_get_settings_returns_instance(self) -> None:
        """Test get_settings returns a VeritextSettings instance."""
        settings = get_settings()
        assert isinstance(settings, VeritextSettings)

    def test_get_settings_returns_valid_defaults(self) -> None:
        """Test get_settings returns instance with valid defaults."""
        settings = get_settings()
        assert settings.log_level in ("DEBUG", "INFO", "WARNING", "ERROR")
        assert settings.log_format in ("console", "json")
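The changelog in this comparison notes that the settings instance is now cached via `@lru_cache`. That may explain why the environment-variable tests above construct `VeritextSettings()` directly while the `get_settings` tests only check for membership in a set of valid values. The sketch below illustrates the caching behaviour with a standard-library dataclass. `SettingsSketch` and `get_settings_sketch` are hypothetical names, and the real class may be built differently.

```python
import os
from dataclasses import dataclass, field
from functools import lru_cache


def _env(name: str, default: str) -> str:
    """Read a VERITEXT_-prefixed environment variable, falling back to a default."""
    return os.environ.get(f"VERITEXT_{name}", default)


@dataclass(frozen=True)
class SettingsSketch:
    """Hypothetical stand-in for VeritextSettings with two of its fields."""

    # default_factory runs at construction time, so each new instance re-reads the environment.
    log_level: str = field(default_factory=lambda: _env("LOG_LEVEL", "INFO"))
    log_format: str = field(default_factory=lambda: _env("LOG_FORMAT", "console"))


@lru_cache(maxsize=1)
def get_settings_sketch() -> SettingsSketch:
    """Return one shared settings instance; later environment changes are not seen."""
    return SettingsSketch()
```

With this design, a test that changes the environment and then calls the cached accessor sees stale values unless it first calls `get_settings_sketch.cache_clear()`. Constructing the settings class directly avoids that.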
56
tests/test_core/test_logging.py
Normal file
@@ -0,0 +1,56 @@
"""Tests for logging module."""

from veritext.core.logging import configure_logging, get_logger


class TestGetLogger:
    """Tests for get_logger function."""

    def test_get_logger_returns_logger(self) -> None:
        """Test get_logger returns a logger instance."""
        logger = get_logger()
        assert logger is not None

    def test_get_logger_default_name(self) -> None:
        """Test get_logger uses 'veritext' as default name."""
        logger = get_logger()
        # The logger should be a bound logger from structlog
        assert hasattr(logger, "info")
        assert hasattr(logger, "debug")
        assert hasattr(logger, "warning")
        assert hasattr(logger, "error")

    def test_get_logger_custom_name(self) -> None:
        """Test get_logger respects custom name parameter."""
        logger = get_logger("custom.module")
        assert logger is not None
        assert hasattr(logger, "info")


class TestConfigureLogging:
    """Tests for configure_logging function."""

    def test_configure_logging_console_format(self) -> None:
        """Test configure_logging with console format does not raise."""
        configure_logging(level="INFO", log_format="console")
        logger = get_logger()
        assert logger is not None

    def test_configure_logging_json_format(self) -> None:
        """Test configure_logging with json format does not raise."""
        configure_logging(level="DEBUG", log_format="json")
        logger = get_logger()
        assert logger is not None

    def test_configure_logging_uses_defaults(self) -> None:
        """Test configure_logging uses settings defaults when not provided."""
        configure_logging()
        logger = get_logger()
        assert logger is not None

    def test_configure_logging_different_levels(self) -> None:
        """Test configure_logging accepts different log levels."""
        for level in ("DEBUG", "INFO", "WARNING", "ERROR"):
            configure_logging(level=level)
            logger = get_logger()
            assert logger is not None
274
tests/test_metrics/test_readability.py
Normal file
@@ -0,0 +1,274 @@
"""Tests for the readability metric."""

import pytest

from veritext.metrics import Readability, ReadabilityResult


class TestReadability:
    """Tests for the Readability metric class."""

    @pytest.fixture
    def readability(self) -> Readability:
        """Provide a readability metric instance."""
        return Readability()

    def test_name(self, readability: Readability) -> None:
        """Test that name returns 'readability'."""
        assert readability.name == "readability"

    def test_requires_reference(self, readability: Readability) -> None:
        """Test that readability does NOT require reference text."""
        assert readability.requires_reference is False

    def test_simple_text(self, readability: Readability) -> None:
        """Test readability of simple, easy text."""
        # Simple children's text - short sentences, simple words
        text = "The cat sat. The dog ran. I see a bird."
        result = readability.score(text)

        # Should have low grade level and high reading ease
        assert result.flesch_kincaid_grade < 5.0
        assert result.flesch_reading_ease > 80.0

    def test_complex_text(self, readability: Readability) -> None:
        """Test readability of complex, academic text."""
        # Complex academic text - long sentences, polysyllabic words
        text = (
            "The implementation of sophisticated computational methodologies "
            "necessitates comprehensive understanding of algorithmic complexity "
            "and architectural considerations."
        )
        result = readability.score(text)

        # Should have high grade level and low reading ease
        assert result.flesch_kincaid_grade > 12.0
        assert result.flesch_reading_ease < 30.0

    def test_medium_text(self, readability: Readability) -> None:
        """Test readability of medium-difficulty text."""
        text = (
            "The weather today is quite pleasant. "
            "Many people are enjoying the sunshine in the park. "
            "Children play while parents watch nearby."
        )
        result = readability.score(text)

        # Should be middle of the road
        assert 3.0 < result.flesch_kincaid_grade < 10.0
        assert 50.0 < result.flesch_reading_ease < 90.0

    def test_single_sentence(self, readability: Readability) -> None:
        """Test readability with a single sentence."""
        text = "The cat sat on the mat."
        result = readability.score(text)

        # Should compute without error
        assert result.flesch_kincaid_grade is not None
        assert result.flesch_reading_ease is not None

    def test_single_word(self, readability: Readability) -> None:
        """Test readability with a single word."""
        text = "Cat"
        result = readability.score(text)

        # Should handle single word (1 word, 1 sentence, 1 syllable)
        assert result.flesch_kincaid_grade is not None
        assert result.flesch_reading_ease is not None

    def test_empty_text(self, readability: Readability) -> None:
        """Test that empty text returns zero scores."""
        result = readability.score("")

        assert result.flesch_kincaid_grade == 0.0
        assert result.flesch_reading_ease == 0.0

    def test_whitespace_only(self, readability: Readability) -> None:
        """Test that whitespace-only text returns zero scores."""
        result = readability.score(" \t\n ")

        assert result.flesch_kincaid_grade == 0.0
        assert result.flesch_reading_ease == 0.0

    def test_reference_ignored(self, readability: Readability) -> None:
        """Test that reference parameter is ignored."""
        text = "The cat sat on the mat."

        # Score with no reference
        result1 = readability.score(text)
        # Score with reference (should be ignored)
        result2 = readability.score(text, "Completely different text")
        # Score with list of references
        result3 = readability.score(text, ["ref1", "ref2"])

        # All should produce identical results
        assert result1.flesch_kincaid_grade == result2.flesch_kincaid_grade
        assert result1.flesch_reading_ease == result2.flesch_reading_ease
        assert result1.flesch_kincaid_grade == result3.flesch_kincaid_grade

    def test_punctuation_handling(self, readability: Readability) -> None:
        """Test that punctuation affects sentence counting."""
        # Same words, different sentence structure
        text1 = "The cat sat on the mat"  # 1 sentence
        text2 = "The cat sat. On the mat."  # 2 sentences

        result1 = readability.score(text1)
        result2 = readability.score(text2)

        # Different sentence counts should affect scores
        assert result1.flesch_kincaid_grade != result2.flesch_kincaid_grade

    def test_question_marks_count_sentences(self, readability: Readability) -> None:
        """Test that question marks end sentences."""
        text = "What is this? It is a test."
        result = readability.score(text)

        # Should count as 2 sentences
        # With 7 words total, words_per_sentence = 3.5
        assert result.flesch_kincaid_grade is not None

    def test_exclamation_marks_count_sentences(self, readability: Readability) -> None:
        """Test that exclamation marks end sentences."""
        text = "Wow! That is amazing!"
        result = readability.score(text)

        # Should count as 2 sentences
        assert result.flesch_kincaid_grade is not None

    def test_multiple_punctuation(self, readability: Readability) -> None:
        """Test handling of multiple punctuation marks."""
        text = "What?! That's crazy... Well then."
        result = readability.score(text)

        # Should handle gracefully
        assert result.flesch_kincaid_grade is not None

    def test_result_score_property(self, readability: Readability) -> None:
        """Test that result.score returns flesch_reading_ease."""
        result = readability.score("The cat sat on the mat.")
        assert result.score == result.flesch_reading_ease

    def test_contractions(self, readability: Readability) -> None:
        """Test handling of contractions."""
        text = "I'm going to the store. It's not far away."
        result = readability.score(text)

        # Should handle contractions as words
        assert result.flesch_kincaid_grade is not None
        assert result.flesch_reading_ease is not None


class TestReadabilityBatch:
    """Tests for readability batch scoring."""

    @pytest.fixture
    def readability(self) -> Readability:
        """Provide a readability metric instance."""
        return Readability()

    def test_batch_score_basic(self, readability: Readability) -> None:
        """Test basic batch scoring."""
        candidates = [
            "The cat sat on the mat.",
            "A dog ran through the park.",
        ]
        result = readability.batch_score(candidates)

        assert result.count == 2
        assert len(result.results) == 2

    def test_batch_score_statistics(self, readability: Readability) -> None:
        """Test that batch scoring computes statistics."""
        candidates = [
            "Cat sat.",  # Very simple
            "The implementation of sophisticated methodologies requires expertise.",
        ]
        result = readability.batch_score(candidates)

        # Check statistics are computed
        assert "flesch_kincaid_grade" in result.stats
        assert "flesch_reading_ease" in result.stats

        # First should be easier than second
        assert (
            result.results[0].flesch_reading_ease
            > result.results[1].flesch_reading_ease
        )

    def test_batch_score_percentiles(self, readability: Readability) -> None:
        """Test that batch scoring computes percentiles."""
        candidates = ["a", "b", "c", "d", "e"]
        result = readability.batch_score(candidates)

        stats = result.stats["flesch_reading_ease"]
        assert 25 in stats.percentiles
        assert 50 in stats.percentiles
        assert 75 in stats.percentiles
        assert 95 in stats.percentiles

    def test_batch_score_references_ignored(self, readability: Readability) -> None:
        """Test that batch scoring ignores references."""
        candidates = ["The cat sat.", "A dog ran."]

        result1 = readability.batch_score(candidates)
        result2 = readability.batch_score(candidates, ["ref1", "ref2"])

        # Results should be identical
        assert result1.results[0].flesch_kincaid_grade == (
            result2.results[0].flesch_kincaid_grade
        )

    def test_batch_score_empty_list_raises(self, readability: Readability) -> None:
        """Test that empty candidate list raises ValueError."""
        with pytest.raises(ValueError, match="empty"):
            readability.batch_score([])


class TestReadabilityResult:
    """Tests for ReadabilityResult type."""

    def test_frozen(self) -> None:
        """Test that ReadabilityResult is frozen."""
        from pydantic import ValidationError

        result = ReadabilityResult(flesch_kincaid_grade=5.0, flesch_reading_ease=70.0)
        with pytest.raises(ValidationError):
            result.flesch_kincaid_grade = 6.0  # type: ignore[misc]

    def test_values(self) -> None:
        """Test that values are stored correctly."""
        result = ReadabilityResult(flesch_kincaid_grade=8.5, flesch_reading_ease=65.0)
        assert result.flesch_kincaid_grade == 8.5
        assert result.flesch_reading_ease == 65.0

    def test_score_property(self) -> None:
        """Test that score property returns flesch_reading_ease."""
        result = ReadabilityResult(flesch_kincaid_grade=8.5, flesch_reading_ease=65.0)
        assert result.score == 65.0


class TestSyllableCounting:
|
||||||
|
"""Tests for syllable counting heuristics."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def readability(self) -> Readability:
|
||||||
|
"""Provide a readability metric instance."""
|
||||||
|
return Readability()
|
||||||
|
|
||||||
|
def test_monosyllabic_words(self, readability: Readability) -> None:
|
||||||
|
"""Test that monosyllabic words don't inflate scores."""
|
||||||
|
# All one-syllable words
|
||||||
|
text = "The cat sat on the mat."
|
||||||
|
result = readability.score(text)
|
||||||
|
|
||||||
|
# Should be very easy to read
|
||||||
|
assert result.flesch_reading_ease > 90.0
|
||||||
|
|
||||||
|
def test_polysyllabic_words(self, readability: Readability) -> None:
|
||||||
|
"""Test that polysyllabic words affect scores."""
|
||||||
|
# Words with multiple syllables
|
||||||
|
text = "International communication facilitates understanding."
|
||||||
|
result = readability.score(text)
|
||||||
|
|
||||||
|
# Should be harder to read
|
||||||
|
assert result.flesch_reading_ease < 50.0
|
||||||
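The two thresholds asserted in `TestSyllableCounting` are consistent with the standard Flesch formulas. The sketch below is a minimal illustration, assuming the word, sentence, and syllable counts are already known. The function names are hypothetical and are not taken from veritext, whose syllable heuristics are not shown in this diff. The syllable counts in the usage lines were counted by hand.

```python
def flesch_reading_ease(words: int, sentences: int, syllables: int) -> float:
    """Standard Flesch Reading Ease formula (higher means easier to read)."""
    return 206.835 - 1.015 * (words / sentences) - 84.6 * (syllables / words)


def flesch_kincaid_grade(words: int, sentences: int, syllables: int) -> float:
    """Standard Flesch-Kincaid Grade Level formula (approximate US school grade)."""
    return 0.39 * (words / sentences) + 11.8 * (syllables / words) - 15.59


# "The cat sat on the mat." -> 6 words, 1 sentence, 6 syllables (hand-counted).
easy = flesch_reading_ease(words=6, sentences=1, syllables=6)

# "International communication facilitates understanding."
# -> 4 words, 1 sentence, 18 syllables (hand-counted: 5 + 5 + 4 + 4).
hard = flesch_reading_ease(words=4, sentences=1, syllables=18)

assert easy > 90.0  # matches the monosyllabic test's threshold
assert hard < 50.0  # matches the polysyllabic test's threshold
```

With these hand counts the easy sentence scores well above 90 and the polysyllabic one falls far below 50, so both tests leave a wide margin for small differences in a syllable-counting heuristic.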
295
tests/test_metrics/test_rouge.py
Normal file
@@ -0,0 +1,295 @@
"""Tests for the ROUGE metric."""

import pytest

from veritext.metrics import Rouge, RougeResult, RougeScore


class TestRouge:
    """Tests for the Rouge metric class."""

    @pytest.fixture
    def rouge(self) -> Rouge:
        """Provide a ROUGE metric instance."""
        return Rouge()

    def test_name(self, rouge: Rouge) -> None:
        """Test that name returns 'rouge'."""
        assert rouge.name == "rouge"

    def test_requires_reference(self, rouge: Rouge) -> None:
        """Test that ROUGE requires reference text."""
        assert rouge.requires_reference is True

    def test_identical_texts(self, rouge: Rouge) -> None:
        """Test that identical texts produce perfect scores."""
        text = "The cat sat on the mat"
        result = rouge.score(text, text)

        assert result.rouge1.precision == 1.0
        assert result.rouge1.recall == 1.0
        assert result.rouge1.fmeasure == 1.0
        assert result.rouge2.fmeasure == 1.0
        assert result.rouge_l.fmeasure == 1.0

    def test_no_overlap(self, rouge: Rouge) -> None:
        """Test that texts with no overlap produce zero scores."""
        candidate = "apple banana cherry"
        reference = "dog elephant fox"
        result = rouge.score(candidate, reference)

        assert result.rouge1.precision == 0.0
        assert result.rouge1.recall == 0.0
        assert result.rouge1.fmeasure == 0.0
        assert result.rouge2.fmeasure == 0.0
        assert result.rouge_l.fmeasure == 0.0

    def test_partial_overlap_rouge1(self, rouge: Rouge) -> None:
        """Test ROUGE-1 with partial overlap."""
        candidate = "the cat sat"
        reference = "the dog sat"
        result = rouge.score(candidate, reference)

        # Candidate: {the, cat, sat}, Reference: {the, dog, sat}
        # Overlap: {the, sat} = 2
        # Precision = 2/3, Recall = 2/3
        assert abs(result.rouge1.precision - 2 / 3) < 1e-10
        assert abs(result.rouge1.recall - 2 / 3) < 1e-10

    def test_partial_overlap_rouge2(self, rouge: Rouge) -> None:
        """Test ROUGE-2 (bigram) with partial overlap."""
        candidate = "the cat sat on the mat"
        reference = "the cat lay on the mat"
        result = rouge.score(candidate, reference)

        # Bigrams in candidate: (the, cat), (cat, sat), (sat, on), (on, the), (the, mat)
        # Bigrams in reference: (the, cat), (cat, lay), (lay, on), (on, the), (the, mat)
        # Overlap: (the, cat), (on, the), (the, mat) = 3
        # Precision = 3/5, Recall = 3/5
        assert abs(result.rouge2.precision - 3 / 5) < 1e-10
        assert abs(result.rouge2.recall - 3 / 5) < 1e-10

    def test_rouge_l_basic(self, rouge: Rouge) -> None:
        """Test ROUGE-L (LCS) computation."""
        candidate = "the cat sat on the mat"
        reference = "the cat sat"
        result = rouge.score(candidate, reference)

        # LCS = "the cat sat" = 3 tokens
        # Precision = 3/6 = 0.5, Recall = 3/3 = 1.0
        assert result.rouge_l.precision == 0.5
        assert result.rouge_l.recall == 1.0

    def test_rouge_l_non_contiguous(self, rouge: Rouge) -> None:
        """Test ROUGE-L with non-contiguous LCS."""
        candidate = "the big cat sat"
        reference = "the cat sat"
        result = rouge.score(candidate, reference)

        # LCS = "the cat sat" = 3 (skipping "big")
        # Precision = 3/4, Recall = 3/3 = 1.0
        assert result.rouge_l.precision == 0.75
        assert result.rouge_l.recall == 1.0

    def test_precision_vs_recall(self, rouge: Rouge) -> None:
        """Test that precision and recall differ appropriately."""
        # Short candidate, long reference
        candidate = "the cat"
        reference = "the cat sat on the mat"
        result = rouge.score(candidate, reference)

        # Precision should be high (all candidate tokens in reference)
        assert result.rouge1.precision == 1.0
        # Recall should be lower (not all reference tokens in candidate)
        assert result.rouge1.recall < 1.0

    def test_empty_candidate(self, rouge: Rouge) -> None:
        """Test that empty candidate returns zero scores."""
        result = rouge.score("", "The cat sat")

        assert result.rouge1.fmeasure == 0.0
        assert result.rouge2.fmeasure == 0.0
        assert result.rouge_l.fmeasure == 0.0

    def test_whitespace_only_candidate(self, rouge: Rouge) -> None:
        """Test that whitespace-only candidate returns zero scores."""
        result = rouge.score(" \t\n ", "The cat sat")

        assert result.rouge1.fmeasure == 0.0
        assert result.rouge_l.fmeasure == 0.0

    def test_empty_reference_raises(self, rouge: Rouge) -> None:
        """Test that empty reference raises ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):
            rouge.score("The cat sat", "")

    def test_none_reference_raises(self, rouge: Rouge) -> None:
        """Test that None reference raises ValueError."""
        with pytest.raises(ValueError, match="requires reference"):
            rouge.score("The cat sat", None)

    def test_multiple_references_uses_max(self, rouge: Rouge) -> None:
        """Test that multiple references use max scores."""
        candidate = "the cat sat on the mat"
        references = [
            "a dog ran across the room",  # Low overlap
            "the cat sat on the mat",  # Exact match
        ]
        result = rouge.score(candidate, references)

        # Should get perfect scores due to exact match
        assert result.rouge1.fmeasure == 1.0
        assert result.rouge_l.fmeasure == 1.0

    def test_multiple_references_partial(self, rouge: Rouge) -> None:
        """Test multiple references with partial matches."""
        candidate = "the quick brown fox"
        references = [
            "the fast brown fox",  # 3/4 unigram match (the, brown, fox)
            "a quick brown dog",  # 2/4 unigram match (quick, brown)
        ]
        result = rouge.score(candidate, references)

        # Should pick best from either reference
        assert result.rouge1.fmeasure > 0.0

    def test_result_score_property(self, rouge: Rouge) -> None:
        """Test that result.score returns rouge_l.fmeasure."""
        result = rouge.score("The cat sat", "The cat sat")
        assert result.score == result.rouge_l.fmeasure

    def test_case_insensitivity(self, rouge: Rouge) -> None:
        """Test that ROUGE is case insensitive by default."""
        result = rouge.score("THE CAT SAT", "the cat sat")
        assert result.rouge1.fmeasure == 1.0
        assert result.rouge_l.fmeasure == 1.0

    def test_punctuation_ignored(self, rouge: Rouge) -> None:
        """Test that punctuation is ignored by default."""
        result = rouge.score("The cat sat.", "The cat sat!")
        assert result.rouge1.fmeasure == 1.0

    def test_single_word(self, rouge: Rouge) -> None:
        """Test ROUGE with single word texts."""
        result = rouge.score("cat", "cat")

        assert result.rouge1.fmeasure == 1.0
        # ROUGE-2 should be 0 for single words (no bigrams)
        assert result.rouge2.fmeasure == 0.0
        assert result.rouge_l.fmeasure == 1.0

    def test_fmeasure_calculation(self, rouge: Rouge) -> None:
        """Test that F-measure is calculated correctly."""
        # Create a case where P != R
        candidate = "the cat sat on"
        reference = "the cat"
        result = rouge.score(candidate, reference)

        # P = 2/4 = 0.5, R = 2/2 = 1.0
        # F = 2 * 0.5 * 1.0 / (0.5 + 1.0) = 1.0 / 1.5 = 2/3
        expected_f = 2 * 0.5 * 1.0 / (0.5 + 1.0)
        assert abs(result.rouge1.fmeasure - expected_f) < 1e-10


class TestRougeBatch:
    """Tests for ROUGE batch scoring."""

    @pytest.fixture
    def rouge(self) -> Rouge:
        """Provide a ROUGE metric instance."""
        return Rouge()

    def test_batch_score_basic(self, rouge: Rouge) -> None:
        """Test basic batch scoring."""
        candidates = ["The cat sat", "A dog runs"]
        references = ["The cat sat", "A dog runs"]
        result = rouge.batch_score(candidates, references)

        assert result.count == 2
        assert len(result.results) == 2
        assert all(r.rouge_l.fmeasure == 1.0 for r in result.results)

    def test_batch_score_statistics(self, rouge: Rouge) -> None:
        """Test that batch scoring computes statistics."""
        candidates = ["The cat sat", "Completely different words"]
        references = ["The cat sat", "The cat sat"]
        result = rouge.batch_score(candidates, references)

        # Check statistics are computed
        assert "rouge1_fmeasure" in result.stats
        assert "rouge2_fmeasure" in result.stats
        assert "rouge_l_fmeasure" in result.stats
        assert "rouge1_precision" in result.stats
        assert "rouge1_recall" in result.stats

        # First result should be 1.0, second should be 0.0
        assert result.results[0].rouge1.fmeasure == 1.0
        assert result.results[1].rouge1.fmeasure == 0.0

    def test_batch_score_percentiles(self, rouge: Rouge) -> None:
        """Test that batch scoring computes percentiles."""
        candidates = ["a", "b", "c", "d", "e"]
        references = ["a", "b", "c", "d", "e"]
        result = rouge.batch_score(candidates, references)

        stats = result.stats["rouge1_fmeasure"]
        assert 25 in stats.percentiles
        assert 50 in stats.percentiles
        assert 75 in stats.percentiles
        assert 95 in stats.percentiles

    def test_batch_score_none_references_raises(self, rouge: Rouge) -> None:
        """Test that batch scoring raises for None references."""
        with pytest.raises(ValueError, match="requires reference"):
            rouge.batch_score(["text"], None)

    def test_batch_score_length_mismatch_raises(self, rouge: Rouge) -> None:
        """Test that batch scoring raises for mismatched lengths."""
        with pytest.raises(ValueError, match="must match"):
            rouge.batch_score(["a", "b"], ["a"])

    def test_batch_score_with_multiple_references(self, rouge: Rouge) -> None:
        """Test batch scoring with multiple references per candidate."""
        candidates = [
            "The cat sat on the mat",
            "A quick brown fox",
        ]
        references = [
            ["The cat sat on the mat", "A cat rests on floor"],
            ["A quick brown fox", "The fast brown fox"],
        ]
        result = rouge.batch_score(candidates, references)

        assert result.count == 2
        # Both should get perfect scores due to exact matches
        assert result.results[0].rouge_l.fmeasure == 1.0
        assert result.results[1].rouge_l.fmeasure == 1.0


class TestRougeResult:
    """Tests for RougeResult and RougeScore types."""

    def test_rouge_score_frozen(self) -> None:
        """Test that RougeScore is frozen."""
        from pydantic import ValidationError

        score = RougeScore(precision=0.5, recall=0.6, fmeasure=0.55)
        with pytest.raises(ValidationError):
            score.precision = 0.7  # type: ignore[misc]

    def test_rouge_result_frozen(self) -> None:
        """Test that RougeResult is frozen."""
        from pydantic import ValidationError

        score = RougeScore(precision=0.5, recall=0.6, fmeasure=0.55)
        result = RougeResult(rouge1=score, rouge2=score, rouge_l=score)
        with pytest.raises(ValidationError):
            result.rouge1 = score  # type: ignore[misc]

    def test_score_property(self) -> None:
        """Test that score property returns rouge_l.fmeasure."""
        r1 = RougeScore(precision=0.9, recall=0.9, fmeasure=0.9)
        r2 = RougeScore(precision=0.8, recall=0.8, fmeasure=0.8)
        rl = RougeScore(precision=0.7, recall=0.7, fmeasure=0.7)
        result = RougeResult(rouge1=r1, rouge2=r2, rouge_l=rl)
        assert result.score == 0.7
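The worked numbers in the test comments above (2/3 for ROUGE-1, 3/5 for ROUGE-2, 3/4 for ROUGE-L, and F = 2/3) can be reproduced with a small reference computation. This is a minimal sketch under stated assumptions, not veritext's implementation: tokenization is plain lowercase whitespace splitting, and every function name here is hypothetical.

```python
from collections import Counter


def tokenize(text: str) -> list[str]:
    # Assumption: lowercase + whitespace split (no punctuation handling).
    return text.lower().split()


def ngram_counts(tokens: list[str], n: int) -> Counter:
    # Multiset of n-grams; empty when the text has fewer than n tokens.
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def prf(matches: int, cand_total: int, ref_total: int) -> tuple[float, float, float]:
    # Precision, recall, and their harmonic mean, with zero-safe division.
    p = matches / cand_total if cand_total else 0.0
    r = matches / ref_total if ref_total else 0.0
    f = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f


def rouge_n(candidate: str, reference: str, n: int) -> tuple[float, float, float]:
    cand = ngram_counts(tokenize(candidate), n)
    ref = ngram_counts(tokenize(reference), n)
    overlap = sum((cand & ref).values())  # clipped n-gram matches
    return prf(overlap, sum(cand.values()), sum(ref.values()))


def lcs_length(a: list[str], b: list[str]) -> int:
    # Classic dynamic-programming longest common subsequence, one row at a time.
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l(candidate: str, reference: str) -> tuple[float, float, float]:
    cand, ref = tokenize(candidate), tokenize(reference)
    return prf(lcs_length(cand, ref), len(cand), len(ref))


r1 = rouge_n("the cat sat", "the dog sat", 1)
r2 = rouge_n("the cat sat on the mat", "the cat lay on the mat", 2)
rl = rouge_l("the big cat sat", "the cat sat")
```

Here `r1`, `r2`, and `rl` are `(precision, recall, fmeasure)` tuples for the inputs used in `test_partial_overlap_rouge1`, `test_partial_overlap_rouge2`, and `test_rouge_l_non_contiguous`. The multiset intersection `cand & ref` clips repeated n-grams, so a candidate cannot earn credit for repeating a token more often than the reference contains it.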
1
tests/test_pytest_plugin/__init__.py
Normal file
@@ -0,0 +1 @@
"""Tests for the Veritext pytest plugin."""
32
tests/test_pytest_plugin/conftest.py
Normal file
@@ -0,0 +1,32 @@
"""Pytest configuration for pytest_plugin tests."""

import pytest

from veritext.pytest_plugin.fixtures import ValidatorFactory

# Enable the pytester fixture for plugin testing
pytest_plugins = ["pytester"]

# Re-export fixtures from the plugin module for testing


@pytest.fixture
def text_validator() -> ValidatorFactory:
    """Provide a factory for building validators."""
    return ValidatorFactory()


@pytest.fixture
def validation_context() -> type:
    """Provide a factory for creating ValidationContext objects."""
    from typing import Any

    from veritext.core.types import ValidationContext

    def _create(
        reference: str | list[str] | None = None,
        **metadata: Any,
    ) -> ValidationContext:
        return ValidationContext(reference=reference, metadata=metadata)

    return _create
211
tests/test_pytest_plugin/test_assertions.py
Normal file
@@ -0,0 +1,211 @@
"""Tests for the validate_text assertion function."""

import pytest

from veritext.pytest_plugin import validate_text


class TestValidateTextBasicValidation:
    """Test basic validation scenarios."""

    def test_passes_with_valid_length(self) -> None:
        """Test validation passes when length constraints are met."""
        text = "The quick brown fox jumps over the lazy dog."
        validate_text(text, min_length=10, max_length=100)

    def test_fails_when_too_short(self) -> None:
        """Test validation fails when text is below minimum length."""
        text = "Short."
        with pytest.raises(AssertionError) as exc_info:
            validate_text(text, min_length=50)
        assert "length" in str(exc_info.value).lower()

    def test_fails_when_too_long(self) -> None:
        """Test validation fails when text exceeds maximum length."""
        text = "A" * 100
        with pytest.raises(AssertionError) as exc_info:
            validate_text(text, max_length=50)
        assert "length" in str(exc_info.value).lower()


class TestValidateTextReadability:
    """Test readability validation."""

    def test_passes_with_simple_text(self) -> None:
        """Test validation passes for simple, readable text."""
        text = "The cat sat on the mat. It was a nice day."
        validate_text(text, max_reading_grade=10.0)

    def test_fails_with_complex_text(self) -> None:
        """Test validation fails for overly complex text."""
        text = (
            "The implementation of sophisticated metacognitive strategies "
            "necessitates the comprehensive understanding of epistemological "
            "frameworks and their corresponding methodological implications."
        )
        with pytest.raises(AssertionError) as exc_info:
            validate_text(text, max_reading_grade=3.0)
        assert "readability" in str(exc_info.value).lower()


class TestValidateTextPatterns:
    """Test pattern matching validation."""

    def test_passes_when_contains_pattern(self) -> None:
        """Test validation passes when required pattern is present."""
        text = "Please contact support@example.com for assistance."
        validate_text(text, must_contain=["support@example.com"])

    def test_fails_when_missing_required_pattern(self) -> None:
        """Test validation fails when required pattern is missing."""
        text = "Please contact us for assistance."
        with pytest.raises(AssertionError) as exc_info:
            validate_text(text, must_contain=["@example.com"])
        assert "contains" in str(exc_info.value).lower()

    def test_passes_when_excludes_pattern(self) -> None:
        """Test validation passes when forbidden pattern is absent."""
        text = "The report is complete and reviewed."
        validate_text(text, must_exclude=["TODO", "FIXME"])

    def test_fails_when_contains_forbidden_pattern(self) -> None:
        """Test validation fails when forbidden pattern is present."""
        text = "The report is almost done. TODO: add conclusion."
        with pytest.raises(AssertionError) as exc_info:
            validate_text(text, must_exclude=["TODO"])
        assert "excludes" in str(exc_info.value).lower()


class TestValidateTextComparisonMetrics:
    """Test comparison-based validation (BLEU, ROUGE)."""

    def test_passes_with_high_bleu_score(self) -> None:
        """Test validation passes when BLEU score meets threshold."""
        reference = "The quick brown fox jumps over the lazy dog."
        text = "The quick brown fox jumps over the lazy dog."
        validate_text(text, reference=reference, min_bleu=0.9)

    def test_fails_with_low_bleu_score(self) -> None:
        """Test validation fails when BLEU score is below threshold."""
        reference = "The quick brown fox jumps over the lazy dog."
        text = "A slow red cat sleeps under the active mouse."
        with pytest.raises(AssertionError) as exc_info:
            validate_text(text, reference=reference, min_bleu=0.5)
        assert "bleu" in str(exc_info.value).lower()

    def test_passes_with_high_rouge_score(self) -> None:
        """Test validation passes when ROUGE score meets threshold."""
        reference = "Machine learning models require extensive training data."
        text = "Machine learning models need extensive training data."
        validate_text(text, reference=reference, min_rouge=0.5)

    def test_fails_with_low_rouge_score(self) -> None:
        """Test validation fails when ROUGE score is below threshold."""
        reference = "The algorithm processes input data efficiently."
        text = "Cats enjoy sleeping in sunny spots."
        with pytest.raises(AssertionError) as exc_info:
            validate_text(text, reference=reference, min_rouge=0.5)
        assert "rouge" in str(exc_info.value).lower()


class TestValidateTextErrorHandling:
    """Test error handling and edge cases."""

    def test_raises_value_error_when_no_criteria(self) -> None:
        """Test that ValueError is raised when no validation criteria provided."""
        with pytest.raises(ValueError, match="At least one validation criterion"):
            validate_text("Some text")

    def test_raises_value_error_when_bleu_without_reference(self) -> None:
        """Test that ValueError is raised when BLEU requested without reference."""
        with pytest.raises(ValueError, match="Reference text required"):
            validate_text("Some text", min_bleu=0.5)

    def test_raises_value_error_when_rouge_without_reference(self) -> None:
        """Test that ValueError is raised when ROUGE requested without reference."""
        with pytest.raises(ValueError, match="Reference text required"):
            validate_text("Some text", min_rouge=0.5)

    def test_raises_value_error_when_semantic_without_reference(self) -> None:
        """Test that ValueError is raised for semantic without reference."""
        with pytest.raises(ValueError, match="Reference text required"):
            validate_text("Some text", min_semantic=0.5)


class TestValidateTextMultipleCriteria:
    """Test validation with multiple criteria combined."""

    def test_passes_all_criteria(self) -> None:
        """Test validation passes when all criteria are met."""
        reference = "The quick brown fox jumps over the lazy dog."
        text = "The quick brown fox jumps over the lazy dog."
        validate_text(
            text,
            reference=reference,
            min_bleu=0.9,
            min_length=10,
            max_length=100,
        )

    def test_fails_when_one_criterion_fails(self) -> None:
        """Test validation fails when any criterion fails."""
        reference = "The quick brown fox jumps over the lazy dog."
        text = "The quick brown fox jumps over the lazy dog."
        with pytest.raises(AssertionError):
            validate_text(
                text,
                reference=reference,
                min_bleu=0.9,
                max_length=10,  # This will fail
            )


class TestValidateTextFailureMessage:
    """Test failure message formatting."""

    def test_failure_message_includes_text_preview(self) -> None:
        """Test that failure message includes preview of the text."""
        text = "Short text"
        with pytest.raises(AssertionError) as exc_info:
            validate_text(text, min_length=100)
        assert "Short text" in str(exc_info.value)

    def test_failure_message_truncates_long_text(self) -> None:
        """Test that long text is truncated in failure message."""
        text = "A" * 200
        with pytest.raises(AssertionError) as exc_info:
            validate_text(text, max_length=50)
        message = str(exc_info.value)
        assert "..." in message
        assert "A" * 200 not in message

    def test_failure_message_includes_check_details(self) -> None:
        """Test that failure message includes check name and details."""
        text = "Short"
        with pytest.raises(AssertionError) as exc_info:
            validate_text(text, min_length=100)
        message = str(exc_info.value)
        assert "Failed checks:" in message
        assert "length" in message.lower()


class TestValidateTextListReference:
    """Test validation with list of reference texts."""

    def test_bleu_with_multiple_references(self) -> None:
        """Test BLEU validation accepts multiple reference texts."""
        references = [
            "The quick brown fox jumps over the lazy dog.",
            "A fast brown fox leaps over a sleepy dog.",
        ]
        text = "The quick brown fox jumps over the lazy dog."
        validate_text(text, reference=references, min_bleu=0.9)

    def test_rouge_with_multiple_references(self) -> None:
        """Test ROUGE validation accepts multiple reference texts."""
        references = [
            "Machine learning requires data.",
            "ML models need training data.",
        ]
        text = "Machine learning models require training data."
        validate_text(text, reference=references, min_rouge=0.3)
88
tests/test_pytest_plugin/test_fixtures.py
Normal file
@@ -0,0 +1,88 @@
"""Tests for the pytest plugin fixtures."""

from veritext.core.types import ValidationContext
from veritext.pytest_plugin.fixtures import ValidatorFactory
from veritext.validators import bleu, length


class TestValidatorFactory:
    """Test the ValidatorFactory class."""

    def test_creates_validator_from_checks(self) -> None:
        """Test that factory creates a callable validator."""
        factory = ValidatorFactory()
        validate = factory(checks=[length(min_chars=5)])

        result = validate("Hello, World!")
        assert result.passed

    def test_validator_uses_provided_reference(self) -> None:
        """Test that factory passes reference to context."""
        factory = ValidatorFactory()
        reference = "The quick brown fox."
        validate = factory(
            checks=[bleu(min_score=0.5)],
            reference=reference,
        )

        # Exact match should pass
        result = validate("The quick brown fox.")
        assert result.passed

    def test_validator_returns_validation_result(self) -> None:
        """Test that validator returns a ValidationResult."""
        factory = ValidatorFactory()
        validate = factory(checks=[length(min_chars=100)])

        result = validate("Short")
        assert not result.passed
        assert len(result.checks) == 1
        assert result.checks[0].name == "length"


class TestTextValidatorFixture:
    """Test the text_validator fixture."""

    def test_fixture_returns_factory(self, text_validator: ValidatorFactory) -> None:
        """Test that fixture provides a ValidatorFactory."""
        assert isinstance(text_validator, ValidatorFactory)

    def test_fixture_can_create_validators(
        self,
        text_validator: ValidatorFactory,
    ) -> None:
        """Test that fixture can be used to create validators."""
        validate = text_validator(checks=[length(min_chars=5, max_chars=50)])

        assert validate("Hello, World!").passed
        assert not validate("Hi").passed


class TestValidationContextFixture:
    """Test the validation_context fixture."""

    def test_fixture_creates_context(
        self,
        validation_context: type,
    ) -> None:
        """Test that fixture creates ValidationContext."""
        ctx = validation_context(reference="Test reference")
        assert isinstance(ctx, ValidationContext)
        assert ctx.reference == "Test reference"

    def test_fixture_accepts_metadata(
        self,
        validation_context: type,
    ) -> None:
        """Test that fixture passes metadata to context."""
        ctx = validation_context(reference="Test", source="unit_test", version=1)
        assert ctx.metadata["source"] == "unit_test"
        assert ctx.metadata["version"] == 1

    def test_fixture_allows_no_reference(
        self,
        validation_context: type,
    ) -> None:
        """Test that fixture allows creating context without reference."""
        ctx = validation_context()
        assert ctx.reference is None
99
tests/test_pytest_plugin/test_plugin.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Tests for the pytest plugin hooks."""

import pytest


@pytest.fixture
def plugin_pytester(pytester: pytest.Pytester) -> pytest.Pytester:
    """Configure pytester to use the veritext plugin.

    Note: The plugin is already loaded via the entry point in pyproject.toml,
    so no explicit pytest_plugins declaration is needed.
    """
    return pytester


def test_plugin_registers_marker(plugin_pytester: pytest.Pytester) -> None:
    """Test that the text_validation marker is registered."""
    plugin_pytester.makepyfile(
        """
        import pytest

        @pytest.mark.text_validation
        def test_example():
            pass
        """
    )
    # Run with strict markers - this will fail if marker isn't registered
    result = plugin_pytester.runpytest("--strict-markers")
    result.assert_outcomes(passed=1)


def test_marker_can_be_used(plugin_pytester: pytest.Pytester) -> None:
    """Test that the text_validation marker can filter tests."""
    plugin_pytester.makepyfile(
        """
        import pytest

        @pytest.mark.text_validation
        def test_marked():
            pass

        def test_unmarked():
            pass
        """
    )
    # Run only marked tests
    result = plugin_pytester.runpytest("-m", "text_validation")
    result.assert_outcomes(passed=1)


def test_validate_text_is_importable(plugin_pytester: pytest.Pytester) -> None:
    """Test that validate_text can be imported from the plugin."""
    plugin_pytester.makepyfile(
        """
        from veritext.pytest_plugin import validate_text

        def test_import():
            assert callable(validate_text)
        """
    )
    result = plugin_pytester.runpytest()
    result.assert_outcomes(passed=1)


def test_validate_text_works_in_tests(plugin_pytester: pytest.Pytester) -> None:
    """Test that validate_text can be used in test functions."""
    plugin_pytester.makepyfile(
        """
        from veritext.pytest_plugin import validate_text

        def test_validation_passes():
            validate_text(
                "The quick brown fox jumps over the lazy dog.",
                min_length=10,
                max_length=100,
            )
        """
    )
    result = plugin_pytester.runpytest()
    result.assert_outcomes(passed=1)


def test_validate_text_failure_in_tests(plugin_pytester: pytest.Pytester) -> None:
    """Test that validate_text failures are reported properly."""
    plugin_pytester.makepyfile(
        """
        from veritext.pytest_plugin import validate_text

        def test_validation_fails():
            validate_text(
                "Short",
                min_length=100,
            )
        """
    )
    result = plugin_pytester.runpytest()
    result.assert_outcomes(failed=1)
    # Check that failure message contains useful information
    result.stdout.fnmatch_lines(["*Text validation failed*"])
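The tests above rely on pytest's `pytester` fixture, which comes from the bundled `pytester` plugin and is not active by default. A minimal sketch of enabling it, assuming the project does not already do so elsewhere (for example by passing `-p pytester` on the command line); the file location is an assumption, not something shown in this diff:

```python
# conftest.py at the top of the test tree (location assumed for illustration).
# Enables pytest's bundled "pytester" plugin so the `pytester` fixture exists.
pytest_plugins = ["pytester"]
```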
1
tests/test_semantic/__init__.py
Normal file
@@ -0,0 +1 @@
"""Tests for semantic similarity module."""
240
tests/test_semantic/test_similarity.py
Normal file
@@ -0,0 +1,240 @@
"""Tests for the semantic similarity metric."""

import pytest

# Skip all tests if sentence-transformers is not installed
pytest.importorskip("sentence_transformers")

from veritext.metrics.results import SemanticResult
from veritext.semantic import SemanticSimilarity


class TestSemanticSimilarity:
    """Tests for the SemanticSimilarity metric class."""

    @pytest.fixture
    def semantic(self) -> SemanticSimilarity:
        """Provide a SemanticSimilarity metric instance."""
        return SemanticSimilarity()

    def test_name(self, semantic: SemanticSimilarity) -> None:
        """Test that name returns 'semantic'."""
        assert semantic.name == "semantic"

    def test_requires_reference(self, semantic: SemanticSimilarity) -> None:
        """Test that semantic similarity requires reference text."""
        assert semantic.requires_reference is True

    def test_identical_texts(self, semantic: SemanticSimilarity) -> None:
        """Test that identical texts produce high similarity."""
        text = "The cat sat on the mat"
        result = semantic.score(text, text)

        # Identical texts should have very high similarity (close to 1.0)
        assert result.similarity >= 0.99
        assert result.model == "all-MiniLM-L6-v2"

    def test_semantically_similar_texts(self, semantic: SemanticSimilarity) -> None:
        """Test that semantically similar texts have high similarity."""
        candidate = "The cat sat on the mat"
        reference = "A feline rested on the rug"
        result = semantic.score(candidate, reference)

        # Similar meanings should have reasonable similarity
        assert result.similarity > 0.3

    def test_unrelated_texts(self, semantic: SemanticSimilarity) -> None:
        """Test that unrelated texts have low similarity."""
        candidate = "The quick brown fox"
        reference = "Quantum physics describes particle behaviour"
        result = semantic.score(candidate, reference)

        # Unrelated texts should have low similarity
        assert result.similarity < 0.5

    def test_empty_candidate(self, semantic: SemanticSimilarity) -> None:
        """Test that empty candidate returns zero similarity."""
        result = semantic.score("", "The cat sat on the mat")
        assert result.similarity == 0.0

    def test_whitespace_only_candidate(self, semantic: SemanticSimilarity) -> None:
        """Test that whitespace-only candidate returns zero similarity."""
        result = semantic.score(" \t\n ", "The cat sat on the mat")
        assert result.similarity == 0.0

    def test_none_reference_raises(self, semantic: SemanticSimilarity) -> None:
        """Test that None reference raises ValueError."""
        with pytest.raises(ValueError, match="requires reference"):
            semantic.score("The cat sat", None)

    def test_empty_reference_raises(self, semantic: SemanticSimilarity) -> None:
        """Test that empty reference raises ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):
            semantic.score("The cat sat", "")

    def test_whitespace_reference_raises(self, semantic: SemanticSimilarity) -> None:
        """Test that whitespace-only reference raises ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):
            semantic.score("The cat sat", " \t\n ")

    def test_multiple_references(self, semantic: SemanticSimilarity) -> None:
        """Test semantic similarity with multiple references uses max."""
        candidate = "The cat sat on the mat"
        references = [
            "A dog ran through the park",
            "The cat sat on the mat",  # Exact match
        ]
        result = semantic.score(candidate, references)

        # Should get high similarity due to exact match reference
        assert result.similarity >= 0.99

    def test_multiple_references_takes_max(self, semantic: SemanticSimilarity) -> None:
        """Test that multiple references returns maximum similarity."""
        candidate = "The cat sat on the mat"
        references = [
            "Quantum physics is complex",  # Low similarity
            "A feline rested on the rug",  # Higher similarity
        ]
        result = semantic.score(candidate, references)

        # Should use the higher similarity
        assert result.similarity > 0.3

    def test_result_score_property(self, semantic: SemanticSimilarity) -> None:
        """Test that result.score returns similarity."""
        result = semantic.score("The cat sat", "The cat sat")
        assert result.score == result.similarity

    def test_caching_behaviour(self) -> None:
        """Test that caching works for repeated texts."""
        semantic = SemanticSimilarity(cache_embeddings=True)

        # Score same texts multiple times
        text = "The cat sat on the mat"
        result1 = semantic.score(text, text)
        result2 = semantic.score(text, text)

        # Results should be identical
        assert result1.similarity == result2.similarity

        # Clear cache and check again
        semantic.clear_cache()
        result3 = semantic.score(text, text)
        assert result3.similarity == result1.similarity

    def test_caching_disabled(self) -> None:
        """Test that caching can be disabled."""
        semantic = SemanticSimilarity(cache_embeddings=False)

        text = "The cat sat on the mat"
        result1 = semantic.score(text, text)
        result2 = semantic.score(text, text)

        # Results should still be identical (just not cached)
        assert result1.similarity == result2.similarity

        # Clear cache should not raise even when disabled
        semantic.clear_cache()

    def test_custom_model(self) -> None:
        """Test that custom model name is recorded in result."""
        # Use the same model but verify it's recorded correctly
        semantic = SemanticSimilarity(model="all-MiniLM-L6-v2")
        result = semantic.score("Test text", "Test text")
        assert result.model == "all-MiniLM-L6-v2"


class TestSemanticSimilarityBatch:
    """Tests for semantic similarity batch scoring."""

    @pytest.fixture
    def semantic(self) -> SemanticSimilarity:
        """Provide a SemanticSimilarity metric instance."""
        return SemanticSimilarity()

    def test_batch_score_basic(self, semantic: SemanticSimilarity) -> None:
        """Test basic batch scoring."""
        candidates = ["The cat sat on the mat", "A quick brown dog runs fast"]
        references = ["The cat sat on the mat", "A quick brown dog runs fast"]
        result = semantic.batch_score(candidates, references)

        assert result.count == 2
        assert len(result.results) == 2
        # Identical texts should have very high similarity
        assert all(r.similarity >= 0.99 for r in result.results)

    def test_batch_score_statistics(self, semantic: SemanticSimilarity) -> None:
        """Test that batch scoring computes statistics."""
        candidates = ["The cat sat", "Quantum physics is complex"]
        references = ["The cat sat", "The cat sat"]
        result = semantic.batch_score(candidates, references)

        # Check statistics are computed
        assert "similarity" in result.stats

        # Mean should be between min and max
        stats = result.stats["similarity"]
        assert stats.min <= stats.mean <= stats.max

    def test_batch_score_percentiles(self, semantic: SemanticSimilarity) -> None:
        """Test that batch scoring computes percentiles."""
        candidates = ["a", "b", "c", "d", "e"]
        references = ["a", "b", "c", "d", "e"]
        result = semantic.batch_score(candidates, references)

        stats = result.stats["similarity"]
        assert 25 in stats.percentiles
        assert 50 in stats.percentiles
        assert 75 in stats.percentiles
        assert 95 in stats.percentiles

    def test_batch_score_none_references_raises(
        self, semantic: SemanticSimilarity
    ) -> None:
        """Test that batch scoring raises for None references."""
        with pytest.raises(ValueError, match="requires reference"):
            semantic.batch_score(["text"], None)

    def test_batch_score_length_mismatch_raises(
        self, semantic: SemanticSimilarity
    ) -> None:
        """Test that batch scoring raises for mismatched lengths."""
        with pytest.raises(ValueError, match="must match"):
            semantic.batch_score(["a", "b"], ["a"])

    def test_batch_score_with_multiple_references(
        self, semantic: SemanticSimilarity
    ) -> None:
        """Test batch scoring with multiple references per candidate."""
        candidates = [
            "The cat sat on the mat",
            "A quick brown dog runs fast",
        ]
        references = [
            ["The cat sat on the mat", "A cat rests on floor"],
            ["A quick brown dog runs fast", "Dogs run very quickly"],
        ]
        result = semantic.batch_score(candidates, references)

        assert result.count == 2
        # Each candidate has an exact-match reference
        assert result.results[0].similarity >= 0.99
        assert result.results[1].similarity >= 0.99


class TestSemanticResult:
    """Tests for SemanticResult type."""

    def test_frozen(self) -> None:
        """Test that SemanticResult is frozen."""
        from pydantic import ValidationError

        result = SemanticResult(similarity=0.85, model="test-model")
        with pytest.raises(ValidationError):
            result.similarity = 0.9  # type: ignore[misc]

    def test_score_property(self) -> None:
        """Test that score property returns similarity."""
        result = SemanticResult(similarity=0.75, model="test-model")
        assert result.score == 0.75
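The multiple-reference tests in this file expect the reported score to be the best match across all references. A minimal sketch of that idea, assuming the embeddings are already available as plain float vectors; `cosine_similarity` and `max_reference_similarity` are hypothetical helpers written for illustration, not veritext's API, and the real metric presumably obtains its embeddings from a sentence-transformers model:

```python
import math
from collections.abc import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either has zero norm."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (norm_a * norm_b)


def max_reference_similarity(
    candidate: Sequence[float], references: Sequence[Sequence[float]]
) -> float:
    """Score a candidate against several references and keep the best match."""
    if not references:
        raise ValueError("at least one reference embedding is required")
    return max(cosine_similarity(candidate, ref) for ref in references)
```

Taking the maximum means one unrelated reference cannot lower the score as long as another reference matches well, which is the behaviour the exact-match test case depends on.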
1
tests/test_validators/__init__.py
Normal file
@@ -0,0 +1 @@
"""Tests for the validators module."""
198
tests/test_validators/test_composite.py
Normal file
@@ -0,0 +1,198 @@
"""Tests for composite validators."""

import pytest

from veritext.core.types import ValidationContext
from veritext.validators import all_of, any_of, bleu, contains, excludes, length
from veritext.validators.composite import AllOf, AnyOf


class TestAllOf:
    """Tests for AllOf composite validator."""

    def test_all_of_passes_when_all_checks_pass(self) -> None:
        """Test that AllOf passes when all checks pass."""
        validator = AllOf(
            checks=[
                length(min_words=2),
                contains(patterns=["hello"]),
            ]
        )
        context = ValidationContext()
        result = validator.check("hello world", context)

        assert result.passed is True
        assert len(result.checks) == 2
        assert all(c.passed for c in result.checks)

    def test_all_of_fails_when_one_check_fails(self) -> None:
        """Test that AllOf fails when any check fails."""
        validator = AllOf(
            checks=[
                length(min_words=2),
                contains(patterns=["goodbye"]),
            ]
        )
        context = ValidationContext()
        result = validator.check("hello world", context)

        assert result.passed is False
        assert len(result.checks) == 2
        assert len(result.failed_checks) == 1

    def test_all_of_fails_when_all_checks_fail(self) -> None:
        """Test that AllOf fails when all checks fail."""
        validator = AllOf(
            checks=[
                length(min_words=10),
                contains(patterns=["goodbye"]),
            ]
        )
        context = ValidationContext()
        result = validator.check("hello", context)

        assert result.passed is False
        assert len(result.failed_checks) == 2

    def test_all_of_with_metric_validators(self) -> None:
        """Test AllOf with metric-based validators."""
        validator = AllOf(
            checks=[
                bleu(min_score=0.5),
                length(min_words=3),
            ]
        )
        context = ValidationContext(reference="the quick brown fox")
        result = validator.check("the quick brown fox jumps", context)

        assert result.passed is True
        assert len(result.checks) == 2

    def test_all_of_failure_summary(self) -> None:
        """Test the failure summary property."""
        validator = AllOf(
            checks=[
                length(min_words=10),
                contains(patterns=["goodbye"]),
            ]
        )
        context = ValidationContext()
        result = validator.check("hello", context)

        summary = result.failure_summary
        assert "failed" in summary.lower()
        assert "length" in summary
        assert "contains" in summary

    def test_all_of_raises_on_empty_checks(self) -> None:
        """Test that empty checks list raises error."""
        with pytest.raises(ValueError, match="cannot be empty"):
            AllOf(checks=[])

    def test_all_of_name_property(self) -> None:
        """Test the name property."""
        validator = AllOf(checks=[length(min_chars=1)])
        assert validator.name == "all_of"

    def test_all_of_factory_function(self) -> None:
        """Test the all_of() factory function."""
        validator = all_of(checks=[length(min_chars=1)])
        assert isinstance(validator, AllOf)


class TestAnyOf:
    """Tests for AnyOf composite validator."""

    def test_any_of_passes_when_any_check_passes(self) -> None:
        """Test that AnyOf passes when any check passes."""
        validator = AnyOf(
            checks=[
                length(min_words=10),  # Will fail
                contains(patterns=["hello"]),  # Will pass
            ]
        )
        context = ValidationContext()
        result = validator.check("hello world", context)

        assert result.passed is True
        assert len(result.checks) == 2
        # At least one check passed
        assert any(c.passed for c in result.checks)

    def test_any_of_passes_when_all_checks_pass(self) -> None:
        """Test that AnyOf passes when all checks pass."""
        validator = AnyOf(
            checks=[
                length(min_words=2),
                contains(patterns=["hello"]),
            ]
        )
        context = ValidationContext()
        result = validator.check("hello world", context)

        assert result.passed is True
        assert all(c.passed for c in result.checks)

    def test_any_of_fails_when_all_checks_fail(self) -> None:
        """Test that AnyOf fails when all checks fail."""
        validator = AnyOf(
            checks=[
                length(min_words=10),
                contains(patterns=["goodbye"]),
            ]
        )
        context = ValidationContext()
        result = validator.check("hello", context)

        assert result.passed is False
        assert not any(c.passed for c in result.checks)

    def test_any_of_with_metric_validators(self) -> None:
        """Test AnyOf with metric-based validators."""
        validator = AnyOf(
            checks=[
                bleu(min_score=0.9),  # Might fail
                length(min_words=3),  # Should pass
            ]
        )
        context = ValidationContext(reference="different text entirely")
        result = validator.check("the quick brown fox jumps", context)

        assert result.passed is True  # Length check passes

    def test_any_of_with_excludes(self) -> None:
        """Test AnyOf with excludes validator."""
        validator = AnyOf(
            checks=[
                excludes(patterns=["error"]),
                excludes(patterns=["warning"]),
            ]
        )
        context = ValidationContext()

        # Should pass - neither pattern found
        result = validator.check("All is well", context)
        assert result.passed is True

        # Should pass - one pattern found, other not
        result = validator.check("This is an error", context)
        assert result.passed is True

        # Should fail - both patterns found
        result = validator.check("error and warning", context)
        assert result.passed is False

    def test_any_of_raises_on_empty_checks(self) -> None:
        """Test that empty checks list raises error."""
        with pytest.raises(ValueError, match="cannot be empty"):
            AnyOf(checks=[])

    def test_any_of_name_property(self) -> None:
        """Test the name property."""
        validator = AnyOf(checks=[length(min_chars=1)])
        assert validator.name == "any_of"

    def test_any_of_factory_function(self) -> None:
        """Test the any_of() factory function."""
        validator = any_of(checks=[length(min_chars=1)])
        assert isinstance(validator, AnyOf)
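The composite tests assert two things at once: the overall verdict follows all/any logic, and every sub-check is still reported (`len(result.checks) == 2` even when the first check fails). A minimal sketch of that combination rule, assuming checks are plain callables; `CheckOutcome`, `CompositeOutcome`, `run_composite`, `min_words` and `contains_word` are hypothetical names for illustration and are not taken from veritext:

```python
from collections.abc import Callable, Sequence
from dataclasses import dataclass


@dataclass(frozen=True)
class CheckOutcome:
    name: str
    passed: bool


@dataclass(frozen=True)
class CompositeOutcome:
    passed: bool
    checks: tuple[CheckOutcome, ...]

    @property
    def failed_checks(self) -> tuple[CheckOutcome, ...]:
        return tuple(c for c in self.checks if not c.passed)


Check = Callable[[str], CheckOutcome]


def run_composite(
    text: str, checks: Sequence[Check], *, require_all: bool
) -> CompositeOutcome:
    """Run every check, then combine the verdicts with all() or any()."""
    if not checks:
        raise ValueError("checks cannot be empty")
    # Build the full tuple first: no short-circuiting, so each outcome is kept.
    outcomes = tuple(check(text) for check in checks)
    combine = all if require_all else any
    return CompositeOutcome(
        passed=combine(o.passed for o in outcomes), checks=outcomes
    )


def min_words(n: int) -> Check:
    """Hypothetical check: text has at least n whitespace-separated words."""
    return lambda text: CheckOutcome("length", len(text.split()) >= n)


def contains_word(word: str) -> Check:
    """Hypothetical check: case-insensitive substring match."""
    return lambda text: CheckOutcome("contains", word.lower() in text.lower())
```

With `checks = [min_words(10), contains_word("hello")]` and the text `"hello world"`, `require_all=False` passes and `require_all=True` fails, while both results still carry two check outcomes.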
344
tests/test_validators/test_constraint.py
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
"""Tests for constraint validators."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from veritext.core.exceptions import InvalidThresholdError
|
||||||
|
from veritext.core.types import ValidationContext
|
||||||
|
from veritext.validators import contains, excludes, length, readability
|
||||||
|
from veritext.validators.constraint import (
|
||||||
|
ContainsValidator,
|
||||||
|
ExcludesValidator,
|
||||||
|
LengthValidator,
|
||||||
|
ReadabilityValidator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLengthValidator:
|
||||||
|
"""Tests for LengthValidator."""
|
||||||
|
|
||||||
|
def test_length_validator_min_chars_passes(self) -> None:
|
||||||
|
"""Test that validator passes when char count meets minimum."""
|
||||||
|
validator = LengthValidator(min_chars=10)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world!", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "length"
|
||||||
|
assert result.actual["chars"] == 12
|
||||||
|
|
||||||
|
def test_length_validator_min_chars_fails(self) -> None:
|
||||||
|
"""Test that validator fails when char count below minimum."""
|
||||||
|
validator = LengthValidator(min_chars=20)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "< min" in result.message
|
||||||
|
|
||||||
|
def test_length_validator_max_chars_passes(self) -> None:
|
||||||
|
"""Test that validator passes when char count within maximum."""
|
||||||
|
validator = LengthValidator(max_chars=20)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.actual["chars"] == 11
|
||||||
|
|
||||||
|
def test_length_validator_max_chars_fails(self) -> None:
|
||||||
|
"""Test that validator fails when char count exceeds maximum."""
|
||||||
|
validator = LengthValidator(max_chars=5)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "> max" in result.message
|
||||||
|
|
||||||
|
def test_length_validator_min_words_passes(self) -> None:
|
||||||
|
"""Test that validator passes when word count meets minimum."""
|
||||||
|
validator = LengthValidator(min_words=3)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("the quick brown fox", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.actual["words"] == 4
|
||||||
|
|
||||||
|
def test_length_validator_min_words_fails(self) -> None:
|
||||||
|
"""Test that validator fails when word count below minimum."""
|
||||||
|
validator = LengthValidator(min_words=10)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "words < min" in result.message
|
||||||
|
|
||||||
|
def test_length_validator_max_words_passes(self) -> None:
|
||||||
|
"""Test that validator passes when word count within maximum."""
|
||||||
|
validator = LengthValidator(max_words=5)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("hello world", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
|
||||||
|
def test_length_validator_max_words_fails(self) -> None:
|
||||||
|
"""Test that validator fails when word count exceeds maximum."""
|
||||||
|
validator = LengthValidator(max_words=2)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("the quick brown fox", context)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "words > max" in result.message
|
||||||
|
|
||||||
|
def test_length_validator_combined_constraints(self) -> None:
|
||||||
|
"""Test validator with multiple constraints."""
|
||||||
|
validator = LengthValidator(
|
||||||
|
min_chars=5, max_chars=50, min_words=2, max_words=10
|
||||||
|
)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("the quick brown fox", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert "min_chars" in result.threshold
|
||||||
|
assert "max_chars" in result.threshold
|
||||||
|
assert "min_words" in result.threshold
|
||||||
|
assert "max_words" in result.threshold
|
||||||
|
|
||||||
|
def test_length_validator_raises_when_no_constraints(self) -> None:
|
||||||
|
"""Test that validator raises when no constraints provided."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="At least one"):
|
||||||
|
LengthValidator()
|
||||||
|
|
||||||
|
def test_length_validator_raises_on_negative_values(self) -> None:
|
||||||
|
"""Test that negative constraint values raise error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="min_chars must be >= 0"):
|
||||||
|
LengthValidator(min_chars=-1)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match="max_chars must be >= 0"):
|
||||||
|
LengthValidator(max_chars=-1)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match="min_words must be >= 0"):
|
||||||
|
LengthValidator(min_words=-1)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match="max_words must be >= 0"):
|
||||||
|
LengthValidator(max_words=-1)
|
||||||
|
|
||||||
|
def test_length_validator_raises_on_invalid_range(self) -> None:
|
||||||
|
"""Test that min > max raises error."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="cannot exceed max_chars"):
|
||||||
|
LengthValidator(min_chars=100, max_chars=50)
|
||||||
|
|
||||||
|
with pytest.raises(InvalidThresholdError, match="cannot exceed max_words"):
|
||||||
|
LengthValidator(min_words=20, max_words=5)
|
||||||
|
|
||||||
|
def test_length_factory_function(self) -> None:
|
||||||
|
"""Test the length() factory function."""
|
||||||
|
validator = length(min_chars=10, max_words=100)
|
||||||
|
assert isinstance(validator, LengthValidator)
|
||||||
|
assert validator.name == "length"
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadabilityValidator:
|
||||||
|
"""Tests for ReadabilityValidator."""
|
||||||
|
|
||||||
|
def test_readability_validator_max_grade_passes(self) -> None:
|
||||||
|
"""Test that validator passes when grade level within maximum."""
|
||||||
|
validator = ReadabilityValidator(max_grade=12.0)
|
||||||
|
context = ValidationContext()
|
||||||
|
# Simple text should have low grade level
|
||||||
|
result = validator.check("The cat sat on the mat. It was a nice day.", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert result.name == "readability"
|
||||||
|
assert "grade" in result.actual
|
||||||
|
|
||||||
|
def test_readability_validator_max_grade_fails(self) -> None:
|
||||||
|
"""Test that validator fails when grade level exceeds maximum."""
|
||||||
|
validator = ReadabilityValidator(max_grade=1.0)
|
||||||
|
context = ValidationContext()
|
||||||
|
# Complex text
|
||||||
|
result = validator.check(
|
||||||
|
"The implementation of sophisticated methodologies necessitates "
|
||||||
|
"comprehensive analytical frameworks for systematic evaluation.",
|
||||||
|
context,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "grade level" in result.message
|
||||||
|
assert "> max" in result.message
|
||||||
|
|
||||||
|
def test_readability_validator_min_ease_passes(self) -> None:
|
||||||
|
"""Test that validator passes when reading ease meets minimum."""
|
||||||
|
validator = ReadabilityValidator(min_ease=30.0)
|
||||||
|
context = ValidationContext()
|
||||||
|
# Simple text should have high reading ease
|
||||||
|
result = validator.check("The cat sat. The dog ran. It was fun.", context)
|
||||||
|
|
||||||
|
assert result.passed is True
|
||||||
|
assert "ease" in result.actual
|
||||||
|
|
||||||
|
def test_readability_validator_min_ease_fails(self) -> None:
|
||||||
|
"""Test that validator fails when reading ease below minimum."""
|
||||||
|
validator = ReadabilityValidator(min_ease=100.0)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check(
|
||||||
|
"The implementation of sophisticated methodologies necessitates "
|
||||||
|
"comprehensive analytical frameworks.",
|
||||||
|
context,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.passed is False
|
||||||
|
assert "reading ease" in result.message
|
||||||
|
assert "< min" in result.message
|
||||||
|
|
||||||
|
def test_readability_validator_combined_constraints(self) -> None:
|
||||||
|
"""Test validator with both grade and ease constraints."""
|
||||||
|
validator = ReadabilityValidator(max_grade=12.0, min_ease=30.0)
|
||||||
|
context = ValidationContext()
|
||||||
|
result = validator.check("The cat sat on the mat.", context)
|
||||||
|
|
||||||
|
assert "max_grade" in result.threshold
|
||||||
|
assert "min_ease" in result.threshold
|
||||||
|
|
||||||
|
def test_readability_validator_raises_when_no_constraints(self) -> None:
|
||||||
|
"""Test that validator raises when no constraints provided."""
|
||||||
|
with pytest.raises(InvalidThresholdError, match="At least one"):
|
||||||
|
ReadabilityValidator()
|
||||||
|
|
||||||
|
def test_readability_factory_function(self) -> None:
|
||||||
|
"""Test the readability() factory function."""
|
||||||
|
validator = readability(max_grade=8.0, min_ease=60.0)
|
||||||
|
assert isinstance(validator, ReadabilityValidator)
|
||||||
|
assert validator.name == "readability"
|
||||||
|
|
||||||
|
class TestContainsValidator:
    """Tests for ContainsValidator."""

    def test_contains_validator_passes_when_pattern_found(self) -> None:
        """Test that validator passes when all patterns are found."""
        validator = ContainsValidator(patterns=["hello", "world"])
        context = ValidationContext()
        result = validator.check("Hello World!", context)

        assert result.passed is True
        assert result.name == "contains"
        assert result.actual["found"] == 2
        assert result.actual["missing"] == []

    def test_contains_validator_fails_when_pattern_missing(self) -> None:
        """Test that validator fails when a pattern is missing."""
        validator = ContainsValidator(patterns=["hello", "goodbye"])
        context = ValidationContext()
        result = validator.check("Hello World!", context)

        assert result.passed is False
        assert "goodbye" in result.actual["missing"]
        assert "missing" in result.message

    def test_contains_validator_case_insensitive_by_default(self) -> None:
        """Test that matching is case-insensitive by default."""
        validator = ContainsValidator(patterns=["HELLO"])
        context = ValidationContext()
        result = validator.check("hello world", context)

        assert result.passed is True

    def test_contains_validator_case_sensitive(self) -> None:
        """Test case-sensitive matching."""
        validator = ContainsValidator(patterns=["HELLO"], case_sensitive=True)
        context = ValidationContext()
        result = validator.check("hello world", context)

        assert result.passed is False

    def test_contains_validator_regex_patterns(self) -> None:
        """Test regex pattern matching."""
        validator = ContainsValidator(patterns=[r"\d{3}-\d{4}"])
        context = ValidationContext()
        result = validator.check("Call 555-1234 for info", context)

        assert result.passed is True

    def test_contains_validator_raises_on_empty_patterns(self) -> None:
        """Test that empty patterns list raises error."""
        with pytest.raises(InvalidThresholdError, match="cannot be empty"):
            ContainsValidator(patterns=[])

    def test_contains_validator_raises_on_invalid_regex(self) -> None:
        """Test that invalid regex pattern raises error at init time."""
        with pytest.raises(InvalidThresholdError, match="Invalid regex"):
            ContainsValidator(patterns=[r"[invalid"])

    def test_contains_factory_function(self) -> None:
        """Test the contains() factory function."""
        validator = contains(patterns=["test"], case_sensitive=True)
        assert isinstance(validator, ContainsValidator)
        assert validator.name == "contains"


class TestExcludesValidator:
    """Tests for ExcludesValidator."""

    def test_excludes_validator_passes_when_pattern_absent(self) -> None:
        """Test that validator passes when all patterns are absent."""
        validator = ExcludesValidator(patterns=["bad", "forbidden"])
        context = ValidationContext()
        result = validator.check("This is good text.", context)

        assert result.passed is True
        assert result.name == "excludes"
        assert result.actual["found"] == []

    def test_excludes_validator_fails_when_pattern_found(self) -> None:
        """Test that validator fails when a forbidden pattern is found."""
        validator = ExcludesValidator(patterns=["bad", "forbidden"])
        context = ValidationContext()
        result = validator.check("This is bad text.", context)

        assert result.passed is False
        assert "bad" in result.actual["found"]
        assert "forbidden" in result.message

    def test_excludes_validator_case_insensitive_by_default(self) -> None:
        """Test that matching is case-insensitive by default."""
        validator = ExcludesValidator(patterns=["BAD"])
        context = ValidationContext()
        result = validator.check("This is bad text.", context)

        assert result.passed is False

    def test_excludes_validator_case_sensitive(self) -> None:
        """Test case-sensitive matching."""
        validator = ExcludesValidator(patterns=["BAD"], case_sensitive=True)
        context = ValidationContext()
        result = validator.check("This is bad text.", context)

        assert result.passed is True

    def test_excludes_validator_regex_patterns(self) -> None:
        """Test regex pattern matching."""
        validator = ExcludesValidator(patterns=[r"\b\d{4}\b"])  # 4-digit numbers
        context = ValidationContext()

        # Should fail when pattern found
        result = validator.check("PIN is 1234", context)
        assert result.passed is False

        # Should pass when pattern absent
        result = validator.check("No numbers here", context)
        assert result.passed is True

    def test_excludes_validator_raises_on_empty_patterns(self) -> None:
        """Test that empty patterns list raises error."""
        with pytest.raises(InvalidThresholdError, match="cannot be empty"):
            ExcludesValidator(patterns=[])

    def test_excludes_validator_raises_on_invalid_regex(self) -> None:
        """Test that invalid regex pattern raises error at init time."""
        with pytest.raises(InvalidThresholdError, match="Invalid regex"):
            ExcludesValidator(patterns=[r"[invalid"])

    def test_excludes_factory_function(self) -> None:
        """Test the excludes() factory function."""
        validator = excludes(patterns=["test"], case_sensitive=True)
        assert isinstance(validator, ExcludesValidator)
        assert validator.name == "excludes"
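The pattern tests above pin down three behaviours: patterns are treated as regular expressions, matching is case-insensitive unless `case_sensitive=True`, and an empty list or an invalid regex is rejected at construction time. The following is a minimal sketch of that matching logic using only the standard `re` module. It is not the veritext implementation: `compile_patterns` and `split_matches` are hypothetical names, and `ValueError` stands in for the project's `InvalidThresholdError`.

```python
from __future__ import annotations

import re


def compile_patterns(patterns: list[str], case_sensitive: bool = False) -> list[re.Pattern[str]]:
    """Compile every pattern up front so a bad regex fails early (hypothetical helper)."""
    if not patterns:
        raise ValueError("patterns cannot be empty")
    # Assumption: case-insensitive matching is the default, as the tests suggest.
    flags = 0 if case_sensitive else re.IGNORECASE
    try:
        return [re.compile(pattern, flags) for pattern in patterns]
    except re.error as exc:
        raise ValueError(f"Invalid regex: {exc}") from exc


def split_matches(
    text: str, patterns: list[str], case_sensitive: bool = False
) -> tuple[list[str], list[str]]:
    """Return (found, missing) pattern strings for the given text (hypothetical helper)."""
    found: list[str] = []
    missing: list[str] = []
    for compiled in compile_patterns(patterns, case_sensitive):
        # re.search matches anywhere in the text, not only at the start.
        (found if compiled.search(text) else missing).append(compiled.pattern)
    return found, missing
```

Under these assumptions, a "contains" check would pass when `missing` is empty and an "excludes" check would pass when `found` is empty. The tests show the real validators reporting `found` as a count in one case and a list in the other, which this sketch does not reproduce.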
283
tests/test_validators/test_metric.py
Normal file
@@ -0,0 +1,283 @@
"""Tests for metric-based validators."""

import pytest

from veritext.core.exceptions import InvalidThresholdError, ValidationError
from veritext.core.types import ValidationContext
from veritext.validators import bleu, lexical, rouge
from veritext.validators.metric import BleuValidator, LexicalValidator, RougeValidator


class TestBleuValidator:
    """Tests for BleuValidator."""

    def test_bleu_validator_passes_when_score_meets_threshold(self) -> None:
        """Test that validator passes when BLEU score meets threshold."""
        validator = BleuValidator(min_score=0.5, variant=4)
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check("the cat sat on the mat", context)

        assert result.passed is True
        assert result.name == "bleu-4"
        assert result.actual == 1.0  # Identical text
        assert result.threshold == 0.5

    def test_bleu_validator_fails_when_score_below_threshold(self) -> None:
        """Test that validator fails when BLEU score is below threshold."""
        validator = BleuValidator(min_score=0.9, variant=4)
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check("a dog ran through the park", context)

        assert result.passed is False
        assert result.name == "bleu-4"
        assert result.actual < 0.9
        assert "below minimum" in result.message

    def test_bleu_validator_variant_selection(self) -> None:
        """Test different BLEU variants."""
        context = ValidationContext(reference="the quick brown fox jumps")

        for variant in (1, 2, 3, 4):
            validator = BleuValidator(min_score=0.0, variant=variant)  # type: ignore[arg-type]
            result = validator.check("the quick brown fox", context)
            assert result.name == f"bleu-{variant}"

    def test_bleu_validator_raises_on_missing_reference(self) -> None:
        """Test that validator raises when reference is missing."""
        validator = BleuValidator(min_score=0.5)
        context = ValidationContext()

        with pytest.raises(ValidationError, match="requires reference text"):
            validator.check("some text", context)

    def test_bleu_validator_raises_on_invalid_min_score(self) -> None:
        """Test that invalid min_score raises error."""
        with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
            BleuValidator(min_score=1.5)

        with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
            BleuValidator(min_score=-0.1)

    def test_bleu_validator_raises_on_invalid_variant(self) -> None:
        """Test that invalid variant raises error."""
        with pytest.raises(InvalidThresholdError, match="variant must be"):
            BleuValidator(min_score=0.5, variant=5)  # type: ignore[arg-type]

    def test_bleu_factory_function(self) -> None:
        """Test the bleu() factory function."""
        validator = bleu(min_score=0.6, variant=2)
        assert isinstance(validator, BleuValidator)
        assert validator.name == "bleu-2"


class TestRougeValidator:
    """Tests for RougeValidator."""

    def test_rouge_validator_passes_when_score_meets_threshold(self) -> None:
        """Test that validator passes when ROUGE score meets threshold."""
        validator = RougeValidator(min_score=0.5, variant="l")
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check("the cat sat on the mat", context)

        assert result.passed is True
        assert result.name == "rouge-l"
        assert result.actual == 1.0  # Identical text
        assert result.threshold == 0.5

    def test_rouge_validator_fails_when_score_below_threshold(self) -> None:
        """Test that validator fails when ROUGE score is below threshold."""
        validator = RougeValidator(min_score=0.9, variant="l")
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check("a dog ran through the park", context)

        assert result.passed is False
        assert result.actual < 0.9
        assert "below minimum" in result.message

    def test_rouge_validator_variant_selection(self) -> None:
        """Test different ROUGE variants."""
        context = ValidationContext(reference="the quick brown fox jumps")

        for variant in ("1", "2", "l"):
            validator = RougeValidator(min_score=0.0, variant=variant)  # type: ignore[arg-type]
            result = validator.check("the quick brown fox", context)
            assert result.name == f"rouge-{variant}"

    def test_rouge_validator_raises_on_missing_reference(self) -> None:
        """Test that validator raises when reference is missing."""
        validator = RougeValidator(min_score=0.5)
        context = ValidationContext()

        with pytest.raises(ValidationError, match="requires reference text"):
            validator.check("some text", context)

    def test_rouge_validator_raises_on_invalid_min_score(self) -> None:
        """Test that invalid min_score raises error."""
        with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
            RougeValidator(min_score=1.5)

    def test_rouge_validator_raises_on_invalid_variant(self) -> None:
        """Test that invalid variant raises error."""
        with pytest.raises(InvalidThresholdError, match="variant must be"):
            RougeValidator(min_score=0.5, variant="3")  # type: ignore[arg-type]

    def test_rouge_factory_function(self) -> None:
        """Test the rouge() factory function."""
        validator = rouge(min_score=0.6, variant="2")
        assert isinstance(validator, RougeValidator)
        assert validator.name == "rouge-2"


class TestLexicalValidator:
    """Tests for LexicalValidator."""

    def test_lexical_validator_passes_on_jaccard(self) -> None:
        """Test that validator passes when Jaccard similarity meets threshold."""
        validator = LexicalValidator(min_jaccard=0.5)
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check("the cat sat on the mat", context)

        assert result.passed is True
        assert result.name == "lexical"
        assert result.actual["jaccard"] == 1.0

    def test_lexical_validator_fails_on_jaccard(self) -> None:
        """Test that validator fails when Jaccard is below threshold."""
        validator = LexicalValidator(min_jaccard=0.9)
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check("a dog ran through the park", context)

        assert result.passed is False
        assert "Jaccard" in result.message
        assert "below minimum" in result.message

    def test_lexical_validator_passes_on_overlap(self) -> None:
        """Test that validator passes when token overlap meets threshold."""
        validator = LexicalValidator(min_overlap=0.5)
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check("the cat sat on the mat", context)

        assert result.passed is True
        assert result.actual["token_overlap"] == 1.0

    def test_lexical_validator_fails_on_overlap(self) -> None:
        """Test that validator fails when overlap is below threshold."""
        validator = LexicalValidator(min_overlap=0.9)
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check("a dog ran through", context)

        assert result.passed is False
        assert "overlap" in result.message

    def test_lexical_validator_with_both_thresholds(self) -> None:
        """Test validator with both Jaccard and overlap thresholds."""
        validator = LexicalValidator(min_jaccard=0.3, min_overlap=0.5)
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check("the cat sat", context)

        # Should check both thresholds
        assert "min_jaccard" in result.threshold
        assert "min_overlap" in result.threshold

    def test_lexical_validator_raises_when_no_threshold(self) -> None:
        """Test that validator raises when no threshold is provided."""
        with pytest.raises(InvalidThresholdError, match="At least one"):
            LexicalValidator()

    def test_lexical_validator_raises_on_invalid_jaccard(self) -> None:
        """Test that invalid Jaccard threshold raises error."""
        with pytest.raises(InvalidThresholdError, match="min_jaccard"):
            LexicalValidator(min_jaccard=1.5)

    def test_lexical_validator_raises_on_invalid_overlap(self) -> None:
        """Test that invalid overlap threshold raises error."""
        with pytest.raises(InvalidThresholdError, match="min_overlap"):
            LexicalValidator(min_overlap=-0.1)

    def test_lexical_validator_raises_on_missing_reference(self) -> None:
        """Test that validator raises when reference is missing."""
        validator = LexicalValidator(min_jaccard=0.5)
        context = ValidationContext()

        with pytest.raises(ValidationError, match="requires reference text"):
            validator.check("some text", context)

    def test_lexical_factory_function(self) -> None:
        """Test the lexical() factory function."""
        validator = lexical(min_jaccard=0.5, min_overlap=0.6)
        assert isinstance(validator, LexicalValidator)
        assert validator.name == "lexical"


# SemanticValidator tests - conditionally run if sentence-transformers is installed
class TestSemanticValidator:
    """Tests for SemanticValidator."""

    @staticmethod
    def _skip_if_no_transformers() -> None:
        """Skip test if sentence-transformers is not installed."""
        pytest.importorskip("sentence_transformers")

    def test_semantic_validator_passes_when_score_meets_threshold(self) -> None:
        """Test that validator passes when semantic similarity meets threshold."""
        self._skip_if_no_transformers()
        from veritext.validators.metric import SemanticValidator

        validator = SemanticValidator(min_score=0.5)
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check("the cat sat on the mat", context)

        assert result.passed is True
        assert result.name == "semantic"
        assert result.actual >= 0.99  # Identical text
        assert result.threshold == 0.5

    def test_semantic_validator_fails_when_score_below_threshold(self) -> None:
        """Test that validator fails when semantic similarity is below threshold."""
        self._skip_if_no_transformers()
        from veritext.validators.metric import SemanticValidator

        validator = SemanticValidator(min_score=0.99)
        context = ValidationContext(reference="the cat sat on the mat")
        result = validator.check(
            "quantum physics describes particle behaviour", context
        )

        assert result.passed is False
        assert result.name == "semantic"
        assert result.actual < 0.99
        assert "below minimum" in result.message

    def test_semantic_validator_raises_on_missing_reference(self) -> None:
        """Test that validator raises when reference is missing."""
        self._skip_if_no_transformers()
        from veritext.validators.metric import SemanticValidator

        validator = SemanticValidator(min_score=0.5)
        context = ValidationContext()

        with pytest.raises(ValidationError, match="requires reference text"):
            validator.check("some text", context)

    def test_semantic_validator_raises_on_invalid_min_score(self) -> None:
        """Test that invalid min_score raises error without loading model."""
        # This test doesn't need sentence-transformers since validation happens first
        with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
            from veritext.validators.metric import SemanticValidator

            SemanticValidator(min_score=1.5)

        with pytest.raises(InvalidThresholdError, match=r"between 0\.0 and 1\.0"):
            from veritext.validators.metric import SemanticValidator

            SemanticValidator(min_score=-0.1)

    def test_semantic_factory_function(self) -> None:
        """Test the semantic() factory function."""
        self._skip_if_no_transformers()
        from veritext.validators import semantic
        from veritext.validators.metric import SemanticValidator

        validator = semantic(min_score=0.6)
        assert isinstance(validator, SemanticValidator)
        assert validator.name == "semantic"
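The lexical tests rely on two scores, Jaccard similarity and token overlap, without showing how they are computed. As a rough illustration only, the sketch below computes both over lowercased, whitespace-split token sets. The tokenisation, the function names, the reading of "token overlap" as the share of reference tokens present in the candidate, and the 1.0 convention for empty inputs are all assumptions here, not taken from the veritext source.

```python
def tokens(text: str) -> set[str]:
    """Lowercase and split on whitespace (assumed tokenisation)."""
    return set(text.lower().split())


def jaccard(candidate: str, reference: str) -> float:
    """Size of the intersection over size of the union of the two token sets."""
    a, b = tokens(candidate), tokens(reference)
    union = a | b
    # Assumption: two empty texts count as identical.
    return len(a & b) / len(union) if union else 1.0


def token_overlap(candidate: str, reference: str) -> float:
    """Fraction of reference tokens that also occur in the candidate (assumed definition)."""
    a, b = tokens(candidate), tokens(reference)
    # Assumption: an empty reference is trivially covered.
    return len(a & b) / len(b) if b else 1.0
```

With these definitions, identical texts score 1.0 on both measures, while "a dog ran through the park" against "the cat sat on the mat" shares only "the", giving a Jaccard score of 1/10.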