# Source code for oumi.cli.judge

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Optional

import jsonlines
import typer

from oumi.cli import cli_utils
from oumi.utils.io_utils import load_jsonlines

if TYPE_CHECKING:
    from oumi.core.configs import InferenceConfig, JudgeConfig


def _load_judge_config(config: str, extra_args: list[str]) -> "JudgeConfig":
    """Resolve a judge config from a registered name or a YAML file path.

    Args:
        config: Either the name of a judge config registered in ``REGISTRY``
            or a path to a judge config YAML file.
        extra_args: Extra CLI arguments used as config overrides. Only allowed
            when ``config`` is a file path.

    Returns:
        The resolved ``JudgeConfig``.

    Raises:
        typer.Exit: If a named config is combined with overrides, or if the
            config file does not exist.
    """
    # Delayed imports keep CLI startup fast.
    # BUG FIX: ``JudgeConfig`` was previously imported only under
    # TYPE_CHECKING at module level, so the call at the bottom of this
    # function raised NameError at runtime.
    from oumi.core.configs import JudgeConfig
    from oumi.core.registry import REGISTRY

    judge_config_builder = REGISTRY.get_judge_config(config)

    if judge_config_builder:
        # Named configs are intentionally immutable: allowing overrides would
        # make results non-reproducible for the same named judge.
        if extra_args:
            typer.echo(
                "For consistent judge results, a named judge config cannot be "
                f"overridden with extra arguments. Got: {extra_args}. "
                "Please register a new named judge config, or provide a path to a "
                "judge config file."
            )
            raise typer.Exit(code=1)
        return judge_config_builder()

    if not Path(config).exists():
        typer.echo(f"Config file not found: '{config}'")
        raise typer.Exit(code=1)

    return JudgeConfig.from_yaml_and_arg_list(config, extra_args)


def dataset(
    ctx: typer.Context,
    config: Annotated[
        str,
        typer.Option(*cli_utils.CONFIG_FLAGS, help="Path to the judge config file"),
    ],
    dataset_name: Annotated[
        Optional[str], typer.Option(help="Name of the dataset from the registry")
    ] = None,
    dataset_subset: Annotated[
        Optional[str],
        typer.Option(help="Subset of the dataset to use, if applicable"),
    ] = None,
    dataset_split: Annotated[
        Optional[str], typer.Option(help="Split of the dataset to use.")
    ] = "train",
    output_file: Annotated[
        Optional[str], typer.Option(help="Path to the output file (jsonl)")
    ] = None,
    level: cli_utils.LOG_LEVEL_TYPE = None,
):
    """Judge a dataset."""
    # Imports are deferred so that CLI startup stays fast.
    from oumi import judge_dataset
    from oumi.core.registry import REGISTRY

    if not dataset_name:
        typer.echo("Dataset name is required.")
        raise typer.Exit(code=1)

    # Resolve the judge config (named config, or YAML path + CLI overrides).
    overrides = cli_utils.parse_extra_cli_args(ctx)
    judge_config = _load_judge_config(config, overrides)

    # Look up the dataset class in the registry and instantiate it.
    dataset_cls = REGISTRY.get_dataset(dataset_name, subset=dataset_subset)
    if dataset_cls is None:
        typer.echo(f"Dataset '{dataset_name}' not found in registry.")
        raise typer.Exit(code=1)
    target_dataset = dataset_cls(
        split=dataset_split,
        subset=dataset_subset,
    )

    # Run the judge over every example in the dataset.
    results = judge_dataset(judge_config, dataset=target_dataset)

    # Persist the results as JSONL, or stream JSON lines to stdout.
    if output_file:
        typer.echo(f"Saving results to {output_file}")
        with jsonlines.open(output_file, mode="w") as writer:
            writer.write_all(results)
    else:
        for judged in results:
            print(json.dumps(judged))
