Source code for oumi.cli.cli_utils

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from enum import Enum
from typing import Annotated, Optional

import typer

from oumi.utils.logging import logger

# Context settings that let a command accept extra arguments and options it
# does not itself declare (left in the context for later parsing).
CONTEXT_ALLOW_EXTRA_ARGS = dict(
    allow_extra_args=True,
    ignore_unknown_options=True,
)
# Long and short flag names for passing a config file.
CONFIG_FLAGS = ["--config", "-c"]


def parse_extra_cli_args(ctx: typer.Context) -> list[str]:
    """Parses extra CLI arguments into a list of strings.

    Two formats are supported for each extra argument:

    1. Space separated: ``--foo 2``
    2. ``=``-separated: ``--foo=2``

    Both forms are normalized to ``"foo=2"``.

    Args:
        ctx: The Typer context object. Its ``args`` attribute holds the extra,
            unparsed command-line tokens.

    Returns:
        List[str]: The extra CLI arguments, each formatted as ``"key=value"``.

    Raises:
        typer.BadParameter: If an argument does not start with ``--``, has an
            empty key name, or is a trailing key with no value.
    """
    args = []
    try:
        num_args = len(ctx.args)
        idx = 0
        while idx < num_args:
            original_key = ctx.args[idx]
            key = original_key.strip()
            if not key.startswith("--"):
                raise typer.BadParameter(
                    "Extra arguments must start with '--'. "
                    f"Found argument `{original_key}` at position {idx}: `{ctx.args}`"
                )
            # Strip leading "--"
            key = key[2:]
            pos = key.find("=")
            if pos >= 0:
                # '='-separated argument: the token holds both key and value.
                value = key[(pos + 1) :].strip()
                key = key[:pos].strip()
                if not key:
                    raise typer.BadParameter(
                        "Empty key name for `=`-separated argument. "
                        f"Found argument `{original_key}` at position {idx}: "
                        f"`{ctx.args}`"
                    )
                idx += 1
            else:
                # Space-separated argument: the value is the next token.
                # Strip and validate the key the same way as the `=` form so a
                # bare "--" cannot produce an entry with an empty key.
                key = key.strip()
                if not key:
                    raise typer.BadParameter(
                        "Empty key name for space-separated argument. "
                        f"Found argument `{original_key}` at position {idx}: "
                        f"`{ctx.args}`"
                    )
                if idx + 1 >= num_args:
                    raise typer.BadParameter(
                        "Trailing argument has no value assigned. "
                        f"Found argument `{original_key}` at position {idx}: "
                        f"`{ctx.args}`"
                    )
                value = ctx.args[idx + 1].strip()
                idx += 2
            if value.startswith("--"):
                # Likely a missing value (the next flag was consumed as the
                # value). Warn instead of failing, since it may be intentional.
                logger.warning(
                    f"Argument value ('{value}') starts with `--`! "
                    f"Key: '{original_key}'"
                )
            args.append(f"{key}={value}")
    except ValueError as err:
        # Defensive: none of the string operations above are expected to
        # raise ValueError. Any such error is still reported as a bad
        # parameter, with the original cause chained.
        bad_args = " ".join(ctx.args)
        raise typer.BadParameter(
            "Extra arguments must be in `--argname value` pairs. "
            f"Received: `{bad_args}`"
        ) from err
    logger.debug(f"\n\nParsed CLI args:\n{args}\n\n")
    return args
def configure_common_env_vars() -> None:
    """Populates default values for common environment variables.

    Variables already present in ``os.environ`` are left untouched. Only
    missing ones receive the default listed below.
    """
    defaults = {
        "ACCELERATE_LOG_LEVEL": "info",
        "TOKENIZERS_PARALLELISM": "false",
    }
    for var_name, default_value in defaults.items():
        os.environ.setdefault(var_name, default_value)
class LogLevel(str, Enum):
    """Logging levels that can be selected for a CLI command.

    Each member's value is the standard level name returned by
    ``logging.getLevelName`` for the corresponding numeric level.
    """

    # Members are listed from most to least verbose. Declaration order is
    # also the enum's iteration order, so it is kept stable.
    DEBUG = logging.getLevelName(logging.DEBUG)
    INFO = logging.getLevelName(logging.INFO)
    WARNING = logging.getLevelName(logging.WARNING)
    ERROR = logging.getLevelName(logging.ERROR)
    CRITICAL = logging.getLevelName(logging.CRITICAL)
def set_log_level(level: Optional[LogLevel]) -> Optional[LogLevel]:
    """Sets the logging level for the current command.

    This is intended to be used as a Typer option ``callback``. Typer passes
    the callback's return value back to Click, which uses it as the option's
    final value. The level is therefore returned unchanged instead of
    implicitly returning ``None``, which would discard the parsed option.

    Args:
        level (Optional[LogLevel]): The log level to use. If falsy (for
            example, the option was not provided), the logger is left
            unchanged.

    Returns:
        Optional[LogLevel]: The ``level`` argument, unchanged.
    """
    if not level:
        return level
    uppercase_level = level.upper()
    logger.setLevel(uppercase_level)
    print(f"Set log level to {uppercase_level}")
    return level
# Reusable annotated parameter type for a `--log-level` / `-log` option.
# Choices are the `LogLevel` members, matched case-insensitively, and
# `set_log_level` is registered as the option's callback.
LOG_LEVEL_TYPE = Annotated[
    Optional[LogLevel],
    typer.Option(
        "--log-level",
        "-log",
        callback=set_log_level,
        case_sensitive=False,
        help="The logging level for the specified command.",
        show_choices=True,
        show_default=False,
    ),
]