# Source code for oumi.utils.analysis_utils

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path
from typing import Any, Optional

from oumi.builders.models import build_tokenizer
from oumi.core.configs.analyze_config import AnalyzeConfig
from oumi.core.configs.params.model_params import ModelParams
from oumi.core.datasets.base_map_dataset import BaseMapDataset
from oumi.core.registry.registry import REGISTRY

logger = logging.getLogger(__name__)


def build_tokenizer_from_config(tokenizer_config: Optional[dict[str, Any]]):
    """Build a tokenizer from a configuration dictionary.

    Args:
        tokenizer_config: Dictionary with tokenizer settings. Recognized keys
            are ``model_name`` (required), ``tokenizer_kwargs``, and
            ``trust_remote_code``.

    Returns:
        The constructed tokenizer, or ``None`` when no configuration is given.

    Raises:
        ValueError: If ``model_name`` is missing from ``tokenizer_config``.
    """
    # An empty/None config means "no tokenizer requested".
    if not tokenizer_config:
        return None

    if "model_name" not in tokenizer_config:
        raise ValueError("tokenizer_config must contain 'model_name' field")

    params = ModelParams(
        model_name=tokenizer_config["model_name"],
        tokenizer_kwargs=tokenizer_config.get("tokenizer_kwargs", {}),
        trust_remote_code=tokenizer_config.get("trust_remote_code", False),
    )
    tokenizer = build_tokenizer(params)
    logger.info(f"Built tokenizer for model: {params.model_name}")
    return tokenizer
def load_dataset_from_config(
    config: AnalyzeConfig, tokenizer: Optional[Any] = None
) -> BaseMapDataset:
    """Load a dataset described by ``config``.

    Registered datasets are instantiated directly from the Oumi ``REGISTRY``;
    custom datasets (``config.dataset_path`` set) are loaded from local files
    via ``_load_custom_dataset_from_path``. A tokenizer, when given, is passed
    through to the dataset constructor.

    Args:
        config: Configuration object containing dataset parameters.
        tokenizer: Optional tokenizer to use with the dataset.

    Returns:
        The loaded dataset (guaranteed to inherit from ``BaseMapDataset``).

    Raises:
        ValueError: If neither ``dataset_name`` nor ``dataset_path`` is set,
            or if ``dataset_name`` is ``None`` for a registered dataset.
        NotImplementedError: If the dataset is not registered, or the loaded
            object does not inherit from ``BaseMapDataset``.
    """
    if not config.dataset_name and not config.dataset_path:
        raise ValueError("Either dataset_name or dataset_path must be provided")

    # A local path takes precedence: load as a custom (file-based) dataset.
    if config.dataset_path:
        return _load_custom_dataset_from_path(
            config.dataset_path, config.dataset_format, tokenizer, config
        )

    try:
        if config.dataset_name is None:
            raise ValueError("dataset_name cannot be None for registered datasets")

        dataset_class = REGISTRY.get_dataset(
            config.dataset_name, subset=config.subset
        )
        if dataset_class is None:
            # TODO: Implement HuggingFace Hub loading
            raise NotImplementedError(
                f"Dataset '{config.dataset_name}' is not registered in the "
                "REGISTRY. Loading from HuggingFace Hub is not yet implemented."
            )

        # Assemble the constructor arguments for the registered dataset.
        kwargs: dict[str, Any] = {
            "dataset_name": config.dataset_name,
            "dataset_path": None,
            "split": config.split,
            "subset": config.subset,
            "trust_remote_code": config.trust_remote_code,
        }
        if tokenizer is not None:
            kwargs["tokenizer"] = tokenizer
        # Vision-language datasets additionally need processor parameters.
        if config.processor_name:
            kwargs["processor_name"] = config.processor_name
            kwargs["processor_kwargs"] = config.processor_kwargs
            kwargs["trust_remote_code"] = config.trust_remote_code

        dataset = dataset_class(**kwargs)

        # Analysis only supports map-style datasets.
        if not isinstance(dataset, BaseMapDataset):
            raise NotImplementedError(
                f"Dataset type {type(dataset)} is not supported for analysis. "
                "Please use a dataset that inherits from BaseMapDataset."
            )
        return dataset
    except Exception as e:
        logger.error(f"Failed to load dataset {config.dataset_name}: {e}")
        raise
def _load_custom_dataset_from_path(
    dataset_path: str,
    dataset_format: Optional[str],
    tokenizer: Optional[Any],
    config: AnalyzeConfig,
) -> BaseMapDataset:
    """Load a custom dataset from a local file path.

    Args:
        dataset_path: Path to the dataset file.
        dataset_format: Format of the dataset ('oumi' or 'alpaca') - required
            for custom datasets.
        tokenizer: Optional tokenizer to use with the dataset.
        config: Configuration object containing additional parameters.

    Returns:
        Loaded dataset (TextSftJsonLinesDataset or VLJsonlinesDataset).

    Raises:
        FileNotFoundError: If ``dataset_path`` does not exist.
        ValueError: If ``dataset_path`` is not a regular file, or the
            multimodal configuration is invalid.
    """
    # Imported lazily to avoid circular imports.
    from oumi.datasets.sft.sft_jsonlines import TextSftJsonLinesDataset
    from oumi.datasets.vision_language.vision_jsonlines import VLJsonlinesDataset

    file_path = Path(dataset_path)
    if not file_path.exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")
    if not file_path.is_file():
        raise ValueError(
            f"Dataset path must be a file, not a directory: {dataset_path}"
        )

    # The text/vision-language choice is explicit via config.is_multimodal.
    if config.is_multimodal is True:
        # processor_name requirement is already validated in AnalyzeConfig.
        vl_kwargs: dict[str, Any] = {
            "dataset_path": str(file_path),
            "tokenizer": tokenizer,
            "processor_name": config.processor_name,
            "processor_kwargs": config.processor_kwargs,
            "trust_remote_code": config.trust_remote_code,
        }
        dataset = VLJsonlinesDataset(
            **{name: value for name, value in vl_kwargs.items() if value is not None}
        )
        logger.info(f"Loaded vision-language dataset from: {dataset_path}")
        return dataset

    if config.is_multimodal is False:
        # Explicitly text-only: load as a plain SFT jsonlines dataset.
        text_kwargs: dict[str, Any] = {
            "dataset_path": str(file_path),
            "format": dataset_format,
        }
        if tokenizer is not None:
            text_kwargs["tokenizer"] = tokenizer
        dataset = TextSftJsonLinesDataset(
            **{name: value for name, value in text_kwargs.items() if value is not None}
        )
        logger.info(f"Loaded text dataset from: {dataset_path}")
        return dataset

    # is_multimodal=None is rejected by AnalyzeConfig.__post_init__, so this
    # branch should be unreachable in practice.
    raise ValueError("Invalid vision-language configuration")