Source code for oumi.performance.mfu

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Based on MFU from PaLM paper: https://arxiv.org/pdf/2204.02311."""

from typing import Optional

import torch

_TFLOPS = "tflops"
_DEVICE_SPECS = {
    # https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
    "NVIDIA A100-PCIE-40GB": {
        _TFLOPS: {
            torch.float32: 19.5,
            torch.float16: 312.0,
            torch.bfloat16: 312.0,
        },
    },
    "NVIDIA A100-PCIE-80GB": {
        _TFLOPS: {
            torch.float32: 19.5,
            torch.float16: 312.0,
            torch.bfloat16: 312.0,
        },
    },
    "NVIDIA A100-SXM4-40GB": {
        _TFLOPS: {
            torch.float32: 19.5,
            torch.float16: 312.0,
            torch.bfloat16: 312.0,
        }
    },
    "NVIDIA A100-SXM4-80GB": {
        _TFLOPS: {
            torch.float32: 19.5,
            torch.float16: 312.0,
            torch.bfloat16: 312.0,
        }
    },
    # https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf
    "NVIDIA GeForce RTX 3090": {
        _TFLOPS: {
            torch.float32: 35.6,
            torch.float16: 71.0,
            torch.bfloat16: 71.0,
        },
    },
    # Only used for testing purposes
    # https://cloud.google.com/tpu/docs/v4
    "TPUv4": {
        _TFLOPS: {
            torch.float16: 275.0,
            torch.bfloat16: 275.0,
        },
    },
    # https://www.nvidia.com/en-us/data-center/l4/
    # Note that the values on that page are reported with sparsity.
    "NVIDIA L4": {
        _TFLOPS: {
            torch.float32: 60.0,
            torch.float16: 121.0,
            torch.bfloat16: 121.0,
        },
    },
    # https://www.nvidia.com/en-us/data-center/tesla-t4/
    "Tesla T4": {
        _TFLOPS: {
            torch.float32: 8.1,
            torch.float16: 65.0,
            torch.bfloat16: 65.0,
        },
    },
}


def _get_device_flops(device_name: str, dtype: torch.dtype) -> float:
    """Returns peak FLOPS for the given device name and dtype."""
    if device_name not in _DEVICE_SPECS:
        raise NotImplementedError(
            f"Unknown device name for getting hardware flops: {device_name}"
        )

    specs = _DEVICE_SPECS[device_name]
    if dtype not in specs[_TFLOPS]:
        raise NotImplementedError(f"Unknown dtype {dtype} for device {device_name}")

    return specs[_TFLOPS][dtype] * 1e12


def _get_model_flops_per_token(
    num_params: int,
    num_layers: Optional[int] = None,
    num_attention_heads: Optional[int] = None,
    attention_head_size: Optional[int] = None,
    sequence_length: Optional[int] = None,
    add_rematerialization: bool = False,
) -> int:
    """Returns the number of FLOPs per token for the given model configuration."""
    if num_params <= 0:
        raise ValueError(f"Must have a positive number of model params: {num_params}")

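    # PaLM-style accounting: roughly 2 FLOPs per parameter per token for the
    # forward pass and 4 per parameter for the backward pass (6N in total).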
    forward_flops = 2 * num_params
    backward_flops = 4 * num_params
    attention_flops = 0
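    # When the attention shape is known, add the PaLM paper's attention term:
    # 12 * seq_len * num_layers * num_heads * head_size FLOPs per token.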
    if num_layers and num_attention_heads and attention_head_size and sequence_length:
        attention_flops = (
            sequence_length
            * num_layers
            * num_attention_heads
            * attention_head_size
            * 12
        )

    rematerialization_flops = 0
    if add_rematerialization:
        # FIXME: Needs to be calculated based on checkpointing configuration
        # 73% of forward and all of attention
        # PaLM paper mentions 75%, but the calculated value requires 73%, paper error?
        rematerialization_flops = int(0.73 * forward_flops + attention_flops)

    return forward_flops + backward_flops + attention_flops + rematerialization_flops


def calculate_mfu_from_model_flops_per_second(
    device_name: str,
    num_devices: int,
    dtype: torch.dtype,
    model_flops_per_second_on_all_devices: float,
) -> float:
    """Returns the MFU for the given model flops per second."""
    if num_devices <= 0:
        raise ValueError(f"Must have a positive number of devices: {num_devices}")

    device_flops_per_second = _get_device_flops(device_name, dtype) * num_devices
    model_flop_utilization = (
        model_flops_per_second_on_all_devices / device_flops_per_second
    )
    return model_flop_utilization
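

# Illustrative usage (assumed numbers, not from the source): a single A100 has
# a bfloat16 peak of 312 TFLOPS, so a model sustaining 156e12 model FLOPs/s
# would report an MFU of 0.5:
#   calculate_mfu_from_model_flops_per_second(
#       device_name="NVIDIA A100-SXM4-80GB",
#       num_devices=1,
#       dtype=torch.bfloat16,
#       model_flops_per_second_on_all_devices=156e12,
#   )  # -> 0.5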


def calculate_mfu(
    device_name: str,
    num_devices: int,
    dtype: torch.dtype,
    num_params: int,
    num_tokens: int,
    delta_time_seconds: float,
    num_layers: Optional[int] = None,
    num_attention_heads: Optional[int] = None,
    attention_head_size: Optional[int] = None,
    sequence_length: Optional[int] = None,
    add_rematerialization: bool = False,
) -> float:
    """Returns the MFU for the given model configuration."""
    if num_tokens <= 0:
        raise ValueError(f"Must have a positive number of tokens: {num_tokens}")
    if delta_time_seconds <= 0:
        raise ValueError(f"Must have a positive delta time: {delta_time_seconds}")

    model_flops_per_token = _get_model_flops_per_token(
        num_params,
        num_layers,
        num_attention_heads,
        attention_head_size,
        sequence_length,
        add_rematerialization,
    )
    tokens_per_second = num_tokens / delta_time_seconds
    model_flops_per_second = model_flops_per_token * tokens_per_second

    return calculate_mfu_from_model_flops_per_second(
        device_name=device_name,
        num_devices=num_devices,
        dtype=dtype,
        model_flops_per_second_on_all_devices=model_flops_per_second,
    )
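

# Illustrative end-to-end example (hypothetical numbers, not from the source):
# a 7B-parameter model processing 100_000 tokens in 10 seconds on 8 A100s in
# bfloat16:
#   calculate_mfu(
#       device_name="NVIDIA A100-SXM4-80GB",
#       num_devices=8,
#       dtype=torch.bfloat16,
#       num_params=7_000_000_000,
#       num_tokens=100_000,
#       delta_time_seconds=10.0,
#   )
#   # model FLOPs/s = (6 * 7e9) FLOPs/token * 1e4 tokens/s = 4.2e14
#   # peak FLOPs/s  = 8 * 312e12 = 2.496e15  ->  MFU ~= 0.168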