Source code for oumi.core.trainers.hf_trainer

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
from typing import Optional

import transformers

from oumi.core.configs import TrainingConfig
from oumi.core.configs.params.peft_params import PeftSaveMode
from oumi.core.distributed import is_world_process_zero
from oumi.core.processors.base_processor import BaseProcessor
from oumi.core.trainers.base_trainer import BaseTrainer
from oumi.utils.logging import logger


class HuggingFaceTrainer(BaseTrainer):
    """Wraps a `transformers.Trainer` behind the Oumi `BaseTrainer` interface."""

    def __init__(
        self,
        hf_trainer: transformers.Trainer,
        processor: Optional[BaseProcessor] = None,
    ):
        """Initializes HuggingFace-specific Trainer version.

        Args:
            hf_trainer: The HuggingFace trainer that all training and saving
                calls are delegated to.
            processor: Optional processor. When set, its config is written to
                the output directory by `save_model()` (non-FSDP path only).
        """
        self._hf_trainer = hf_trainer
        self._processor = processor

    def train(self, resume_from_checkpoint: Optional[str] = None) -> None:
        """Trains a model.

        Args:
            resume_from_checkpoint: Forwarded unchanged to the wrapped
                trainer's `train()` call.
        """
        self._hf_trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    def save_state(self) -> None:
        """See base class.

        Saves the Trainer state, since Trainer.save_model saves only the tokenizer
        with the model.

        HuggingFace normally writes state into "trainer_state.json" under output_dir.

        Only the world-zero process performs the save; every other rank returns
        immediately.
        """
        if not is_world_process_zero():
            return

        self._hf_trainer.save_state()

    def save_model(self, config: TrainingConfig, final: bool = True) -> None:
        """Saves the model's weights to the specified output directory.

        Args:
            config: The Oumi training config.
            final: Whether this is the final model being saved during training.
                Only consulted on the FSDP path.

                - Applies optimizations for the final model checkpoint.
                - In the case of FSDP, this will always save the FULL_STATE_DICT
                  instead of the default STATE_DICT.

        Returns:
            None

        Raises:
            ValueError: If PEFT is enabled and `config.peft.peft_save_mode` is not
                one of the supported `PeftSaveMode` values.
        """
        if self._hf_trainer.is_fsdp_enabled:
            # FSDP is enabled, so we need to save the model in a special way.
            # This has to happen before the rank check below because every rank
            # must take part in the FSDP save.
            # NOTE(review): this path returns before the processor config is
            # saved -- confirm whether that is intentional.
            return self._save_fsdp_model(config=config, final=final)

        if not is_world_process_zero():
            return

        output_dir = config.training.output_dir
        if config.training.use_peft:
            self._save_peft_model(config)
        else:
            self._hf_trainer.save_model(output_dir)
        logger.info(f"Model has been saved at {output_dir}")

        if self._processor is not None:
            self._processor.save_config(output_dir)
            logger.info(f"Processor config has been saved at {output_dir}")

    def _save_peft_model(self, config: TrainingConfig) -> None:
        """Saves a PEFT model according to `config.peft.peft_save_mode`.

        Args:
            config: The Oumi training config; supplies the output directory and
                the PEFT save mode.

        Raises:
            ValueError: If the configured PEFT save mode is not supported.
        """
        output_dir = config.training.output_dir
        save_mode = config.peft.peft_save_mode

        if save_mode == PeftSaveMode.MERGED:
            # Saving the merged model only saves the model weights, not the
            # tokenizer files and training args. To ensure we're saving all relevant
            # files, we save the PEFT model first, move the adapter files out of
            # the way, then save the merged model.
            # The adapter files are moved to the "adapter/" subdirectory to not
            # interfere with the other saved model files.
            self._hf_trainer.save_model(output_dir)

            output_dir_path = pathlib.Path(output_dir)
            adapter_dir = output_dir_path / "adapter"
            adapter_dir.mkdir(parents=True, exist_ok=True)
            for filename in ("adapter_config.json", "adapter_model.safetensors"):
                file_path = output_dir_path / filename
                if file_path.exists():
                    # `replace` (not `rename`) so that saving again into the same
                    # output directory overwrites the previously moved file on
                    # every platform; `rename` raises on Windows if the target
                    # already exists.
                    file_path.replace(adapter_dir / filename)
                else:
                    logger.warning(
                        f"{filename} not found in {output_dir} when "
                        "attempting to move it during model saving."
                    )

            merged_model = self._hf_trainer.model.merge_and_unload(
                progressbar=True, safe_merge=True
            )
            merged_model.save_pretrained(output_dir)
        elif save_mode == PeftSaveMode.ADAPTER_ONLY:
            # Save the LoRA adapter (doesn't include the base model).
            self._hf_trainer.save_model(output_dir)
        elif save_mode == PeftSaveMode.ADAPTER_AND_BASE_MODEL:
            self._hf_trainer.save_model(output_dir)
            # Saving the base model requires a separate call.
            self._hf_trainer.model.base_model.save_pretrained(output_dir)
        else:
            raise ValueError(f"Unsupported PEFT save mode: {save_mode}")

    def _save_fsdp_model(self, config: TrainingConfig, final: bool = True) -> None:
        """Saves the model's weights to the specified output directory.

        For FSDP, all ranks should call into this function.

        Args:
            config: The Oumi training config; supplies the output directory.
            final: When True, the FSDP plugin (if one is present) is switched to
                "FULL_STATE_DICT" before saving.
        """
        if final:
            # For the final checkpoint, we need to save the FULL_STATE_DICT instead of
            # the default STATE_DICT.
            if (
                self._hf_trainer.is_fsdp_enabled
                and self._hf_trainer.accelerator.state.fsdp_plugin is not None
            ):
                logger.info("Saving FULL_STATE_DICT for final model checkpoint.")
                self._hf_trainer.accelerator.state.fsdp_plugin.set_state_dict_type(
                    "FULL_STATE_DICT"
                )

        output_dir = config.training.output_dir
        self._hf_trainer.save_model(output_dir)
        logger.info(f"Model has been saved at {output_dir}")