# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from pprint import pformat
from typing import NamedTuple, Optional
from oumi.utils.logging import logger
try:
# The library is only useful for NVIDIA GPUs, and
# may not be installed for other vendors e.g., AMD
import pynvml # pyright: ignore[reportMissingImports]
except ModuleNotFoundError:
pynvml = None
# TODO: OPE-562 - Add support for `amdsmi.amdsmi_init()` for AMD GPUs
def _initialize_pynvml() -> bool:
    """Tries to call `pynvml.nvmlInit()` and reports whether it succeeded.

    If the module is unavailable, `False` is returned right away. If the
    initialization call raises, the error is logged and the module-level
    `pynvml` reference is reset to `None`, which disables all later calls.
    """
    global pynvml
    if pynvml is not None:
        try:
            pynvml.nvmlInit()
        except Exception:
            logger.error(
                "Failed to initialize pynvml library. All pynvml calls will be disabled."
            )
            # Drop the module reference so that subsequent callers bail out early.
            pynvml = None
        else:
            return True
    return False
def _initialize_pynvml_and_get_pynvml_device_count() -> Optional[int]:
    """Initializes pynvml and queries the number of devices it reports.

    Returns:
        The device count as an `int` if pynvml is available and was initialized
        successfully, or `None` otherwise.
    """
    global pynvml
    # The explicit `pynvml is not None` test is technically redundant
    # (`_initialize_pynvml()` performs the same check) but it keeps pyright happy.
    if pynvml is not None and _initialize_pynvml():
        return int(pynvml.nvmlDeviceGetCount())
    return None
[docs]
class NVidiaGpuRuntimeInfo(NamedTuple):
    """A snapshot of NVIDIA GPU measurements and stats obtained through `pynvml`.

    Only `device_index` and `device_count` are mandatory. Every other field
    defaults to `None` and is populated only if the corresponding boolean query
    parameter of `_get_nvidia_gpu_runtime_info_impl(...)` (e.g., `memory`,
    `temperature`, `fan_speed`) is enabled.
    """

    device_index: int
    """Zero-based index of the device."""

    device_count: int
    """Total number of GPU devices on this node."""

    used_memory_mb: Optional[float] = None
    """Used GPU memory in MB (computed as bytes / 1024**2, i.e., MiB)."""

    temperature: Optional[int] = None
    """GPU temperature in Celsius."""

    fan_speed: Optional[int] = None
    """GPU fan speed in [0,100] range."""

    fan_speeds: Optional[Sequence[int]] = None
    """A sequence of GPU fan speeds.

    Its length is equal to the number of fans per GPU (can be multiple).
    Speed values are in [0, 100] range.
    """

    power_usage_watts: Optional[float] = None
    """GPU power usage in Watts."""

    power_limit_watts: Optional[float] = None
    """GPU power limit in Watts."""

    gpu_utilization: Optional[int] = None
    """GPU compute utilization. Range: [0,100]."""

    memory_utilization: Optional[int] = None
    """GPU memory utilization. Range: [0,100]."""

    performance_state: Optional[int] = None
    """See `nvmlPstates_t`. Valid values are in [0,15] range, or 32 if unknown.

    0 corresponds to Maximum Performance.
    15 corresponds to Minimum Performance.
    """

    clock_speed_graphics: Optional[int] = None
    """Graphics clock speed (`NVML_CLOCK_GRAPHICS`) in MHz."""

    clock_speed_sm: Optional[int] = None
    """SM clock speed (`NVML_CLOCK_SM`) in MHz."""

    clock_speed_memory: Optional[int] = None
    """Memory clock speed (`NVML_CLOCK_MEM`) in MHz."""
