# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
from typing import Optional

import transformers

from oumi.core.configs import TrainingConfig
from oumi.core.configs.params.peft_params import PeftSaveMode
from oumi.core.distributed import is_world_process_zero
from oumi.core.processors.base_processor import BaseProcessor
from oumi.core.trainers.base_trainer import BaseTrainer
from oumi.utils.logging import logger
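
# A minimal class scaffold for the methods below (hedged sketch: the names
# `HuggingFaceTrainer`, `hf_trainer`, and `processor` are assumptions, not
# confirmed by this excerpt). It wraps a `transformers.Trainer` plus an optional
# processor, which the methods reference as `self._hf_trainer` and
# `self._processor`.
class HuggingFaceTrainer(BaseTrainer):
    def __init__(
        self,
        hf_trainer: transformers.Trainer,
        processor: Optional[BaseProcessor] = None,
    ):
        """Initializes the HuggingFace trainer wrapper."""
        self._hf_trainer = hf_trainer
        self._processor = processor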

    def train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        """Trains a model."""
        self._hf_trainer.train(resume_from_checkpoint=resume_from_checkpoint)
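
    # Usage sketch (assumptions: `trainer` is an instance of this class and
    # `config` is a loaded TrainingConfig; both names are hypothetical). It
    # resumes from the newest HuggingFace "checkpoint-*" directory, located with
    # the standard `transformers.trainer_utils.get_last_checkpoint` helper:
    #
    #   from transformers.trainer_utils import get_last_checkpoint
    #
    #   last_checkpoint = get_last_checkpoint(config.training.output_dir)
    #   trainer.train(resume_from_checkpoint=last_checkpoint)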

    def save_state(self) -> None:
        """See base class.

        Saves the Trainer state, since Trainer.save_model saves only the
        tokenizer with the model.

        HuggingFace normally writes state into "trainer_state.json" under
        output_dir.
        """
        if not is_world_process_zero():
            return

        self._hf_trainer.save_state()
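
    # Sketch of reading the saved state back (the `output_dir` name below is an
    # assumption, not defined in this file): `Trainer.save_state` writes
    # "trainer_state.json", which `transformers.TrainerState.load_from_json`
    # can reload, e.g. to inspect the step the run stopped at:
    #
    #   state = transformers.TrainerState.load_from_json(
    #       str(pathlib.Path(output_dir) / "trainer_state.json")
    #   )
    #   logger.info(f"Saved state is at global step {state.global_step}")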

    def save_model(self, config: TrainingConfig, final: bool = True) -> None:
        """Saves the model's weights to the specified output directory.

        Args:
            config: The Oumi training config.
            final: Whether this is the final model being saved during training.
                - Applies optimizations for the final model checkpoint.
                - In the case of FSDP, this will always save the FULL_STATE_DICT
                  instead of the default STATE_DICT.

        Returns:
            None
        """
        if self._hf_trainer.is_fsdp_enabled:
            # FSDP is enabled, so we need to save the model in a special way.
            return self._save_fsdp_model(config=config, final=final)

        if not is_world_process_zero():
            return

        output_dir = config.training.output_dir

        if not config.training.use_peft:
            self._hf_trainer.save_model(output_dir)
        else:
            if config.peft.peft_save_mode == PeftSaveMode.MERGED:
                # Saving the merged model only saves the model weights, not the
                # tokenizer files and training args. To ensure we're saving all
                # relevant files, we save the PEFT model first, move the adapter
                # files aside, then save the merged model.
                # The adapter files are moved to the "adapter/" subdirectory so
                # they don't interfere with the other saved model files.
                self._hf_trainer.save_model(output_dir)
                output_dir_path = pathlib.Path(output_dir)
                adapter_dir = output_dir_path / "adapter"
                adapter_dir.mkdir(parents=True, exist_ok=True)
                for filename in [
                    "adapter_config.json",
                    "adapter_model.safetensors",
                ]:
                    file_path = output_dir_path / filename
                    if file_path.exists():
                        file_path.rename(adapter_dir / filename)
                    else:
                        logger.warning(
                            f"{filename} not found in {output_dir} when "
                            "attempting to move it during model saving."
                        )
                merged_model = self._hf_trainer.model.merge_and_unload(
                    progressbar=True, safe_merge=True
                )
                merged_model.save_pretrained(output_dir)
            elif config.peft.peft_save_mode == PeftSaveMode.ADAPTER_ONLY:
                # Save the LoRA adapter (doesn't include the base model).
                self._hf_trainer.save_model(output_dir)
            elif config.peft.peft_save_mode == PeftSaveMode.ADAPTER_AND_BASE_MODEL:
                self._hf_trainer.save_model(output_dir)
                # Saving the base model requires a separate call.
                self._hf_trainer.model.base_model.save_pretrained(output_dir)
            else:
                raise ValueError(
                    f"Unsupported PEFT save mode: {config.peft.peft_save_mode}"
                )

        logger.info(f"Model has been saved at {output_dir}")

        if self._processor is not None:
            self._processor.save_config(output_dir)
            logger.info(f"Processor config has been saved at {output_dir}")
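
    # A hedged sketch of how checkpoints written by the PEFT branches above are
    # typically reloaded (not part of this class; "my-base-model" is a
    # hypothetical placeholder for the original base model name):
    #
    #   from peft import PeftModel
    #   from transformers import AutoModelForCausalLM
    #
    #   # MERGED: plain weights; loads like any regular HuggingFace model.
    #   model = AutoModelForCausalLM.from_pretrained(output_dir)
    #
    #   # ADAPTER_ONLY: load the base model, then attach the saved adapter.
    #   base = AutoModelForCausalLM.from_pretrained("my-base-model")
    #   model = PeftModel.from_pretrained(base, output_dir)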

    def _save_fsdp_model(self, config: TrainingConfig, final: bool = True) -> None:
        """Saves the model's weights to the specified output directory.

        For FSDP, all ranks should call into this function.
        """
        if final:
            # For the final checkpoint, we need to save the FULL_STATE_DICT
            # instead of the default STATE_DICT.
            if (
                self._hf_trainer.is_fsdp_enabled
                and self._hf_trainer.accelerator.state.fsdp_plugin is not None
            ):
                logger.info("Saving FULL_STATE_DICT for final model checkpoint.")
                self._hf_trainer.accelerator.state.fsdp_plugin.set_state_dict_type(
                    "FULL_STATE_DICT"
                )

        output_dir = config.training.output_dir
        self._hf_trainer.save_model(output_dir)
        logger.info(f"Model has been saved at {output_dir}")
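
    # Note on the trade-off (hedged): FULL_STATE_DICT gathers all FSDP shards
    # onto rank 0 so the final checkpoint is one consolidated set of weights
    # loadable with `from_pretrained`; intermediate checkpoints keep the
    # plugin's default type (typically sharded per-rank files), which is
    # cheaper to write. If training were to continue after a final save, the
    # sharded default could be restored symmetrically (hypothetical call):
    #
    #   self._hf_trainer.accelerator.state.fsdp_plugin.set_state_dict_type(
    #       "SHARDED_STATE_DICT"
    #   )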