Source code for oumi.utils.str_utils

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import hashlib
import logging
import os
import re
from typing import Optional

from oumi.core.tokenizers.base_tokenizer import BaseTokenizer


def sanitize_run_name(run_name: Optional[str]) -> Optional[str]:
    """Computes a sanitized version of a wandb run name.

    A valid run name may only contain alphanumeric characters, dashes,
    underscores, and dots, with length not exceeding the max limit.

    Args:
        run_name: The original raw value of the run name.
    """
    if not run_name:
        return run_name

    # Technically, the limit is 128 chars, but we limit to 100 characters
    # because the system may generate aux artifact names e.g., by prepending
    # a prefix (e.g., "model-") to our original run name, which are also
    # subject to the max 128 chars limit.
    _MAX_RUN_NAME_LENGTH = 100

    # Replace all unsupported characters with '_'.
    result = re.sub("[^a-zA-Z0-9\\_\\-\\.]", "_", run_name)
    if len(result) > _MAX_RUN_NAME_LENGTH:
        suffix = "..." + hashlib.shake_128(run_name.encode("utf-8")).hexdigest(8)
        result = result[0 : (_MAX_RUN_NAME_LENGTH - len(suffix))] + suffix

    if result != run_name:
        logger = logging.getLogger("oumi")
        logger.warning(f"Run name '{run_name}' got sanitized to '{result}'")

    return result

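# Illustrative usage (a sketch, not part of the module; the outputs follow
# directly from the regex and length logic above):
#   >>> sanitize_run_name("my run/name!")
#   'my_run_name_'
#   >>> sanitize_run_name(None)  # Falsy inputs pass through unchanged.
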
def try_str_to_bool(s: str) -> Optional[bool]:
    """Attempts to convert a string representation to a boolean value.

    This function interprets various string inputs as boolean values.
    It is case-insensitive and recognizes common boolean representations.

    Args:
        s: The string to convert to a boolean.

    Returns:
        Optional[bool]: The boolean interpretation of the input string,
            or `None` for unrecognized string values.

    Examples:
        >>> try_str_to_bool("true") # doctest: +SKIP
        True
        >>> try_str_to_bool("FALSE") # doctest: +SKIP
        False
        >>> try_str_to_bool("1") # doctest: +SKIP
        True
        >>> try_str_to_bool("no") # doctest: +SKIP
        False
        >>> try_str_to_bool("peach") # doctest: +SKIP
        None
    """
    s = s.strip().lower()
    if s in ("true", "yes", "1", "on", "t", "y"):
        return True
    elif s in ("false", "no", "0", "off", "f", "n"):
        return False
    return None

def str_to_bool(s: str) -> bool:
    """Converts a string representation to a boolean value.

    This function interprets various string inputs as boolean values.
    It is case-insensitive and recognizes common boolean representations.

    Args:
        s: The string to convert to a boolean.

    Returns:
        bool: The boolean interpretation of the input string.

    Raises:
        ValueError: If the input string cannot be interpreted as a boolean.

    Examples:
        >>> str_to_bool("true") # doctest: +SKIP
        True
        >>> str_to_bool("FALSE") # doctest: +SKIP
        False
        >>> str_to_bool("1") # doctest: +SKIP
        True
        >>> str_to_bool("no") # doctest: +SKIP
        False
    """
    result = try_str_to_bool(s)
    if result is None:
        raise ValueError(f"Cannot convert '{s}' to boolean.")
    return result

def compute_utf8_len(s: str) -> int:
    """Computes string length in UTF-8 bytes."""
    # This is inefficient: allocates a temporary copy of string content.
    # FIXME Can we do better?
    return len(s.encode("utf-8"))

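# Illustrative usage (a sketch): compute_utf8_len("héllo") == 6, since "é"
# encodes to two UTF-8 bytes while each ASCII character takes one byte.
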
def get_editable_install_override_env_var() -> bool:
    """Returns whether OUMI_FORCE_EDITABLE_INSTALL env var is set to a truthy value."""
    s = os.environ.get("OUMI_FORCE_EDITABLE_INSTALL", "")
    mode = s.lower().strip()
    bool_result = try_str_to_bool(mode)
    if bool_result is not None:
        return bool_result
    return False

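# Illustrative note (not part of the module): with `OUMI_FORCE_EDITABLE_INSTALL=yes`
# (or "1", "true", etc.) exported in the environment, this returns True; when the
# variable is unset or holds an unrecognized value, it falls back to False.
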
# Experimental function, only for developer usage.
def set_oumi_install_editable(setup: str) -> str:
    """Tries to replace oumi PyPI installs with an editable installation from source.

    For example, the following line:
    `pip install uv && uv pip -q install oumi[gpu,dev] vllm`
    will be replaced with:
    `pip install uv && uv pip -q install -e '.[gpu,dev]' vllm`

    Args:
        setup (str): The bash setup script to modify. May be multi-line.

    Returns:
        The modified setup script.
    """
    setup_lines = setup.split("\n")
    for i, line in enumerate(setup_lines):
        # Skip comments.
        if line.strip().startswith("#"):
            continue
        # In summary, this regex looks for variants of `pip install oumi` and
        # replaces the oumi package with an editable install from the current
        # directory.
        #
        # Tip: Use https://regexr.com/ or an LLM to help understand the regex.
        # It captures any misc. tokens like flags for the pip and install
        # commands, in addition to any optional dependencies oumi is installed
        # with.
        #
        # `((?:[-'\"\w]+ +)*)` matches whitespace-separated tokens potentially
        # containing quotes, such as flag names and values.
        # `((?:[-'\",\[\]\w]+ +)*)` does the same, with the addition of commas and
        # brackets, which may be present for packages with optional dependencies.
        # Since these don't include special characters like && and ;, it shouldn't
        # span across multiple pip install commands.
        # `(?<!-e )` means we don't match if the previous token is -e. This means an
        # editable install of a local dir called "oumi" is being done, so we skip it.
        # NOTE: We ideally should check for `--editable` as well, but Python re
        # doesn't support lookbehinds with variable length.
        # We additionally consume quotation marks around oumi if present.
        # Finally, `(\[[^\]]*\])?['\"]?` captures optional dependencies, if present.
        pattern = (
            r"pip3? +((?:[-'\"\w]+ +)*)install +((?:[-'\",\[\]\w]+ +)*)"
            r"(?<!-e )['\"]?oumi(\[[^\]]*\])?['\"]?"
        )
        # Compared to the pattern we captured, the changes are replacing `oumi` with
        # `.` and adding `-e` to make the install editable.
        replacement = r"pip \1install \2-e '.\3'"
        result = re.sub(pattern, replacement, line)
        if result == line:
            continue
        # Replace the line in the setup script.
        logger = logging.getLogger("oumi")
        logger.info(f"Detected the following oumi installation: `{line}`")
        logger.info(f"Replaced with: `{result}`")
        setup_lines[i] = result
    return "\n".join(setup_lines)

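# Illustrative usage (a sketch; mirrors the docstring example above):
#   >>> set_oumi_install_editable("pip install uv && uv pip -q install oumi[gpu,dev] vllm")
#   "pip install uv && uv pip -q install -e '.[gpu,dev]' vllm"
# Commented-out lines (starting with `#`) and existing `-e oumi` installs are
# left untouched, per the skip logic above.
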
def truncate_to_max_tokens_limit(
    text: str,
    tokenizer: BaseTokenizer,
    *,
    max_tokens: int,
    truncation_side: str = "right",
) -> tuple[str, int]:
    """Truncates text to at most `max_tokens` tokens.

    Args:
        text: A text prompt.
        tokenizer: The tokenizer used for encoding the data.
        max_tokens: Maximum number of tokens to keep.
        truncation_side: The side to truncate the tokens ("right" or "left").

    Returns:
        A tuple containing the truncated text prompt and the number of tokens.
    """
    if max_tokens <= 0:
        raise ValueError("`max_tokens` must be a positive integer")
    elif truncation_side not in ("left", "right"):
        raise ValueError(
            f"Invalid truncation_side: '{truncation_side}'. Expected 'left' or 'right'."
        )

    if not text:
        return ("", 0)

    left_side = truncation_side == "left"
    # The `truncation_side` parameter isn't universally supported by all tokenizers.
    # Let's do left-side truncation as post-processing.
    result = tokenizer(
        text,
        return_offsets_mapping=True,
        return_length=True,
        max_length=(None if left_side else max_tokens),
        truncation=(not left_side),  # Left-side truncation is done as post-processing.
    )
    if "offset_mapping" not in result:
        raise RuntimeError(
            f"Tokenizer must return offset mapping for truncation! Got: {result.keys()}"
        )
    token2char_offsets = result["offset_mapping"]
    if not isinstance(token2char_offsets, list):
        raise RuntimeError(
            "offset_mapping returned by tokenizer is not a list! "
            f"Got: {type(token2char_offsets)}"
        )

    truncated_text: str = ""
    num_truncated_tokens: int = 0
    if len(token2char_offsets) > 0:
        num_truncated_tokens = min(len(token2char_offsets), max_tokens)
        if left_side:
            lead_token_start = token2char_offsets[-num_truncated_tokens][0]
            if not (lead_token_start >= 0 and lead_token_start < len(text)):
                raise RuntimeError(
                    f"Truncation error: lead_token_start={lead_token_start} "
                    f"for text of length {len(text)}"
                )
            truncated_text = text[lead_token_start:]
        else:
            last_token_end = token2char_offsets[num_truncated_tokens - 1][1]
            if not (last_token_end >= 0 and last_token_end <= len(text)):
                raise RuntimeError(
                    f"Truncation error: last_token_end={last_token_end} "
                    f"for text of length {len(text)}"
                )
            truncated_text = text[:last_token_end]

    return (truncated_text, num_truncated_tokens)

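# Illustrative usage (a sketch; assumes a Hugging Face "fast" tokenizer, which
# supports `return_offsets_mapping` -- the model name and the exact token split
# shown are assumptions, not part of this module):
#   >>> from transformers import AutoTokenizer
#   >>> tok = AutoTokenizer.from_pretrained("gpt2")
#   >>> truncate_to_max_tokens_limit("Hello world, this is a test.", tok, max_tokens=3)
#   ('Hello world,', 3)
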
def truncate_text_pieces_to_max_tokens_limit(
    text_pieces: list[str],
    tokenizer: BaseTokenizer,
    *,
    max_tokens: int,
    truncation_side: str = "right",
) -> list[str]:
    """Truncates text pieces so their combined length does not exceed `max_tokens`.

    Args:
        text_pieces: A list of text prompts.
        tokenizer: The tokenizer used for encoding the data.
        max_tokens: Maximum number of tokens to keep in all text pieces combined.
        truncation_side: The side to truncate the tokens ("right" or "left").

    Returns:
        A list of truncated text prompts.
    """
    if max_tokens <= 0:
        raise ValueError("`max_tokens` must be a positive integer")

    remaining_tokens = max_tokens

    result = copy.deepcopy(text_pieces)
    if truncation_side == "left":
        result.reverse()

    for idx, text_piece in enumerate(result):
        if len(text_piece) == 0:
            continue
        elif remaining_tokens > 0:
            truncated_text_piece, num_tokens = truncate_to_max_tokens_limit(
                text_piece,
                tokenizer=tokenizer,
                max_tokens=remaining_tokens,
                truncation_side=truncation_side,
            )
            result[idx] = truncated_text_piece
            remaining_tokens -= num_tokens
        else:
            result[idx] = ""

    if truncation_side == "left":
        result.reverse()

    return result

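# Illustrative usage (a sketch; continues the hypothetical `tok` from the
# previous example). With `truncation_side="right"`, the token budget is spent
# left to right, so later pieces that exceed it are truncated or emptied:
#   >>> truncate_text_pieces_to_max_tokens_limit(
#   ...     ["Hello world,", "this is a test."], tok, max_tokens=3
#   ... )
#   ['Hello world,', '']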