Source code for oumi.core.callbacks.hf_mfu_callback
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MFU calculator based on theoretical model flops computed by HuggingFace libraries."""
import time
from typing import Optional, Union

import torch
import transformers

from oumi.core.callbacks.base_trainer_callback import BaseTrainerCallback
from oumi.core.configs import TrainingParams
from oumi.core.distributed import get_device_rank_info, is_world_process_zero
from oumi.performance.mfu import (
    calculate_mfu_from_model_flops_per_second,
)
from oumi.utils.logging import logger

_LOGS_KWARG = "logs"
# MFU using only the time between `on_step_begin` and `on_step_end`
# (excluding the first step), based on the built-in HuggingFace model
# flops estimate.
_HF_TRAIN_STEP_MFU = "hf_train_step_mfu"
# MFU using the total time since training started (excluding the first step),
# based on the built-in HuggingFace model flops estimate.
_HF_TRAIN_MFU = "hf_train_mfu"
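# How the two metrics relate (a sketch of the intended semantics; the exact
# peak-FLOPs lookup lives in `calculate_mfu_from_model_flops_per_second`):
#
#     MFU = achieved_model_flops_per_second / (num_devices * peak_flops_per_second)
#
# Illustrative arithmetic: on a single device with a bfloat16 peak of
# 312 TFLOP/s (an NVIDIA A100), a model sustaining 156 TFLOP/s would report
# MFU = 156 / 312 = 0.5. Since step time <= total elapsed time,
# `hf_train_step_mfu` is always >= `hf_train_mfu`.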
class HfMfuTrainerCallback(BaseTrainerCallback):
    """Trainer callback to calculate the MFU of the model during training.

    Relies on the model's FLOPs estimate computed by HuggingFace in the
    `total_flos` metric.
    """

    def __init__(
        self,
        dtype: torch.dtype,
    ):
        """Initialize the HfMfuTrainerCallback.

        Args:
            dtype: The data type of the model.
        """
        self._dtype = dtype
        self._time_of_second_step: Optional[float] = None
        self._flops_at_second_step: Optional[float] = None
        self._time_for_train_steps = 0.0
        self._first_step_finished = False

        device_rank_info = get_device_rank_info()
        self._num_devices = device_rank_info.world_size
        self._is_world_rank_zero = is_world_process_zero()
        logger.info(f"HF MFU number of devices: {self._num_devices}")

        # Assume all devices are identical.
        self._device_name = "CPU"
        if torch.cuda.is_available():
            self._device_name = torch.cuda.get_device_name(0)
        logger.info(f"HF MFU device name: {self._device_name}")
        if self._device_name == "CPU":
            logger.warning(
                "HF MFU is not supported on CPU; the callback will do nothing."
            )
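
    # Note (inferred, not stated in the original comments): the device name and
    # `dtype` recorded above are later passed to
    # `calculate_mfu_from_model_flops_per_second`, which presumably uses them to
    # look up the device's theoretical peak FLOP/s, the denominator of the MFU
    # ratio.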

    def _callback_disabled(self) -> bool:
        """Check if the callback should be disabled."""
        return not self._is_world_rank_zero or self._device_name == "CPU"

    def on_step_begin(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the beginning of each train step."""
        if self._callback_disabled():
            return

        self._step_start_time = time.time()
        if not self._first_step_finished:
            return
        if self._time_of_second_step is None:
            self._time_of_second_step = self._step_start_time
            if state is not None:
                self._flops_at_second_step = state.total_flos

    def on_step_end(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the end of each train step.

        Note that this will be called after all gradient accumulation substeps.
        """
        if self._callback_disabled():
            return

        # Keep track of only the training step time for "ideal" MFU.
        delta_time_seconds = time.time() - self._step_start_time
        if not self._first_step_finished:
            self._first_step_finished = True
            logger.info(f"First step time: {delta_time_seconds:.2f}s")
            return
        self._time_for_train_steps += delta_time_seconds
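
    # Note: the first step is excluded from both MFU metrics because it tends
    # to carry one-off warmup costs (e.g. CUDA context initialization,
    # compilation, cache population) that would skew the averages. This
    # rationale is inferred from the code; the original comments only state
    # that the first step is skipped.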

    def on_log(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called after logging the last logs."""
        if self._callback_disabled():
            return

        # Avoid logging until after the first step.
        if self._time_of_second_step is None:
            return

        delta_time_seconds_train = time.time() - self._time_of_second_step
        delta_time_seconds_step = self._time_for_train_steps

        if self._flops_at_second_step is not None and (
            state is not None and state.total_flos > 0.0
        ):
            flops_since_second_step_on_all_devices = (
                state.total_flos - self._flops_at_second_step
            ) * self._num_devices
            train_step_mfu = calculate_mfu_from_model_flops_per_second(
                device_name=self._device_name,
                num_devices=self._num_devices,
                dtype=self._dtype,
                model_flops_per_second_on_all_devices=(
                    flops_since_second_step_on_all_devices / delta_time_seconds_step
                ),
            )
            train_mfu = calculate_mfu_from_model_flops_per_second(
                device_name=self._device_name,
                num_devices=self._num_devices,
                dtype=self._dtype,
                model_flops_per_second_on_all_devices=(
                    flops_since_second_step_on_all_devices / delta_time_seconds_train
                ),
            )
            if _LOGS_KWARG in kwargs:
                kwargs[_LOGS_KWARG][_HF_TRAIN_STEP_MFU] = train_step_mfu
                kwargs[_LOGS_KWARG][_HF_TRAIN_MFU] = train_mfu
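
# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of this module). Assumes a
# standard HuggingFace `transformers.Trainer` setup; `model`, `training_args`,
# and `train_dataset` are hypothetical placeholders:
#
#     import torch
#     import transformers
#
#     from oumi.core.callbacks.hf_mfu_callback import HfMfuTrainerCallback
#
#     trainer = transformers.Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         callbacks=[HfMfuTrainerCallback(dtype=torch.bfloat16)],
#     )
#     trainer.train()
#
# Once training progresses past the first step, "hf_train_step_mfu" and
# "hf_train_mfu" should appear in the logged metrics on the rank-zero process.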