Source code for oumi.core.collators.text_completions_collator_with_padding
# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.fromtypingimportAnyimporttrlfromoumi.core.tokenizers.base_tokenizerimportBaseTokenizerfromoumi.utils.debug_utilsimportlog_example_for_debugging_INPUT_IDS_KEY="input_ids"
class TextCompletionsCollatorWithPadding:
    """Collator for completion-only LM training that pads to the batch max length.

    Thin wrapper around ``trl.DataCollatorForCompletionOnlyLM`` that validates
    inputs and can log one example (raw text, formatted text, per-token ids,
    and the collated model input) for debugging.
    """

    def __init__(
        self,
        tokenizer: BaseTokenizer,
        instruction_prefix: str,
        response_prefix: str,
        debug: bool = False,
    ):
        """Custom collator for text LLM training.

        Args:
            tokenizer: The tokenizer used for encoding the data.
            instruction_prefix: The prefix marking the beginning of the user
                instruction.
            response_prefix: The prefix marking the beginning of the assistant
                response.
            debug: If True, enables debug mode for logging.

        Raises:
            RuntimeError: If the tokenizer does not define `pad_token_id`.
        """
        # Validate the tokenizer *before* constructing the trl collator so an
        # invalid tokenizer fails fast with a clear error instead of whatever
        # trl might raise internally.
        if not hasattr(tokenizer, "pad_token_id") or tokenizer.pad_token_id is None:
            raise RuntimeError("Tokenizer doesn't define `pad_token_id`.")

        self._default_collator = trl.DataCollatorForCompletionOnlyLM(
            tokenizer=tokenizer,
            instruction_template=instruction_prefix,
            response_template=response_prefix,
        )

        self._debug = debug
        # Ensures the debug example is logged at most once per collator.
        self._has_logged_example = False

    def _collate(self, inputs: list[Any]) -> dict[str, Any]:
        """Delegates collation (padding + completion-only label masking) to trl."""
        return self._default_collator(inputs)

    def __call__(self, batch) -> dict[str, Any]:
        """Pads to the longest length present in the batch.

        Args:
            batch: List of batch items.

        Returns:
            Dict[str, torch.Tensor]: Processed batch.

        Raises:
            ValueError: If any batch item lacks the input-ids key.
        """
        # Validate every item up front so a malformed example produces a clear
        # error rather than a failure deep inside the trl collator.
        for item in batch:
            if _INPUT_IDS_KEY not in item:
                raise ValueError(
                    f"Item doesn't contain '{_INPUT_IDS_KEY}' key. "
                    f"Available keys: {item.keys()}"
                )

        # Collate batch prompts.
        collated_text_inputs = self._collate(batch)

        if self._debug and not self._has_logged_example:
            # Log an example of the data in the first step for debugging purposes.
            self._log_debug_example(batch, collated_text_inputs)

        return collated_text_inputs

    def _log_debug_example(
        self, batch: list[dict[str, Any]], collated_text_inputs: dict[str, Any]
    ) -> None:
        """Logs an example of the data in each step for debugging purposes.

        Args:
            batch: The batch of examples to log.
            collated_text_inputs: The collated inputs after processing.
        """
        raw_example = batch[0]
        token_ids = raw_example[_INPUT_IDS_KEY]

        # Raw text without special tokens.
        raw_text = self._default_collator.tokenizer.decode(
            token_ids, skip_special_tokens=True
        )
        # Formatted example with special tokens.
        formatted_example = self._default_collator.tokenizer.decode(
            token_ids, skip_special_tokens=False
        )
        # Pair each token id with its decoded text for per-token inspection.
        tokenized_example = [
            (token_id, self._default_collator.tokenizer.decode([token_id]))
            for token_id in token_ids
        ]
        self._has_logged_example = True

        # Extract the first example from the batched tensors for cleaner
        # debug output.
        def _to_py(x):
            """Convert tensor-like objects to Python native types."""
            if hasattr(x, "tolist"):
                return x.tolist()
            elif hasattr(x, "item"):
                return x.item()
            else:
                return x

        # Process the collated inputs to get a clean representation for debugging.
        model_input = {}
        for key, value in collated_text_inputs.items():
            # For batch tensors, extract just the first example.
            if hasattr(value, "dim") and value.dim() > 1:
                model_input[key] = _to_py(value[0])
            # For single tensors or other objects.
            else:
                model_input[key] = _to_py(value)

        # Log all components for debugging.
        log_example_for_debugging(
            raw_text, formatted_example, tokenized_example, model_input
        )