# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importtimefromtypingimportUnionimportpandasaspdimporttorchfromtorch.utils.dataimportDatasetfromtyping_extensionsimportoverridefromoumi.core.datasets.base_dpo_datasetimportBaseExperimentalDpoDatasetfromoumi.core.datasets.base_pretraining_datasetimportBasePretrainingDatasetfromoumi.core.datasets.base_sft_datasetimportBaseSftDatasetfromoumi.core.registryimportregister_datasetfromoumi.core.types.conversationimportConversation,Message,Role
# NOTE: the registry key below is misspelled ("classfication"). It is kept
# unchanged because any caller looking this dataset up by that key would break.
@register_dataset("debug_classfication")
class DebugClassificationDataset(Dataset):
    """Synthetic classification dataset of random features and integer labels.

    Intended for debugging training loops and data pipelines: all data is
    generated eagerly in ``__init__`` with ``torch.randn`` / ``torch.randint``.
    """

    def __init__(
        self,
        dataset_size: int = 1000,
        feature_dim: int = 128,
        data_type: str = "float32",
        num_classes: int = 10,
        preprocessing_time_ms: float = 0,
        **kwargs,
    ):
        """Initialize a DebugClassificationDataset.

        This dataset generates random data and labels for debugging purposes.

        Args:
            dataset_size: The number of examples in the dataset.
            feature_dim: The dimension of each feature vector.
            data_type: The dtype of the feature tensor. Supported values are
                "float32", "float16", and "bfloat16".
            num_classes: The number of classes; labels are drawn uniformly
                from ``[0, num_classes)``.
            preprocessing_time_ms: Artificial per-item delay, in milliseconds,
                applied in ``__getitem__`` to simulate preprocessing cost.
            **kwargs: Additional keyword arguments (accepted and ignored).

        Raises:
            ValueError: If the data_type is not supported.
        """
        self.size = dataset_size
        self.feature_dim = feature_dim
        self.data_type = data_type
        self.num_classes = num_classes
        self.preprocessing_time_ms = preprocessing_time_ms

        if self.data_type == "float32":
            dtype = torch.float32
        elif self.data_type == "float16":
            dtype = torch.float16
        elif self.data_type == "bfloat16":
            dtype = torch.bfloat16
        else:
            raise ValueError(f"Unsupported data type: {self.data_type}")

        # Features: (dataset_size, feature_dim) tensor of the requested dtype.
        self.data = torch.randn(self.size, self.feature_dim, dtype=dtype)
        # Labels: (dataset_size,) tensor of class indices.
        self.labels = torch.randint(0, self.num_classes, (self.size,))

    def __len__(self):
        """Return the size of the dataset."""
        return self.size

    def __getitem__(self, idx):
        """Return the data and label at the given index.

        If ``preprocessing_time_ms`` is positive, sleeps that long first to
        simulate per-item preprocessing cost.
        """
        if self.preprocessing_time_ms > 0:
            # time.sleep expects seconds, so convert from milliseconds.
            time.sleep(self.preprocessing_time_ms / 1000)
        return {"features": self.data[idx], "labels": self.labels[idx]}
@register_dataset("debug_pretraining")
class DebugPretrainingDataset(BasePretrainingDataset):
    """Tiny synthetic pretraining corpus of numbered one-line documents."""

    default_dataset = "debug_pretraining"

    def __init__(
        self,
        dataset_size: int = 1000,
        **kwargs,
    ):
        """Initializes a DebugPretrainingDataset.

        Args:
            dataset_size: How many synthetic documents to generate.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # Stored before the base constructor runs, since _load_data reads it
        # (presumably the base __init__ triggers loading — TODO confirm).
        self.size = dataset_size
        super().__init__(**kwargs)

    def _load_data(self) -> list[dict]:
        """Build one ``{"text": ...}`` record per document index."""
        documents: list[dict] = []
        for doc_number in range(self.size):
            documents.append({"text": f"This is document number {doc_number}."})
        return documents
@register_dataset("debug_sft")
class DebugSftDataset(BaseSftDataset):
    """Tiny synthetic SFT dataset of numbered user/assistant exchanges."""

    default_dataset = "debug_sft"

    def __init__(
        self,
        dataset_size: int = 5,
        **kwargs,
    ):
        """Initializes a DebugSftDataset.

        Args:
            dataset_size: Number of synthetic conversations to generate.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # _load_data reads self.size, so it is assigned before delegating.
        self.size = dataset_size
        super().__init__(**kwargs)

    def transform_conversation(self, example: Union[dict, pd.Series]) -> Conversation:
        """Convert one row into a two-turn user/assistant Conversation.

        Missing or falsy ``user_message`` / ``assistant_message`` values are
        replaced with the empty string.
        """
        user_text = example.get("user_message", None) or ""
        assistant_text = example.get("assistant_message", None) or ""
        user_turn = Message(role=Role.USER, content=user_text)
        assistant_turn = Message(role=Role.ASSISTANT, content=assistant_text)
        return Conversation(messages=[user_turn, assistant_turn])

    @override
    def _load_data(self) -> pd.DataFrame:
        """Return a DataFrame with ``self.size`` numbered message pairs."""
        indices = range(self.size)
        user_messages = [f"Hello, how are you? (Document number {i})" for i in indices]
        assistant_messages = [
            f"I'm fine, thank you! (Document number {i})" for i in indices
        ]
        return pd.DataFrame(
            {
                "user_message": user_messages,
                "assistant_message": assistant_messages,
            }
        )
@register_dataset("debug_dpo")
class DebugDpoDataset(BaseExperimentalDpoDataset):
    """Tiny synthetic DPO dataset of numbered prompt/chosen/rejected triples."""

    default_dataset = "debug_dpo"

    def __init__(
        self,
        dataset_size: int = 5,
        **kwargs,
    ):
        """Initializes a DebugDpoDataset.

        Args:
            dataset_size: Number of synthetic preference triples to generate.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # _load_data reads self.size, so it is assigned before delegating.
        self.size = dataset_size
        super().__init__(**kwargs)

    def transform_preference(self, sample: dict) -> dict:
        """Transforms the sample into a preference dict.

        Returns a new dict containing only the ``prompt``, ``chosen`` and
        ``rejected`` entries of ``sample``.

        Raises:
            KeyError: If any of those keys is missing from ``sample``.
        """
        return {
            "prompt": sample["prompt"],
            "chosen": sample["chosen"],
            "rejected": sample["rejected"],
        }

    @override
    def _load_data(self) -> pd.DataFrame:
        """Return a DataFrame with ``self.size`` numbered preference triples."""
        return pd.DataFrame(
            {
                "prompt": [
                    f"Hello, how are you? (Document number {idx})"
                    for idx in range(self.size)
                ],
                "chosen": [
                    f"I'm fine, thank you! (Document number {idx})"
                    for idx in range(self.size)
                ],
                "rejected": [
                    f"fine (Document number {idx})" for idx in range(self.size)
                ],
            }
        )