# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from dataclasses import dataclass, field
from typing import Final

import torch

from oumi.core.configs.base_config import BaseConfig
from oumi.core.configs.params.data_params import DataParams
from oumi.core.configs.params.fsdp_params import FSDPParams
from oumi.core.configs.params.model_params import ModelParams
from oumi.core.configs.params.peft_params import PeftParams
from oumi.core.configs.params.training_params import (
    MixedPrecisionDtype,
    TrainerType,
    TrainingParams,
)
from oumi.utils.logging import logger
@dataclass
class TrainingConfig(BaseConfig):
    data: DataParams = field(default_factory=DataParams)
    """Parameters for the dataset.

    This field contains all the necessary settings for data processing and loading.
    It includes options for train and evaluation datasets and preprocessing steps.

    For more details, see the
    :class:`oumi.core.configs.params.data_params.DataParams` class.
    """

    model: ModelParams = field(default_factory=ModelParams)
    """Parameters for the model.

    This field defines the model architecture, size, and other model-specific
    settings. It includes options for model type, pretrained weights, and
    tokenizer configuration.

    For more details, see the
    :class:`oumi.core.configs.params.model_params.ModelParams` class.
    """

    training: TrainingParams = field(default_factory=TrainingParams)
    """Parameters for the training process.

    This field contains all settings related to the training loop, including
    learning rate, batch size, number of epochs, and optimization parameters.

    For more details, see
    :class:`oumi.core.configs.params.training_params.TrainingParams`.
    """

    peft: PeftParams = field(default_factory=PeftParams)
    """Parameters for Parameter-Efficient Fine-Tuning (PEFT).

    This field defines settings for various PEFT methods such as LoRA or
    Prefix Tuning. It includes options for rank, alpha values, and other
    PEFT-specific parameters.

    For more details, see :class:`oumi.core.configs.params.peft_params.PeftParams`.
    """

    fsdp: FSDPParams = field(default_factory=FSDPParams)
    """Parameters for FSDP (Fully Sharded Data Parallel)."""
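
    # Usage sketch (illustrative only): since all sub-configs are dataclasses,
    # a config can be constructed directly in Python. The model name below is a
    # placeholder, and passing `model_name`/`trainer_type` as constructor
    # arguments assumes they are fields of the corresponding params dataclasses,
    # as the checks in `__post_init__` suggest.
    #
    #   config = TrainingConfig(
    #       model=ModelParams(model_name="org/some-model"),
    #       training=TrainingParams(trainer_type=TrainerType.TRL_SFT),
    #   )
    #   # `__post_init__` below then validates and normalizes the settings.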

    def __post_init__(self):
        """Verifies/populates params."""
        if self.model.compile:
            raise ValueError(
                "Use `training.compile` instead of `model.compile` to "
                "enable model compilation during training."
            )

        if self.training.compile and (
            self.fsdp.use_orig_params is not None
            and not self.fsdp.use_orig_params
        ):
            raise ValueError(
                "`fsdp.use_orig_params` must be True for model compilation."
            )

        # Verify values for model dtype and mixed precision training.
        if self.training.mixed_precision_dtype in [
            MixedPrecisionDtype.FP16,
            MixedPrecisionDtype.BF16,
        ]:
            if self.model.torch_dtype != torch.float32:
                raise ValueError(
                    "Model must be loaded in fp32 to enable mixed precision training."
                )

        trainer_type: Final[TrainerType] = self.training.trainer_type

        # Check values for model sequence length.
        if self.model.model_max_length and self.model.model_max_length > 0:
            max_seq_length_value = int(self.model.model_max_length)
            max_seq_length_key = None
            if trainer_type == TrainerType.TRL_SFT:
                max_seq_length_key = "max_seq_length"
            elif trainer_type == TrainerType.TRL_DPO:
                max_seq_length_key = "max_length"
                # TODO: DPOTrainer also defines "max_prompt_length" and
                # "max_target_length". How to handle them?
            else:
                logger.warning(
                    f"Ignored model.model_max_length={max_seq_length_value} "
                    f"parameter for trainer {self.training.trainer_type}."
                )

            if max_seq_length_key:
                existing_max_seq_length = self.training.trainer_kwargs.get(
                    max_seq_length_key
                )
                if (existing_max_seq_length is not None) and (
                    existing_max_seq_length != max_seq_length_value
                ):
                    logger.warning(
                        f"Overriding existing '{max_seq_length_key}' value "
                        f"'{existing_max_seq_length}' with '{max_seq_length_value}'"
                    )
                self.training.trainer_kwargs[max_seq_length_key] = (
                    max_seq_length_value
                )

        # Set Liger kernel flags if using a HF trainer; in that case, don't apply
        # the Liger patch ourselves.
        # TODO(OPE-1117): Clean up this logic after upgrading to trl 0.16.
        if self.model.enable_liger_kernel:
            if trainer_type == TrainerType.TRL_SFT:
                self.training.trainer_kwargs["use_liger"] = True
                self.training.trainer_kwargs["use_liger_kernel"] = True
                self.model.enable_liger_kernel = False
            elif trainer_type in (TrainerType.TRL_DPO, TrainerType.HF):
                self.training.trainer_kwargs["use_liger_kernel"] = True
                self.model.enable_liger_kernel = False
            elif trainer_type == TrainerType.OUMI:
                # We need to apply the Liger patch ourselves for our own
                # training loop.
                pass
            else:
                raise ValueError("Unrecognized trainer type!")

        # Set up and validate params for the "vision_language_sft" collator.
        # The collator expects the VLM SFT dataset to produce just one column:
        # 'conversation_json' (a JSON-encoded `Conversation`)!
        collator_name: Final[str] = self.data.train.collator_name or ""
        if collator_name == "vision_language_sft":
            for dataset_params in itertools.chain(
                self.data.train.datasets,
                self.data.validation.datasets,
                self.data.test.datasets,
            ):
                if not dataset_params.dataset_kwargs.get("return_conversations", True):
                    raise ValueError(
                        "`return_conversations` must be True "
                        f"for the dataset '{dataset_params.dataset_name}' "
                        f"when using '{collator_name}' collator!"
                    )
                dataset_params.dataset_kwargs["return_conversations"] = True

            # Extra setup for TRL_SFT.
            if trainer_type == TrainerType.TRL_SFT:
                if self.training.trainer_kwargs.get("remove_unused_columns", False):
                    raise ValueError(
                        "`remove_unused_columns` must be False "
                        f"when using '{collator_name}' collator! "
                        'The "unused" columns are consumed by the collator, '
                        "not by a model."
                    )
                self.training.trainer_kwargs["remove_unused_columns"] = False

                # `trl` shouldn't be preparing the dataset, as we do it in Oumi.
                dataset_kwargs = self.training.trainer_kwargs.get("dataset_kwargs", {})
                dataset_kwargs["skip_prepare_dataset"] = True
                self.training.trainer_kwargs["dataset_kwargs"] = dataset_kwargs

            if len(self.model.processor_kwargs) > 0:
                model_processor_name: Final[str] = (
                    self.model.tokenizer_name or self.model.model_name
                )
                for dataset_params in itertools.chain(
                    self.data.train.datasets,
                    self.data.validation.datasets,
                    self.data.test.datasets,
                ):
                    if (
                        "processor_name" not in dataset_params.dataset_kwargs
                        or "processor_kwargs" in dataset_params.dataset_kwargs
                    ):
                        continue
                    dataset_processor_name: str = dataset_params.dataset_kwargs[
                        "processor_name"
                    ]
                    if dataset_processor_name == model_processor_name:
                        # Copy processor kwargs from the model if processor names
                        # match and the dataset doesn't override them.
                        dataset_params.dataset_kwargs["processor_kwargs"] = {
                            **self.model.processor_kwargs
                        }

        # Verl will error without a validation dataset.
        if (
            self.training.trainer_type == TrainerType.VERL_GRPO
            and not self.data.validation.datasets
        ):
            raise ValueError(
                "At least one validation dataset is required for VERL_GRPO training."
            )