# Source code for oumi.core.collators.text_completions_collator_with_padding
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import trl
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
# Every batch item must expose its token ids under this key; __call__ below
# validates this before delegating to the TRL collator.
_INPUT_IDS_KEY = "input_ids"
class TextCompletionsCollatorWithPadding:
    """Collator for completion-only text LLM training.

    Delegates padding and label preparation to TRL's
    ``DataCollatorForCompletionOnlyLM``, which (per TRL docs) masks labels so
    the loss is computed only on the assistant response tokens.
    """

    def __init__(
        self, tokenizer: BaseTokenizer, instruction_prefix: str, response_prefix: str
    ):
        """Custom collator for text LLM training.

        Args:
            tokenizer: The tokenizer used for encoding the data.
            instruction_prefix: The prefix marking the beginning of the user
                instruction.
            response_prefix: The prefix marking the beginning of the assistant
                response.

        Raises:
            RuntimeError: If the tokenizer doesn't define `pad_token_id`.
        """
        # Validate the tokenizer before constructing the delegate collator:
        # padding is impossible without a defined pad token id, so fail fast.
        if not hasattr(tokenizer, "pad_token_id") or tokenizer.pad_token_id is None:
            raise RuntimeError("Tokenizer doesn't define `pad_token_id`.")
        self._default_collator = trl.DataCollatorForCompletionOnlyLM(
            tokenizer=tokenizer,
            instruction_template=instruction_prefix,
            response_template=response_prefix,
        )

    def _collate(self, inputs: list[Any]) -> dict[str, Any]:
        """Runs the underlying TRL collator on the inputs."""
        return self._default_collator(inputs)

    def __call__(self, batch) -> dict[str, Any]:
        """Pads to the longest length present in the batch.

        Args:
            batch: List of batch items.

        Returns:
            Dict[str, torch.Tensor]: Processed batch.

        Raises:
            ValueError: If any batch item is missing the 'input_ids' key.
        """
        # Fail fast with a clear message if an item lacks token ids, rather
        # than surfacing an opaque error from inside the TRL collator.
        for item in batch:
            if _INPUT_IDS_KEY not in item:
                raise ValueError(
                    f"Item doesn't contain '{_INPUT_IDS_KEY}' key. "
                    f"Available keys: {item.keys()}"
                )

        # Collate batch prompts.
        return self._collate(batch)