# Source code for oumi.core.callbacks.mfu_callback

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Based on MFU from PaLM paper: https://arxiv.org/pdf/2204.02311."""

import time
from typing import Optional, Union

import torch
import transformers

from oumi.core.callbacks.base_trainer_callback import BaseTrainerCallback
from oumi.core.configs import TrainingParams
from oumi.core.distributed import get_device_rank_info, is_world_process_zero
from oumi.performance.mfu import calculate_mfu
from oumi.utils.logging import logger

# Name of the keyword argument (passed to `on_log`) holding the dict of logged
# metrics; the MFU values below are written into that dict when it is present.
_LOGS_KWARG = "logs"

# Metric key: MFU computed from the accumulated time spent between on_step_begin
# and on_step_end only (the first step is excluded).
_TRAIN_STEP_MFU = "train_step_mfu"
# Metric key: MFU computed from the wall-clock time elapsed since the start of
# the second training step (i.e. since training started, excluding the first step).
_TRAIN_MFU = "train_mfu"


class MfuTrainerCallback(BaseTrainerCallback):
    """Trainer callback to calculate the MFU of the model during training.

    Should be compatible with all trainers that inherit from transformers.Trainer.

    The first training step is excluded from all measurements. Two metrics are
    added to the logs dict on every log event after the second step has started:

    - ``train_step_mfu``: uses only the time spent between ``on_step_begin`` and
      ``on_step_end``.
    - ``train_mfu``: uses the wall-clock time since the second step started.

    The callback does nothing on processes other than world rank zero, or when
    no CUDA device is available.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        num_params: int,
        sequence_length: int,
        num_layers: Optional[int] = None,
        num_attention_heads: Optional[int] = None,
        attention_head_size: Optional[int] = None,
        add_rematerialization: bool = False,
    ):
        """Initialize the MfuTrainerCallback.

        Args:
            dtype: The data type of the model.
            num_params: The number of parameters in the model.
            sequence_length: The sequence length of the model.
            num_layers: The number of layers in the model.
            num_attention_heads: The number of attention heads in the model.
            attention_head_size: The size of each attention head in the model.
            add_rematerialization: Whether to add rematerialization to FLOPs per
                token.
        """
        self._dtype = dtype
        self._num_params = num_params
        self._sequence_length = sequence_length
        self._num_layers = num_layers
        self._num_attention_heads = num_attention_heads
        self._attention_head_size = attention_head_size
        self._add_rematerialization = add_rematerialization

        # Timing / token bookkeeping. Everything is declared here so that no hook
        # can hit a missing attribute, regardless of the order hooks fire in.
        self._step_start_time: Optional[float] = None
        self._time_of_second_step: Optional[float] = None
        self._time_for_train_steps = 0.0
        self._tokens_per_step = 0
        self._tokens_seen_so_far = 0
        self._first_step_finished = False
        self._steps_since_last_log = 0

        device_rank_info = get_device_rank_info()
        self._num_devices = device_rank_info.world_size
        self._is_world_rank_zero = is_world_process_zero()
        logger.info(f"MFU number of devices: {self._num_devices}")

        # Assume all devices are identical
        self._device_name = "CPU"
        if torch.cuda.is_available():
            self._device_name = torch.cuda.get_device_name(0)

        logger.info(f"MFU device name: {self._device_name}")
        if self._device_name == "CPU":
            logger.warning("MFU is not supported on CPU, the callback will do nothing.")

    def _callback_disabled(self) -> bool:
        """Check if the callback should be disabled."""
        return not self._is_world_rank_zero or self._device_name == "CPU"

    def _compute_mfu(self, num_tokens: int, delta_time_seconds: float):
        """Return the MFU for `num_tokens` processed in `delta_time_seconds`."""
        return calculate_mfu(
            device_name=self._device_name,
            num_devices=self._num_devices,
            dtype=self._dtype,
            num_params=self._num_params,
            num_tokens=num_tokens,
            delta_time_seconds=delta_time_seconds,
            num_layers=self._num_layers,
            num_attention_heads=self._num_attention_heads,
            attention_head_size=self._attention_head_size,
            sequence_length=self._sequence_length,
            add_rematerialization=self._add_rematerialization,
        )

    def on_step_begin(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the beginning of each train step."""
        if self._callback_disabled():
            return

        self._step_start_time = time.time()
        if not self._first_step_finished:
            # Calculate the number of tokens processed per step during the first step
            self._tokens_per_step = (
                args.gradient_accumulation_steps
                * args.per_device_train_batch_size
                * self._num_devices
                * self._sequence_length
            )
            return

        # The first step is excluded, so the "since training started" clock
        # begins when the second step begins.
        if self._time_of_second_step is None:
            self._time_of_second_step = self._step_start_time

    def on_step_end(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the end of each train step.

        Note that this will be called after all gradient accumulation substeps.
        """
        if self._callback_disabled():
            return

        if self._step_start_time is None:
            # on_step_begin has never run, so there is no interval to measure.
            return

        # Keep track of only the training step time for "ideal" MFU
        delta_time_seconds = time.time() - self._step_start_time
        if not self._first_step_finished:
            self._first_step_finished = True
            logger.info(f"First step time: {delta_time_seconds:.2f}s")
            return

        self._time_for_train_steps += delta_time_seconds
        self._steps_since_last_log += 1

    def on_log(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called after logging the last logs.

        Adds the `train_step_mfu` and `train_mfu` metrics to the `logs` keyword
        argument, if it was provided.
        """
        if self._callback_disabled():
            return

        # Avoid logging until after the first step.
        if self._time_of_second_step is None:
            return

        delta_time_seconds_train = time.time() - self._time_of_second_step
        delta_time_seconds_step = self._time_for_train_steps

        tokens_since_last_log = self._tokens_per_step * self._steps_since_last_log
        total_tokens = self._tokens_seen_so_far + tokens_since_last_log

        # A log event can fire after the second step has begun but before any
        # measured step has finished. There are then zero tokens and zero step
        # time, and an MFU over that is meaningless, so skip this log event.
        # Counters are left untouched; the tokens are picked up by the next log.
        if (
            total_tokens <= 0
            or delta_time_seconds_step <= 0
            or delta_time_seconds_train <= 0
        ):
            return

        # MFU using only the time spent on training steps (excluding the first step).
        train_step_mfu = self._compute_mfu(total_tokens, delta_time_seconds_step)

        # MFU using the time since training started (excluding the first step).
        train_mfu = self._compute_mfu(total_tokens, delta_time_seconds_train)

        if _LOGS_KWARG in kwargs:
            kwargs[_LOGS_KWARG][_TRAIN_STEP_MFU] = train_step_mfu
            kwargs[_LOGS_KWARG][_TRAIN_MFU] = train_mfu

        # Cleanup values
        self._tokens_seen_so_far = total_tokens
        self._steps_since_last_log = 0