# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importcopyfromcollections.abcimportSequencefromtypingimportCallable,Optional,TypeVar,Union,castimportdatasetsfromoumi.core.configsimport(DataParams,DatasetParams,DatasetSplit,DatasetSplitParams,MixtureStrategy,)fromoumi.core.datasets.base_pretraining_datasetimportBasePretrainingDatasetfromoumi.core.datasets.pretraining_async_text_datasetimport(PretrainingAsyncTextDataset,)fromoumi.core.registryimportREGISTRYfromoumi.core.tokenizersimportBaseTokenizerfromoumi.utils.hf_utilsimportis_cached_to_disk_hf_datasetfromoumi.utils.loggingimportloggerDatasetType=TypeVar("DatasetType",datasets.Dataset,datasets.IterableDataset)
[docs]defbuild_dataset_mixture(data_params:DataParams,tokenizer:Optional[BaseTokenizer],dataset_split:DatasetSplit,seq_length:Optional[int]=None,seed:Optional[int]=None,)->Union[DatasetType,PretrainingAsyncTextDataset]:"""Builds a dataset for the specified split. Args: data_params: The data params. tokenizer: The tokenizer object to use for preprocessing. dataset_split: The split of the dataset to load. seq_length: The length each example will be packed to. This is only used if packing is requested, and the dataset isn't already packed. If not provided, defaults to 1024. seed: If specified, a seed used for random sampling. kwargs: Keyword arguments. Returns: dataset: The built dataset for `dataset_split`. """dataset_split_params:DatasetSplitParams=data_params.get_split(dataset_split)ifdataset_split_params.use_torchdata:fromoumi.builders.oumi_dataimportbuild_dataset_mixtureasbuild_oumi_datasetlogger.warning("Using torchdata preprocessing pipeline. ""This is currently in beta and may not be stable.")# TODO: OPE-271. Some type hackery going on here.# We return a torchdata.IterDataPipe instead of a HuggingFace Dataset or# IterableDataset. This is a temporary workaround until torchdata is stable# and becomes the default processing pipeline.returnbuild_oumi_dataset(data_params,tokenizer,dataset_split,seed)# type: ignore# Check if the underlying dataset is already packed, or if we need to pack it# ourselves.is_packed=_is_mixture_packed(dataset_split_params)datasets=[_sample_dataset(_load_dataset(dataset_params=dataset_params,stream=dataset_split_params.stream,tokenizer=tokenizer,),dataset_params=dataset_params,stream=dataset_split_params.stream,)fordataset_paramsindataset_split_params.datasets]mixture_proportions=[dataset.mixture_proportionfordatasetindataset_split_params.datasets]# Interleave datasets using mixture_strategy.dataset=_mix_datasets(datasets,mixture_proportions,dataset_split_params.mixture_strategy,dataset_split_params.seed,)ifdataset_split_params.packandnotis_packed:# Fetch max sequence length. If not specified, defaults to 1024.dataset_kwargs={}ifseq_lengthisnotNone:dataset_kwargs["seq_length"]=seq_lengthdataset=PretrainingAsyncTextDataset(tokenizer,dataset,**dataset_kwargs,)returndataset
[docs]defbuild_dataset(dataset_name:str,tokenizer:Optional[BaseTokenizer],seed:Optional[int]=None,stream:bool=False,pack:bool=False,use_torchdata:Optional[bool]=None,**kwargs,)->Union[DatasetType,PretrainingAsyncTextDataset]:"""Builds a dataset from a dataset name. Please refer to `DatasetParams` & `DatasetSplitParams` for a description of the all the arguments. """dataset_params=DatasetParams(dataset_name=dataset_name,**kwargs,)data_params=DataParams(train=DatasetSplitParams(datasets=[dataset_params],stream=stream,pack=pack,use_torchdata=use_torchdata,))returnbuild_dataset_mixture(data_params=data_params,dataset_split=DatasetSplit.TRAIN,tokenizer=tokenizer,seed=seed,)
def_mix_datasets(dataset_list:list[DatasetType],mixture_proportions:Sequence[Optional[float]],mixture_strategy:str,seed:Optional[int],)->DatasetType:"""Joins multiple datasets using the provided `mixture_strategy`."""ifany([proportionisNoneforproportioninmixture_proportions]):# All datasets should be concatenated when no proportion is specified.returndatasets.concatenate_datasets(dataset_list)else:# All mixture_proportions are not None.mixture_proportions=cast(list[float],mixture_proportions)# Interleave datasets using the specified proportions and mixture strategy.returndatasets.interleave_datasets(dataset_list,probabilities=mixture_proportions,seed=seed,stopping_strategy=(MixtureStrategy(mixture_strategy).get_literal_value()),)def_sample_dataset(dataset:Union[datasets.DatasetDict,datasets.Dataset,datasets.IterableDatasetDict,datasets.IterableDataset,],dataset_params:DatasetParams,stream:bool,)->DatasetType:"""Samples the specified dataset."""ifdataset_params.sample_countisNone:# No sampling.ifdataset_params.shuffle:dataset=dataset.shuffle(dataset_params.seed)dataset=cast(DatasetType,dataset)returndatasetifstream:dataset=cast(datasets.IterableDataset,dataset)ifdataset_params.shuffle:dataset=dataset.shuffle(dataset_params.seed)generator=_build_iterable_dataset_sampler(dataset,dataset_params.sample_count)returncast(DatasetType,datasets.IterableDataset.from_generator(generator,dataset.features),)dataset=cast(datasets.Dataset,dataset)ifdataset.num_rows>=dataset_params.sample_count:ifdataset_params.shuffle:dataset=dataset.shuffle(dataset_params.seed).flatten_indices()returncast(DatasetType,dataset.take(dataset_params.sample_count))# Oversample the dataset.oversampling_copies=int(dataset_params.sample_count//dataset.num_rows)dataset_list=[cast(datasets.Dataset,copy.deepcopy(dataset))for_inrange(oversampling_copies)]remaining_rows=dataset_params.sample_count%dataset.num_rowsifremaining_rows>0:sampled_dataset=cast(datasets.Dataset,dataset)ifdataset_params.shuffle:sampled_dataset=sampled_dataset.shuffle(dataset_params.seed)dataset_list.append(sampled_dataset.take(remaining_rows))oversampled_dataset=datasets.concatenate_datasets(dataset_list)ifdataset_params.shuffle:oversampled_dataset=oversampled_dataset.shuffle(dataset_params.seed).flatten_indices()returncast(DatasetType,oversampled_dataset)def_build_iterable_dataset_sampler(dataset:datasets.IterableDataset,n:int)->Callable:"""Returns a generator that supports oversampling an IterableDataset."""def_generator():generation_count=0whilegeneration_count<n:forgenerationindataset:generation_count+=1yieldgenerationifgeneration_count>=n:breakreturn_generatordef_load_dataset(dataset_params:DatasetParams,stream:bool,tokenizer:Optional[BaseTokenizer]=None,)->Union[datasets.DatasetDict,datasets.Dataset,datasets.IterableDatasetDict,datasets.IterableDataset,]:"""Loads a dataset with the specified name and subset. Note: For custom map datasets, streaming is only partially supported: - The full dataset is downloaded (or loaded from disk), and loaded in memory. - However, transformations are applied lazily in streaming mode. The raw dataset is not post-processed (i.e., not "transformed") before training starts. Instead, it's returned as `IterableDataset` with lazy feature generation i.e., `transform()` is called on-demand during training. """dataset_class=REGISTRY.get_dataset(dataset_params.dataset_name,subset=dataset_params.subset)ifdataset_classisnotNone:dataset_kwargs={**dataset_params.dataset_kwargs}ifdataset_params.transform_num_workersisnotNone:dataset_kwargs["transform_num_workers"]=(dataset_params.transform_num_workers)# Use the dataset name override from 'dataset_kwargs' if specified (OPE-897).dataset_name=(dataset_kwargs.pop("dataset_name_override",None)ordataset_params.dataset_name)dataset=dataset_class(dataset_name=dataset_name,dataset_path=dataset_params.dataset_path,split=dataset_params.split,subset=dataset_params.subset,tokenizer=tokenizer,trust_remote_code=dataset_params.trust_remote_code,**dataset_kwargs,)returndataset.to_hf(return_iterable=stream)# Load a fully preprocessed (tokenized, etc) dataset from disk.# The raw data will be used for training, with any processing# other than collation (if enabled).dataset_path=dataset_params.dataset_pathifdataset_pathandis_cached_to_disk_hf_dataset(dataset_path):returndatasets.Dataset.load_from_disk(dataset_path)else:returndatasets.load_dataset(dataset_params.dataset_name,name=dataset_params.subset,split=dataset_params.split,streaming=stream,trust_remote_code=dataset_params.trust_remote_code,**dataset_params.dataset_kwargs,)def_is_mixture_packed(dataset_split_params:DatasetSplitParams)->bool:"""Returns whether all datasets in the mixture are packed. Raises: ValueError: If a mixture of packed and unpacked datasets is detected. """num_packed=0fordatasetindataset_split_params.datasets:dataset_class=REGISTRY.get_dataset(dataset.dataset_name,subset=dataset.subset)ifdataset_classisnotNoneandissubclass(dataset_class,# type: ignoreBasePretrainingDataset,):num_packed+=1ifnum_packed==len(dataset_split_params.datasets):returnTrueelifnum_packed==0:returnFalseelse:# Currently, registered datasets get packed and unregistered ones don't. We# don't support mixing both at the moment.raiseValueError("We currently don't support mixing registered and unregistered datasets.")