Source code for oumi.core.datasets.base_dpo_dataset
# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.fromtypingimportOptionalfromtyping_extensionsimportoverridefromoumi.core.datasets.base_map_datasetimportBaseMapDatasetfromoumi.core.tokenizers.base_tokenizerimportBaseTokenizer_PROMPT_KEY="prompt"_CHOSEN_KEY="chosen"_REJECTED_KEY="rejected"_ROLE="role"_CONTENT="content"_ASSISTANT="assistant"
[docs]classBaseExperimentalDpoDataset(BaseMapDataset):"""Preprocess the samples to the Oumi format. Warning: This class is experimental and subject to change. """def__init__(self,*,dataset_name:Optional[str]=None,dataset_path:Optional[str]=None,split:Optional[str]=None,tokenizer:Optional[BaseTokenizer]=None,return_tensors:bool=False,**kwargs,)->None:"""Initializes a new instance of the BaseExperimentalDpoDataset class."""super().__init__(dataset_name=dataset_name,dataset_path=dataset_path,split=split,**kwargs,)ifreturn_tensors:raiseNotImplementedError("return_tensors=True is not implemented for this class")self._tokenizer=tokenizerself._return_tensors=return_tensorsself._data=self._load_data()
[docs]deftransform_preference(self,samples:dict)->dict:"""Transform the samples to the Oumi format."""prompt=samples[_PROMPT_KEY]chosen_chat=samples[_CHOSEN_KEY]rejected_chat=samples[_REJECTED_KEY]chosen_chat_response=self._extract_from_chat_format(chosen_chat)rejected_chat_response=self._extract_from_chat_format(rejected_chat)return{_PROMPT_KEY:prompt,_CHOSEN_KEY:chosen_chat_response,_REJECTED_KEY:rejected_chat_response,}
[docs]@overridedeftransform(self,sample:dict)->dict:"""Transform the samples to the Oumi format."""returnself.transform_preference(sample)
def_extract_from_chat_format(self,sample:dict)->str:"""Extract the last 'assistant' turn in the chat."""forturninsample[::-1]:ifturn[_ROLE]==_ASSISTANT:returnturn[_CONTENT]raiseValueError("No chat turn was found with an 'assistant' role.")