def conversations(
    ctx: typer.Context,
    config: Annotated[
        str,
        typer.Option(*cli_utils.CONFIG_FLAGS, help="Path to the judge config file"),
    ],
    input_file: Annotated[
        Optional[str], typer.Option(help="Path to the input file (jsonl)")
    ] = None,
    output_file: Annotated[
        Optional[str], typer.Option(help="Path to the output file (jsonl)")
    ] = None,
    level: cli_utils.LOG_LEVEL_TYPE = None,
):
    """Judge a list of conversations."""
    overrides = cli_utils.parse_extra_cli_args(ctx)

    # Imports are deferred so that CLI startup stays fast.
    from oumi import judge_conversations
    from oumi.core.types.conversation import Conversation

    # Resolve the judge config (named config, or YAML path + CLI overrides).
    judge_config = _load_judge_config(config, overrides)

    if not input_file:
        typer.echo("Input file is required.")
        raise typer.Exit(code=1)

    # Each JSONL record is deserialized into a Conversation object.
    raw_records = load_jsonlines(input_file)
    convs = [Conversation.from_dict(record) for record in raw_records]

    results = judge_conversations(judge_config, judge_inputs=convs)

    # Persist the results as JSONL, or stream JSON lines to stdout.
    if output_file:
        typer.echo(f"Saving results to {output_file}")
        with jsonlines.open(output_file, mode="w") as writer:
            writer.write_all(results)
    else:
        for judged in results:
            print(json.dumps(judged))
def model(
    ctx: typer.Context,
    config: Annotated[
        str,
        typer.Option(*cli_utils.CONFIG_FLAGS, help="Path to the judge config file"),
    ],
    inference_config: Annotated[
        str,
        typer.Option(*cli_utils.CONFIG_FLAGS, help="Path to the inference config file"),
    ],
    input_file: Annotated[
        Optional[str], typer.Option(help="Path to the input file (jsonl)")
    ] = None,
    output_file: Annotated[
        Optional[str], typer.Option(help="Path to the output file (jsonl)")
    ] = None,
    level: cli_utils.LOG_LEVEL_TYPE = None,
):
    """Judge the outputs of a model on a dataset."""
    # NOTE(review): both `config` and `inference_config` register the same
    # *cli_utils.CONFIG_FLAGS — verify these option flags do not collide in
    # typer before relying on both from the command line.

    # Delayed imports keep CLI startup fast.
    # BUG FIX: ``InferenceConfig`` was previously imported only under
    # TYPE_CHECKING at module level, so referencing it below raised
    # NameError at runtime.
    from oumi import judge_conversations
    from oumi.builders.inference_engines import build_inference_engine
    from oumi.core.configs import InferenceConfig
    from oumi.core.types.conversation import Conversation
    # End imports

    # Load the judge config (named config, or YAML path + CLI overrides).
    judge_extra_args = cli_utils.parse_extra_cli_args(ctx)
    judge_config = _load_judge_config(config, judge_extra_args)

    # Load the inference config; the same extra CLI args are applied as
    # overrides here as well.
    inference_extra_args = cli_utils.parse_extra_cli_args(ctx)
    model_inference_config: InferenceConfig = InferenceConfig.from_yaml_and_arg_list(
        inference_config, inference_extra_args
    )
    if not model_inference_config.engine:
        typer.echo("Inference engine is required.")
        raise typer.Exit(code=1)

    # Load the input conversations from the JSONL file.
    if not input_file:
        typer.echo("Input file is required.")
        raise typer.Exit(code=1)
    input_data = load_jsonlines(input_file)
    input_conversations = [Conversation.from_dict(output) for output in input_data]

    # Run inference on the inputs, then judge the model's outputs.
    inference_engine = build_inference_engine(
        model_inference_config.engine,
        model_params=model_inference_config.model,
        remote_params=model_inference_config.remote_params,
        generation_params=model_inference_config.generation,
    )
    model_outputs = inference_engine.infer(
        input=input_conversations, inference_config=model_inference_config
    )
    results = judge_conversations(judge_config, judge_inputs=model_outputs)

    # Persist the results as JSONL, or stream JSON lines to stdout.
    if output_file:
        typer.echo(f"Saving results to {output_file}")
        with jsonlines.open(output_file, mode="w") as writer:
            writer.write_all(results)
    else:
        for result in results:
            print(json.dumps(result))