Source code for oumi.core.configs.params.data_params
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Any, Literal, Optional, Union

from omegaconf import MISSING

from oumi.core.configs.params.base_params import BaseParams

# Training Params
#
#
# Dataset Splits
#
class DatasetSplit(Enum):
    """Enum representing the split for a dataset."""

    TRAIN = "train"
    TEST = "test"
    VALIDATION = "validation"
class MixtureStrategy(str, Enum):
    """Enum representing the supported mixture strategies for datasets."""

    FIRST_EXHAUSTED = "first_exhausted"
    ALL_EXHAUSTED = "all_exhausted"
    def get_literal_value(self) -> Literal["first_exhausted", "all_exhausted"]:
        """Returns the literal value of the enum."""
        if self.value == MixtureStrategy.FIRST_EXHAUSTED:
            return "first_exhausted"
        elif self.value == MixtureStrategy.ALL_EXHAUSTED:
            return "all_exhausted"
        else:
            raise ValueError("Unsupported value for MixtureStrategy")
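A quick usage sketch (not part of the module source): because MixtureStrategy is a str-valued enum, a member can be looked up by its string value, and get_literal_value() maps it back to the corresponding literal.

from oumi.core.configs.params.data_params import MixtureStrategy

strategy = MixtureStrategy("all_exhausted")  # value-based lookup on the str enum
assert strategy is MixtureStrategy.ALL_EXHAUSTED
assert strategy.get_literal_value() == "all_exhausted"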
@dataclass
class DatasetParams(BaseParams):
    dataset_name: str = MISSING
    """The name of the dataset to load. Required.

    This field is used to retrieve the appropriate class from the dataset registry
    that can be used to instantiate and preprocess the data.

    If `dataset_path` is not specified, then the raw data will be automatically
    downloaded from the huggingface hub or oumi registry. Otherwise, the dataset
    will be loaded from the specified `dataset_path`.
    """

    dataset_path: Optional[str] = None
    """The path to the dataset to load.

    This can be used to load a dataset of type `dataset_name` from a custom path.

    If `dataset_path` is not specified, then the raw data will be automatically
    downloaded from the huggingface hub or oumi registry.
    """

    subset: Optional[str] = None
    """The subset of the dataset to load.

    This is usually a subfolder within the dataset root.
    """

    split: str = "train"
    """The split of the dataset to load.

    This is typically one of "train", "test", or "validation". Defaults to "train".
    """

    dataset_kwargs: dict[str, Any] = field(default_factory=dict)
    """Keyword arguments to pass to the dataset constructor.

    These arguments will be passed directly to the dataset constructor.
    """

    sample_count: Optional[int] = None
    """The number of examples to sample from the dataset.

    Must be non-negative. If `sample_count` is larger than the size of the dataset,
    then the required additional examples are sampled by looping over the original
    dataset.
    """

    mixture_proportion: Optional[float] = None
    """The proportion of examples from this dataset relative to other datasets
    in the mixture.

    If specified, all datasets must supply this value.
    Must be a float in the range [0, 1.0].
    The `mixture_proportion` for all input datasets must sum to 1.

    Examples are sampled after the dataset has been sampled using `sample_count`,
    if specified.
    """

    shuffle: bool = False
    """Whether to shuffle the dataset before any sampling occurs."""

    seed: Optional[int] = None
    """The random seed used for shuffling the dataset before sampling.

    If set to `None`, shuffling will be non-deterministic.
    """

    shuffle_buffer_size: int = 1000
    """The size of the shuffle buffer used for shuffling the dataset before sampling."""

    trust_remote_code: bool = False
    """Whether to trust remote code when loading the dataset."""

    transform_num_workers: Optional[Union[str, int]] = None
    """Number of subprocesses to use for dataset post-processing (`ds.transform()`).

    Multiprocessing is disabled by default (`None`).

    You can also use the special value "auto" to let oumi automatically
    select the number of subprocesses.

    Using multiple processes can speed up processing, e.g., for large
    or multi-modal datasets.

    The parameter is only supported for Map (non-iterable) datasets.
    """
    def __post_init__(self):
        """Verifies params."""
        if self.sample_count is not None:
            if self.sample_count < 0:
                raise ValueError("`sample_count` must be non-negative.")
        if self.mixture_proportion is not None:
            if self.mixture_proportion < 0:
                raise ValueError("`mixture_proportion` must be non-negative.")
            if self.mixture_proportion > 1:
                raise ValueError("`mixture_proportion` must not be greater than 1.0.")
        if self.transform_num_workers is not None:
            if isinstance(self.transform_num_workers, str):
                if self.transform_num_workers != "auto":
                    raise ValueError(
                        "Unknown value of transform_num_workers: "
                        f"{self.transform_num_workers}. Must be 'auto' if string."
                    )
            elif (not isinstance(self.transform_num_workers, int)) or (
                self.transform_num_workers <= 0
            ):
                raise ValueError(
                    "Non-positive value of transform_num_workers: "
                    f"{self.transform_num_workers}."
                )
        if len(self.dataset_kwargs) > 0:
            conflicting_keys = {f.name for f in fields(self)}.intersection(
                self.dataset_kwargs.keys()
            )
            if len(conflicting_keys) > 0:
                raise ValueError(
                    "dataset_kwargs attempts to override the following "
                    f"reserved fields: {conflicting_keys}. "
                    "Use properties of DatasetParams instead."
                )
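A minimal construction sketch (not part of the module source; the dataset name is purely illustrative, and this assumes BaseParams behaves like a plain dataclass base). __post_init__ above validates the sampling parameters and rejects dataset_kwargs keys that shadow reserved fields.

from oumi.core.configs.params.data_params import DatasetParams

# A single dataset entry: sample 1000 shuffled examples from the "train" split.
params = DatasetParams(
    dataset_name="example_org/example_dataset",  # hypothetical name, for illustration
    split="train",
    sample_count=1000,
    shuffle=True,
    seed=42,
)

# Reserved field names cannot be smuggled in through `dataset_kwargs`;
# __post_init__ raises because "split" is already a DatasetParams field.
try:
    DatasetParams(
        dataset_name="example_org/example_dataset",
        dataset_kwargs={"split": "test"},
    )
except ValueError as err:
    print(err)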
@dataclass
class DatasetSplitParams(BaseParams):
    datasets: list[DatasetParams] = field(default_factory=list)
    """The datasets in this split."""

    collator_name: Optional[str] = None
    """Name of Oumi data collator.

    Data collator controls how to form a mini-batch from individual dataset elements.

    Valid options are:

    - "text_with_padding": Dynamically pads the inputs received to the longest length.
    - "vision_language_with_padding": Uses VisionLanguageCollator for image+text
      multi-modal data.

    If None, then a default collator will be assigned.
    """

    pack: bool = False
    """Whether to pack the text into constant-length chunks.

    Each chunk will be the size of the model's max input length.
    This will stream the dataset, and tokenize on the fly if the dataset
    isn't already tokenized (i.e. has an `input_ids` column).
    """

    stream: bool = False
    """Whether to stream the dataset."""

    target_col: Optional[str] = None
    """The dataset column name containing the input for training/testing/validation.

    Deprecated:
        This parameter is deprecated and will be removed in the future.
    """

    mixture_strategy: str = field(
        default=MixtureStrategy.FIRST_EXHAUSTED.value,
        metadata={
            "help": "The mixture strategy to use when multiple datasets are "
            f"provided. `{MixtureStrategy.FIRST_EXHAUSTED.value}` will sample from all "
            "datasets until exactly one dataset is completely represented in the "
            f"mixture. `{MixtureStrategy.ALL_EXHAUSTED.value}` will sample from all "
            "datasets until every dataset is completely represented in the "
            f"mixture. Note that `{MixtureStrategy.ALL_EXHAUSTED.value}` may result in "
            "significant oversampling. Defaults to "
            f"`{MixtureStrategy.FIRST_EXHAUSTED.value}`."
        },
    )
    """The strategy for mixing multiple datasets.

    When multiple datasets are provided, this parameter determines how they are
    combined. Two strategies are available:

    1. FIRST_EXHAUSTED: Samples from all datasets until one is fully represented
       in the mixture. This is the default strategy.
    2. ALL_EXHAUSTED: Samples from all datasets until each one is fully represented
       in the mixture. This may lead to significant oversampling.
    """

    seed: Optional[int] = None
    """The random seed used for mixing this dataset split, if specified.

    If set to `None`, mixing will be non-deterministic.
    """

    use_async_dataset: bool = False
    """Whether to use the PretrainingAsyncTextDataset instead of ConstantLengthDataset.

    Deprecated:
        This parameter is deprecated and will be removed in the future.
    """

    use_torchdata: Optional[bool] = None
    """Whether to use the `torchdata` library for dataset loading and processing.

    If set to `None`, this setting may be auto-inferred.
    """
    def __post_init__(self):
        """Verifies params."""
        if any(
            [dataset.mixture_proportion is not None for dataset in self.datasets]
        ):
            if not all(
                [dataset.mixture_proportion is not None for dataset in self.datasets]
            ):
                raise ValueError(
                    "If `mixture_proportion` is specified it must be "
                    "specified for all datasets."
                )
            mix_sum = sum(
                filter(None, [dataset.mixture_proportion for dataset in self.datasets])
            )
            if not self._is_sum_normalized(mix_sum):
                raise ValueError(
                    "The sum of `mixture_proportion` must be 1.0. "
                    f"The current sum is {mix_sum}."
                )
        if (
            self.mixture_strategy != MixtureStrategy.ALL_EXHAUSTED
            and self.mixture_strategy != MixtureStrategy.FIRST_EXHAUSTED
        ):
            raise ValueError(
                "`mixture_strategy` must be one of "
                f'["{MixtureStrategy.FIRST_EXHAUSTED.value}", '
                f'"{MixtureStrategy.ALL_EXHAUSTED.value}"].'
            )
    def _is_sum_normalized(self, mix_sum) -> bool:
        # Note: the underlying interleave implementation requires
        # the mixture proportions to sum to 1.0.
        return math.isclose(mix_sum, 1.0)
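An illustrative sketch (not part of the module source; dataset names are hypothetical): two datasets can be mixed by proportion as long as the proportions sum to 1.0, which __post_init__ checks via _is_sum_normalized and math.isclose.

from oumi.core.configs.params.data_params import (
    DatasetParams,
    DatasetSplitParams,
    MixtureStrategy,
)

# A 70/30 mixture of two datasets, oversampling until both are exhausted.
split = DatasetSplitParams(
    datasets=[
        DatasetParams(dataset_name="dataset_a", mixture_proportion=0.7),
        DatasetParams(dataset_name="dataset_b", mixture_proportion=0.3),
    ],
    mixture_strategy=MixtureStrategy.ALL_EXHAUSTED.value,
    seed=123,
)

# Proportions that do not sum to 1.0 are rejected during validation.
try:
    DatasetSplitParams(
        datasets=[
            DatasetParams(dataset_name="dataset_a", mixture_proportion=0.5),
            DatasetParams(dataset_name="dataset_b", mixture_proportion=0.6),
        ]
    )
except ValueError as err:
    print(err)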
@dataclass
class DataParams(BaseParams):
    train: DatasetSplitParams = field(default_factory=DatasetSplitParams)
    """The input datasets used for training."""

    test: DatasetSplitParams = field(default_factory=DatasetSplitParams)
    """The input datasets used for testing. This field is currently unused."""

    validation: DatasetSplitParams = field(default_factory=DatasetSplitParams)
    """The input datasets used for validation."""
    def get_split(self, split: DatasetSplit) -> DatasetSplitParams:
        """A public getter for individual dataset splits."""
        if split == DatasetSplit.TRAIN:
            return self.train
        elif split == DatasetSplit.TEST:
            return self.test
        elif split == DatasetSplit.VALIDATION:
            return self.validation
        else:
            raise ValueError(f"Received invalid split: {split}.")
    def __finalize_and_validate__(self):
        """Verifies params."""
        if len(self.train.datasets) == 0:
            raise ValueError("At least one training dataset is required.")

        all_collators = set()
        if self.train.collator_name:
            all_collators.add(self.train.collator_name)
        if self.validation.collator_name:
            all_collators.add(self.validation.collator_name)
        if self.test.collator_name:
            all_collators.add(self.test.collator_name)
        if len(all_collators) >= 2:
            raise ValueError(
                f"Different data collators are not supported yet: {all_collators}"
            )
        elif len(all_collators) == 1 and not self.train.collator_name:
            raise ValueError(
                "Data collator must also be specified "
                f"on the `train` split: {all_collators}"
            )
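Finally, a sketch of how the pieces fit together (not part of the module source; dataset names are hypothetical, and in practice __finalize_and_validate__ is presumably invoked by the config-loading machinery rather than called by hand): at least one training dataset is required, and all splits that set a collator must agree on it, including the train split.

from oumi.core.configs.params.data_params import (
    DataParams,
    DatasetParams,
    DatasetSplit,
    DatasetSplitParams,
)

data_params = DataParams(
    train=DatasetSplitParams(
        datasets=[DatasetParams(dataset_name="dataset_a")],
        collator_name="text_with_padding",
    ),
    validation=DatasetSplitParams(
        datasets=[DatasetParams(dataset_name="dataset_b")],
        collator_name="text_with_padding",  # must match the train collator
    ),
)

# Passes: a training dataset exists and both splits share one collator.
data_params.__finalize_and_validate__()

train_split = data_params.get_split(DatasetSplit.TRAIN)
print([d.dataset_name for d in train_split.datasets])  # ['dataset_a']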