"""Validate command for computing text metrics."""

# NOTE: ``validate`` is registered as a CLI subcommand from
# src/veritext/cli/main.py:
#
#     from veritext.cli.validate import validate
#     app.command()(validate)

from pathlib import Path
from typing import Annotated

import typer

from veritext.cli.formatters import console, print_validation_output
from veritext.cli.readers import read_jsonl, read_paired_jsonl
from veritext.metrics.bleu import Bleu
from veritext.metrics.lexical import Lexical
from veritext.metrics.rouge import Rouge

# Every metric name accepted by ``--metrics``.
AVAILABLE_METRICS = frozenset(
    {"bleu", "bleu1", "bleu2", "bleu3", "bleu4", "rouge", "rouge_l", "lexical"}
)

# BLEU metric name -> key used both to read the BLEU result/stats and to store
# the value in the output dict. Plain "bleu" is an alias for "bleu4".
_BLEU_KEYS = {
    "bleu": "bleu4",
    "bleu1": "bleu1",
    "bleu2": "bleu2",
    "bleu3": "bleu3",
    "bleu4": "bleu4",
}

# Both names report the ROUGE-L F-measure under the "rouge_l" output key.
_ROUGE_METRICS = frozenset({"rouge", "rouge_l"})

# Output keys produced by the single "lexical" metric name.
_LEXICAL_KEYS = ("jaccard", "token_overlap")

# Values accepted by ``--output``.
_OUTPUT_FORMATS = ("table", "json", "simple")


def _compute_metrics(
    candidate: str,
    reference: str,
    metric_names: list[str],
) -> dict[str, float]:
    """Compute requested metrics for a single text pair.

    Args:
        candidate: Text being evaluated.
        reference: Text the candidate is compared against.
        metric_names: Metric names, as returned by ``_parse_metrics``.
            Names that match no known metric are ignored.

    Returns:
        Mapping of output key to score. Possible keys are ``bleu1``..``bleu4``,
        ``rouge_l``, ``jaccard`` and ``token_overlap``, inserted in the order
        the metrics were requested.
    """
    results: dict[str, float] = {}
    # One BLEU result carries every n-gram order, so score at most once no
    # matter how many BLEU variants were requested.
    # NOTE(review): assumes the scorers are deterministic for identical
    # inputs, so one call is equivalent to the repeated calls -- confirm.
    bleu_result = None

    for metric in metric_names:
        if metric in _BLEU_KEYS:
            if bleu_result is None:
                bleu_result = Bleu().score(candidate, reference)
            key = _BLEU_KEYS[metric]
            results[key] = getattr(bleu_result, key)
        elif metric in _ROUGE_METRICS:
            # "rouge" and "rouge_l" are aliases; skip the second computation.
            if "rouge_l" not in results:
                rouge_result = Rouge().score(candidate, reference)
                results["rouge_l"] = rouge_result.rouge_l.fmeasure
        elif metric == "lexical":
            if "jaccard" not in results:
                lexical_result = Lexical().score(candidate, reference)
                results["jaccard"] = lexical_result.jaccard
                results["token_overlap"] = lexical_result.token_overlap

    return results


def _store_mean(results: dict[str, float], key: str, stats) -> None:
    """Record ``stats.mean`` under ``key`` when statistics are available."""
    # Truthiness (rather than ``is not None``) is kept deliberately: the stats
    # type is not visible from this module, so the existing check is preserved.
    if stats:
        results[key] = stats.mean


def _compute_batch_metrics(
    candidates: list[str],
    references: list[str],
    metric_names: list[str],
) -> dict[str, float]:
    """Compute average metrics for a batch of text pairs.

    Args:
        candidates: Texts being evaluated.
        references: Reference texts passed alongside ``candidates`` to each
            scorer's ``batch_score``.
        metric_names: Metric names, as returned by ``_parse_metrics``.
            Names that match no known metric are ignored.

    Returns:
        Mapping of output key to the mean score over the batch. A key is
        omitted when the scorer reports no statistics for it.
    """
    results: dict[str, float] = {}
    # Each scorer family is run over the batch at most once; its stats mapping
    # is reused for every metric name that reads from it.
    bleu_stats = None
    rouge_stats = None
    lexical_stats = None

    for metric in metric_names:
        if metric in _BLEU_KEYS:
            if bleu_stats is None:
                bleu_stats = Bleu().batch_score(candidates, references).stats
            key = _BLEU_KEYS[metric]
            _store_mean(results, key, bleu_stats.get(key))
        elif metric in _ROUGE_METRICS:
            if rouge_stats is None:
                rouge_stats = Rouge().batch_score(candidates, references).stats
            # The batch stats name the F-measure "rouge_l_fmeasure"; the CLI
            # reports it as "rouge_l", matching single-pair mode.
            _store_mean(results, "rouge_l", rouge_stats.get("rouge_l_fmeasure"))
        elif metric == "lexical":
            if lexical_stats is None:
                lexical_stats = Lexical().batch_score(candidates, references).stats
            for key in _LEXICAL_KEYS:
                _store_mean(results, key, lexical_stats.get(key))

    return results


