Source code for oumi.core.datasets.base_iterable_dataset
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from collections.abc import Iterable
from typing import Any, Optional

import datasets
from torch.utils.data import IterDataPipe

from oumi.utils.logging import logger


class BaseIterableDataset(IterDataPipe, abc.ABC):
    """Abstract base class for iterable datasets."""

    dataset_name: str
    dataset_path: Optional[str] = None
    default_dataset: Optional[str] = None
    default_subset: Optional[str] = None
    trust_remote_code: bool = False

    def __init__(
        self,
        *,
        dataset_name: Optional[str] = None,
        dataset_path: Optional[str] = None,
        subset: Optional[str] = None,
        split: Optional[str] = None,
        trust_remote_code: bool = False,
        stream: bool = True,
        **kwargs,
    ) -> None:
        """Initializes a new instance of the BaseIterableDataset class."""
        dataset_type_name = self.__class__.__name__
        logger.info(f"Creating iterable dataset (type: {dataset_type_name})...")

        if len(kwargs) > 0:
            logger.debug(
                f"Unknown arguments: {', '.join(kwargs.keys())}. "
                "Please check the class constructor for supported arguments "
                f"(type: {dataset_type_name})."
            )

        dataset_name = dataset_name or self.default_dataset

        if dataset_name is None:
            raise ValueError(
                "Please specify a dataset_name or "
                "set the default_dataset class attribute "
                f"(type: {dataset_type_name})."
            )

        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.dataset_subset = subset or self.default_subset
        self.split = split
        self.trust_remote_code = trust_remote_code
        self.stream = stream
        self._data = self._load_data()

    #
    # Main API
    #
    def __iter__(self):
        """Iterates over the dataset."""
        for item in self.data:
            yield self.transform(item)

    def iter_raw(self):
        """Iterates over the raw dataset."""
        yield from self.data

    def to_hf(self, return_iterable: bool = True) -> datasets.IterableDataset:
        """Converts the dataset to a Hugging Face dataset."""
        if not return_iterable:
            raise NotImplementedError("Only returning IterableDataset is supported.")
        return datasets.IterableDataset.from_generator(self.__iter__)

    @property
    def data(self) -> Iterable[Any]:
        """Returns the underlying dataset data."""
        return self._data

    #
    # Abstract Methods
    #
    @abc.abstractmethod
    def transform(self, sample: Any) -> dict[str, Any]:
        """Preprocesses the inputs in the given sample.

        Args:
            sample (Any): A sample from the dataset.

        Returns:
            dict: A dictionary containing the preprocessed input data.
        """
        raise NotImplementedError

    def _load_data(self) -> Iterable[Any]:
        """Loads the dataset from the specified source."""
        if self.dataset_path:
            result = self._load_local_dataset(self.dataset_path)
        else:
            result = self._load_hf_hub_dataset()
        return result

    def _load_hf_hub_dataset(self) -> Iterable[Any]:
        """Loads the dataset from the specified source."""
        return datasets.load_dataset(
            path=self.dataset_name,
            name=self.dataset_subset,
            split=self.split,
            streaming=self.stream,
            trust_remote_code=self.trust_remote_code,
        )

    def _load_dataset_from_disk(self, path: str) -> Iterable[Any]:
        return datasets.Dataset.load_from_disk(path)
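The class above is abstract: a concrete dataset subclasses it and implements transform, which is applied to each raw item during iteration. Below is a minimal usage sketch, not part of the module source; the subclass name, the Hugging Face dataset id, and the sample field names are illustrative assumptions.

    from typing import Any

    from oumi.core.datasets.base_iterable_dataset import BaseIterableDataset


    class AlpacaExampleDataset(BaseIterableDataset):
        # Hypothetical default; any Hugging Face Hub dataset id works here.
        default_dataset = "tatsu-lab/alpaca"

        def transform(self, sample: Any) -> dict[str, Any]:
            # Map the raw record's fields into the dict format the caller expects.
            return {
                "prompt": sample.get("instruction", ""),
                "response": sample.get("output", ""),
            }


    # stream=True keeps the Hub dataset in streaming mode; __iter__ applies
    # transform() to every raw item as it is yielded.
    dataset = AlpacaExampleDataset(split="train", stream=True)
    for example in dataset:
        print(example["prompt"][:80])
        break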