diff --git a/src/veritext/cli/main.py b/src/veritext/cli/main.py
index 2e35a18..a177036 100644
--- a/src/veritext/cli/main.py
+++ b/src/veritext/cli/main.py
@@ -3,6 +3,7 @@
 import typer
 
 import veritext
+from veritext.cli.validate import validate
 
 app = typer.Typer(
     name="veritext",
@@ -10,6 +11,9 @@ app = typer.Typer(
     no_args_is_help=True,
 )
 
+# Register commands
+app.command()(validate)
+
 
 @app.callback(invoke_without_command=True)
 def main(
diff --git a/src/veritext/cli/validate.py b/src/veritext/cli/validate.py
new file mode 100644
index 0000000..179c395
--- /dev/null
+++ b/src/veritext/cli/validate.py
@@ -0,0 +1,262 @@
+"""Validate command for computing text metrics."""
+
+from pathlib import Path
+from typing import Annotated
+
+import typer
+
+from veritext.cli.formatters import console, print_validation_output
+from veritext.cli.readers import read_jsonl, read_paired_jsonl
+from veritext.metrics.bleu import Bleu
+from veritext.metrics.lexical import Lexical
+from veritext.metrics.rouge import Rouge
+
+AVAILABLE_METRICS = frozenset(
+    {"bleu", "bleu1", "bleu2", "bleu3", "bleu4", "rouge", "rouge_l", "lexical"}
+)
+
+# Metric objects are created on first use and reused by later calls.
+_bleu: Bleu | None = None
+_rouge: Rouge | None = None
+_lexical: Lexical | None = None
+
+
+def _get_bleu() -> Bleu:
+    """Return the shared Bleu instance, creating it on first use."""
+    global _bleu
+    if _bleu is None:
+        _bleu = Bleu()
+    return _bleu
+
+
+def _get_rouge() -> Rouge:
+    """Return the shared Rouge instance, creating it on first use."""
+    global _rouge
+    if _rouge is None:
+        _rouge = Rouge()
+    return _rouge
+
+
+def _get_lexical() -> Lexical:
+    """Return the shared Lexical instance, creating it on first use."""
+    global _lexical
+    if _lexical is None:
+        _lexical = Lexical()
+    return _lexical
+
+
+def _bleu_single(candidate: str, reference: str, key: str) -> dict[str, float]:
+    """Score one pair with BLEU and return the result attribute named by key."""
+    result = _get_bleu().score(candidate, reference)
+    return {key: getattr(result, key)}
+
+
+def _bleu_batch(
+    candidates: list[str], references: list[str], key: str
+) -> dict[str, float]:
+    """Batch-score with BLEU and return the mean of stat key, if present."""
+    batch = _get_bleu().batch_score(candidates, references)
+    stats = batch.stats.get(key)
+    return {key: stats.mean} if stats else {}
+
+
+def _rouge_single(candidate: str, reference: str) -> dict[str, float]:
+    """Score one pair with ROUGE and return its ROUGE-L F-measure."""
+    result = _get_rouge().score(candidate, reference)
+    return {"rouge_l": result.rouge_l.fmeasure}
+
+
+def _rouge_batch(candidates: list[str], references: list[str]) -> dict[str, float]:
+    """Batch-score with ROUGE and return the mean ROUGE-L F-measure, if present."""
+    batch = _get_rouge().batch_score(candidates, references)
+    stats = batch.stats.get("rouge_l_fmeasure")
+    return {"rouge_l": stats.mean} if stats else {}
+
+
+def _lexical_single(candidate: str, reference: str) -> dict[str, float]:
+    """Score one pair lexically and return its jaccard and token_overlap values."""
+    result = _get_lexical().score(candidate, reference)
+    return {"jaccard": result.jaccard, "token_overlap": result.token_overlap}
+
+
+def _lexical_batch(candidates: list[str], references: list[str]) -> dict[str, float]:
+    """Batch-score lexically and return whichever mean stats are present."""
+    batch = _get_lexical().batch_score(candidates, references)
+    results: dict[str, float] = {}
+    jaccard_stats = batch.stats.get("jaccard")
+    overlap_stats = batch.stats.get("token_overlap")
+    if jaccard_stats:
+        results["jaccard"] = jaccard_stats.mean
+    if overlap_stats:
+        results["token_overlap"] = overlap_stats.mean
+    return results
+
+
+def _compute_metrics(
+    candidate: str,
+    reference: str,
+    metric_names: list[str],
+) -> dict[str, float]:
+    """Compute the requested metrics for a single candidate/reference pair."""
+    results: dict[str, float] = {}
+
+    for metric in metric_names:
+        if metric in ("bleu", "bleu4"):
+            results.update(_bleu_single(candidate, reference, "bleu4"))
+        elif metric in ("bleu1", "bleu2", "bleu3"):
+            results.update(_bleu_single(candidate, reference, metric))
+        elif metric in ("rouge", "rouge_l"):
+            results.update(_rouge_single(candidate, reference))
+        elif metric == "lexical":
+            results.update(_lexical_single(candidate, reference))
+
+    return results
+
+
+def _compute_batch_metrics(
+    candidates: list[str],
+    references: list[str],
+    metric_names: list[str],
+) -> dict[str, float]:
+    """Compute the requested metrics, as batch means, over parallel text lists."""
+    results: dict[str, float] = {}
+
+    for metric in metric_names:
+        if metric in ("bleu", "bleu4"):
+            results.update(_bleu_batch(candidates, references, "bleu4"))
+        elif metric in ("bleu1", "bleu2", "bleu3"):
+            results.update(_bleu_batch(candidates, references, metric))
+        elif metric in ("rouge", "rouge_l"):
+            results.update(_rouge_batch(candidates, references))
+        elif metric == "lexical":
+            results.update(_lexical_batch(candidates, references))
+
+    return results
+
+
+def _parse_metrics(metrics_str: str) -> list[str]:
+    """Parse a comma-separated metric list into validated metric names.
+
+    Names are stripped and lower-cased. Empty entries (for example from a
+    trailing comma) are skipped, and repeated names are kept only once in
+    first-seen order.
+
+    Raises:
+        typer.BadParameter: If no metric is named, or if any name is not in
+            AVAILABLE_METRICS.
+    """
+    names = (m.strip().lower() for m in metrics_str.split(","))
+    # dict.fromkeys drops duplicates while preserving first-seen order.
+    metrics = list(dict.fromkeys(name for name in names if name))
+    if not metrics:
+        raise typer.BadParameter(
+            "No metrics specified. "
+            f"Available: {', '.join(sorted(AVAILABLE_METRICS))}"
+        )
+
+    invalid = [m for m in metrics if m not in AVAILABLE_METRICS]
+    if invalid:
+        raise typer.BadParameter(
+            f"Unknown metrics: {', '.join(invalid)}. "
+            f"Available: {', '.join(sorted(AVAILABLE_METRICS))}"
+        )
+
+    return metrics
+
+
+def validate(
+    text: Annotated[
+        str | None,
+        typer.Argument(help="Candidate text to validate (inline mode)."),
+    ] = None,
+    reference: Annotated[
+        str | None,
+        typer.Option("--reference", "-r", help="Reference text for comparison."),
+    ] = None,
+    file: Annotated[
+        Path | None,
+        typer.Option("--file", "-f", help="JSONL file with candidate/reference pairs."),
+    ] = None,
+    reference_file: Annotated[
+        Path | None,
+        typer.Option(
+            "--reference-file",
+            "-R",
+            help="Separate JSONL file with references (requires --file).",
+        ),
+    ] = None,
+    metrics: Annotated[
+        str,
+        typer.Option(
+            "--metrics",
+            "-m",
+            help="Comma-separated metrics: bleu, bleu1-4, rouge, rouge_l, lexical.",
+        ),
+    ] = "bleu,rouge",
+    output: Annotated[
+        str,
+        typer.Option("--output", "-o", help="Output format: table, json, or simple."),
+    ] = "table",
+    threshold: Annotated[
+        float | None,
+        typer.Option("--threshold", "-t", help="Score threshold for pass/fail status."),
+    ] = None,
+) -> None:
+    """
+    Validate text quality using various metrics.
+
+    Use inline mode for single texts:
+        veritext validate "text" -r "reference" -m bleu,rouge
+
+    Use file mode for batches:
+        veritext validate -f outputs.jsonl -m bleu,rouge
+    """
+    try:
+        metric_names = _parse_metrics(metrics)
+    except typer.BadParameter as e:
+        console.print(f"[red]Error:[/red] {e}")
+        raise typer.Exit(code=1) from e
+
+    if output not in ("table", "json", "simple"):
+        console.print(f"[red]Error:[/red] Invalid output format: {output}")
+        raise typer.Exit(code=1)
+
+    # --reference-file only has meaning in batch mode; reject it instead of
+    # silently ignoring it.
+    if reference_file is not None and file is None:
+        console.print("[red]Error:[/red] --reference-file requires --file.")
+        raise typer.Exit(code=1)
+
+    if file is not None:
+        try:
+            if reference_file is not None:
+                pairs = read_paired_jsonl(file, reference_file)
+            else:
+                pairs = read_jsonl(file)
+        except (OSError, ValueError) as e:
+            # OSError covers FileNotFoundError as well as unreadable paths
+            # (permission denied, path is a directory).
+            console.print(f"[red]Error:[/red] {e}")
+            raise typer.Exit(code=1) from e
+
+        if not pairs:
+            console.print("[yellow]Warning:[/yellow] No text pairs found in file.")
+            raise typer.Exit(code=0)
+
+        candidates = [p.candidate for p in pairs]
+        references = [p.reference for p in pairs]
+
+        results = _compute_batch_metrics(candidates, references, metric_names)
+        console.print(f"[dim]Evaluated {len(pairs)} text pairs.[/dim]\n")
+
+    elif text is not None and reference is not None:
+        results = _compute_metrics(text, reference, metric_names)
+
+    else:
+        console.print(
+            "[red]Error:[/red] Provide either text and --reference, "
+            "or --file for batch mode."
+        )
+        raise typer.Exit(code=1)
+
+    print_validation_output(results, output, threshold)