cli validate command
Implement validate command with inline and file-based modes supporting BLEU, ROUGE, and lexical metrics with multiple output formats.
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
import typer
|
import typer
|
||||||
|
|
||||||
import veritext
|
import veritext
|
||||||
|
from veritext.cli.validate import validate
|
||||||
|
|
||||||
app = typer.Typer(
|
app = typer.Typer(
|
||||||
name="veritext",
|
name="veritext",
|
||||||
@@ -10,6 +11,9 @@ app = typer.Typer(
|
|||||||
no_args_is_help=True,
|
no_args_is_help=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Register commands: attach each command function to the Typer app.
app.command()(validate)
|
||||||
|
|
||||||
|
|
||||||
@app.callback(invoke_without_command=True)
|
@app.callback(invoke_without_command=True)
|
||||||
def main(
|
def main(
|
||||||
|
|||||||
222
src/veritext/cli/validate.py
Normal file
222
src/veritext/cli/validate.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""Validate command for computing text metrics."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from veritext.cli.formatters import console, print_validation_output
|
||||||
|
from veritext.cli.readers import read_jsonl, read_paired_jsonl
|
||||||
|
from veritext.metrics.bleu import Bleu
|
||||||
|
from veritext.metrics.lexical import Lexical
|
||||||
|
from veritext.metrics.rouge import Rouge
|
||||||
|
|
||||||
|
# Metric names accepted by the --metrics option. The compute helpers in this
# module treat "bleu" the same as "bleu4" and "rouge" the same as "rouge_l".
AVAILABLE_METRICS = frozenset(
    (
        "bleu",
        "bleu1",
        "bleu2",
        "bleu3",
        "bleu4",
        "rouge",
        "rouge_l",
        "lexical",
    )
)
|
||||||
|
|
||||||
|
# Module-level slots for the scorer instances. Each starts as None and is
# filled in by the matching _get_* accessor the first time it is called.
_bleu: Bleu | None = None
_rouge: Rouge | None = None
_lexical: Lexical | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_bleu() -> Bleu:
    """Return the module-wide ``Bleu`` scorer, constructing it on first use."""
    global _bleu
    scorer = _bleu
    if scorer is None:
        # First request: build the scorer and keep it for later calls.
        scorer = Bleu()
        _bleu = scorer
    return scorer
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rouge() -> Rouge:
    """Return the module-wide ``Rouge`` scorer, constructing it on first use."""
    global _rouge
    scorer = _rouge
    if scorer is None:
        # First request: build the scorer and keep it for later calls.
        scorer = Rouge()
        _rouge = scorer
    return scorer
|
||||||
|
|
||||||
|
|
||||||
|
def _get_lexical() -> Lexical:
    """Return the module-wide ``Lexical`` scorer, constructing it on first use."""
    global _lexical
    scorer = _lexical
    if scorer is None:
        # First request: build the scorer and keep it for later calls.
        scorer = Lexical()
        _lexical = scorer
    return scorer


|
||||||
|
def _bleu_single(candidate: str, reference: str, key: str) -> dict[str, float]:
    """Score one candidate/reference pair with BLEU and report a single field.

    Args:
        candidate: Text being evaluated.
        reference: Text to compare against.
        key: Name of the attribute to read from the BLEU score result
            (callers in this module pass "bleu1" through "bleu4").

    Returns:
        A one-entry mapping from ``key`` to that attribute's value.

    Raises:
        AttributeError: If the score result has no attribute named ``key``.
    """
    scorer = _get_bleu()
    score = scorer.score(candidate, reference)
    value = getattr(score, key)
    return {key: value}
|
||||||
|
|
||||||
|
|
||||||
|
def _bleu_batch(
    candidates: list[str], references: list[str], key: str
) -> dict[str, float]:
    """Batch-score the texts with BLEU and report the mean of one statistic.

    Args:
        candidates: Texts being evaluated.
        references: Texts to compare against, passed alongside ``candidates``
            to the scorer's ``batch_score``.
        key: Name of the entry to look up in the batch result's ``stats``.

    Returns:
        ``{key: mean}`` when ``stats`` has a truthy entry for ``key``;
        otherwise an empty dict.
    """
    summary = _get_bleu().batch_score(candidates, references).stats.get(key)
    if not summary:
        return {}
    return {key: summary.mean}
|
||||||
|
|
||||||
|
|
||||||
|
def _rouge_single(candidate: str, reference: str) -> dict[str, float]:
    """Score one candidate/reference pair with ROUGE.

    Returns:
        A one-entry mapping ``{"rouge_l": ...}`` holding the ``fmeasure`` of
        the result's ``rouge_l`` field.
    """
    scorer = _get_rouge()
    score = scorer.score(candidate, reference)
    return {"rouge_l": score.rouge_l.fmeasure}
|
||||||
|
|
||||||
|
|
||||||
|
def _rouge_batch(candidates: list[str], references: list[str]) -> dict[str, float]:
    """Batch-score the texts with ROUGE and report the mean ROUGE-L f-measure.

    Returns:
        ``{"rouge_l": mean}`` when the batch ``stats`` has a truthy
        ``"rouge_l_fmeasure"`` entry; otherwise an empty dict.
    """
    summary = (
        _get_rouge().batch_score(candidates, references).stats.get("rouge_l_fmeasure")
    )
    if not summary:
        return {}
    return {"rouge_l": summary.mean}
|
||||||
|
|
||||||
|
|
||||||
|
def _lexical_single(candidate: str, reference: str) -> dict[str, float]:
    """Score one candidate/reference pair with the lexical scorer.

    Returns:
        A mapping with the result's ``jaccard`` and ``token_overlap`` values
        under keys of the same names.
    """
    scorer = _get_lexical()
    score = scorer.score(candidate, reference)
    return dict(jaccard=score.jaccard, token_overlap=score.token_overlap)
|
||||||
|
|
||||||
|
|
||||||
|
def _lexical_batch(candidates: list[str], references: list[str]) -> dict[str, float]:
    """Batch-score the texts with the lexical scorer and report stat means.

    Returns:
        A mapping containing the mean of the ``"jaccard"`` and
        ``"token_overlap"`` statistics, in that order. A statistic whose
        ``stats`` entry is missing or falsy is left out.
    """
    batch = _get_lexical().batch_score(candidates, references)
    jaccard_summary = batch.stats.get("jaccard")
    overlap_summary = batch.stats.get("token_overlap")

    means: dict[str, float] = {}
    for name, summary in (
        ("jaccard", jaccard_summary),
        ("token_overlap", overlap_summary),
    ):
        if summary:
            means[name] = summary.mean
    return means
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_metrics(
|
||||||
|
candidate: str,
|
||||||
|
reference: str,
|
||||||
|
metric_names: list[str],
|
||||||
|
) -> dict[str, float]:
|
||||||
|
results: dict[str, float] = {}
|
||||||
|
|
||||||
|
for metric in metric_names:
|
||||||
|
if metric in ("bleu", "bleu4"):
|
||||||
|
results.update(_bleu_single(candidate, reference, "bleu4"))
|
||||||
|
elif metric in ("bleu1", "bleu2", "bleu3"):
|
||||||
|
results.update(_bleu_single(candidate, reference, metric))
|
||||||
|
elif metric in ("rouge", "rouge_l"):
|
||||||
|
results.update(_rouge_single(candidate, reference))
|
||||||
|
elif metric == "lexical":
|
||||||
|
results.update(_lexical_single(candidate, reference))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_batch_metrics(
|
||||||
|
candidates: list[str],
|
||||||
|
references: list[str],
|
||||||
|
metric_names: list[str],
|
||||||
|
) -> dict[str, float]:
|
||||||
|
results: dict[str, float] = {}
|
||||||
|
|
||||||
|
for metric in metric_names:
|
||||||
|
if metric in ("bleu", "bleu4"):
|
||||||
|
results.update(_bleu_batch(candidates, references, "bleu4"))
|
||||||
|
elif metric in ("bleu1", "bleu2", "bleu3"):
|
||||||
|
results.update(_bleu_batch(candidates, references, metric))
|
||||||
|
elif metric in ("rouge", "rouge_l"):
|
||||||
|
results.update(_rouge_batch(candidates, references))
|
||||||
|
elif metric == "lexical":
|
||||||
|
results.update(_lexical_batch(candidates, references))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_metrics(metrics_str: str) -> list[str]:
    """Parse and validate a comma-separated list of metric names.

    Each segment is stripped and lower-cased. Empty segments (for example
    from a trailing or doubled comma) are ignored, and repeated names are
    collapsed to their first occurrence so no metric is requested twice.

    Args:
        metrics_str: Raw value of the ``--metrics`` option.

    Returns:
        The normalized metric names, in first-seen order.

    Raises:
        typer.BadParameter: If no metric name remains after dropping empty
            segments, or if any name is not in ``AVAILABLE_METRICS``.
    """
    available = ", ".join(sorted(AVAILABLE_METRICS))

    normalized = (segment.strip().lower() for segment in metrics_str.split(","))
    # dict.fromkeys removes duplicates while preserving first-seen order.
    metrics = list(dict.fromkeys(name for name in normalized if name))

    if not metrics:
        # Without this guard an empty segment would be reported as an
        # unknown metric with a blank name.
        raise typer.BadParameter(f"No metrics specified. Available: {available}")

    invalid = [m for m in metrics if m not in AVAILABLE_METRICS]
    if invalid:
        raise typer.BadParameter(
            f"Unknown metrics: {', '.join(invalid)}. "
            f"Available: {available}"
        )

    return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def validate(
    text: Annotated[
        str | None,
        typer.Argument(help="Candidate text to validate (inline mode)."),
    ] = None,
    reference: Annotated[
        str | None,
        typer.Option("--reference", "-r", help="Reference text for comparison."),
    ] = None,
    file: Annotated[
        Path | None,
        typer.Option("--file", "-f", help="JSONL file with candidate/reference pairs."),
    ] = None,
    reference_file: Annotated[
        Path | None,
        typer.Option(
            "--reference-file",
            "-R",
            help="Separate JSONL file with references (requires --file).",
        ),
    ] = None,
    metrics: Annotated[
        str,
        typer.Option(
            "--metrics",
            "-m",
            help="Comma-separated metrics: bleu, bleu1-4, rouge, rouge_l, lexical.",
        ),
    ] = "bleu,rouge",
    output: Annotated[
        str,
        typer.Option("--output", "-o", help="Output format: table, json, or simple."),
    ] = "table",
    threshold: Annotated[
        float | None,
        typer.Option("--threshold", "-t", help="Score threshold for pass/fail status."),
    ] = None,
) -> None:
    """
    Validate text quality using various metrics.

    Use inline mode for single texts:
        veritext validate "text" -r "reference" -m bleu,rouge

    Use file mode for batches:
        veritext validate -f outputs.jsonl -m bleu,rouge
    """
    # The docstring above is kept verbatim because Typer shows it as the
    # command's help text; implementation notes live in comments instead.
    #
    # Flow: validate options -> pick batch (--file) or inline (text plus
    # --reference) mode -> compute scores -> hand them to the formatter.
    # Every usage or input error is printed and ends with exit code 1.
    try:
        metric_names = _parse_metrics(metrics)
    except typer.BadParameter as e:
        console.print(f"[red]Error:[/red] {e}")
        raise typer.Exit(code=1) from e

    if output not in ("table", "json", "simple"):
        console.print(f"[red]Error:[/red] Invalid output format: {output}")
        raise typer.Exit(code=1)

    # --reference-file is documented as requiring --file. Reject the
    # combination explicitly instead of silently ignoring the option.
    if reference_file is not None and file is None:
        console.print("[red]Error:[/red] --reference-file requires --file.")
        raise typer.Exit(code=1)

    if file is not None:
        # Batch mode takes precedence over inline arguments when both are given.
        try:
            if reference_file is not None:
                pairs = read_paired_jsonl(file, reference_file)
            else:
                pairs = read_jsonl(file)
        except (OSError, ValueError) as e:
            # OSError covers FileNotFoundError as well as unreadable paths
            # (permission denied, path is a directory), so those are reported
            # as a clean error instead of a traceback.
            console.print(f"[red]Error:[/red] {e}")
            raise typer.Exit(code=1) from e

        if not pairs:
            # An empty input is a warning, not a failure: exit successfully.
            console.print("[yellow]Warning:[/yellow] No text pairs found in file.")
            raise typer.Exit(code=0)

        candidates = [p.candidate for p in pairs]
        references = [p.reference for p in pairs]

        results = _compute_batch_metrics(candidates, references, metric_names)
        # NOTE(review): if `console` writes to stdout, this line precedes the
        # payload in `--output json` mode and may break machine parsing; confirm.
        console.print(f"[dim]Evaluated {len(pairs)} text pairs.[/dim]\n")

    elif text is not None and reference is not None:
        results = _compute_metrics(text, reference, metric_names)

    else:
        console.print(
            "[red]Error:[/red] Provide either text and --reference, "
            "or --file for batch mode."
        )
        raise typer.Exit(code=1)

    print_validation_output(results, output, threshold)
|
||||||
Reference in New Issue
Block a user