Source code for oumi.core.collators.vision_language_collator_with_padding
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from typing import Any, Optional

import torch

from oumi.core.collators.text_collator_with_padding import TextCollatorWithPadding
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.utils.torch_utils import pad_to_max_dim_and_stack
class VisionLanguageCollatorWithPadding:
    def __init__(
        self,
        tokenizer: BaseTokenizer,
        *,
        max_length: Optional[int],
        truncation: bool = False,
        label_ignore_index: Optional[int] = None,
        allow_multi_image_inputs: bool = True,
        main_image_feature: str = "pixel_values",
    ):
        """Custom collator for multi-modal vision-language training.

        Args:
            tokenizer: The tokenizer used for encoding the data.
            max_length: Padding length.
            truncation: Whether to truncate long inputs to `max_length`.
                If False, long inputs are preserved as is even if they exceed
                `max_length`. Only has an effect if `max_length` is specified.
            label_ignore_index: If set, then label values of tokens that shouldn't
                contribute to the loss computation will be replaced by this
                special value.
            allow_multi_image_inputs: Whether to allow multi-image inputs.
            main_image_feature: The key to use for fetching the main image data
                (e.g., raw pixels, patches, etc.) from the input.
        """
        self._allow_multi_image_inputs = allow_multi_image_inputs
        self._main_image_feature = main_image_feature
        self._text_collator: TextCollatorWithPadding = TextCollatorWithPadding(
            tokenizer=tokenizer,
            max_length=max_length,
            truncation=truncation,
            label_ignore_index=label_ignore_index,
            max_variable_sized_dims=(
                # If multi-image inputs are possible, then allow
                # 2 variable-sized dimensions: `seq_len`, `num_images`.
                2 if allow_multi_image_inputs else 1
            ),
        )
    def __call__(self, batch) -> dict[str, Any]:
        """Custom collator for multi-modal vision-language training.

        Args:
            batch: List of batch items.

        Returns:
            Dict[str, torch.Tensor]: Processed batch.
        """
        # Collate batch prompts.
        collated_batch = self._text_collator(batch)  # type: ignore
        known_input_names: set[str] = set(collated_batch.keys()).union(
            {self._main_image_feature}
        )
        other_input_names: set[str] = set()

        images = []
        for item in batch:
            # TODO Consider relaxing this constraint: a vision/language model
            # can handle text-only inputs e.g., a follow-up to an answer,
            # or image-only inputs e.g., captioning.
            if self._main_image_feature not in item:
                raise ValueError(
                    f"Item doesn't contain '{self._main_image_feature}' key. "
                    f"Available keys: {item.keys()}"
                )
            images.append(item[self._main_image_feature])

            for key in item:
                if (
                    key
                    and (key not in known_input_names)
                    and (key not in other_input_names)
                ):
                    other_input_names.add(key)

        # Collate images.
        image_input_features = self.collate_images(images)

        # Add images to other inputs.
        collated_batch[self._main_image_feature] = image_input_features

        # For other inputs, verify they are present in all examples, then stack them.
        if len(other_input_names) > 0:
            other_inputs: dict[str, list[Any]] = collections.defaultdict(list)
            for item in batch:
                for input_name in other_input_names:
                    if input_name not in item:
                        raise ValueError(
                            f"Item doesn't contain '{input_name}' key. "
                            f"Available keys: {item.keys()}"
                        )
                    other_inputs[input_name].append(item[input_name])

            for input_name, values_list in other_inputs.items():
                collated_value = pad_to_max_dim_and_stack(
                    values_list,
                    max_variable_sized_dims=(
                        # If multi-image inputs are possible, then allow
                        # 1 variable-sized dimension (`num_images`).
                        1 if self._allow_multi_image_inputs else 0
                    ),
                )
                collated_batch[input_name] = collated_value

        return collated_batch
    def collate_images(self, images) -> torch.Tensor:
        """Collate images for multi-modal training.

        Args:
            images: List of images to collate.

        Returns:
            torch.Tensor: Batch of processed images.
        """
        if len(images) == 0:
            raise ValueError("No images found in the batch")

        return pad_to_max_dim_and_stack(
            images,
            max_variable_sized_dims=(
                # If multi-image inputs are possible, then allow
                # 1 variable-sized dimension (`num_images`).
                1 if self._allow_multi_image_inputs else 0
            ),
        )
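The following usage sketch is not part of the module above; it illustrates how such a collator is typically wired into a PyTorch DataLoader as its collate_fn. The tokenizer name, example keys, and tensor shapes are illustrative assumptions, not taken from the Oumi source.

# Minimal usage sketch. Assumptions: the tokenizer choice, example dict keys,
# and tensor shapes below are hypothetical and only meant for illustration.
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from oumi.core.collators.vision_language_collator_with_padding import (
    VisionLanguageCollatorWithPadding,
)

# Any HuggingFace tokenizer compatible with BaseTokenizer should work;
# this particular checkpoint is an arbitrary example.
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")

collator = VisionLanguageCollatorWithPadding(
    tokenizer=tokenizer,
    max_length=512,
    truncation=True,
    label_ignore_index=-100,
)

# Each example is assumed to be a dict with tokenized text plus a
# "pixel_values" tensor; shapes are made up for illustration.
examples = [
    {
        "input_ids": torch.tensor([1, 2, 3]),
        "attention_mask": torch.tensor([1, 1, 1]),
        "pixel_values": torch.rand(3, 336, 336),
    },
    {
        "input_ids": torch.tensor([1, 2, 3, 4, 5]),
        "attention_mask": torch.tensor([1, 1, 1, 1, 1]),
        "pixel_values": torch.rand(3, 336, 336),
    },
]

loader = DataLoader(examples, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
print(batch["input_ids"].shape)     # text padded to the longest sequence in the batch
print(batch["pixel_values"].shape)  # images stacked along the batch dimension

Note that with allow_multi_image_inputs=True (the default), per-example image tensors may also vary along a leading num_images dimension; pad_to_max_dim_and_stack pads that dimension before stacking.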