Source code for oumi.builders.collators
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
import oumi.core.constants as constants
from oumi.core.collators.text_collator_with_padding import TextCollatorWithPadding
from oumi.core.collators.text_completions_collator_with_padding import (
TextCompletionsCollatorWithPadding,
)
from oumi.core.collators.vision_language_collator_with_padding import (
VisionLanguageCollatorWithPadding,
)
from oumi.core.configs import DatasetSplit, TrainingConfig
from oumi.core.configs.internal.supported_models import (
find_internal_model_config,
)
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
from oumi.utils.logging import logger
# Threshold used to recognize a tokenizer with effectively unbounded input size:
# `build_data_collator` only compares `max_length` against
# `tokenizer.model_max_length` when the latter is strictly below this value.
# NOTE(review): presumably mirrors the sentinel that tokenizer libraries report
# when a model has no input-length limit -- confirm against the tokenizer in use.
# `int(1e30)` goes through a float, so the value is not exactly 10**30.
_VERY_LARGE_INTEGER: int = int(1e30)
[docs]
def build_data_collator(
    collator_name: str,
    tokenizer: BaseTokenizer,
    *,
    max_length: Optional[int],
    label_ignore_index: Optional[int] = constants.LABEL_IGNORE_INDEX,
    **kwargs,
) -> Callable:
    """Builds a data collator based on the given collator name.

    Args:
        collator_name: The name of the collator to build.
            Supported values are:

            - "text_with_padding": Uses `TextCollatorWithPadding`.
            - "vision_language_with_padding": Uses
              `VisionLanguageCollatorWithPadding`.
            - "text_completions_only_with_padding": Uses
              `TextCompletionsCollatorWithPadding`. For this collator,
              `max_length` and `label_ignore_index` are not forwarded; only
              the optional `instruction_prefix` and `response_prefix` keyword
              arguments are consumed.

        tokenizer: A tokenizer.
        max_length: An optional maximum sequence length. Truncation is enabled
            only when this is a positive integer.
        label_ignore_index: If set, then label values of tokens that shouldn't
            contribute to the loss computation will be replaced by this special
            value. For example, this can be `PAD`, or image tokens.
            PyTorch convention is to use -100 as the `ignore_index` label. Refer
            to the `ignore_index` parameter of `torch.nn.CrossEntropyLoss()`
            for more details.
        **kwargs: Additional keyword arguments to pass to the collator
            constructor. For "text_completions_only_with_padding", the keys
            `instruction_prefix` and `response_prefix` override the built-in
            default prefixes; any other keys are ignored with a warning.

    Returns:
        Callable: The data collator function or class.

    Raises:
        ValueError: If the collator name is empty or unsupported.
    """
    if not collator_name:
        raise ValueError("Empty data collator name.")

    enable_truncation: bool = False
    if max_length is not None and max_length > 0:
        enable_truncation = True
        tokenizer_max_length = tokenizer.model_max_length
        # Only compare against the tokenizer's limit when it is a real limit,
        # i.e. not the "unbounded input" sentinel.
        if (
            tokenizer_max_length is not None
            and tokenizer_max_length < _VERY_LARGE_INTEGER
            and max_length != tokenizer_max_length
        ):
            relation = (
                "greater than" if max_length > tokenizer_max_length else "less than"
            )
            logger.warning(
                f"Data collator's maximum length: ({max_length}) is "
                f"{relation}"
                f" tokenizer's model maximum length ({tokenizer_max_length})"
            )

    if collator_name == "text_with_padding":
        return TextCollatorWithPadding(
            tokenizer=tokenizer,
            max_length=max_length,
            label_ignore_index=label_ignore_index,
            truncation=enable_truncation,
            **kwargs,
        )
    elif collator_name == "vision_language_with_padding":
        return VisionLanguageCollatorWithPadding(
            tokenizer=tokenizer,
            max_length=max_length,
            label_ignore_index=label_ignore_index,
            truncation=enable_truncation,
            **kwargs,
        )
    elif collator_name == "text_completions_only_with_padding":
        # Default prefixes use the `<|start_header_id|>...<|end_header_id|>`
        # chat-header format. NOTE(review): this appears to target Llama-3-style
        # chat templates; other models need the prefixes overridden via kwargs.
        instruction_prefix = kwargs.pop("instruction_prefix", None)
        if instruction_prefix is None:
            instruction_prefix = "<|start_header_id|>user<|end_header_id|>\n\n"
        response_prefix = kwargs.pop("response_prefix", None)
        if response_prefix is None:
            response_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        if kwargs:
            # Surface dropped arguments instead of silently discarding them.
            logger.warning(
                f"Ignoring unsupported keyword arguments for data collator "
                f"'{collator_name}': {sorted(kwargs)}"
            )
        return TextCompletionsCollatorWithPadding(
            tokenizer=tokenizer,
            instruction_prefix=instruction_prefix,
            response_prefix=response_prefix,
        )

    raise ValueError(f"Unknown data collator name: '{collator_name}'")
[docs]
def build_collator_from_config(
    config: TrainingConfig, tokenizer: Optional[BaseTokenizer]
) -> Optional[Callable]:
    """Creates the data collator named by the training split of `config`.

    Args:
        config: The training configuration. The collator name is read from its
            `TRAIN` dataset split, and the maximum length from
            `config.model.model_max_length`.
        tokenizer: The tokenizer handed to the collator. Must not be `None`
            when the training split specifies a collator.

    Returns:
        The collator produced by `build_data_collator`, or `None` if the
        training split does not specify a collator name.

    Raises:
        ValueError: If a collator is specified but `tokenizer` is `None`.
    """
    split = config.data.get_split(DatasetSplit.TRAIN)
    requested_collator = split.collator_name

    # No collator configured: nothing to build.
    if not requested_collator:
        return None

    if tokenizer is None:
        raise ValueError(
            "Tokenizer must be provided if collator is specified! "
            f"collator: '{requested_collator}'"
        )

    # Prefer the model-specific ignore index when an internal model config
    # exists; otherwise fall back to the library-wide default.
    internal_config = find_internal_model_config(config.model)
    ignore_index: Optional[int]
    if internal_config is None:
        ignore_index = constants.LABEL_IGNORE_INDEX
    else:
        ignore_index = internal_config.label_ignore_index

    return build_data_collator(
        collator_name=requested_collator,
        tokenizer=tokenizer,
        max_length=config.model.model_max_length,
        label_ignore_index=ignore_index,
    )