def _get_nvidia_gpu_runtime_info_impl(
    device_index: int = 0,
    *,
    memory: bool = False,
    temperature: bool = False,
    fan_speed: bool = False,
    power_usage: bool = False,
    utilization: bool = False,
    performance_state: bool = False,
    clock_speed: bool = False,
) -> Optional[NVidiaGpuRuntimeInfo]:
    """Queries `pynvml` for the requested runtime stats of a single NVIDIA GPU.

    Args:
        device_index: Zero-based index of the GPU to query.
        memory: Whether to populate `used_memory_mb`.
        temperature: Whether to populate `temperature`.
        fan_speed: Whether to populate `fan_speed` and `fan_speeds`.
        power_usage: Whether to populate `power_usage_watts` and
            `power_limit_watts`.
        utilization: Whether to populate `gpu_utilization` and
            `memory_utilization`.
        performance_state: Whether to populate `performance_state`.
        clock_speed: Whether to populate the `clock_speed_*` fields.

    Returns:
        The collected stats, or `None` if pynvml is unavailable, no devices are
        reported, the device handle cannot be obtained, or any requested query
        other than fan speed fails. Fan speed failures are tolerated: the
        corresponding fields are simply left unset.

    Raises:
        ValueError: If `device_index` is outside the `[0, device_count)` range.
    """
    global pynvml
    if pynvml is None:
        return None

    device_count = _initialize_pynvml_and_get_pynvml_device_count()
    if device_count is None or device_count <= 0:
        return None
    elif device_index < 0 or device_index >= device_count:
        raise ValueError(
            f"Device index ({device_index}) must be "
            f"within the [0, {device_count}) range."
        )

    try:
        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    except Exception:
        logger.exception(f"Failed to get GPU handle for device: {device_index}")
        return None

    used_memory_mb_value: Optional[float] = None
    if memory:
        try:
            info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle)
            # Whole number of MiB, kept as a float.
            used_memory_mb_value = float(info.used) // 1024**2
        except Exception:
            logger.exception(
                f"Failed to get GPU memory info for device: {device_index}"
            )
            return None

    temperature_value: Optional[int] = None
    if temperature:
        try:
            temperature_value = pynvml.nvmlDeviceGetTemperature(
                gpu_handle, pynvml.NVML_TEMPERATURE_GPU
            )
        except Exception:
            logger.exception(
                f"Failed to get GPU temperature for device: {device_index}"
            )
            return None

    fan_speed_value: Optional[int] = None
    fan_speeds_value: Optional[Sequence[int]] = None
    if fan_speed:
        try:
            fan_speed_value = pynvml.nvmlDeviceGetFanSpeed(gpu_handle)
        except Exception:
            # The `GetFanSpeed` function fails on many systems
            # Only do DEBUG-level logging to reduce noise.
            logger.debug(
                f"Failed to get GPU fan speed for device: {device_index}", exc_info=True
            )

        if fan_speed_value is not None:
            # Default to the aggregate reading; it is replaced below only if
            # the per-fan API is present and every per-fan query succeeds.
            fan_speeds_value = (fan_speed_value,)
            if hasattr(pynvml, "nvmlDeviceGetNumFans"):
                try:
                    fan_count = pynvml.nvmlDeviceGetNumFans(gpu_handle)
                    # A tuple keeps the result immutable.
                    fan_speeds_value = tuple(
                        pynvml.nvmlDeviceGetFanSpeed_v2(gpu_handle, fan_index)
                        for fan_index in range(fan_count)
                    )
                except Exception:
                    # Best-effort: keep the aggregate reading assigned above,
                    # but leave a trace instead of swallowing the error.
                    logger.debug(
                        f"Failed to get per-fan GPU speeds for device: {device_index}",
                        exc_info=True,
                    )

    power_usage_watts_value: Optional[float] = None
    power_limit_watts_value: Optional[float] = None
    if power_usage:
        try:
            # NVML reports both values in milliwatts.
            milliwatts = pynvml.nvmlDeviceGetPowerUsage(gpu_handle)
            power_usage_watts_value = float(milliwatts) * 1e-3
            milliwatts = pynvml.nvmlDeviceGetPowerManagementLimit(gpu_handle)
            power_limit_watts_value = float(milliwatts) * 1e-3
        except Exception:
            logger.exception(
                f"Failed to get GPU power usage for device: {device_index}"
            )
            return None

    gpu_utilization_value: Optional[int] = None
    memory_utilization_value: Optional[int] = None
    if utilization:
        try:
            result = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)
            gpu_utilization_value = int(result.gpu)
            memory_utilization_value = int(result.memory)
        except Exception:
            logger.exception(
                f"Failed to get GPU utilization for device: {device_index}"
            )
            return None

    performance_state_value: Optional[int] = None
    if performance_state:
        try:
            performance_state_value = int(
                pynvml.nvmlDeviceGetPerformanceState(gpu_handle)
            )
        except Exception:
            logger.exception(
                f"Failed to get GPU performance state for device: {device_index}"
            )
            return None

    clock_speed_graphics_value: Optional[int] = None
    clock_speed_sm_value: Optional[int] = None
    clock_speed_memory_value: Optional[int] = None
    if clock_speed:
        try:
            clock_speed_graphics_value = int(
                pynvml.nvmlDeviceGetClockInfo(gpu_handle, pynvml.NVML_CLOCK_GRAPHICS)
            )
            clock_speed_sm_value = int(
                pynvml.nvmlDeviceGetClockInfo(gpu_handle, pynvml.NVML_CLOCK_SM)
            )
            clock_speed_memory_value = int(
                pynvml.nvmlDeviceGetClockInfo(gpu_handle, pynvml.NVML_CLOCK_MEM)
            )
        except Exception:
            logger.exception(
                f"Failed to get GPU clock speed for device: {device_index}"
            )
            return None

    return NVidiaGpuRuntimeInfo(
        device_index=device_index,
        device_count=device_count,
        used_memory_mb=used_memory_mb_value,
        temperature=temperature_value,
        fan_speed=fan_speed_value,
        fan_speeds=fan_speeds_value,
        power_usage_watts=power_usage_watts_value,
        power_limit_watts=power_limit_watts_value,
        gpu_utilization=gpu_utilization_value,
        memory_utilization=memory_utilization_value,
        performance_state=performance_state_value,
        clock_speed_graphics=clock_speed_graphics_value,
        clock_speed_sm=clock_speed_sm_value,
        clock_speed_memory=clock_speed_memory_value,
    )
