# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.fromtypingimportCallable,Optionalimportoumi.core.constantsasconstantsfromoumi.core.collators.text_collator_with_paddingimportTextCollatorWithPaddingfromoumi.core.collators.text_completions_collator_with_paddingimport(TextCompletionsCollatorWithPadding,)fromoumi.core.collators.vision_language_collator_with_paddingimport(VisionLanguageCollatorWithPadding,)fromoumi.core.collators.vision_language_sft_collatorimportVisionLanguageSftCollatorfromoumi.core.configsimportDatasetSplit,TrainingConfigfromoumi.core.configs.internal.supported_modelsimport(find_internal_model_config,)fromoumi.core.tokenizers.base_tokenizerimportBaseTokenizerfromoumi.utils.loggingimportlogger# This is used to set the max input length for a model with infinite size input_VERY_LARGE_INTEGER=int(1e30)
def build_data_collator(
    collator_name: str,
    tokenizer: BaseTokenizer,
    *,
    max_length: Optional[int],
    label_ignore_index: Optional[int] = constants.LABEL_IGNORE_INDEX,
    **kwargs,
) -> Callable:
    """Constructs the data collator identified by ``collator_name``.

    Args:
        collator_name: Which collator to build. One of:
            - "text_with_padding": `TextCollatorWithPadding`.
            - "text_completions_only_with_padding":
              `TextCompletionsCollatorWithPadding`.
            - "vision_language_with_padding": `VisionLanguageCollatorWithPadding`.
            - "vision_language_sft": `VisionLanguageSftCollator`.
        tokenizer: A tokenizer.
        max_length: An optional maximum sequence length. Truncation is enabled
            only when this is a positive integer.
        label_ignore_index: If set, then label values of tokens that shouldn't
            contribute to the loss computation will be replaced by this special
            value (e.g. `PAD` or image tokens). PyTorch convention is -100; see
            the `ignore_index` parameter of `torch.nn.CrossEntropyLoss()`.
        **kwargs: Extra keyword arguments forwarded to the collator constructor.

    Returns:
        Callable: The data collator function or class.

    Raises:
        ValueError: If an unsupported collator name is provided.
    """
    if not collator_name:
        raise ValueError("Empty data collator name.")

    truncation_enabled: bool = max_length is not None and max_length > 0
    if truncation_enabled:
        # Warn when the requested length disagrees with a tokenizer that has a
        # real (finite) model maximum length configured.
        model_limit = tokenizer.model_max_length
        if (
            model_limit is not None
            and model_limit < _VERY_LARGE_INTEGER
            and max_length != model_limit
        ):
            comparison = "greater than" if max_length > model_limit else "less than"
            logger.warning(
                f"Data collator's maximum length: ({max_length}) is "
                + comparison
                + f" tokenizer's model maximum length ({model_limit})"
            )

    if collator_name == "text_with_padding":
        return TextCollatorWithPadding(
            tokenizer=tokenizer,
            max_length=max_length,
            truncation=truncation_enabled,
            label_ignore_index=label_ignore_index,
            **kwargs,
        )

    if collator_name == "vision_language_with_padding":
        return VisionLanguageCollatorWithPadding(
            tokenizer=tokenizer,
            max_length=max_length,
            truncation=truncation_enabled,
            label_ignore_index=label_ignore_index,
            **kwargs,
        )

    if collator_name == "vision_language_sft":
        # This collator needs a processor; its name must be supplied by the caller.
        processor_name = kwargs.pop("processor_name", None)
        if not processor_name:
            raise ValueError(f"Empty processor_name for '{collator_name}'")
        processor_kwargs = kwargs.pop("processor_kwargs", None)
        return VisionLanguageSftCollator(
            tokenizer=tokenizer,
            processor_name=processor_name,
            processor_kwargs=processor_kwargs,
            max_length=max_length,
            truncation=truncation_enabled,
            label_ignore_index=label_ignore_index,
            **kwargs,
        )

    if collator_name == "text_completions_only_with_padding":
        # NOTE: uses Llama-3-style chat header prefixes; max_length/truncation
        # and extra kwargs are intentionally not forwarded here.
        return TextCompletionsCollatorWithPadding(
            tokenizer=tokenizer,
            instruction_prefix="<|start_header_id|>user<|end_header_id|>\n\n",
            response_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n",
        )

    raise ValueError(f"Unknown data collator name: '{collator_name}'")
def build_collator_from_config(
    config: TrainingConfig, tokenizer: Optional[BaseTokenizer]
) -> Optional[Callable]:
    """Creates data collator if specified in config."""
    train_split = config.data.get_split(DatasetSplit.TRAIN)
    if not train_split.collator_name:
        # No collator configured for the TRAIN split: nothing to build.
        return None
    collator_name: str = train_split.collator_name

    if tokenizer is None:
        raise ValueError(
            "Tokenizer must be provided if collator is specified! "
            f"collator: '{collator_name}'"
        )

    model_config = find_internal_model_config(config.model)

    # Resolution order: explicit training-config value, then the internal
    # model config's value, then the library-wide default.
    if config.training.label_ignore_index is not None:
        label_ignore_index: Optional[int] = config.training.label_ignore_index
    elif model_config is not None:
        label_ignore_index = model_config.label_ignore_index
    else:
        label_ignore_index = constants.LABEL_IGNORE_INDEX

    extra_kwargs: dict = {}

    is_vision_collator = collator_name in {
        "vision_language_with_padding",
        "vision_language_sft",
    }
    if (
        is_vision_collator
        and model_config is not None
        and model_config.visual_config is not None
    ):
        visual_config = model_config.visual_config
        extra_kwargs["allow_multi_image_inputs"] = visual_config.supports_multiple_images
        if collator_name == "vision_language_with_padding":
            extra_kwargs["main_image_feature"] = visual_config.main_image_feature

    if collator_name == "vision_language_sft":
        # Fall back to the tokenizer name, then the model name, for the processor.
        processor_name = extra_kwargs.get(
            "processor_name",
            config.model.tokenizer_name or config.model.model_name,
        )
        if not processor_name:
            raise ValueError(f"Processor name must be provided for '{collator_name}'!")
        extra_kwargs["processor_name"] = processor_name
        extra_kwargs["processor_kwargs"] = config.model.processor_kwargs
        extra_kwargs["trust_remote_code"] = extra_kwargs.get(
            "trust_remote_code", config.model.trust_remote_code
        )

    return build_data_collator(
        collator_name=collator_name,
        tokenizer=tokenizer,
        max_length=config.model.model_max_length,
        label_ignore_index=label_ignore_index,
        **extra_kwargs,
    )