Source code for oumi.core.collators.text_collator_with_padding
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from typing import Any, NamedTuple, Optional

from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.utils.debug_utils import log_example_for_debugging
from oumi.utils.logging import logger
from oumi.utils.torch_utils import (
    create_ones_like,
    pad_sequences,
    pad_to_max_dim_and_stack,
)

_INPUT_IDS_KEY = "input_ids"
_ATTENTION_MASK_KEY = "attention_mask"
_CROSS_ATTENTION_MASK_KEY = "cross_attention_mask"
_LABELS_KEY = "labels"


class _SpecialTokens(NamedTuple):
    """Special tokens used by TextCollatorWithPadding."""

    pad_token_id: int
    """Token id of `PAD` token."""

    label_ignore_index: Optional[int]
    """If set, then `PAD` tokens will be replaced by this special value
    to exclude them from the loss computation.
    """
[docs]
class TextCollatorWithPadding:
    def __init__(
        self,
        tokenizer: BaseTokenizer,
        *,
        max_length: Optional[int],
        truncation: bool = False,
        label_ignore_index: Optional[int] = None,
        max_variable_sized_dims: int = 1,
        debug: bool = False,
    ):
        """Custom collator for text LLM training.

        Args:
            tokenizer: The tokenizer used for encoding the data.
            max_length: Padding length.
            truncation: Whether to truncate long inputs to `max_length`.
                If False, long inputs are preserved as-is even if they exceed
                `max_length`. Only has an effect if `max_length` is specified.
            label_ignore_index: If set, then label values of tokens that shouldn't
                contribute to the loss computation will be replaced by this
                special value.
            max_variable_sized_dims: Maximum number of variable-sized dimensions.
                Normally, it's 1 (the sequence length dimension), but it can
                sometimes be higher, e.g., 2 for "cross_attention_mask" for VLMs
                with multi-image inputs. A negative value means `Unlimited`.
            debug: Whether to log a debug example.
        """
        self._max_length: Optional[int] = (
            int(max_length) if max_length is not None and max_length > 0 else None
        )
        self._truncation: bool = bool(truncation)

        if not hasattr(tokenizer, "padding_side") or not tokenizer.padding_side:
            raise RuntimeError("Tokenizer doesn't define `padding_side`.")
        self._padding_side = str(tokenizer.padding_side)

        if not hasattr(tokenizer, "pad_token_id") or tokenizer.pad_token_id is None:
            raise RuntimeError("Tokenizer doesn't define `pad_token_id`.")
        elif not isinstance(tokenizer.pad_token_id, int):
            raise RuntimeError(
                "Tokenizer's `pad_token_id` is not an integer. "
                f"{tokenizer.pad_token_id}. Type: {type(tokenizer.pad_token_id)}"
            )

        self._special_tokens: _SpecialTokens = _SpecialTokens(
            pad_token_id=int(tokenizer.pad_token_id),
            label_ignore_index=label_ignore_index,
        )

        self._max_input_ids_length: int = 0
        self._max_previously_logged_input_ids_length: int = 0
        self._max_variable_sized_dims: int = max_variable_sized_dims
        self._debug: bool = debug
        # Track if we've already logged an example
        self._has_logged_example: bool = False
        self._tokenizer = tokenizer  # Store tokenizer for debugging

    def _collate_simple(
        self,
        inputs_dict: dict[str, list[Any]],
        *,
        batch_max_length: int,
        padding_value_overrides: dict[str, int],
    ) -> dict[str, Any]:
        result: dict[str, Any] = {}
        for key, sequences_list in inputs_dict.items():
            try:
                padding_value = padding_value_overrides.get(key, 0)
                if self._max_variable_sized_dims == 1:
                    collated_tensor = pad_sequences(
                        sequences_list,
                        padding_side=self._padding_side,
                        padding_value=padding_value,
                    )
                else:
                    collated_tensor = pad_to_max_dim_and_stack(
                        sequences_list,
                        max_variable_sized_dims=self._max_variable_sized_dims,
                        padding_side=self._padding_side,
                        padding_value=padding_value,
                    )
                result[key] = collated_tensor
            except Exception:
                logger.error(
                    f"Failed to collate '{key}'! "
                    f"Max variable size dims: {self._max_variable_sized_dims}, "
                    f"Batch maximum length: {batch_max_length}, "
                    f"Maximum allowed length: {self._max_length}, "
                    f"Truncation: {self._truncation}."
                )
                raise
        return result
[docs]
    def __call__(self, batch) -> dict[str, Any]:
        """Pads to the longest length present in the batch.

        Args:
            batch: List of batch items.

        Returns:
            Dict[str, torch.Tensor]: Processed batch.
        """
        collation_inputs: dict[str, list[Any]] = collections.defaultdict(list)

        labels_on = _LABELS_KEY in batch[0]
        attention_mask_on = _ATTENTION_MASK_KEY in batch[0]
        cross_attention_mask_on = _CROSS_ATTENTION_MASK_KEY in batch[0]

        # Maximum sequence lengths in this batch.
        batch_max_input_ids_length: int = 0

        for item in batch:
            if _INPUT_IDS_KEY not in item:
                raise ValueError(
                    f"Item doesn't contain '{_INPUT_IDS_KEY}' key. "
                    f"Available keys: {item.keys()}"
                )
            batch_max_input_ids_length = max(
                batch_max_input_ids_length, len(item[_INPUT_IDS_KEY])
            )
            collation_inputs[_INPUT_IDS_KEY].append(item[_INPUT_IDS_KEY])
            collation_inputs[_ATTENTION_MASK_KEY].append(
                item[_ATTENTION_MASK_KEY]
                if attention_mask_on
                else create_ones_like(item[_INPUT_IDS_KEY])
            )
            if cross_attention_mask_on:
                collation_inputs[_CROSS_ATTENTION_MASK_KEY].append(
                    item[_CROSS_ATTENTION_MASK_KEY]
                )
            if labels_on:
                collation_inputs[_LABELS_KEY].append(item[_LABELS_KEY])

        if self._max_length is not None:
            if self._truncation:
                for key in collation_inputs:
                    collation_inputs[key] = [
                        item[0 : self._max_length] for item in collation_inputs[key]
                    ]
            else:
                for key in collation_inputs:
                    for item in collation_inputs[key]:
                        seq_len = len(item)
                        if seq_len > self._max_length:
                            raise ValueError(
                                "Maximum sequence length exceeded. "
                                "You should probably activate truncation. "
                                f"'{key}' length: ({seq_len}). "
                                f"Maximum model length: ({self._max_length})"
                            )

        # Update global (dataset) maximum lengths, and log a warning
        # about truncation if needed.
        self._update_max_lengths_and_log(
            max_input_ids_length=batch_max_input_ids_length
        )

        # Collate batch prompts.
        pad_token_id = self._special_tokens.pad_token_id
        collated_text_inputs = self._collate_simple(
            collation_inputs,
            batch_max_length=batch_max_input_ids_length,
            padding_value_overrides={
                _INPUT_IDS_KEY: pad_token_id,
                _LABELS_KEY: (
                    self._special_tokens.label_ignore_index
                    if self._special_tokens.label_ignore_index is not None
                    else pad_token_id
                ),
            },
        )

        # Combine all inputs.
        combined_batch = {
            _INPUT_IDS_KEY: collated_text_inputs[_INPUT_IDS_KEY],
            _ATTENTION_MASK_KEY: collated_text_inputs.get(_ATTENTION_MASK_KEY),
        }
        if cross_attention_mask_on:
            combined_batch[_CROSS_ATTENTION_MASK_KEY] = collated_text_inputs[
                _CROSS_ATTENTION_MASK_KEY
            ]

        # Add labels if present.
        if labels_on:
            combined_batch[_LABELS_KEY] = collated_text_inputs[_LABELS_KEY]

        # If debug is on and we haven't logged an example yet, log the first example
        if self._debug and not self._has_logged_example and len(batch) > 0:
            # Log an example of the data in the first step for debugging purposes.
            self._log_debug_example(batch, combined_batch)

        return combined_batch
    def _log_debug_example(
        self,
        batch: list[dict[str, Any]],
        combined_batch: dict[str, Any],
    ) -> None:
        """Logs a debug example if debug is enabled.

        Args:
            batch: The original batch of data.
            combined_batch: The collated batch after processing.
        """
        first_input_ids = combined_batch[_INPUT_IDS_KEY][0]
        formatted_example = self._tokenizer.decode(
            first_input_ids, skip_special_tokens=False
        )
        # Decode raw text without special tokens for raw example
        raw_text = self._tokenizer.decode(first_input_ids, skip_special_tokens=True)

        tokenized_example = []
        for tid in first_input_ids:
            if hasattr(tid, "item"):
                token_id = int(tid.item())
                decoded_token = self._tokenizer.decode([tid])
            else:
                token_id = int(tid)
                decoded_token = self._tokenizer.decode(tid)
            tokenized_example.append((token_id, decoded_token))

        model_input = {
            "input_ids": (
                first_input_ids.tolist()
                if hasattr(first_input_ids, "tolist")
                else first_input_ids
            ),
            "attention_mask": (
                combined_batch[_ATTENTION_MASK_KEY][0].tolist()
                if hasattr(combined_batch[_ATTENTION_MASK_KEY][0], "tolist")
                else combined_batch[_ATTENTION_MASK_KEY][0]
            ),
        }
        if _LABELS_KEY in combined_batch:
            lbl = combined_batch[_LABELS_KEY][0]
            model_input["labels"] = lbl.tolist() if hasattr(lbl, "tolist") else lbl

        # Mark that we've logged an example to avoid logging again
        self._has_logged_example = True

        log_example_for_debugging(
            raw_example=raw_text,
            formatted_example=str(formatted_example),
            tokenized_example=tokenized_example,
            model_input=model_input,
        )

    def _update_max_lengths_and_log(self, *, max_input_ids_length: int):
        """Updates max length counters.

        Also, logs a truncation warning if the increment is large enough.
        """
        _LOG_REL_INCREMENT = 0.1  # log if max length is up 10%
        log_max_lengths: bool = False
        if max_input_ids_length > self._max_input_ids_length:
            if (
                self._max_length is not None
                and max_input_ids_length > self._max_length
            ):
                if (
                    max_input_ids_length
                    - self._max_previously_logged_input_ids_length
                ) >= _LOG_REL_INCREMENT * self._max_previously_logged_input_ids_length:
                    log_max_lengths = True
                    self._max_previously_logged_input_ids_length = (
                        max_input_ids_length
                    )
            self._max_input_ids_length = max_input_ids_length

        if log_max_lengths:
            logger.warning(
                "Input sequence exceeded max length"
                + (" and truncated! " if self._truncation else ". ")
                + (
                    f"Max allowed length: {self._max_length}. "
                    f"'input_ids' length: {self._max_input_ids_length}."
                )
            )
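
# Usage sketch (not part of the module above). It assumes a Hugging Face tokenizer
# that satisfies the attributes the collator checks (`padding_side` and an integer
# `pad_token_id`); the "gpt2" checkpoint, the EOS-as-PAD workaround, and the toy
# token ids below are illustrative assumptions, while the constructor arguments and
# the list-of-dicts batch format come from the class definition above.
import torch
from transformers import AutoTokenizer

from oumi.core.collators.text_collator_with_padding import TextCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no PAD token by default.

collator = TextCollatorWithPadding(
    tokenizer,
    max_length=128,
    truncation=True,
    label_ignore_index=-100,  # Padded label positions are excluded from the loss.
)

# Each item carries per-example token ids; "labels" and "attention_mask" are
# optional (a mask of ones is created when the attention mask is absent). 1-D
# tensors are used here; whether plain lists are also accepted depends on
# `oumi.utils.torch_utils.pad_sequences`.
batch = [
    {
        "input_ids": torch.tensor([11, 12, 13, 14]),
        "labels": torch.tensor([11, 12, 13, 14]),
    },
    {
        "input_ids": torch.tensor([21, 22]),
        "labels": torch.tensor([21, 22]),
    },
]
collated = collator(batch)
# collated["input_ids"], collated["attention_mask"], and collated["labels"] are
# stacked tensors padded to the longest sequence in the batch (length 4 here).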
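
# Minimal illustration (a sketch, not the Oumi implementation) of the right-padding
# behavior the collator delegates to `pad_sequences`: each key is padded to the
# batch maximum with its own padding value, i.e. `pad_token_id` for "input_ids",
# 0 for "attention_mask", and `label_ignore_index` (commonly -100) for "labels" so
# padded positions are ignored by the loss. Built here with
# `torch.nn.utils.rnn.pad_sequence`; the actual helpers (`pad_sequences`,
# `pad_to_max_dim_and_stack`) in `oumi.utils.torch_utils` additionally handle the
# padding side and extra variable-sized dimensions.
import torch
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 0  # Illustrative pad id.
label_ignore_index = -100

input_ids = [torch.tensor([11, 12, 13, 14]), torch.tensor([21, 22])]
labels = [torch.tensor([11, 12, 13, 14]), torch.tensor([21, 22])]
attention_mask = [torch.ones_like(ids) for ids in input_ids]

padded = {
    "input_ids": pad_sequence(
        input_ids, batch_first=True, padding_value=pad_token_id
    ),
    "attention_mask": pad_sequence(attention_mask, batch_first=True, padding_value=0),
    "labels": pad_sequence(
        labels, batch_first=True, padding_value=label_ignore_index
    ),
}
# padded["input_ids"]      -> [[11, 12, 13, 14], [21, 22,    0,    0]]
# padded["attention_mask"] -> [[ 1,  1,  1,  1], [ 1,  1,    0,    0]]
# padded["labels"]         -> [[11, 12, 13, 14], [21, 22, -100, -100]]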