# Source code for oumi.builders.collators

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import oumi.core.constants as constants
from oumi.core.collators.text_collator_with_padding import TextCollatorWithPadding
from oumi.core.collators.text_completions_collator_with_padding import (
    TextCompletionsCollatorWithPadding,
)
from oumi.core.collators.vision_language_collator_with_padding import (
    VisionLanguageCollatorWithPadding,
)
from oumi.core.configs import DatasetSplit, TrainingConfig
from oumi.core.configs.internal.supported_models import (
    find_internal_model_config,
)
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.utils.logging import logger

# This is used to set the max input length for a model with infinite size input
_VERY_LARGE_INTEGER = int(1e30)


def build_data_collator(
    collator_name: str,
    tokenizer: BaseTokenizer,
    *,
    max_length: Optional[int],
    label_ignore_index: Optional[int] = constants.LABEL_IGNORE_INDEX,
    **kwargs,
) -> Callable:
    """Builds a data collator based on the given collator name.

    Args:
        collator_name: The name of the collator to build. Supported values are:

            - "text_with_padding": Uses `TextCollatorWithPadding`.
            - "vision_language_with_padding": Uses
              `VisionLanguageCollatorWithPadding`.
            - "text_completions_only_with_padding": Uses
              `TextCompletionsCollatorWithPadding`.
        tokenizer: A tokenizer.
        max_length: An optional maximum sequence length. Truncation is enabled
            in the collator only when this is a positive integer.
        label_ignore_index: If set, then label values of tokens that shouldn't
            contribute to the loss computation will be replaced by this special
            value. For example, this can be `PAD`, or image tokens.
            PyTorch convention is to use -100 as the `ignore_index` label. Refer
            to the `ignore_index` parameter of `torch.nn.CrossEntropyLoss()`
            for more details.
        **kwargs: Additional keyword arguments to pass to the collator
            constructor. Note: not forwarded to
            `TextCompletionsCollatorWithPadding` (see below).

    Returns:
        Callable: The data collator function or class.

    Raises:
        ValueError: If an unsupported collator name is provided.
    """
    if not collator_name:
        raise ValueError("Empty data collator name.")

    # Truncation is only enabled when a positive max length is configured.
    enable_truncation: bool = max_length is not None and max_length > 0

    if enable_truncation and (
        tokenizer.model_max_length is not None
        # Tokenizers with effectively unbounded input report a huge sentinel
        # value; don't warn about a "mismatch" against that sentinel.
        and tokenizer.model_max_length < _VERY_LARGE_INTEGER
        and max_length != tokenizer.model_max_length
    ):
        # A mismatch is not fatal (the collator's limit wins), but it is
        # surprising enough to surface in the logs.
        comparison = (
            "greater than"
            if max_length > tokenizer.model_max_length
            else "less than"
        )
        logger.warning(
            f"Data collator's maximum length: ({max_length}) is "
            f"{comparison}"
            f" tokenizer's model maximum length ({tokenizer.model_max_length})"
        )

    if collator_name == "text_with_padding":
        return TextCollatorWithPadding(
            tokenizer=tokenizer,
            max_length=max_length,
            label_ignore_index=label_ignore_index,
            truncation=enable_truncation,
            **kwargs,
        )
    elif collator_name == "vision_language_with_padding":
        return VisionLanguageCollatorWithPadding(
            tokenizer=tokenizer,
            max_length=max_length,
            label_ignore_index=label_ignore_index,
            truncation=enable_truncation,
            **kwargs,
        )
    elif collator_name == "text_completions_only_with_padding":
        # This collator's constructor takes neither max_length,
        # label_ignore_index, nor extra kwargs. Previously those
        # caller-provided settings were dropped silently; warn so the
        # mismatch is visible instead of invisible.
        if kwargs:
            logger.warning(
                "Ignoring extra kwargs for collator "
                f"'{collator_name}': {sorted(kwargs.keys())}"
            )
        # NOTE(review): prefixes are hard-coded to the Llama-3 chat template —
        # confirm this collator is only selected for Llama-3-style tokenizers.
        return TextCompletionsCollatorWithPadding(
            tokenizer=tokenizer,
            instruction_prefix="<|start_header_id|>user<|end_header_id|>\n\n",
            response_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n",
        )

    raise ValueError(f"Unknown data collator name: '{collator_name}'")
def build_collator_from_config(
    config: TrainingConfig, tokenizer: Optional[BaseTokenizer]
) -> Optional[Callable]:
    """Creates data collator if specified in config."""
    # Only the TRAIN split's settings determine which collator (if any) to use.
    split = config.data.get_split(DatasetSplit.TRAIN)
    collator_name = split.collator_name
    if not collator_name:
        # No collator configured: nothing to build.
        return None

    # A collator cannot operate without a tokenizer to pad/encode with.
    if tokenizer is None:
        raise ValueError(
            "Tokenizer must be provided if collator is specified! "
            f"collator: '{collator_name}'"
        )

    # Prefer a model-specific ignore index when one is registered; otherwise
    # fall back to the project-wide default.
    internal_config = find_internal_model_config(config.model)
    if internal_config is not None:
        label_ignore_index: Optional[int] = internal_config.label_ignore_index
    else:
        label_ignore_index = constants.LABEL_IGNORE_INDEX

    return build_data_collator(
        collator_name=collator_name,
        tokenizer=tokenizer,
        max_length=config.model.model_max_length,
        label_ignore_index=label_ignore_index,
    )