Source code for oumi.core.callbacks.telemetry_callback

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collects sub-step/step/epoch timings."""

import copy
import pathlib
import sys
from pprint import pformat
from typing import Optional, Union

import transformers

import wandb  # isort: skip
from oumi.core.callbacks.base_trainer_callback import BaseTrainerCallback
from oumi.core.configs import TrainingParams
from oumi.core.distributed import get_device_rank_info, is_world_process_zero
from oumi.performance.telemetry import TelemetryTracker, TimerContext
from oumi.utils.device_utils import (
    log_nvidia_gpu_runtime_info,
)
from oumi.utils.io_utils import save_json
from oumi.utils.logging import logger

_LOGS_KWARG = "logs"


class TelemetryCallback(BaseTrainerCallback):
    """Trainer callback to collect sub-step/step/epoch timings.

    Based on `oumi.performance.telemetry.TelemetryTracker`.
    """

    def __init__(
        self,
        skip_first_steps: int = 1,
        world_process_zero_only: bool = True,
        include_timer_metrics: bool = False,
        track_gpu_temperature: bool = False,
        output_dir: Optional[pathlib.Path] = None,
    ):
        """Initializes the TelemetryCallback.

        Args:
            skip_first_steps: The number of initial steps to exclude from stats.
            world_process_zero_only: Whether to collect stats on the main
                process only.
            include_timer_metrics: Whether to add timer stats to reported metrics.
                The timing stats can be verbose/distracting, so `False` by default.
                The timings will be written to a file at the end of training
                regardless of the value of this flag.
            track_gpu_temperature: Whether to record GPU temperature.
            output_dir: If specified, then telemetry stats will be written to the
                directory as JSON files.
        """
        self._telemetry = TelemetryTracker()
        self._microstep_timer: Optional[TimerContext] = None
        self._step_timer: Optional[TimerContext] = None
        self._epoch_timer: Optional[TimerContext] = None

        self._skip_first_steps: int = skip_first_steps
        self._include_timer_metrics = include_timer_metrics
        self._track_gpu_temperature = track_gpu_temperature
        self._output_dir: Optional[pathlib.Path] = output_dir
        self._permanently_disabled: bool = (
            world_process_zero_only and not is_world_process_zero()
        )
        self._world_process_zero_only = world_process_zero_only
        self._step: int = 0
        self._last_metrics_dict: Optional[dict[str, float]] = None
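    # Example: with `skip_first_steps=1` (the default), the first optimizer step is
    # excluded from all timing stats, so one-time warm-up costs (e.g. compilation or
    # cache initialization) do not skew the reported distributions.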
    def on_step_begin(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the beginning of a training step.

        If using gradient accumulation, one training step might take several inputs.
        """
        self._step += 1

        if self._callback_disabled():
            return

        self._complete_previous_microstep_if_needed()
        self._start_microstep()
        self._complete_previous_step_if_needed()
        self._start_step()
    def on_substep_end(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the end of a substep during gradient accumulation."""
        if self._callback_disabled():
            return

        self._complete_previous_microstep_if_needed()
        self._start_microstep()
    def on_step_end(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the end of each train step.

        Note that this will be called after all gradient accumulation substeps.
        """
        if self._callback_disabled():
            return

        self._complete_previous_microstep_if_needed()
        self._complete_previous_step_if_needed()

        if self._track_gpu_temperature:
            self._telemetry.record_gpu_temperature()
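    # Note: with gradient accumulation, `on_substep_end` fires after each
    # accumulation sub-step and `on_step_end` after the optimizer step, so every
    # "steps" timing spans one or more "microsteps" timings.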
    def on_epoch_begin(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the beginning of an epoch."""
        if self._permanently_disabled:
            return

        self._complete_previous_epoch_if_needed()
        self._start_epoch()
    def on_epoch_end(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the end of an epoch."""
        if self._permanently_disabled:
            return

        self._complete_previous_epoch_if_needed()
        log_nvidia_gpu_runtime_info(log_prefix="On epoch end:")
    def on_log(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called after logging the last logs."""
        if self._callback_disabled():
            return

        device_rank_info = get_device_rank_info()
        basename = f"telemetry_rank{device_rank_info.rank:03}"
        summary = self._telemetry.get_summary()

        if (
            self._include_timer_metrics
            and "timers" in summary
            and _LOGS_KWARG in kwargs
        ):
            for name, stats in summary["timers"].items():
                for stats_key in ("mean", "median", "std_dev", "min", "max", "count"):
                    if stats_key in stats:
                        metric_name = f"{basename}_{name}_{stats_key}"
                        kwargs[_LOGS_KWARG][metric_name] = float(stats[stats_key])

        if (
            self._track_gpu_temperature
            and "gpu_temperature" in summary
            and summary["gpu_temperature"]
            and _LOGS_KWARG in kwargs
        ):
            stats = summary["gpu_temperature"]
            for stats_key in ("mean", "median", "std_dev", "min", "max", "count"):
                metric_name = f"{basename}_gpu_temperature_{stats_key}"
                kwargs[_LOGS_KWARG][metric_name] = float(stats[stats_key])

        if _LOGS_KWARG in kwargs and is_world_process_zero():
            self._last_metrics_dict = copy.deepcopy(kwargs[_LOGS_KWARG])
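    # Example metric names produced above (rank 0, `include_timer_metrics=True`):
    # "telemetry_rank000_microsteps_mean", "telemetry_rank000_steps_max",
    # "telemetry_rank000_gpu_temperature_median", etc.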
    def on_train_end(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called at the end of training."""
        if self._callback_disabled() or not self._output_dir:
            return

        device_rank_info = get_device_rank_info()

        if is_world_process_zero():
            metrics_dict = self._last_metrics_dict or {}
            save_json(
                metrics_dict,
                self._output_dir
                / f"telemetry_callback_metrics_rank{device_rank_info.rank:04}.json",
            )
            if wandb.run:
                save_json(
                    {
                        "id": wandb.run.id,
                        "name": wandb.run.name,
                        "url": wandb.run.get_url(),
                    },
                    self._output_dir
                    / f"telemetry_callback_wandb_rank{device_rank_info.rank:04}.json",
                )

        if self._world_process_zero_only:
            if is_world_process_zero():
                summary = self._telemetry.get_summary()
                telemetry_file = (
                    self._output_dir
                    / f"telemetry_callback_rank{device_rank_info.rank:04}.json"
                )
                logger.info(f"Saving telemetry callback summary to {telemetry_file}...")
                save_json(summary, telemetry_file)
        else:
            # The function has to be called by all ranks.
            summaries = self._telemetry.get_summaries_from_all_ranks()
            if is_world_process_zero():
                summaries_dict = {
                    f"rank{rank:04}": summary
                    for rank, summary in enumerate(summaries)
                }
                telemetry_file = self._output_dir / "telemetry_callback_all_ranks.json"
                logger.info(
                    "Saving telemetry callback summaries "
                    f"for all ranks to {telemetry_file}..."
                )
                save_json(summaries_dict, telemetry_file)

                gpu_temperature_info_dict = (
                    self._telemetry.compute_cross_rank_summaries(
                        summaries,
                        measurement_names={
                            "gpu_temperature": {"max", "mean", "median"},
                        },
                    )
                )
                logger.info(
                    f"GPU temperature summary:\n{pformat(gpu_temperature_info_dict)}"
                )
                save_json(
                    gpu_temperature_info_dict,
                    self._output_dir
                    / "telemetry_callback_gpu_temperature_summary.json",
                )
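    # Files written above when `output_dir` is set (rank 0 shown):
    #   telemetry_callback_metrics_rank0000.json         - last logged metrics
    #   telemetry_callback_wandb_rank0000.json           - W&B run id/name/url (if a run is active)
    #   telemetry_callback_rank0000.json                 - telemetry summary (single-rank mode)
    #   telemetry_callback_all_ranks.json                - per-rank summaries (all-ranks mode)
    #   telemetry_callback_gpu_temperature_summary.json  - cross-rank GPU temperature stats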
    def _callback_disabled(self) -> bool:
        """Check if the callback should be disabled."""
        if self._permanently_disabled:
            return True
        if self._skip_first_steps > 0 and self._step <= self._skip_first_steps:
            return True
        return False

    @staticmethod
    def _exit_timer_if_needed(timer: Optional[TimerContext]) -> Optional[TimerContext]:
        if timer is not None:
            timer.__exit__(*sys.exc_info())
        return None

    def _start_timer(self, timer_name: str) -> TimerContext:
        timer: TimerContext = self._telemetry.timer(timer_name)
        timer.__enter__()
        return timer

    def _complete_previous_microstep_if_needed(self):
        self._microstep_timer = TelemetryCallback._exit_timer_if_needed(
            self._microstep_timer
        )

    def _start_microstep(self):
        self._microstep_timer = self._start_timer("microsteps")

    def _complete_previous_step_if_needed(self):
        self._step_timer = TelemetryCallback._exit_timer_if_needed(self._step_timer)

    def _start_step(self):
        self._step_timer = self._start_timer("steps")

    def _complete_previous_epoch_if_needed(self):
        self._epoch_timer = TelemetryCallback._exit_timer_if_needed(self._epoch_timer)

    def _start_epoch(self):
        self._epoch_timer = self._start_timer("epochs")
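
# Minimal usage sketch (illustrative, not part of this module). It assumes
# `BaseTrainerCallback` is compatible with the HuggingFace `TrainerCallback`
# interface (the method signatures above mirror it); `model` and `train_dataset`
# are hypothetical placeholders, and "output/telemetry" is an example path.
#
#   import pathlib
#   import transformers
#   from oumi.core.callbacks.telemetry_callback import TelemetryCallback
#
#   telemetry = TelemetryCallback(
#       skip_first_steps=1,
#       track_gpu_temperature=True,
#       output_dir=pathlib.Path("output/telemetry"),
#   )
#   trainer = transformers.Trainer(
#       model=model,
#       args=transformers.TrainingArguments(output_dir="output"),
#       train_dataset=train_dataset,
#       callbacks=[telemetry],
#   )
#   trainer.train()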