# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from pprint import pformat
from typing import Callable, Optional, cast

import transformers
import trl

from oumi.core.configs import TrainerType, TrainingParams
from oumi.core.distributed import is_world_process_zero
from oumi.core.processors.base_processor import BaseProcessor
from oumi.core.trainers import BaseTrainer, HuggingFaceTrainer, VerlGrpoTrainer
from oumi.core.trainers import Trainer as OumiTrainer
from oumi.utils.logging import logger
def build_trainer(
    trainer_type: TrainerType, processor: Optional[BaseProcessor]
) -> Callable[..., BaseTrainer]:
    """Builds a trainer creator functor based on the provided configuration.

    Args:
        trainer_type (TrainerType): Enum indicating the type of training.
        processor: An optional processor.

    Returns:
        A builder function that can create an appropriate trainer based on the
        trainer type specified in the configuration. All function arguments
        supplied by the caller are forwarded to the trainer's constructor.

    Raises:
        NotImplementedError: If the trainer type specified in the configuration
            is not supported.
    """

    def _create_hf_builder_fn(
        cls: type[transformers.Trainer],
    ) -> Callable[..., BaseTrainer]:
        def _init_hf_trainer(*args, **kwargs) -> BaseTrainer:
            training_args = kwargs.pop("args", None)
            callbacks = kwargs.pop("callbacks", [])
            if training_args is not None:
                # If set, convert to the HuggingFace Trainer args format.
                training_args = cast(TrainingParams, training_args)
                training_args.finalize_and_validate()

            # Note: `args` is effectively required; `to_hf()` below assumes
            # `training_args` is set.
            hf_args = training_args.to_hf()
            if is_world_process_zero():
                logger.info(pformat(hf_args))
            trainer = HuggingFaceTrainer(cls(*args, **kwargs, args=hf_args), processor)
            if callbacks:
                # TODO(OPE-250): Define generalizable callback abstraction
                # Incredibly ugly, but this is the only way to add callbacks that add
                # metrics to wandb. The Transformers trainer exposes no public way to
                # control the order in which callbacks are called.
                training_callbacks = (
                    [transformers.trainer_callback.DefaultFlowCallback]
                    + callbacks
                    # Skip the first callback, which is the DefaultFlowCallback above.
                    + trainer._hf_trainer.callback_handler.callbacks[1:]
                )
                trainer._hf_trainer.callback_handler.callbacks = []
                for c in training_callbacks:
                    trainer._hf_trainer.add_callback(c)
            return trainer

        return _init_hf_trainer

    def _create_oumi_builder_fn() -> Callable[..., BaseTrainer]:
        def _init_oumi_trainer(*args, **kwargs) -> BaseTrainer:
            kwargs_processor = kwargs.get("processor", None)
            if processor is not None:
                if kwargs_processor is None:
                    kwargs["processor"] = processor
                elif id(kwargs_processor) != id(processor):
                    raise ValueError(
                        "Different processor instances passed to the Oumi trainer "
                        "and to build_trainer()."
                    )
            return OumiTrainer(*args, **kwargs)

        return _init_oumi_trainer

    def _create_verl_grpo_builder_fn() -> Callable[..., BaseTrainer]:
        def _init_verl_grpo_trainer(*args, **kwargs) -> BaseTrainer:
            return VerlGrpoTrainer(*args, **kwargs)

        return _init_verl_grpo_trainer

    if trainer_type == TrainerType.TRL_SFT:
        return _create_hf_builder_fn(trl.SFTTrainer)
    elif trainer_type == TrainerType.TRL_DPO:
        return _create_hf_builder_fn(trl.DPOTrainer)
    elif trainer_type == TrainerType.TRL_GRPO:
        return _create_hf_builder_fn(trl.GRPOTrainer)
    elif trainer_type == TrainerType.HF:
        return _create_hf_builder_fn(transformers.Trainer)
    elif trainer_type == TrainerType.OUMI:
        warnings.warn(
            "The OUMI trainer is still in alpha. "
            "Prefer the HF trainer when possible."
        )
        return _create_oumi_builder_fn()
    elif trainer_type == TrainerType.VERL_GRPO:
        return _create_verl_grpo_builder_fn()

    raise NotImplementedError(f"Trainer type {trainer_type} not supported.")
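
# Illustrative usage sketch (not part of the module): the returned builder is a
# functor whose arguments are forwarded to the underlying trainer constructor.
# The keyword names below (`model`, `train_dataset`) mirror the
# `transformers.Trainer` constructor; `model` and `train_dataset` are assumed
# to be prepared elsewhere, and `TrainingParams()` defaults are an assumption
# (see oumi.core.configs for the actual fields).
#
#     create_trainer = build_trainer(TrainerType.TRL_SFT, processor=None)
#     trainer = create_trainer(
#         model=model,              # a loaded transformers model (assumed)
#         args=TrainingParams(),    # converted to HF args via to_hf()
#         train_dataset=train_dataset,  # a tokenized dataset (assumed)
#     )
#     trainer.train()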