# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Based on MFU from PaLM paper: https://arxiv.org/pdf/2204.02311."""

import time
from typing import Optional, Union

import torch
import transformers

from oumi.core.callbacks.base_trainer_callback import BaseTrainerCallback
from oumi.core.configs import TrainingParams
from oumi.core.distributed import get_device_rank_info, is_world_process_zero
from oumi.performance.mfu import calculate_mfu
from oumi.utils.logging import logger
from oumi.utils.torch_utils import get_device_name

_LOGS_KWARG = "logs"
# MFU using only the time between on_step_begin and on_step_end
# (except the first step).
_TRAIN_STEP_MFU = "train_step_mfu"
# MFU using the time since training started (except the first step).
_TRAIN_MFU = "train_mfu"
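# For reference, the PaLM paper (Appendix B) estimates model FLOPs per token as
#
#     6 * num_params
#         + 12 * num_layers * num_attention_heads
#              * attention_head_size * sequence_length
#
# where the second term accounts for attention over the sequence, and defines
# MFU as achieved throughput divided by the theoretical peak:
#
#     mfu = flops_per_token * tokens_per_second / (num_devices * peak_flops)
#
# This is a summary of the paper's formula only; the actual computation used
# here is delegated to oumi.performance.mfu.calculate_mfu, whose implementation
# may differ in detail.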
class MfuTrainerCallback(BaseTrainerCallback):
    """Trainer callback to calculate the MFU of the model during training.

    Should be compatible with all trainers that inherit from
    transformers.Trainer. A usage sketch appears after the class definition.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        num_params: int,
        sequence_length: int,
        num_layers: Optional[int] = None,
        num_attention_heads: Optional[int] = None,
        attention_head_size: Optional[int] = None,
        add_rematerialization: bool = False,
    ):
        """Initialize the MfuTrainerCallback.

        Args:
            dtype: The data type of the model.
            num_params: The number of parameters in the model.
            sequence_length: The sequence length of the model.
            num_layers: The number of layers in the model.
            num_attention_heads: The number of attention heads in the model.
            attention_head_size: The size of each attention head in the model.
            add_rematerialization: Whether to add rematerialization to FLOPs
                per token.
        """
        self._dtype = dtype
        self._num_params = num_params
        self._time_of_second_step: Optional[float] = None
        self._time_for_train_steps = 0.0
        self._tokens_seen_so_far = 0
        self._sequence_length = sequence_length
        self._num_layers = num_layers
        self._num_attention_heads = num_attention_heads
        self._attention_head_size = attention_head_size
        self._add_rematerialization = add_rematerialization
        self._first_step_finished = False
        self._steps_since_last_log = 0

        device_rank_info = get_device_rank_info()
        self._num_devices = device_rank_info.world_size
        self._is_world_rank_zero = is_world_process_zero()
        logger.info(f"MFU number of devices: {self._num_devices}")

        self._device_name = get_device_name()
        logger.info(f"MFU device name: {self._device_name}")
        if self._device_name == "CPU":
            logger.warning(
                "MFU is not supported on CPU, the callback will do nothing."
            )

    def _callback_disabled(self) -> bool:
        """Check if the callback should be disabled."""
        return not self._is_world_rank_zero or self._device_name == "CPU"
    def on_step_begin(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the beginning of each train step."""
        if self._callback_disabled():
            return

        self._step_start_time = time.time()
        if not self._first_step_finished:
            # Calculate the number of tokens processed per step during the
            # first step.
            self._tokens_per_step = (
                args.gradient_accumulation_steps
                * args.per_device_train_batch_size
                * self._num_devices
                * self._sequence_length
            )
            return
        if self._time_of_second_step is None:
            self._time_of_second_step = self._step_start_time
    def on_step_end(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the end of each train step.

        Note that this will be called after all gradient accumulation substeps.
        """
        if self._callback_disabled():
            return

        # Keep track of only the training step time for "ideal" MFU.
        delta_time_seconds = time.time() - self._step_start_time
        if not self._first_step_finished:
            self._first_step_finished = True
            logger.info(f"First step time: {delta_time_seconds:.2f}s")
            return

        self._time_for_train_steps += delta_time_seconds
        self._steps_since_last_log += 1
    def on_log(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called after logging the last logs."""
        if self._callback_disabled():
            return
        # Avoid logging until after the first step.
        if self._time_of_second_step is None:
            return

        delta_time_seconds_train = time.time() - self._time_of_second_step
        delta_time_seconds_step = self._time_for_train_steps

        tokens_since_last_log = self._tokens_per_step * self._steps_since_last_log
        total_tokens = self._tokens_seen_so_far + tokens_since_last_log

        # MFU using only the time spent on training steps
        # (excluding the first step).
        train_step_mfu = calculate_mfu(
            device_name=self._device_name,
            num_devices=self._num_devices,
            dtype=self._dtype,
            num_params=self._num_params,
            num_tokens=total_tokens,
            delta_time_seconds=delta_time_seconds_step,
            num_layers=self._num_layers,
            num_attention_heads=self._num_attention_heads,
            attention_head_size=self._attention_head_size,
            sequence_length=self._sequence_length,
            add_rematerialization=self._add_rematerialization,
        )

        # MFU using the time since training started (excluding the first step).
        train_mfu = calculate_mfu(
            device_name=self._device_name,
            num_devices=self._num_devices,
            dtype=self._dtype,
            num_params=self._num_params,
            num_tokens=total_tokens,
            delta_time_seconds=delta_time_seconds_train,
            num_layers=self._num_layers,
            num_attention_heads=self._num_attention_heads,
            attention_head_size=self._attention_head_size,
            sequence_length=self._sequence_length,
            add_rematerialization=self._add_rematerialization,
        )

        if _LOGS_KWARG in kwargs:
            kwargs[_LOGS_KWARG][_TRAIN_STEP_MFU] = train_step_mfu
            kwargs[_LOGS_KWARG][_TRAIN_MFU] = train_mfu

        # Reset the running counters now that this window has been logged.
        self._tokens_seen_so_far = total_tokens
        self._steps_since_last_log = 0
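# Usage sketch (illustrative; the model, training args, and dataset below are
# assumed to exist and are not part of this module):
#
#     mfu_callback = MfuTrainerCallback(
#         dtype=torch.bfloat16,
#         num_params=sum(p.numel() for p in model.parameters()),
#         sequence_length=2048,
#     )
#     trainer = transformers.Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         callbacks=[mfu_callback],
#     )
#     trainer.train()  # "train_step_mfu" and "train_mfu" appear in the logs.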