Source code for oumi.core.datasets.base_map_dataset
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import os
import time
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Sized
from pathlib import Path
from typing import Any, NamedTuple, Optional, Union, cast

import datasets
import pandas as pd
from torch.utils.data import MapDataPipe

from oumi.utils.hf_utils import is_cached_to_disk_hf_dataset
from oumi.utils.logging import logger
from oumi.utils.torch_utils import (
    estimate_sample_dict_size_in_bytes,
    get_shape_as_list,
)


class _ExamplesIndicesRange(NamedTuple):
    """A valid sub-range of example indices."""

    start_index: int
    end_index: int


class _InferredFeatureMap(NamedTuple):
    feature_map: datasets.Features
    """Inferred feature map."""

    is_feature_map_optimized: bool
    """Indicates whether the original feature map was optimized.

    In optimized feature maps, large features use the inferred `ArrayXD` arrow
    column type (not `sequence`) which supports larger datasets with more elements.
    """

    element_size_in_bytes: int
    """Estimated element size in bytes."""

    multimodal: bool
    """Whether the features are multimodal."""

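For illustration (not part of the module source): the `is_feature_map_optimized` flag above refers to replacing the default nested-`Sequence` Arrow encoding of a large feature such as `pixel_values` with an explicit `ArrayXD` column type. A minimal sketch of the two encodings, assuming a hypothetical fixed 3x224x224 float32 image tensor:

import datasets

# Default encoding: a deeply nested sequence of float32 values.
default_features = datasets.Features(
    {
        "pixel_values": datasets.Sequence(
            datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
        )
    }
)

# Optimized encoding: a fixed-shape Arrow tensor column.
optimized_features = datasets.Features(
    {"pixel_values": datasets.Array3D(dtype="float32", shape=(3, 224, 224))}
)
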
class BaseMapDataset(MapDataPipe, Sized, ABC):
    """Abstract base class for map datasets."""

    _data: pd.DataFrame
    dataset_name: str
    dataset_path: Optional[str] = None
    default_dataset: Optional[str] = None
    default_subset: Optional[str] = None
    trust_remote_code: bool
    transform_num_workers: Optional[Union[str, int]] = None

    def __init__(
        self,
        *,
        dataset_name: Optional[str],
        dataset_path: Optional[str] = None,
        subset: Optional[str] = None,
        split: Optional[str] = None,
        trust_remote_code: bool = False,
        transform_num_workers: Optional[Union[str, int]] = None,
        **kwargs,
    ) -> None:
        """Initializes a new instance of the BaseMapDataset class."""
        dataset_type_name = self.__class__.__name__

        if len(kwargs) > 0:
            logger.debug(
                f"Unknown arguments: {', '.join(kwargs.keys())}. "
                "Please check the class constructor for supported arguments "
                f"(type: {dataset_type_name})."
            )

        dataset_name = dataset_name or self.default_dataset

        logger.info(
            f"Creating map dataset (type: {dataset_type_name})..."
            + (f" dataset_name: '{dataset_name}'" if dataset_name else "")
            + (f" dataset_path: '{dataset_path}'" if dataset_path else "")
        )

        if dataset_name is None:
            raise ValueError(
                "Please specify a dataset_name or "
                "set the default_dataset class attribute "
                f"(type: {dataset_type_name})."
            )

        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.dataset_subset = subset or self.default_subset
        self.split = split
        self.trust_remote_code = trust_remote_code
        self.transform_num_workers = transform_num_workers

    #
    # Main API
    #

    def __getitem__(self, idx: int) -> dict:
        """Gets the item at the specified index.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            dict: The item at the specified index.
        """
        sample = self.raw(idx)
        processed = self.transform(sample)
        return processed

    def __len__(self) -> int:
        """Gets the number of items in the dataset.

        Returns:
            int: The number of items in the dataset.
        """
        return len(self._data)

    @property
    def data(self) -> pd.DataFrame:
        """Returns the underlying dataset data."""
        return self._data

    def raw(self, idx: int) -> pd.Series:
        """Returns the raw data at the specified index.

        Args:
            idx (int): The index of the data to retrieve.

        Returns:
            pd.Series: The raw data at the specified index.
        """
        return self._data.iloc[idx]

    def as_generator(self) -> Generator[dict[str, Any], None, None]:
        """Returns a generator for the dataset."""
        for idx in range(len(self)):
            yield self[idx]

    def _as_generator_over_shards(
        self, shards: list[_ExamplesIndicesRange]
    ) -> Generator[dict[str, Any], None, None]:
        """Returns a sharded generator for the dataset."""
        for shard in shards:
            for idx in range(shard.start_index, shard.end_index):
                yield self[idx]

    def _detect_features_and_estimate_element_size_bytes(
        self, samples_iter: Iterable[dict[str, Any]]
    ) -> _InferredFeatureMap:
        """Returns an estimate of max element size in bytes."""
        samples_list = list(samples_iter)

        def _dummy_generator():
            yield from samples_list

        sample_dataset = cast(
            datasets.Dataset,
            datasets.Dataset.from_generator(_dummy_generator, keep_in_memory=True),
        )
        if len(sample_dataset) <= 0:
            raise ValueError("Empty sample dataset!")

        max_elem_bytes = max(
            [estimate_sample_dict_size_in_bytes(elem) for elem in samples_list]
        )

        features = sample_dataset.features.copy()
        is_feature_map_optimized: bool = False
        is_multimodal: bool = False

        # At this time, we care mostly about `pixel_values` as it's by far the largest
        # feature (e.g., 15MB for Llama 3.2 Vision), which causes serialization errors
        # for large datasets if saved in the default format, which is
        # a nested sequence (of sequences (of sequences ...)).
        # TODO: Tune feature types for other features for efficiency.
        if "pixel_values" in samples_list[0]:
            is_multimodal = True
            inferred_features = []
            variable_shapes_detected: bool = False
            for elem in samples_list:
                shape = tuple(get_shape_as_list(elem["pixel_values"]))
                shape_dims = len(shape)
                if shape_dims == 2:
                    feature_def = datasets.Array2D(dtype="float32", shape=shape)
                elif shape_dims == 3:
                    feature_def = datasets.Array3D(dtype="float32", shape=shape)
                elif shape_dims == 4:
                    feature_def = datasets.Array4D(dtype="float32", shape=shape)
                elif shape_dims == 5:
                    feature_def = datasets.Array5D(dtype="float32", shape=shape)
                else:
                    raise ValueError(
                        "The `pixel_values` feature has unsupported dimensionality "
                        f"({shape_dims}D). Must be 2D...5D."
                    )
                inferred_features.append(feature_def)

            for i in range(1, len(samples_list)):
                if (
                    type(inferred_features[i - 1]),
                    inferred_features[i - 1].dtype,
                    inferred_features[i - 1].shape,
                ) != (
                    type(inferred_features[i]),
                    inferred_features[i].dtype,
                    inferred_features[i].shape,
                ):
                    variable_shapes_detected = True
                    logger.warning(
                        f"The `pixel_values` feature has variable shapes: "
                        f"{inferred_features[i - 1]} vs {inferred_features[i]}!"
                    )

            if not variable_shapes_detected:
                # Re-define the feature to be `ArrayXD`
                # if all shapes are the same.
                features["pixel_values"] = inferred_features[0]
                is_feature_map_optimized = True
                logger.info(
                    "The `pixel_values` feature has this inferred type: "
                    f"{inferred_features[0]}"
                )

        del sample_dataset
        return _InferredFeatureMap(
            feature_map=features,
            is_feature_map_optimized=is_feature_map_optimized,
            element_size_in_bytes=max_elem_bytes,
            multimodal=is_multimodal,
        )

    def _compute_effective_transform_num_workers(self) -> int:
        """Returns an effective number of dataset transform workers.

        Guaranteed to be a positive integer (>= 1). 1 if no parallelism is used.
        """
        num_proc = None
        if self.transform_num_workers is not None:
            if isinstance(self.transform_num_workers, int):
                num_proc = self.transform_num_workers
            elif self.transform_num_workers == "auto":
                num_proc = os.cpu_count()
                if num_proc is not None:
                    # Limit the max number of sub-processes.
                    num_proc = min(8, num_proc)
        assert num_proc is None or num_proc > 0, (
            f"transform_num_workers: {self.transform_num_workers}"
        )
        num_proc = max(1, num_proc if num_proc is not None else 1)
        assert num_proc >= 1
        return num_proc

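For illustration (not part of the module source): a standalone mirror of the worker-count resolution rules above. `None` disables parallelism, an explicit positive integer is used as-is, and `"auto"` uses `os.cpu_count()` capped at 8 sub-processes. The function name is hypothetical:

import os
from typing import Optional, Union


def resolve_transform_workers(transform_num_workers: Optional[Union[str, int]]) -> int:
    # Sketch only; mirrors _compute_effective_transform_num_workers().
    if transform_num_workers is None:
        return 1  # no parallelism
    if isinstance(transform_num_workers, int):
        return transform_num_workers  # assumed to be a positive worker count
    if transform_num_workers == "auto":
        return min(8, os.cpu_count() or 1)  # cap "auto" at 8 sub-processes
    raise ValueError(f"Unsupported transform_num_workers: {transform_num_workers}")


# e.g., on a 16-core machine: resolve_transform_workers("auto") == 8
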
    def to_hf(
        self, return_iterable: bool = False
    ) -> Union[datasets.Dataset, datasets.IterableDataset]:
        """Converts the dataset to a Hugging Face dataset.

        Args:
            return_iterable: Whether to return an iterable dataset.
                Iterable datasets aren't cached to disk, which can sometimes be
                advantageous. For example, if transformed examples are very large
                (e.g., if `pixel_values` are large for multimodal data), or if you
                don't want to post-process the whole dataset before training starts.

        Returns:
            A HuggingFace dataset. Can be `datasets.Dataset` or
            `datasets.IterableDataset` depending on the value of `return_iterable`.
        """
        _MAX_SHARD_SIZE = 1 * 1024 * 1024 * 1024  # ~1GB

        dataset_type_name = self.__class__.__name__
        num_proc = self._compute_effective_transform_num_workers()
        total_examples = len(self)

        output_features: _InferredFeatureMap = (
            self._detect_features_and_estimate_element_size_bytes(
                self._as_generator_over_shards(
                    [
                        _ExamplesIndicesRange(start_index=i, end_index=(i + 1))
                        for i in range(0, total_examples, max(1, total_examples // 8))
                    ]
                )
            )
        )

        elements_per_shard: int = int(math.ceil(float(total_examples) / num_proc))
        if output_features.element_size_in_bytes > 0:
            elements_per_shard = min(
                elements_per_shard,
                _MAX_SHARD_SIZE // output_features.element_size_in_bytes,
            )
        # Clamp `writer_batch_size` to [1, 200/1000] range.
        writer_batch_size = max(
            1, min(elements_per_shard, 200 if output_features.multimodal else 1000)
        )

        logger.info(
            f"{dataset_type_name}: features={output_features.feature_map.keys()}"
        )
        logger.debug(
            f"{dataset_type_name}: features={output_features} "
            f"examples={total_examples} "
            f"writer_batch_size={writer_batch_size} num_proc={num_proc}"
        )

        # If the feature map isn't "optimized", ignore it to fall back
        # to the default behavior in `from_generator()`.
        feature_map = (
            output_features.feature_map
            if output_features.is_feature_map_optimized
            else None
        )

        start_time = time.perf_counter()
        if num_proc > 1 or (
            output_features.element_size_in_bytes * total_examples > _MAX_SHARD_SIZE
        ):
            starts: list[int] = list(range(0, total_examples, writer_batch_size))
            stops: list[int] = starts[1:] + [total_examples]
            shards: list[_ExamplesIndicesRange] = [
                _ExamplesIndicesRange(start_index=item[0], end_index=item[1])
                for item in zip(starts, stops)
            ]
            if return_iterable:
                result = datasets.IterableDataset.from_generator(
                    self._as_generator_over_shards,
                    gen_kwargs={"shards": shards},
                    features=feature_map,
                )
            else:
                result = datasets.Dataset.from_generator(
                    self._as_generator_over_shards,
                    gen_kwargs={"shards": shards},
                    keep_in_memory=False,
                    num_proc=(num_proc if num_proc > 1 else None),
                    features=feature_map,
                    writer_batch_size=writer_batch_size,
                )
        else:
            if return_iterable:
                result = datasets.IterableDataset.from_generator(
                    self.as_generator,
                    features=feature_map,
                )
            else:
                result = datasets.Dataset.from_generator(
                    self.as_generator,
                    keep_in_memory=False,
                    features=feature_map,
                    writer_batch_size=writer_batch_size,
                )

        duration_sec = time.perf_counter() - start_time
        logger.info(
            f"Finished transforming dataset ({dataset_type_name})! "
            f"Speed: {total_examples / duration_sec:.2f} examples/sec. "
            f"Examples: {total_examples}. "
            f"Duration: {duration_sec:.1f} sec. Transform workers: {num_proc}."
        )

        if return_iterable:
            result = cast(datasets.IterableDataset, result)
            logger.debug(f"{dataset_type_name}: IterableDataset: {result}")
        else:
            result = cast(datasets.Dataset, result)
            logger.debug(
                f"{dataset_type_name}: MapDataset: {result}\n\n"
                f"Arrow schema: {result.features.arrow_schema}"
            )
        return result

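For illustration (not part of the module source): typical usage of `to_hf()`, assuming a hypothetical concrete subclass `MyVisionSftDataset` that implements `transform()` and populates `self._data`:

ds = MyVisionSftDataset(
    dataset_name="my_org/my_vl_dataset",  # hypothetical
    split="train",
    transform_num_workers="auto",
)

# Materialized, Arrow-backed dataset (post-processed up front, cached by `datasets`).
hf_map_ds = ds.to_hf()

# Streaming variant: examples are transformed lazily and never cached to disk.
hf_iterable_ds = ds.to_hf(return_iterable=True)
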
    #
    # Abstract Methods
    #

    @abstractmethod
    def transform(self, sample: pd.Series) -> dict:
        """Preprocesses the inputs in the given sample.

        Args:
            sample (pd.Series): A series containing the input data.

        Returns:
            dict: A dictionary containing the preprocessed input data.
        """
        raise NotImplementedError

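For illustration (not part of the module source): a minimal concrete subclass. The class name, default dataset, and column names are hypothetical; the key obligations are to implement `transform()` and to populate `self._data`, here via the `_load_data()` helper defined in the data-loading section below.

class MyJsonlPromptDataset(BaseMapDataset):
    """Hypothetical dataset over rows with "prompt" and "response" columns."""

    default_dataset = "my_org/my_prompts"  # hypothetical

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Load the underlying DataFrame (local file or Hugging Face Hub).
        self._data = self._load_data()

    def transform(self, sample: pd.Series) -> dict:
        # Convert one raw row into a model-ready example (hypothetical schema).
        return {"text": f"Q: {sample['prompt']}\nA: {sample['response']}"}
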
    #
    # Data Loading
    #
    def _load_data(self) -> pd.DataFrame:
        """Loads the dataset from the specified source.

        Returns:
            pd.DataFrame: The loaded dataset.
        """
        if self.dataset_path:
            result = self._load_local_dataset(self.dataset_path)
        else:
            result = self._load_hf_hub_dataset()

        # Reclaim memory after data loading.
        gc.collect()

        logger.info(
            f"Loaded DataFrame with shape: {result.shape}. Columns:\n{result.dtypes}"
        )
        return result

    def _load_local_dataset(self, path: str) -> pd.DataFrame:
        """Loads the dataset from the specified local source.

        Returns:
            pd.DataFrame: The loaded dataset.
        """
        dataset_path = Path(path)
        if not dataset_path.exists():
            raise FileNotFoundError(f"File not found: {dataset_path}")

        if dataset_path.suffix.lower() == ".jsonl" and dataset_path.is_file():
            result = self._load_jsonl_dataset(dataset_path)
        elif dataset_path.suffix.lower() == ".parquet" and dataset_path.is_file():
            result = self._load_parquet_dataset(dataset_path)
        elif is_cached_to_disk_hf_dataset(dataset_path):
            result = self._load_dataset_from_disk(dataset_path)
        else:
            raise ValueError(f"File format not supported for {self.dataset_name}")
        return result

    def _load_hf_hub_dataset(self) -> pd.DataFrame:
        """Loads the dataset from the specified Hugging Face Hub source.

        Returns:
            pd.DataFrame: The loaded dataset.
        """
        splits_or_dataset = datasets.load_dataset(
            path=self.dataset_name,
            name=self.dataset_subset,
            split=self.split,
            trust_remote_code=self.trust_remote_code,
        )

        if isinstance(
            splits_or_dataset, (datasets.IterableDataset, datasets.IterableDatasetDict)
        ):
            raise ValueError("IterableDataset is not supported with this class.")

        # Grab a single dataset split.
        if isinstance(splits_or_dataset, datasets.Dataset):
            dataset = splits_or_dataset
        elif self.split is not None:
            dataset = splits_or_dataset[self.split]
        elif len(splits_or_dataset) == 1:
            dataset = splits_or_dataset.values().__iter__().__next__()
        else:
            raise ValueError(
                "Multiple splits found in the dataset. Please specify a single split. "
                f"Available splits: {list(splits_or_dataset.keys())}"
            )

        logger.info(
            "\n".join(
                [
                    "Dataset Info:",
                    f"\tSplit: {dataset.split}",
                    f"\tVersion: {dataset.version}",
                    f"\tDataset size: {dataset.dataset_size}",
                    f"\tDownload size: {dataset.download_size}",
                    f"\tSize: {dataset.size_in_bytes} bytes",
                    f"\tRows: {len(dataset)}",
                    f"\tColumns: {dataset.column_names}",
                ]
            )
        )

        result = dataset.to_pandas()
        del dataset
        return cast(pd.DataFrame, result)

    def _load_jsonl_dataset(self, path: Path) -> pd.DataFrame:
        return pd.read_json(path, lines=True)

    def _load_parquet_dataset(self, path: Path) -> pd.DataFrame:
        return pd.read_parquet(path)

    def _load_dataset_from_disk(self, path: Path) -> pd.DataFrame:
        dataset: datasets.Dataset = datasets.Dataset.load_from_disk(path)
        result = dataset.to_pandas()
        del dataset
        return cast(pd.DataFrame, result)

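For illustration (not part of the module source): using the hypothetical `MyJsonlPromptDataset` subclass sketched above with a local file. `dataset_path` may point to a `.jsonl` file, a `.parquet` file, or a Hugging Face dataset cached to disk; if it is omitted, the dataset is fetched from the Hugging Face Hub instead. The path below is hypothetical.

ds = MyJsonlPromptDataset(
    dataset_name="my_org/my_prompts",    # hypothetical
    dataset_path="/data/prompts.jsonl",  # hypothetical local JSONL file
)
print(len(ds))      # number of rows in the underlying DataFrame
print(ds[0])        # transformed example, e.g. {"text": "Q: ...\nA: ..."}
hf_ds = ds.to_hf()  # convert to a Hugging Face `datasets.Dataset`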