# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Generic class for using HuggingFace datasets with input/output columns.Allows users to specify the prompt and response columns at the config level."""fromtypingimportUnionimportpandasaspdfromoumi.core.datasetsimportBaseSftDatasetfromoumi.core.registryimportregister_datasetfromoumi.core.types.conversationimportConversation,Message,Role
@register_dataset("PromptResponseDataset")
class PromptResponseDataset(BaseSftDataset):
    """Converts HuggingFace Datasets with input/output columns to Message format.

    Every example is turned into a `Conversation` that contains a user message
    read from `prompt_column` and, when `response_column` is set, an assistant
    message read from `response_column`.

    Example:
        dataset = PromptResponseDataset(
            hf_dataset_path="O1-OPEN/OpenO1-SFT",
            prompt_column="instruction",
            response_column="output",
        )
    """

    default_dataset = "O1-OPEN/OpenO1-SFT"

    def __init__(
        self,
        *,
        hf_dataset_path: str = "O1-OPEN/OpenO1-SFT",
        prompt_column: str = "instruction",
        response_column: Union[str, None] = "output",
        **kwargs,
    ) -> None:
        """Initializes a new instance of the PromptResponseDataset class.

        Args:
            hf_dataset_path (str): Dataset identifier. It is forwarded to the
                base class as the `dataset_name` keyword argument and replaces
                any `dataset_name` the caller passed in `kwargs`.
            prompt_column (str): Name of the column holding the user prompt.
            response_column (str, optional): Name of the column holding the
                assistant response. A falsy value (`None` or `""`) yields
                conversations that contain only the user message.
            **kwargs: Extra keyword arguments passed through unchanged to the
                base class initializer.
        """
        # The column names are stored before the base initializer runs, same
        # order as the original code.
        # NOTE(review): presumably the base class may already transform
        # examples during its own __init__ - confirm before reordering.
        self.prompt_column = prompt_column
        self.response_column = response_column

        # `hf_dataset_path` always wins over a `dataset_name` given in kwargs.
        kwargs["dataset_name"] = hf_dataset_path
        super().__init__(**kwargs)

    def transform_conversation(self, example: Union[dict, pd.Series]) -> Conversation:
        """Converts a single example into a `Conversation`.

        Args:
            example (dict or Pandas Series): One dataset row. It must contain
                the configured `prompt_column` and, when `response_column` is
                set, that column as well.

        Returns:
            Conversation: A conversation holding one user message, followed by
            one assistant message when `response_column` is set. Column values
            are converted to text with `str()`.

        Raises:
            KeyError: If a configured column is absent from `example`.
        """
        messages = [
            Message(role=Role.USER, content=str(example[self.prompt_column])),
        ]

        # A falsy response column means prompt-only conversations.
        if self.response_column:
            messages.append(
                Message(
                    role=Role.ASSISTANT,
                    content=str(example[self.response_column]),
                )
            )

        return Conversation(messages=messages)