feat(cli): add validate command
Implement validate command with inline and file-based modes supporting BLEU, ROUGE, and lexical metrics with multiple output formats.
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import typer
|
||||
|
||||
import veritext
|
||||
from veritext.cli.validate import validate
|
||||
|
||||
app = typer.Typer(
|
||||
name="veritext",
|
||||
@@ -10,6 +11,9 @@ app = typer.Typer(
|
||||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
# Register commands
|
||||
app.command()(validate)
|
||||
|
||||
|
||||
@app.callback(invoke_without_command=True)
|
||||
def main(
|
||||
|
||||
213
src/veritext/cli/validate.py
Normal file
213
src/veritext/cli/validate.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Validate command for computing text metrics."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
|
||||
from veritext.cli.formatters import console, print_validation_output
|
||||
from veritext.cli.readers import read_jsonl, read_paired_jsonl
|
||||
from veritext.metrics.bleu import Bleu
|
||||
from veritext.metrics.lexical import Lexical
|
||||
from veritext.metrics.rouge import Rouge
|
||||
|
||||
# Available metrics mapped to their computation functions
|
||||
AVAILABLE_METRICS = frozenset(
|
||||
{"bleu", "bleu1", "bleu2", "bleu3", "bleu4", "rouge", "rouge_l", "lexical"}
|
||||
)
|
||||
|
||||
|
||||
def _compute_metrics(
|
||||
candidate: str,
|
||||
reference: str,
|
||||
metric_names: list[str],
|
||||
) -> dict[str, float]:
|
||||
"""Compute requested metrics for a single text pair."""
|
||||
results: dict[str, float] = {}
|
||||
bleu = Bleu()
|
||||
rouge = Rouge()
|
||||
lexical = Lexical()
|
||||
|
||||
for metric in metric_names:
|
||||
if metric == "bleu" or metric == "bleu4":
|
||||
bleu_result = bleu.score(candidate, reference)
|
||||
results["bleu4"] = bleu_result.bleu4
|
||||
elif metric == "bleu1":
|
||||
bleu_result = bleu.score(candidate, reference)
|
||||
results["bleu1"] = bleu_result.bleu1
|
||||
elif metric == "bleu2":
|
||||
bleu_result = bleu.score(candidate, reference)
|
||||
results["bleu2"] = bleu_result.bleu2
|
||||
elif metric == "bleu3":
|
||||
bleu_result = bleu.score(candidate, reference)
|
||||
results["bleu3"] = bleu_result.bleu3
|
||||
elif metric == "rouge" or metric == "rouge_l":
|
||||
rouge_result = rouge.score(candidate, reference)
|
||||
results["rouge_l"] = rouge_result.rouge_l.fmeasure
|
||||
elif metric == "lexical":
|
||||
lexical_result = lexical.score(candidate, reference)
|
||||
results["jaccard"] = lexical_result.jaccard
|
||||
results["token_overlap"] = lexical_result.token_overlap
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _compute_batch_metrics(
|
||||
candidates: list[str],
|
||||
references: list[str],
|
||||
metric_names: list[str],
|
||||
) -> dict[str, float]:
|
||||
"""Compute average metrics for a batch of text pairs."""
|
||||
bleu = Bleu()
|
||||
rouge = Rouge()
|
||||
lexical = Lexical()
|
||||
|
||||
results: dict[str, float] = {}
|
||||
|
||||
for metric in metric_names:
|
||||
if metric == "bleu" or metric == "bleu4":
|
||||
bleu_batch = bleu.batch_score(candidates, references)
|
||||
stats = bleu_batch.stats.get("bleu4")
|
||||
if stats:
|
||||
results["bleu4"] = stats.mean
|
||||
elif metric == "bleu1":
|
||||
bleu_batch = bleu.batch_score(candidates, references)
|
||||
stats = bleu_batch.stats.get("bleu1")
|
||||
if stats:
|
||||
results["bleu1"] = stats.mean
|
||||
elif metric == "bleu2":
|
||||
bleu_batch = bleu.batch_score(candidates, references)
|
||||
stats = bleu_batch.stats.get("bleu2")
|
||||
if stats:
|
||||
results["bleu2"] = stats.mean
|
||||
elif metric == "bleu3":
|
||||
bleu_batch = bleu.batch_score(candidates, references)
|
||||
stats = bleu_batch.stats.get("bleu3")
|
||||
if stats:
|
||||
results["bleu3"] = stats.mean
|
||||
elif metric == "rouge" or metric == "rouge_l":
|
||||
rouge_batch = rouge.batch_score(candidates, references)
|
||||
stats = rouge_batch.stats.get("rouge_l_fmeasure")
|
||||
if stats:
|
||||
results["rouge_l"] = stats.mean
|
||||
elif metric == "lexical":
|
||||
lexical_batch = lexical.batch_score(candidates, references)
|
||||
jaccard_stats = lexical_batch.stats.get("jaccard")
|
||||
overlap_stats = lexical_batch.stats.get("token_overlap")
|
||||
if jaccard_stats:
|
||||
results["jaccard"] = jaccard_stats.mean
|
||||
if overlap_stats:
|
||||
results["token_overlap"] = overlap_stats.mean
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _parse_metrics(metrics_str: str) -> list[str]:
|
||||
"""Parse comma-separated metric names."""
|
||||
metrics = [m.strip().lower() for m in metrics_str.split(",")]
|
||||
|
||||
# Validate metric names
|
||||
invalid = [m for m in metrics if m not in AVAILABLE_METRICS]
|
||||
if invalid:
|
||||
raise typer.BadParameter(
|
||||
f"Unknown metrics: {', '.join(invalid)}. "
|
||||
f"Available: {', '.join(sorted(AVAILABLE_METRICS))}"
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def validate(
|
||||
text: Annotated[
|
||||
str | None,
|
||||
typer.Argument(help="Candidate text to validate (inline mode)."),
|
||||
] = None,
|
||||
reference: Annotated[
|
||||
str | None,
|
||||
typer.Option("--reference", "-r", help="Reference text for comparison."),
|
||||
] = None,
|
||||
file: Annotated[
|
||||
Path | None,
|
||||
typer.Option("--file", "-f", help="JSONL file with candidate/reference pairs."),
|
||||
] = None,
|
||||
reference_file: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
"--reference-file",
|
||||
"-R",
|
||||
help="Separate JSONL file with references (requires --file).",
|
||||
),
|
||||
] = None,
|
||||
metrics: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
"--metrics",
|
||||
"-m",
|
||||
help="Comma-separated metrics: bleu, bleu1-4, rouge, rouge_l, lexical.",
|
||||
),
|
||||
] = "bleu,rouge",
|
||||
output: Annotated[
|
||||
str,
|
||||
typer.Option("--output", "-o", help="Output format: table, json, or simple."),
|
||||
] = "table",
|
||||
threshold: Annotated[
|
||||
float | None,
|
||||
typer.Option("--threshold", "-t", help="Score threshold for pass/fail status."),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Validate text quality using various metrics.
|
||||
|
||||
Use inline mode for single texts:
|
||||
veritext validate "text" -r "reference" -m bleu,rouge
|
||||
|
||||
Use file mode for batches:
|
||||
veritext validate -f outputs.jsonl -m bleu,rouge
|
||||
"""
|
||||
# Parse and validate metric names
|
||||
try:
|
||||
metric_names = _parse_metrics(metrics)
|
||||
except typer.BadParameter as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
raise typer.Exit(code=1) from e
|
||||
|
||||
# Validate output format
|
||||
if output not in ("table", "json", "simple"):
|
||||
console.print(f"[red]Error:[/red] Invalid output format: {output}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
# Determine mode: inline vs file
|
||||
if file is not None:
|
||||
# File mode
|
||||
try:
|
||||
if reference_file is not None:
|
||||
pairs = read_paired_jsonl(file, reference_file)
|
||||
else:
|
||||
pairs = read_jsonl(file)
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
raise typer.Exit(code=1) from e
|
||||
|
||||
if not pairs:
|
||||
console.print("[yellow]Warning:[/yellow] No text pairs found in file.")
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
candidates = [p.candidate for p in pairs]
|
||||
references = [p.reference for p in pairs]
|
||||
|
||||
results = _compute_batch_metrics(candidates, references, metric_names)
|
||||
console.print(f"[dim]Evaluated {len(pairs)} text pairs.[/dim]\n")
|
||||
|
||||
elif text is not None and reference is not None:
|
||||
# Inline mode
|
||||
results = _compute_metrics(text, reference, metric_names)
|
||||
|
||||
else:
|
||||
# Invalid usage
|
||||
console.print(
|
||||
"[red]Error:[/red] Provide either text and --reference, "
|
||||
"or --file for batch mode."
|
||||
)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
print_validation_output(results, output, threshold)
|
||||
Reference in New Issue
Block a user