def _parse_metrics(metrics_str: str) -> list[str]:
    """Parse comma-separated metric names.

    Names are stripped and lower-cased. Empty entries (for example from a
    trailing comma) are dropped and duplicates are removed, keeping the first
    occurrence.

    Args:
        metrics_str: Raw ``--metrics`` value, e.g. ``"bleu, rouge"``.

    Returns:
        The normalized metric names in the order they were given.

    Raises:
        typer.BadParameter: If no metric name remains after dropping empty
            entries, or if any name is not in ``AVAILABLE_METRICS``.
    """
    names = (part.strip().lower() for part in metrics_str.split(","))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    metrics = list(dict.fromkeys(name for name in names if name))

    available = ", ".join(sorted(AVAILABLE_METRICS))
    if not metrics:
        raise typer.BadParameter(f"No metrics specified. Available: {available}")

    invalid = [m for m in metrics if m not in AVAILABLE_METRICS]
    if invalid:
        raise typer.BadParameter(
            f"Unknown metrics: {', '.join(invalid)}. Available: {available}"
        )

    return metrics


# The docstring of ``validate`` is kept verbatim: Typer uses a command
# function's docstring as its ``--help`` text, so it is user-facing output.
def validate(
    text: Annotated[
        str | None,
        typer.Argument(help="Candidate text to validate (inline mode)."),
    ] = None,
    reference: Annotated[
        str | None,
        typer.Option("--reference", "-r", help="Reference text for comparison."),
    ] = None,
    file: Annotated[
        Path | None,
        typer.Option("--file", "-f", help="JSONL file with candidate/reference pairs."),
    ] = None,
    reference_file: Annotated[
        Path | None,
        typer.Option(
            "--reference-file",
            "-R",
            help="Separate JSONL file with references (requires --file).",
        ),
    ] = None,
    metrics: Annotated[
        str,
        typer.Option(
            "--metrics",
            "-m",
            help="Comma-separated metrics: bleu, bleu1-4, rouge, rouge_l, lexical.",
        ),
    ] = "bleu,rouge",
    output: Annotated[
        str,
        typer.Option("--output", "-o", help="Output format: table, json, or simple."),
    ] = "table",
    threshold: Annotated[
        float | None,
        typer.Option("--threshold", "-t", help="Score threshold for pass/fail status."),
    ] = None,
) -> None:
    """
    Validate text quality using various metrics.

    Use inline mode for single texts:
        veritext validate "text" -r "reference" -m bleu,rouge

    Use file mode for batches:
        veritext validate -f outputs.jsonl -m bleu,rouge
    """
    # Reject bad option values before reading any input.
    try:
        metric_names = _parse_metrics(metrics)
    except typer.BadParameter as e:
        console.print(f"[red]Error:[/red] {e}")
        raise typer.Exit(code=1) from e

    if output not in _OUTPUT_FORMATS:
        console.print(f"[red]Error:[/red] Invalid output format: {output}")
        raise typer.Exit(code=1)

    # --reference-file only has meaning in file mode; silently ignoring it
    # would score against references the user did not intend.
    if reference_file is not None and file is None:
        console.print("[red]Error:[/red] --reference-file requires --file.")
        raise typer.Exit(code=1)

    if file is not None:
        # File mode. It takes precedence over inline arguments when both are
        # supplied.
        try:
            if reference_file is not None:
                pairs = read_paired_jsonl(file, reference_file)
            else:
                pairs = read_jsonl(file)
        except (OSError, ValueError) as e:
            # OSError also covers unreadable paths (permissions, directories),
            # not only missing files.
            console.print(f"[red]Error:[/red] {e}")
            raise typer.Exit(code=1) from e

        if not pairs:
            # An empty input is reported but is not treated as a failure.
            console.print("[yellow]Warning:[/yellow] No text pairs found in file.")
            raise typer.Exit(code=0)

        candidates = [p.candidate for p in pairs]
        references = [p.reference for p in pairs]

        results = _compute_batch_metrics(candidates, references, metric_names)
        console.print(f"[dim]Evaluated {len(pairs)} text pairs.[/dim]\n")

    elif text is not None and reference is not None:
        # Inline mode: a single candidate/reference pair from the command line.
        results = _compute_metrics(text, reference, metric_names)

    else:
        console.print(
            "[red]Error:[/red] Provide either text and --reference, "
            "or --file for batch mode."
        )
        raise typer.Exit(code=1)

    print_validation_output(results, output, threshold)