Source code for oumi.cli.infer

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Annotated, Final, Optional

import typer
from rich.table import Table

import oumi.cli.cli_utils as cli_utils
from oumi.cli.alias import AliasType, try_get_config_name_for_alias
from oumi.utils.logging import logger

_DEFAULT_CLI_PDF_DPI: Final[int] = 200


def infer(
    ctx: typer.Context,
    config: Annotated[
        str,
        typer.Option(
            *cli_utils.CONFIG_FLAGS,
            help="Path to the configuration file for inference.",
        ),
    ],
    interactive: Annotated[
        bool,
        typer.Option("-i", "--interactive", help="Run in an interactive session."),
    ] = False,
    image: Annotated[
        Optional[str],
        typer.Option(
            "--image",
            help=(
                "File path or URL of an input image to be used with image+text VLLMs. "
                "Only used in interactive mode."
            ),
        ),
    ] = None,
    system_prompt: Annotated[
        Optional[str],
        typer.Option(
            "--system-prompt",
            help=(
                "System prompt for task-specific instructions. "
                "Only used in interactive mode."
            ),
        ),
    ] = None,
    level: cli_utils.LOG_LEVEL_TYPE = None,
):
    """Run inference on a model.

    If `input_path` is provided in the configuration file, inference will run on
    those input examples. Otherwise, inference will run interactively with
    user-provided inputs.

    Args:
        ctx: The Typer context object.
        config: Path to the configuration file for inference.
        interactive: Whether to run in an interactive session.
        image: Path to the input image for `image+text` VLLMs.
        system_prompt: System prompt for task-specific instructions.
        level: The logging level for the specified command.
    """
    extra_args = cli_utils.parse_extra_cli_args(ctx)

    config = str(
        cli_utils.resolve_and_fetch_config(
            try_get_config_name_for_alias(config, AliasType.INFER),
        )
    )
    # Delayed imports
    from oumi import infer as oumi_infer
    from oumi import infer_interactive as oumi_infer_interactive
    from oumi.core.configs import InferenceConfig
    from oumi.utils.image_utils import (
        create_png_bytes_from_image_list,
        load_image_png_bytes_from_path,
        load_image_png_bytes_from_url,
        load_pdf_pages_from_path,
        load_pdf_pages_from_url,
    )
    # End imports

    parsed_config: InferenceConfig = InferenceConfig.from_yaml_and_arg_list(
        config, extra_args, logger=logger
    )
    parsed_config.finalize_and_validate()

    # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    input_image_png_bytes: Optional[list[bytes]] = None
    if image:
        image_lower = image.lower()
        if image_lower.startswith("http://") or image_lower.startswith("https://"):
            if image_lower.endswith(".pdf"):
                input_image_png_bytes = create_png_bytes_from_image_list(
                    load_pdf_pages_from_url(image, dpi=_DEFAULT_CLI_PDF_DPI)
                )
            else:
                input_image_png_bytes = [load_image_png_bytes_from_url(image)]
        else:
            if image_lower.endswith(".pdf"):
                input_image_png_bytes = create_png_bytes_from_image_list(
                    load_pdf_pages_from_path(image, dpi=_DEFAULT_CLI_PDF_DPI)
                )
            else:
                input_image_png_bytes = [load_image_png_bytes_from_path(image)]

    if parsed_config.input_path:
        if interactive:
            logger.warning(
                "Input path provided, skipping interactive inference. "
                "To run in interactive mode, do not provide an input path."
            )
        generations = oumi_infer(parsed_config)

        # Don't print results if output_path is provided.
        if parsed_config.output_path:
            return

        table = Table(
            title="Inference Results",
            title_style="bold magenta",
            show_edge=False,
            show_lines=True,
        )
        table.add_column("Conversation", style="green")
        for generation in generations:
            table.add_row(repr(generation))
        cli_utils.CONSOLE.print(table)
        return

    if not interactive:
        logger.warning(
            "No input path provided, running in interactive mode. "
            "To run with an input path, provide one in the configuration file."
        )
    return oumi_infer_interactive(
        parsed_config,
        input_image_bytes=input_image_png_bytes,
        system_prompt=system_prompt,
    )
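
For context, a minimal sketch of the programmatic path this CLI command wraps, using only the calls that appear in the listing above (the delayed imports, InferenceConfig.from_yaml_and_arg_list, finalize_and_validate, and oumi.infer). The config path "infer_config.yaml" is a hypothetical example; it is assumed to be a valid InferenceConfig YAML with input_path set so that batch (non-interactive) inference runs.

    # Sketch only: mirrors what `oumi infer -c infer_config.yaml` does in batch mode.
    from oumi import infer as oumi_infer
    from oumi.core.configs import InferenceConfig
    from oumi.utils.logging import logger

    # "infer_config.yaml" is a hypothetical path; extra CLI overrides would go in the list.
    config = InferenceConfig.from_yaml_and_arg_list("infer_config.yaml", [], logger=logger)
    config.finalize_and_validate()

    # Runs inference over the examples referenced by config.input_path and
    # returns the generated conversations, as the CLI does before rendering its table.
    generations = oumi_infer(config)
    for conversation in generations:
        print(repr(conversation))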