# Source code for oumi.core.datasets.base_sft_dataset
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from abc import ABC, abstractmethod
from typing import Literal, Optional, Union, cast

import pandas as pd
from typing_extensions import override

from oumi.core.datasets.base_map_dataset import BaseMapDataset
from oumi.core.tokenizers import BaseTokenizer
from oumi.core.tokenizers.utils import (
    tokenize_for_completions_only_training_with_prefix,
    tokenize_for_completions_only_training_with_template,
)
from oumi.core.types.conversation import Conversation
from oumi.utils.logging import logger
class BaseSftDataset(BaseMapDataset, ABC):
    """In-memory dataset for SFT data."""

    # Subclasses override this with the default HF dataset name they load.
    default_dataset = None

    def __init__(
        self,
        *,
        dataset_name: Optional[str] = None,
        dataset_path: Optional[str] = None,
        split: Optional[str] = None,
        tokenizer: Optional[BaseTokenizer] = None,
        task: Literal["sft", "generation", "auto"] = "auto",
        return_tensors: bool = False,
        text_col: str = "text",
        assistant_only: bool = False,
        response_template: Optional[str] = None,
        instruction_template: Optional[str] = None,
        return_conversations: bool = False,
        **kwargs,
    ) -> None:
        """Initializes a new instance of the BaseSftDataset class.

        Args:
            dataset_name: Name of the dataset to load (forwarded to the base class).
            dataset_path: Optional local path of the dataset.
            split: Optional dataset split to load (e.g. "train").
            tokenizer: Tokenizer used by `tokenize()`; required when
                `assistant_only` is enabled.
            task: Task mode ("sft", "generation", or "auto"); controls whether a
                generation prompt is appended during tokenization.
            return_tensors: If True, tokenization returns PyTorch tensors ("pt");
                otherwise plain lists.
            text_col: Column name under which untokenized text is returned.
            assistant_only: If True, train only on assistant turns
                (completions-only training).
            response_template: Marker for assistant responses; required for
                completions-only training when the chat template lacks a
                `{% generation %}` keyword.
            instruction_template: Marker for instructions; same requirement as
                `response_template`.
            return_conversations: If True, `transform()` returns serialized
                conversations instead of tokenized samples.
            **kwargs: Forwarded to `BaseMapDataset.__init__`.
        """
        super().__init__(
            dataset_name=dataset_name,
            dataset_path=dataset_path,
            split=split,
            **kwargs,
        )

        self._task = task
        self._text_col = text_col
        self._tokenizer = tokenizer
        # HF tokenizers expect "pt" for torch tensors; None yields python lists.
        self._return_tensors = "pt" if return_tensors else None
        self._assistant_only = assistant_only
        self._response_template = response_template
        self._instruction_template = instruction_template
        self._return_conversations = return_conversations

        # Validate tokenizer/template compatibility BEFORE loading data, so
        # misconfiguration fails fast. This also sets
        # `_is_template_compatible_with_completions_only_training` and
        # (when needed) the response/instruction token ids used by `tokenize()`.
        if self._assistant_only:
            self._verify_assistant_only_compatibility()

        self._data = self._load_data()

    #
    # Properties
    #
    @property
    def text_col(self) -> str:
        """Gets the text target column.

        The generated text will be stored in this column.
        """
        return self._text_col

    @property
    def task(self) -> str:
        """Gets the task mode for the dataset.

        The generated prompt is often different for generation vs SFT tasks.
        """
        return self._task

    @property
    def assistant_only(self) -> bool:
        """Gets whether the dataset is set to train only on assistant turns."""
        return self._assistant_only

    #
    # Main API
    #
[docs]defprompt(self,idx:int)->str:"""Returns the prompt at the specified index. Args: idx (int): The index of the prompt to retrieve. Returns: str: The prompt at the specified index. """returnself.tokenize(self.conversation(idx),tokenize=False)[self.text_col]
[docs]defconversation(self,idx:int)->Conversation:"""Returns the conversation at the specified index. Args: idx (int): The index of the conversation to retrieve. Returns: str: The conversation at the specified index. """sample=self.raw(idx)returnself.transform_conversation(sample)
[docs]defconversations(self)->list[Conversation]:"""Returns a list of all conversations."""indexes=range(len(self))return[self.conversation(index)forindexinindexes]
    #
    # Abstract Methods
    #
    @abstractmethod
    def transform_conversation(self, example: Union[dict, pd.Series]) -> Conversation:
        """Preprocesses the inputs of the example and returns a dictionary.

        Args:
            example (dict): The example containing the input and instruction.

        Returns:
            dict: The preprocessed inputs as a dictionary.
        """
        raise NotImplementedError
    #
    # Pre-processing
    #
