Source code for oumi.builders.data

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections.abc import Sequence
from typing import Callable, Optional, TypeVar, Union, cast

import datasets
from trl.trainer import ConstantLengthDataset

from oumi.core.configs import (
    DataParams,
    DatasetParams,
    DatasetSplit,
    DatasetSplitParams,
    MixtureStrategy,
    TrainingConfig,
)
from oumi.core.datasets.base_pretraining_dataset import BasePretrainingDataset
from oumi.core.datasets.pretraining_async_text_dataset import (
    PretrainingAsyncTextDataset,
)
from oumi.core.registry import REGISTRY
from oumi.core.tokenizers import BaseTokenizer
from oumi.utils.hf_datasets_utils import is_cached_to_disk_hf_dataset
from oumi.utils.logging import logger

DatasetType = TypeVar("DatasetType", datasets.Dataset, datasets.IterableDataset)


def build_dataset_mixture(
    config: TrainingConfig,
    tokenizer: Optional[BaseTokenizer],
    dataset_split: DatasetSplit,
    seed: Optional[int] = None,
) -> Union[ConstantLengthDataset, DatasetType, PretrainingAsyncTextDataset]:
    """Builds a dataset for the specified split.

    Args:
        config: The training config.
        tokenizer: The tokenizer object to use for preprocessing.
        dataset_split: The split of the dataset to load.
        seed: If specified, a seed used for random sampling.

    Returns:
        dataset: The built dataset for `dataset_split`.
    """
    dataset_split_params: DatasetSplitParams = config.data.get_split(dataset_split)

    if dataset_split_params.use_torchdata:
        from oumi.builders.oumi_data import (
            build_dataset_mixture as build_oumi_dataset,
        )

        logger.warning(
            "Using torchdata preprocessing pipeline. "
            "This is currently in beta and may not be stable."
        )
        # TODO: OPE-271. Some type hackery going on here.
        # We return a torchdata.IterDataPipe instead of a HuggingFace Dataset or
        # IterableDataset. This is a temporary workaround until torchdata is stable
        # and becomes the default processing pipeline.
        return build_oumi_dataset(config, tokenizer, dataset_split, seed)  # type: ignore

    # Check if the underlying dataset is already packed, or if we need to pack it
    # ourselves.
    is_packed = _is_mixture_packed(dataset_split_params)

    datasets = [
        _sample_dataset(
            _load_dataset(
                dataset_params=dataset_params,
                stream=dataset_split_params.stream,
                tokenizer=tokenizer,
            ),
            dataset_params=dataset_params,
            stream=dataset_split_params.stream,
        )
        for dataset_params in dataset_split_params.datasets
    ]
    mixture_proportions = [
        dataset.mixture_proportion for dataset in dataset_split_params.datasets
    ]

    # Interleave datasets using mixture_strategy.
    dataset = _mix_datasets(
        datasets,
        mixture_proportions,
        dataset_split_params.mixture_strategy,
        dataset_split_params.seed,
    )

    if dataset_split_params.pack and not is_packed:
        # Fetch max sequence length. If not specified, defaults to 1024.
        dataset_kwargs = {}
        if config.model.model_max_length:
            dataset_kwargs["seq_length"] = config.model.model_max_length

        dataset = PretrainingAsyncTextDataset(
            tokenizer,
            dataset,
            **dataset_kwargs,
        )

    return dataset
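

A minimal usage sketch, not part of this module: it builds a two-dataset mixture and assumes `tokenizer` is an already-built `BaseTokenizer`; the dataset names and proportions are placeholders. Note that every dataset needs a `mixture_proportion` for interleaving to kick in, otherwise `_mix_datasets` simply concatenates them.

# Illustrative only: "dataset_a" / "dataset_b" are placeholder names, and
# `tokenizer` is assumed to be a tokenizer constructed elsewhere.
config = TrainingConfig(
    data=DataParams(
        train=DatasetSplitParams(
            datasets=[
                DatasetParams(dataset_name="dataset_a", mixture_proportion=0.7),
                DatasetParams(dataset_name="dataset_b", mixture_proportion=0.3),
            ],
            seed=42,  # used when interleaving the mixture
        )
    )
)
train_dataset = build_dataset_mixture(
    config=config,
    tokenizer=tokenizer,
    dataset_split=DatasetSplit.TRAIN,
)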


def build_dataset_from_params(
    dataset_params: DatasetParams,
    tokenizer: Optional[BaseTokenizer],
    seed: Optional[int] = None,
    stream: bool = False,
    pack: bool = False,
    use_torchdata: Optional[bool] = None,
) -> Union[ConstantLengthDataset, DatasetType, PretrainingAsyncTextDataset]:
    """Builds a dataset from a dataset params object.

    Please refer to `DatasetParams` & `DatasetSplitParams` for a description of
    all the arguments.
    """
    training_config = TrainingConfig(
        data=DataParams(
            train=DatasetSplitParams(
                datasets=[dataset_params],
                stream=stream,
                pack=pack,
                use_torchdata=use_torchdata,
            )
        )
    )

    return build_dataset_mixture(
        config=training_config,
        dataset_split=DatasetSplit.TRAIN,
        tokenizer=tokenizer,
        seed=seed,
    )


def build_dataset(
    dataset_name: str,
    tokenizer: Optional[BaseTokenizer],
    seed: Optional[int] = None,
    stream: bool = False,
    pack: bool = False,
    use_torchdata: Optional[bool] = None,
    **kwargs,
) -> Union[ConstantLengthDataset, DatasetType, PretrainingAsyncTextDataset]:
    """Builds a dataset from a dataset name.

    Please refer to `DatasetParams` & `DatasetSplitParams` for a description of
    all the arguments.
    """
    dataset_params = DatasetParams(
        dataset_name=dataset_name,
        **kwargs,
    )

    return build_dataset_from_params(
        dataset_params=dataset_params,
        tokenizer=tokenizer,
        seed=seed,
        stream=stream,
        pack=pack,
        use_torchdata=use_torchdata,
    )
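

A corresponding sketch for the by-name wrapper, again not part of the module: the dataset name is a placeholder, `tokenizer` is assumed built elsewhere, and the extra keyword arguments are forwarded to `DatasetParams`, so they must match its fields.

# Illustrative only: "my_org/my_dataset" is a placeholder Hugging Face Hub name.
dataset = build_dataset(
    dataset_name="my_org/my_dataset",
    tokenizer=tokenizer,
    stream=True,          # stream the dataset where supported
    sample_count=10_000,  # forwarded to DatasetParams; see _sample_dataset
    shuffle=True,         # forwarded to DatasetParams
)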


def _mix_datasets(
    dataset_list: list[DatasetType],
    mixture_proportions: Sequence[Optional[float]],
    mixture_strategy: str,
    seed: Optional[int],
) -> DatasetType:
    """Joins multiple datasets using the provided `mixture_strategy`."""
    if any([proportion is None for proportion in mixture_proportions]):
        # All datasets should be concatenated when no proportion is specified.
        return datasets.concatenate_datasets(dataset_list)
    else:
        # All mixture_proportions are not None.
        mixture_proportions = cast(list[float], mixture_proportions)
        # Interleave datasets using the specified proportions and mixture strategy.
        return datasets.interleave_datasets(
            dataset_list,
            probabilities=mixture_proportions,
            seed=seed,
            stopping_strategy=(MixtureStrategy(mixture_strategy).get_literal_value()),
        )


def _sample_dataset(
    dataset: Union[
        datasets.DatasetDict,
        datasets.Dataset,
        datasets.IterableDatasetDict,
        datasets.IterableDataset,
    ],
    dataset_params: DatasetParams,
    stream: bool,
) -> DatasetType:
    """Samples the specified dataset."""
    if dataset_params.sample_count is None:
        # No sampling.
        dataset = cast(DatasetType, dataset)
        if dataset_params.shuffle:
            dataset = dataset.shuffle(dataset_params.seed)
        return dataset

    if stream:
        dataset = cast(datasets.IterableDataset, dataset)
        if dataset_params.shuffle:
            dataset = dataset.shuffle(dataset_params.seed)
        generator = _build_iterable_dataset_sampler(
            dataset, dataset_params.sample_count
        )
        return cast(
            DatasetType,
            datasets.IterableDataset.from_generator(generator, dataset.features),
        )

    dataset = cast(datasets.Dataset, dataset)
    if dataset.num_rows >= dataset_params.sample_count:
        if dataset_params.shuffle:
            dataset = dataset.shuffle(dataset_params.seed).flatten_indices()
        return cast(DatasetType, dataset.take(dataset_params.sample_count))

    # Oversample the dataset.
    oversampling_copies = int(dataset_params.sample_count // dataset.num_rows)
    dataset_list = [
        cast(datasets.Dataset, copy.deepcopy(dataset))
        for _ in range(oversampling_copies)
    ]
    remaining_rows = dataset_params.sample_count % dataset.num_rows
    if remaining_rows > 0:
        sampled_dataset = cast(datasets.Dataset, dataset)
        if dataset_params.shuffle:
            sampled_dataset = sampled_dataset.shuffle(dataset_params.seed)
        dataset_list.append(sampled_dataset.take(remaining_rows))
    oversampled_dataset = datasets.concatenate_datasets(dataset_list)
    if dataset_params.shuffle:
        oversampled_dataset = oversampled_dataset.shuffle(
            dataset_params.seed
        ).flatten_indices()
    return cast(DatasetType, oversampled_dataset)


def _build_iterable_dataset_sampler(
    dataset: datasets.IterableDataset, n: int
) -> Callable:
    """Returns a generator that supports oversampling an IterableDataset."""

    def _generator():
        generation_count = 0
        while generation_count < n:
            for generation in dataset:
                generation_count += 1
                yield generation
                if generation_count >= n:
                    break

    return _generator


def _load_dataset(
    dataset_params: DatasetParams,
    stream: bool,
    tokenizer: Optional[BaseTokenizer] = None,
) -> Union[
    datasets.DatasetDict,
    datasets.Dataset,
    datasets.IterableDatasetDict,
    datasets.IterableDataset,
]:
    """Loads a dataset with the specified name and subset.

    Note: For custom map datasets, streaming is only partially supported:
    - The full dataset is downloaded (or loaded from disk), and loaded in memory.
    - However, transformations are applied lazily in streaming mode.
      The raw dataset is not post-processed (i.e., not "transformed") before
      training starts. Instead, it's returned as `IterableDataset` with lazy
      feature generation, i.e., `transform()` is called on-demand during training.
""" dataset_class = REGISTRY.get_dataset( dataset_params.dataset_name, subset=dataset_params.subset ) if dataset_class is not None: dataset_kwargs = {**dataset_params.dataset_kwargs} if dataset_params.transform_num_workers is not None: dataset_kwargs["transform_num_workers"] = ( dataset_params.transform_num_workers ) dataset = dataset_class( dataset_name=dataset_params.dataset_name, dataset_path=dataset_params.dataset_path, split=dataset_params.split, subset=dataset_params.subset, tokenizer=tokenizer, trust_remote_code=dataset_params.trust_remote_code, **dataset_kwargs, ) return dataset.to_hf(return_iterable=stream) # Load a fully preprocessed (tokenized, etc) dataset from disk. # The raw data will be used for training, with any processing # other than collation (if enabled). dataset_path = dataset_params.dataset_path if dataset_path and is_cached_to_disk_hf_dataset(dataset_path): return datasets.Dataset.load_from_disk(dataset_path) else: return datasets.load_dataset( dataset_params.dataset_name, name=dataset_params.subset, split=dataset_params.split, streaming=stream, trust_remote_code=dataset_params.trust_remote_code, **dataset_params.dataset_kwargs, ) def _is_mixture_packed(dataset_split_params: DatasetSplitParams) -> bool: """Returns whether all datasets in the mixture are packed. Raises: ValueError: If a mixture of packed and unpacked datasets is detected. """ num_packed = 0 for dataset in dataset_split_params.datasets: dataset_class = REGISTRY.get_dataset( dataset.dataset_name, subset=dataset.subset ) if dataset_class is not None and issubclass( dataset_class, # type: ignore BasePretrainingDataset, ): num_packed += 1 if num_packed == len(dataset_split_params.datasets): return True elif num_packed == 0: return False else: # Currently, registered datasets get packed and unregistered ones don't. We # don't support mixing both at the moment. raise ValueError( "We currently don't support mixing registered and unregistered datasets." )