[docs]
def get_nvidia_gpu_runtime_info(
    device_index: int = 0,
) -> Optional[NVidiaGpuRuntimeInfo]:
    """Collects every supported runtime stat for the given Nvidia GPU.

    Args:
        device_index: Zero-based index of the GPU to query.

    Returns:
        Whatever `_get_nvidia_gpu_runtime_info_impl` returns with all of its
        query flags enabled.
    """
    # Switch on every query supported by the implementation.
    all_queries = {
        flag_name: True
        for flag_name in (
            "memory",
            "temperature",
            "fan_speed",
            "power_usage",
            "utilization",
            "performance_state",
            "clock_speed",
        )
    }
    return _get_nvidia_gpu_runtime_info_impl(device_index=device_index, **all_queries)
[docs]
def log_nvidia_gpu_runtime_info(device_index: int = 0, log_prefix: str = "") -> None:
    """Logs (at INFO level) the current NVIDIA GPU runtime info.

    Args:
        device_index: Zero-based index of the GPU to query.
        log_prefix: Text to put in front of the message; its trailing whitespace
            is stripped.
    """
    formatted_info = pformat(get_nvidia_gpu_runtime_info(device_index))
    prefix = log_prefix.rstrip()
    logger.info(f"{prefix} GPU runtime info: {formatted_info}.")
[docs]
def get_nvidia_gpu_memory_utilization(device_index: int = 0) -> float:
    """Returns the amount of memory in use on an Nvidia GPU, in MiB.

    Falls back to `0.0` if the stats could not be retrieved.
    """
    stats = _get_nvidia_gpu_runtime_info_impl(device_index=device_index, memory=True)
    if stats is None or stats.used_memory_mb is None:
        return 0.0
    return stats.used_memory_mb
[docs]
def log_nvidia_gpu_memory_utilization(
    device_index: int = 0, log_prefix: str = ""
) -> None:
    """Logs (at INFO level) the amount of memory in use on an Nvidia GPU.

    Args:
        device_index: Zero-based index of the GPU to query.
        log_prefix: Text to put in front of the message; its trailing whitespace
            is stripped.
    """
    used_mib = get_nvidia_gpu_memory_utilization(device_index)
    prefix = log_prefix.rstrip()
    logger.info(f"{prefix} GPU memory occupied: {used_mib} MiB.")
[docs]
def get_nvidia_gpu_temperature(device_index: int = 0) -> int:
    """Returns the device's current temperature reading, in degrees C.

    Falls back to `0` if the reading could not be retrieved.
    """
    stats = _get_nvidia_gpu_runtime_info_impl(
        device_index=device_index, temperature=True
    )
    if stats is None or stats.temperature is None:
        return 0
    return stats.temperature
[docs]
def log_nvidia_gpu_temperature(device_index: int = 0, log_prefix: str = "") -> None:
    """Logs (at INFO level) the device's current temperature, in degrees C.

    Args:
        device_index: Zero-based index of the GPU to query.
        log_prefix: Text to put in front of the message; its trailing whitespace
            is stripped.
    """
    degrees_c = get_nvidia_gpu_temperature(device_index)
    prefix = log_prefix.rstrip()
    logger.info(f"{prefix} GPU temperature: {degrees_c} C.")
[docs]
def get_nvidia_gpu_fan_speeds(device_index: int = 0) -> Sequence[int]:
    """Returns the current fan speeds of an NVIDIA GPU device.

    Falls back to an empty tuple if the fan speeds could not be retrieved.
    """
    stats = _get_nvidia_gpu_runtime_info_impl(
        device_index=device_index, fan_speed=True
    )
    if stats is None or stats.fan_speeds is None:
        return ()
    return stats.fan_speeds
[docs]
def log_nvidia_gpu_fan_speeds(device_index: int = 0, log_prefix: str = "") -> None:
    """Logs (at INFO level) the current NVIDIA GPU fan speeds.

    Args:
        device_index: Zero-based index of the GPU to query.
        log_prefix: Text to put in front of the message; its trailing whitespace
            is stripped.
    """
    speeds = get_nvidia_gpu_fan_speeds(device_index)
    prefix = log_prefix.rstrip()
    logger.info(f"{prefix} GPU fan speeds: {speeds}.")
[docs]
def get_nvidia_gpu_power_usage(device_index: int = 0) -> float:
    """Returns the current power usage of an NVIDIA GPU device, in Watts.

    Falls back to `0.0` if the power usage could not be retrieved.
    """
    stats = _get_nvidia_gpu_runtime_info_impl(
        device_index=device_index, power_usage=True
    )
    if stats is None or stats.power_usage_watts is None:
        return 0.0
    return stats.power_usage_watts
[docs]
def log_nvidia_gpu_power_usage(device_index: int = 0, log_prefix: str = "") -> None:
    """Logs (at INFO level) the current NVIDIA GPU power usage.

    Args:
        device_index: Zero-based index of the GPU to query.
        log_prefix: Text to put in front of the message; its trailing whitespace
            is stripped.
    """
    watts = get_nvidia_gpu_power_usage(device_index)
    prefix = log_prefix.rstrip()
    logger.info(f"{prefix} GPU power usage: {watts:.2f}W.")