Source code for oumi.core.callbacks.nan_inf_detection_callback
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A callback to detect NaN/INF metric values."""
import copy
from typing import Optional, Union
import numpy as np
import transformers
from oumi.core.callbacks.base_trainer_callback import BaseTrainerCallback
from oumi.core.configs import TrainingParams
from oumi.utils.logging import logger
# Name of the keyword argument from which `on_log` reads the logged metrics dict.
_LOGS_KWARG = "logs"
[docs]
class NanInfDetectionCallback(BaseTrainerCallback):
    """Trainer callback that fails fast when a monitored metric is NaN or INF.

    A `NaN` loss value, for example, is an almost certain sign that training is
    going badly; in such cases it is best to detect the condition early and
    fail rather than keep running.
    """

    def __init__(
        self,
        metrics: list[str],
    ):
        """Initializes the NanInfDetectionCallback.

        Args:
            metrics: Names of the logged metrics to monitor. The list is
                deep-copied, so later mutations by the caller are not seen
                by this callback.
        """
        self._metrics = copy.deepcopy(metrics)

    def on_log(
        self,
        args: Union[transformers.TrainingArguments, TrainingParams],
        state: Optional[transformers.TrainerState] = None,
        control: Optional[transformers.TrainerControl] = None,
        **kwargs,
    ):
        """Event called after logging the last logs.

        Reads the metrics dict from the `logs` keyword argument and inspects
        each monitored metric in the order given at construction time.
        Metrics that are absent from the logs (or logged as `None`) are
        skipped. Nothing is checked when `logs` is missing or empty.

        Args:
            args: Training arguments (not used by this callback).
            state: Trainer state (not used by this callback).
            control: Trainer control (not used by this callback).
            **kwargs: Extra keyword arguments; only `logs` is consulted.

        Raises:
            RuntimeError: On the first monitored metric whose value is NaN or
                infinite. The same message is logged as an error beforehand.
        """
        logged_values = kwargs.get(_LOGS_KWARG)
        if not logged_values:
            return

        for metric in self._metrics:
            value = logged_values.get(metric)
            if value is None:
                continue

            # Classify the value; NaN is tested first so it takes precedence
            # in the reported message.
            if np.isnan(value):
                kind = "NaN"
            elif np.isinf(value):
                kind = "INF"
            else:
                continue

            error_message = f"{kind} is detected for the '{metric}' metric."
            logger.error(error_message)
            raise RuntimeError(error_message)