Source code for oumi.core.collators.vision_language_collator_with_padding
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from typing import Any, Optional

import torch

from oumi.core.collators.text_collator_with_padding import TextCollatorWithPadding
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.utils.torch_utils import pad_to_max_dim_and_stack

class VisionLanguageCollatorWithPadding:
    def __init__(
        self,
        tokenizer: BaseTokenizer,
        *,
        max_length: Optional[int],
        truncation: bool = False,
        label_ignore_index: Optional[int] = None,
        allow_multi_image_inputs: bool = True,
        main_image_feature: str = "pixel_values",
    ):
        """Custom collator for multi-modal vision-language training.

        Args:
            tokenizer: The tokenizer used for encoding the data.
            max_length: Padding length.
            truncation: Whether to truncate long inputs to `max_length`.
                If False, long inputs are preserved as-is even if they exceed
                `max_length`. Only has an effect if `max_length` is specified.
            label_ignore_index: If set, label values of tokens that shouldn't
                contribute to the loss computation are replaced by this
                special value.
            allow_multi_image_inputs: Whether to allow multi-image inputs.
            main_image_feature: The key to use for fetching the main image data
                (e.g., raw pixels, patches, etc.) from the input.
        """
        self._allow_multi_image_inputs = allow_multi_image_inputs
        self._main_image_feature = main_image_feature
        self._text_collator: TextCollatorWithPadding = TextCollatorWithPadding(
            tokenizer=tokenizer,
            max_length=max_length,
            truncation=truncation,
            label_ignore_index=label_ignore_index,
            max_variable_sized_dims=(
                # If multi-image inputs are possible, allow 2 variable-sized
                # dimensions: `seq_len` and `num_images`.
                2 if allow_multi_image_inputs else 1
            ),
        )

    def __call__(self, batch) -> dict[str, Any]:
        """Collates a batch of multi-modal vision-language examples.

        Args:
            batch: List of batch items.

        Returns:
            Dict[str, torch.Tensor]: Processed batch.
        """
        # Collate the text inputs (prompts) first.
        collated_batch = self._text_collator(batch)  # type: ignore
        known_input_names: set[str] = set(collated_batch.keys()).union(
            {self._main_image_feature}
        )
        other_input_names: set[str] = set()

        images = []
        for item in batch:
            # TODO Consider relaxing this constraint: a vision/language model
            # can handle text-only inputs (e.g., a follow-up to an answer)
            # or image-only inputs (e.g., captioning).
            if self._main_image_feature not in item:
                raise ValueError(
                    f"Item doesn't contain '{self._main_image_feature}' key. "
                    f"Available keys: {item.keys()}"
                )
            images.append(item[self._main_image_feature])

            for key in item:
                if (
                    key
                    and (key not in known_input_names)
                    and (key not in other_input_names)
                ):
                    other_input_names.add(key)

        # Collate images.
        image_input_features = self.collate_images(images)

        # Add images to the collated batch.
        collated_batch[self._main_image_feature] = image_input_features

        # For other inputs, verify they are present in all examples, then stack them.
        if len(other_input_names) > 0:
            other_inputs: dict[str, list[Any]] = collections.defaultdict(list)
            for item in batch:
                for input_name in other_input_names:
                    if input_name not in item:
                        raise ValueError(
                            f"Item doesn't contain '{input_name}' key. "
                            f"Available keys: {item.keys()}"
                        )
                    other_inputs[input_name].append(item[input_name])

            for input_name, values_list in other_inputs.items():
                collated_value = pad_to_max_dim_and_stack(
                    values_list,
                    max_variable_sized_dims=(
                        # If multi-image inputs are possible, allow 1
                        # variable-sized dimension (`num_images`).
                        1 if self._allow_multi_image_inputs else 0
                    ),
                )
                collated_batch[input_name] = collated_value

        return collated_batch

    def collate_images(self, images) -> torch.Tensor:
        """Collate images for multi-modal training.

        Args:
            images: List of images to collate.

        Returns:
            torch.Tensor: Batch of processed images.
        """
        if len(images) == 0:
            raise ValueError("No images found in the batch")

        return pad_to_max_dim_and_stack(
            images,
            max_variable_sized_dims=(
                # If multi-image inputs are possible, allow 1 variable-sized
                # dimension (`num_images`).
                1 if self._allow_multi_image_inputs else 0
            ),
        )
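
Usage sketch (not part of the module above): a minimal, hedged example of how this collator might be constructed and invoked. The tokenizer checkpoint name, the per-item keys, and the pixel-value shapes below are illustrative assumptions, not values required by this class:

    # Illustrative only: assumes a transformers tokenizer is compatible with
    # BaseTokenizer, and that batch items carry pre-tokenized "input_ids".
    import torch
    from transformers import AutoTokenizer

    from oumi.core.collators.vision_language_collator_with_padding import (
        VisionLanguageCollatorWithPadding,
    )

    # The checkpoint name is an assumption for the example, not a requirement.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    collator = VisionLanguageCollatorWithPadding(
        tokenizer,
        max_length=512,
        truncation=True,
        label_ignore_index=-100,
    )

    # Each item provides text features plus the main image feature
    # ("pixel_values" by default). Shapes are illustrative:
    # (num_images, channels, height, width).
    batch = [
        {
            "input_ids": [101, 2023, 2003, 102],
            "pixel_values": torch.rand(1, 3, 224, 224),
        },
        {
            "input_ids": [101, 2178, 102],
            "pixel_values": torch.rand(2, 3, 224, 224),  # two images in this item
        },
    ]

    collated = collator(batch)
    # Expected keys (assuming the text collator emits them): "input_ids",
    # "attention_mask", "labels", plus a padded and stacked "pixel_values" tensor.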