Source code for oumi.core.collators.vision_language_sft_collator
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from oumi.core.feature_generators import (
FeatureGeneratorOptions,
VisionLanguageConversationFeatureGenerator,
)
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.core.types import Conversation
class VisionLanguageSftCollator:
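    """Collator for multi-modal vision-language SFT (supervised fine-tuning).

    Expects each example in a batch to carry a JSON-serialized `Conversation`
    under the "conversation_json" key, and delegates feature generation
    (tokenization, label masking, and image processing) to
    `VisionLanguageConversationFeatureGenerator`.
    """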
def __init__(
self,
tokenizer: BaseTokenizer,
processor_name: str,
*,
processor_kwargs: Optional[dict[str, Any]] = None,
max_length: Optional[int] = None,
truncation: bool = False,
truncation_side: str = "right",
label_ignore_index: Optional[int] = None,
allow_multi_image_inputs: bool = True,
trust_remote_code: bool = False,
):
"""Custom collator for multi-modal vision-language training.
Args:
tokenizer: The tokenizer used for encoding the data.
processor_name: The name of the processor to use for feature generation.
processor_kwargs: A dictionary of processor-specific parameters.
These parameters are passed to the processor constructor.
They can override model-specific parameters.
max_length: Padding length.
truncation: Whether to truncate long inputs to `max_length`.
If False, the long inputs are preserved as is even if they exceed
`max_length`. Only has effect if `max_length` is specified.
truncation_side: The side to truncate the tokens ("right" or "left").
label_ignore_index: If set, then label values of tokens that shouldn't
contribute to the loss computation will be replaced by
this special value.
allow_multi_image_inputs: Whether to allow multi-image inputs.
trust_remote_code: Whether to trust remote code execution for the processor.
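
        Example:
            A minimal construction sketch; the model/processor name below is
            an illustrative assumption, not a required value::

                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(
                    "Qwen/Qwen2-VL-2B-Instruct"
                )
                collator = VisionLanguageSftCollator(
                    tokenizer=tokenizer,
                    processor_name="Qwen/Qwen2-VL-2B-Instruct",
                    max_length=1024,
                    truncation=True,
                    label_ignore_index=-100,
                )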
"""
self._allow_multi_image_inputs = allow_multi_image_inputs
if not processor_name:
raise ValueError("processor_name is required for VisionLanguageSftCollator")
self._conversation_feature_generator = (
VisionLanguageConversationFeatureGenerator(
tokenizer=tokenizer,
processor_name=processor_name,
processor_kwargs=processor_kwargs,
trust_remote_code=trust_remote_code,
return_tensors="pt",
truncation=truncation,
truncation_side=truncation_side,
max_length=max_length,
label_ignore_index=label_ignore_index,
)
)
    def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
"""Custom collator for multi-modal vision-language training.
Args:
batch: List of batch items.
Returns:
Dict[str, torch.Tensor]: Processed batch.
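
        Example:
            An illustrative call; this assumes `Conversation.to_json()` is
            the inverse of the `Conversation.from_json()` used below::

                batch = [{"conversation_json": conv.to_json()} for conv in convs]
                model_inputs = collator(batch)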
"""
batch_size = len(batch)
if batch_size <= 0:
raise ValueError("Batch is empty")
conversations: list[Conversation] = []
for idx in range(batch_size):
example = batch[idx]
if "conversation_json" not in example:
                raise ValueError(
                    f"Example doesn't contain 'conversation_json' key. "
                    f"Example {idx + 1} of {batch_size}. "
                    f"Available keys: {example.keys()}"
                )
conversation_json = example["conversation_json"]
conversations.append(Conversation.from_json(conversation_json))
assert len(conversations) == batch_size
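        # Generate model features (token IDs, labels, and image tensors)
        # for the whole batch in a single call.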
result = self._conversation_feature_generator.transform_conversations(
conversations,
FeatureGeneratorOptions(allow_feature_reshape=False),
)
return result