[docs]@overridedeftransform(self,sample:pd.Series)->dict:"""Preprocesses the inputs in the given sample."""conversation=self.transform_conversation(sample)ifself._return_conversations:# This may require `use_torchdata=True` for TRL_SFT trainer,# but compatible with TRL_GRPO trainer.conversation_json=conversation.to_json()return{"conversation_json":conversation_json}returnself.tokenize(conversation)
    def tokenize(
        self,
        sample: Union[dict, pd.Series, Conversation],
        tokenize: bool = True,
    ) -> dict:
        """Applies the chat template carried by the tokenizer to the input example.

        Args:
            sample (Dict): Mapping `messages` to a List containing the (ordered)
                messages exchanged within a single chat dialogue.
                Each item of example["messages"] is a dict mapping the `content`
                of the message and the `role` of the one relayed it.
                E.g., role == 'user' or role == 'assistant'.
            tokenize (bool): Whether to tokenize the messages or not.

        Raises:
            NotImplementedError: Currently only the `sft` task mode is supported.
            ValueError: if requested `task` is not in "sft" or "generation"

        Returns:
            Dict: It adds a `text` key in the input `example` dictionary, mapped to
            a string carrying the `messages` to the tokenizer's chat format.
        """
        if self._tokenizer is None:
            raise ValueError("Tokenizer is required for tokenization.")

        # Normalize the input into a Conversation: pd.Series -> dict -> Conversation.
        if isinstance(sample, Conversation):
            conversation = sample
        else:
            if isinstance(sample, pd.Series):
                sample = sample.to_dict()
            if isinstance(sample, dict) and "messages" in sample:
                conversation = Conversation.from_dict(sample)
            else:
                raise ValueError(
                    "Input samples must be a Conversation or a dict with "
                    "'messages' key."
                )

        # Plain (full-sequence) tokenization: used when not training
        # assistant-only, or when the caller only wants the rendered text.
        if not self._assistant_only or not tokenize:
            return self._tokenize(conversation, tokenize)

        # Completions-only path. The flag below is set by
        # `_verify_assistant_only_compatibility()` during __init__.
        if self._is_template_compatible_with_completions_only_training:
            # Chat template has a `{% generation %}` keyword; delegate masking
            # to the template itself.
            return tokenize_for_completions_only_training_with_template(
                tokenizer=self._tokenizer,
                conversation=conversation,
            )
        else:
            # Fall back to locating response/instruction prefixes in the token
            # stream. The casts are safe: templates were validated non-empty
            # in `_verify_assistant_only_compatibility()`.
            return tokenize_for_completions_only_training_with_prefix(
                tokenizer=self._tokenizer,
                conversation=conversation,
                response_template=cast(str, self._response_template),
                instruction_template=cast(str, self._instruction_template),
                response_token_ids=self.response_token_ids,
                instruction_token_ids=self.instruction_token_ids,
            )
    def _tokenize(
        self, sample: Union[dict, pd.Series, Conversation], tokenize: bool = True
    ) -> dict:
        # Render (and optionally tokenize) the sample with the tokenizer's
        # chat template. Truncates to the tokenizer's model_max_length.
        if self._tokenizer is None:
            raise ValueError("Tokenizer is required for tokenization.")
        results = self._tokenizer.apply_chat_template(
            sample,  # type: ignore
            tokenize=tokenize,
            # return_dict only applies when tokenizing; mirrors `tokenize`.
            return_dict=tokenize,
            return_tensors=self._return_tensors,
            max_length=self._tokenizer.model_max_length,
            truncation=True,
            # Generation tasks get a trailing assistant prompt appended.
            add_generation_prompt=(self.task == "generation"),
        )

        if tokenize:
            return cast(dict, results)
        else:
            # When not tokenizing, `results` is the rendered template string.
            return {
                self.text_col: results,
            }

    def _verify_assistant_only_compatibility(self) -> None:
        # Validates that the tokenizer/templates support completions-only
        # (assistant-only) training, and precomputes the state `tokenize()`
        # needs: `_is_template_compatible_with_completions_only_training`,
        # and, for the prefix-based fallback, `response_token_ids` /
        # `instruction_token_ids`.
        if self._tokenizer is None:
            raise ValueError(
                "Tokenizer is required to enable tokenization "
                "for training on assistant-only turns."
            )

        if self._tokenizer.chat_template is None:
            raise ValueError(
                "Tokenizer must have a chat template to enable "
                "tokenization for training on assistant-only turns."
            )

        template: str = self._tokenizer.chat_template  # type: ignore

        # Jinja `{% generation %}` / `{%- generation -%}` keyword marks
        # assistant spans directly in the template.
        if re.search(r"\{\%-?\s*generation\s*-?\%\}", template):
            logger.info(
                "Tokenizer template contains `{% generation %}` keyword. "
                "We will use it for completions-only training."
            )
            self._is_template_compatible_with_completions_only_training = True
        else:
            # Prefix-based fallback: both templates are mandatory.
            if (
                self._response_template is None
                or len(self._response_template.strip()) == 0
            ):
                raise ValueError(
                    "Response template is required for completions-only training."
                )
            if self._response_template.strip() != self._response_template:
                logger.warning(
                    f"Response template '{self._response_template}' contains "
                    "leading or trailing whitespaces. These will be ignored."
                )
                self._response_template = self._response_template.strip()

            if (
                self._instruction_template is None
                or len(self._instruction_template.strip()) == 0
            ):
                raise ValueError(
                    "Instruction template is required for completions-only training."
                )
            if self._instruction_template.strip() != self._instruction_template:
                logger.warning(
                    f"Instruction template '{self._instruction_template}' contains "
                    "leading or trailing whitespaces. These will be ignored."
                )
                self._instruction_template = self._instruction_template.strip()

            # Pre-encode the (stripped) templates once; `tokenize()` passes
            # these ids to the prefix-based completions-only tokenizer.
            self.response_token_ids = self._tokenizer.encode(
                self._response_template, add_special_tokens=False
            )
            self.instruction_token_ids = self._tokenizer.encode(
                self._instruction_template, add_special_tokens=False
            )
            self._is_template_compatible_with_completions_only_training = False