# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
import math
import os
import time
from contextlib import contextmanager
from pathlib import Path
from pprint import pformat
from typing import Any, Callable, Optional, cast
import pydantic
import safetensors.torch
import torch
import torch.amp
import torch.distributed.checkpoint as dcp
import torch.utils.tensorboard as tensorboard
import wandb # isort: skip
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_state_dict,
)
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm.auto import tqdm
from transformers import TrainerCallback
from oumi.core.configs import MixedPrecisionDtype, TrainingConfig, TrainingParams
from oumi.core.configs.params.fsdp_params import FSDPParams, StateDictType
from oumi.core.distributed import (
barrier,
get_device_rank_info,
is_distributed,
is_local_process_zero,
is_world_process_zero,
prepare_model_for_distributed,
)
from oumi.core.processors.base_processor import BaseProcessor
from oumi.core.tokenizers import BaseTokenizer
from oumi.core.trainers.base_trainer import BaseTrainer
from oumi.models.layers.ring_attention import (
apply_zigzag_ring_attn_monkey_patch_llama as apply_ring_attention_monkey_patch,
)
from oumi.models.layers.ring_attention import (
prepare_zigzag_ring_attn_inputs as prepare_seq_parallel_inputs,
)
from oumi.performance.telemetry import TelemetryTracker
from oumi.utils.io_utils import load_json, save_json
from oumi.utils.logging import logger
from oumi.utils.torch_utils import log_trainable_parameters
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
class TrainingState(pydantic.BaseModel):
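"""Serializable snapshot of training progress; saved with and restored from checkpoints."""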
epoch: int = 0
global_step: int = 0
total_tokens_seen: int = 0
class Trainer(BaseTrainer):
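"""PyTorch-native Oumi trainer.

Implements the core training loop with support for mixed-precision training,
gradient accumulation, distributed (DDP/FSDP) execution, periodic evaluation,
and distributed checkpointing.

Illustrative usage sketch (assumes `model`, `tokenizer`, `params`,
`train_dataset`, and `config` have already been built elsewhere)::

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=params,
        train_dataset=train_dataset,
        config=config,
    )
    trainer.train()
    trainer.save_model(config)
"""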
def __init__(
self,
model: torch.nn.Module,
tokenizer: Optional[BaseTokenizer],
args: TrainingParams,
train_dataset: Dataset,
processor: Optional[BaseProcessor] = None,
eval_dataset: Optional[Dataset] = None,
callbacks: Optional[list[TrainerCallback]] = None,
data_collator: Optional[Callable] = None,
config: Optional[TrainingConfig] = None,
**kwargs,
):
"""Initializes the Oumi trainer."""
# Importing these here to avoid circular dependencies
from oumi.builders.lr_schedules import build_lr_scheduler
from oumi.builders.optimizers import build_optimizer
self.telemetry = TelemetryTracker()
self.start_time = time.perf_counter()
self.collator_fn = data_collator
self.tokenizer = tokenizer
self._processor = processor
self.params = copy.deepcopy(args)
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.max_norm = (
float(args.max_grad_norm) if args.max_grad_norm is not None else None
)
self.config = config or TrainingConfig()
self.fsdp_params = self.config.fsdp or FSDPParams()
self.is_using_fsdp = self.fsdp_params.enable_fsdp
# TODO OPE-333 Define a param to enable ring attention + check pre-conditions:
# 1. Flash Attention (`is_ring_attention_available()`),
# 2. CUDA and distributed multi-GPU training (otherwise, pointless).
# 3. Supported model type.
self.is_using_ring_attention = False
self.params.finalize_and_validate()
self.state = TrainingState()
self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
# Enable mixed precision bf16/fp16 training if requested.
# Model dtype has been verified to be fp32 if this is the case.
self.mixed_precision_ctx = contextlib.nullcontext()
mixed_precision_dtype = None
if self.params.mixed_precision_dtype == MixedPrecisionDtype.BF16:
mixed_precision_dtype = torch.bfloat16
elif self.params.mixed_precision_dtype == MixedPrecisionDtype.FP16:
mixed_precision_dtype = torch.float16
if mixed_precision_dtype:
self.mixed_precision_ctx = torch.amp.autocast(
device_type=self.device_type,
enabled=True,
dtype=mixed_precision_dtype,
)
# We want to enable gradient scaling for fp16 mixed precision training
# to prevent gradient underflows. This is not needed for bf16 since it has the
# same dynamic range as fp32. See here for details:
# https://pytorch.org/docs/stable/amp.html#gradient-scaling
self.scaler = torch.amp.GradScaler(
device=self.device_type,
enabled=self.params.mixed_precision_dtype == MixedPrecisionDtype.FP16,
)
device_info = get_device_rank_info()
# TODO: OPE-218 - give users fine-grained control on device placement
# TODO: OPE-217 - non-leader models should be on meta
if torch.cuda.is_available():
self.device = f"cuda:{device_info.local_rank}"
torch.cuda.set_device(self.device)
elif torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
# ----------------------------------
# Prepare model for training
# ----------------------------------
if args.enable_gradient_checkpointing:
model.gradient_checkpointing_enable(args.gradient_checkpointing_kwargs)
model.to(self.device)
if is_distributed():
# Wrap model for distributed training
with self._telemetry_block("wrap model for distributed"):
model = prepare_model_for_distributed(
model,
self.config,
ddp_find_unused_parameters=self.params.ddp_find_unused_parameters,
)
# Apply ring attention monkey patch if enabled
if self.is_using_ring_attention:
apply_ring_attention_monkey_patch()
if self.params.compile:
self.log("Compiling model...")
with self._telemetry_block("compile model"):
model = cast(torch.nn.Module, torch.compile(model))
self.model = model
self.callbacks = callbacks if callbacks is not None else []
self.optimizer = build_optimizer(self.model, self.params)
self.lr_scheduler = build_lr_scheduler(
optimizer=self.optimizer,
training_params=self.params,
current_epoch=self.state.epoch,
num_training_steps=self._estimate_total_training_steps(),
)
self.train_dataloader = self._get_train_dataloader()
self.eval_dataloader = self._get_eval_dataloader() if eval_dataset else None
self._init_logging()
#
# Training
#
def train(self, resume_from_checkpoint: Optional[str] = None):
"""Trains the model."""
if resume_from_checkpoint:
with torch.profiler.record_function("load_from_checkpoint"):
self._load_from_checkpoint(resume_from_checkpoint)
if is_local_process_zero():
log_trainable_parameters(self.model)
total_steps = self._estimate_total_training_steps()
self.start_time = time.perf_counter()
# Make sure all workers start at the same time.
barrier()
with tqdm(
total=total_steps,
desc="Training",
disable=not is_world_process_zero(),
) as progress_bar:
while True:
epoch = self.state.epoch
if self.params.max_steps > 0:
if self.state.global_step >= self.params.max_steps:
self.log(
f"Reached {self.state.global_step} global steps. "
"Training completed."
)
break
elif (
self.params.num_train_epochs > 0
and epoch >= self.params.num_train_epochs
):
self.log(f"Reached {epoch} epochs. Training completed.")
break
with torch.profiler.record_function(f"epoch_{epoch}"):
self._set_sampler_epoch(epoch)
self._train_epoch(progress_bar)
if self.params.save_epoch:
self.save_state()
if (
self.eval_dataloader
and self.params.eval_strategy == "epoch"
and is_world_process_zero()
):
# TODO: OPE-223 - only the global leader is used for evaluation
# To enable distributed evaluation, the eval function needs
# to be updated to aggregate metrics across all workers.
self.evaluate()
self.state.epoch += 1
barrier()
self._process_callbacks("on_train_end")
self.log(
f"Training finished! Global step: {self.state.global_step} "
f"Training runtime: {time.perf_counter() - self.start_time}s"
)
@contextmanager
def _telemetry_block(self, name: str):
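"""Context manager that wraps a code block in both a profiler `record_function` range and a telemetry timer with the same name."""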
with (
torch.profiler.record_function(name) as record_function_context,
self.telemetry.timer(name) as timer_context,
):
yield (record_function_context, timer_context)
@staticmethod
def _cuda_sync_and_empty_cache() -> None:
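"""Synchronizes CUDA and frees cached GPU memory, if CUDA is available and initialized."""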
if torch.cuda.is_available() and torch.cuda.is_initialized():
torch.cuda.synchronize()
torch.cuda.empty_cache()
def _train_epoch(self, progress_bar: tqdm) -> None:
"""Trains the model for one epoch."""
epoch_start_time = time.perf_counter()
self.model.train()
self._cuda_sync_and_empty_cache()
self.optimizer.zero_grad(set_to_none=True)
micro_step = 0
data_iter = iter(self.train_dataloader)
gradient_accumulation_steps = max(1, self.params.gradient_accumulation_steps)
while True:
with torch.profiler.record_function(
"microstep" if gradient_accumulation_steps > 1 else "step"
):
if micro_step % gradient_accumulation_steps == 0:
self._process_callbacks("on_step_begin")
# True if `max_steps` is configured and the current step reaches the limit.
stop_on_max_steps_limit = (
self.params.max_steps > 0
and (self.state.global_step + 1) >= self.params.max_steps
)
# End of global step. May include multiple micro steps
# if gradient_accumulation_steps > 1.
end_of_global_step = (
(micro_step + 1) % gradient_accumulation_steps
) == 0
with self._telemetry_block("fetching batch"):
try:
batch = next(data_iter)
except StopIteration:
# FIXME Update metrics and log
self.log("End of epoch")
break
# Count tokens on CPU.
with self._telemetry_block("computing tokens"):
if self.tokenizer is not None and "input_ids" in batch:
num_tokens = (
batch["input_ids"]
.ne(self.tokenizer.pad_token_id)
.sum()
.item()
)
self.state.total_tokens_seen += num_tokens
with self._telemetry_block("moving batch to device"):
if not self.is_using_fsdp and not self.is_using_ring_attention:
batch = {
k: v.to(self.device, non_blocking=True)
for k, v in batch.items()
}
with self.mixed_precision_ctx, self._telemetry_block("model forward"):
self.model.require_backward_grad_sync = ( # type: ignore
end_of_global_step or stop_on_max_steps_limit
)
if self.is_using_ring_attention:
# Prepare inputs for ring attention
prepared_inputs = prepare_seq_parallel_inputs(
batch["input_ids"],
batch.get("position_ids"),
batch.get("labels"),
get_device_rank_info().rank,
get_device_rank_info().world_size,
self.device,
)
outputs = self.model(**prepared_inputs)
else:
outputs = self.model(**batch)
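# Scale the loss so that gradients accumulated over multiple micro-steps
# average to a single optimizer step's gradient.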
loss = outputs["loss"] / gradient_accumulation_steps
with self._telemetry_block("loss backward"):
self.scaler.scale(loss).backward()
if end_of_global_step or stop_on_max_steps_limit:
with self._telemetry_block("optimizer step"):
self.scaler.unscale_(self.optimizer)
if self.max_norm is not None and self.max_norm > 0:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=self.max_norm
)
# Save the current learning rate for logging.
last_lr = self.lr_scheduler.get_last_lr()[0]
# Step the optimizer, gradient scaler, and LR scheduler.
self.scaler.step(self.optimizer)
self.scaler.update()
self.lr_scheduler.step()
self.optimizer.zero_grad(set_to_none=True)
self.state.global_step += 1
if self.params.telemetry.track_gpu_temperature:
self.telemetry.record_gpu_temperature()
progress_bar.update(1)
self._process_callbacks("on_step_end")
if (
self.params.logging_steps > 0
and not (
self.state.global_step == 1
and self.params.logging_first_step
)
and (
stop_on_max_steps_limit
or (self.state.global_step % self.params.logging_steps == 0)
)
):
# Log metrics
elapsed = time.perf_counter() - self.start_time
loss_value = loss.item() * gradient_accumulation_steps
metrics = {
"train/loss": loss_value,
"learning_rate": last_lr,
"epoch": self.state.epoch,
"global_step": self.state.global_step,
"total_tokens_seen": self.state.total_tokens_seen,
"global_steps_per_second": self.state.global_step / elapsed,
"tokens_per_second": self.state.total_tokens_seen / elapsed,
"tokens_per_step_per_gpu": self.state.total_tokens_seen
/ self.state.global_step,
}
callback_metrics = self._process_callbacks("on_log", metrics)
metrics.update(callback_metrics)
self.log_metrics(metrics, self.state.global_step)
if is_local_process_zero():
self.telemetry.print_summary()
if (
self.params.save_steps > 0
and self.state.global_step % self.params.save_steps == 0
):
self.save_state()
if (
self.eval_dataloader
and self.params.eval_steps > 0
and self.state.global_step % self.params.eval_steps == 0
and is_world_process_zero()
):
# TODO: OPE-223 - only the global leader is used for evaluation
# To enable distributed evaluation, the eval function needs
# to be updated to aggregate metrics across all workers.
self.evaluate()
if stop_on_max_steps_limit:
self.log(f"Reached {self.params.max_steps} max steps condition.")
break
micro_step += 1
self.log(
f"End of epoch. "
f"Global step: {self.state.global_step}. "
f"Epoch runtime: {time.perf_counter() - epoch_start_time}s"
)
#
# Evaluation
#
@torch.no_grad()
def evaluate(self) -> dict[str, float]:
"""Evaluates the model on the evaluation dataset."""
if self.eval_dataloader is None:
raise ValueError("No evaluation dataloader provided.")
self.model.eval()
eval_losses = []
for batch in tqdm(
self.eval_dataloader,
desc="Evaluating",
disable=not is_local_process_zero(),
):
batch = {k: v.to(self.device) for k, v in batch.items()}
outputs = self.model(**batch)
eval_losses.append(outputs.loss.item())
eval_loss = sum(eval_losses) / len(eval_losses)
perplexity = torch.exp(torch.tensor(eval_loss))
results = {"val/loss": eval_loss, "val/perplexity": perplexity.item()}
self.log("Finished evaluation.")
self.log_metrics(results, self.state.global_step)
self.model.train()
return results
#
# Checkpointing
#
def save_model(self, config: TrainingConfig, final: bool = True) -> None:
"""Saves the model."""
self._cuda_sync_and_empty_cache()
if is_world_process_zero():
output_dir = Path(config.training.output_dir)
output_dir.mkdir(exist_ok=True)
model_path = output_dir / "model.safetensors"
safetensors.torch.save_model(model=self.model, filename=str(model_path))
self.log(f"Model saved to {model_path}.")
if self._processor is not None:
self._processor.save_config(output_dir)
logger.info(f"Processor config has been saved at {output_dir}.")
self._cuda_sync_and_empty_cache()
def save_state(self):
"""Saves the training state."""
self._cuda_sync_and_empty_cache()
checkpoint_dir = Path(self.params.output_dir)
if is_local_process_zero():
checkpoint_dir.mkdir(exist_ok=True)
if (
self.params.telemetry.collect_telemetry_for_all_ranks
or is_world_process_zero()
):
telemetry_dir = self.params.telemetry_dir
if telemetry_dir:
device_rank_info = get_device_rank_info()
telemetry_state_path = (
telemetry_dir / f"telemetry_rank{device_rank_info.rank:04}.json"
)
save_json(
data=self.telemetry.state_dict(),
filename=telemetry_state_path,
)
if self.is_using_fsdp:
storage_options = StateDictOptions(
full_state_dict=self.fsdp_params.state_dict_type
== StateDictType.FULL_STATE_DICT,
cpu_offload=self.fsdp_params.cpu_offload,
ignore_frozen_params=False,
strict=True,
broadcast_from_rank0=False, # TODO: make this configurable
)
else:
storage_options = None
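# Gather model and optimizer state dicts (full or sharded, per the FSDP options).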
model_state_dict, optimizer_state_dict = get_state_dict(
model=self.model,
optimizers=self.optimizer,
options=storage_options,
)
model_path = checkpoint_dir / "model"
optimizer_path = checkpoint_dir / "optimizer"
dataloader_state_path = checkpoint_dir / "dataloader.pt"
trainer_state_path = checkpoint_dir / "trainer_state.json"
dcp.save(model_state_dict, checkpoint_id=model_path)
dcp.save(optimizer_state_dict, checkpoint_id=optimizer_path)
if is_world_process_zero():
torch.save(self.train_dataloader.state_dict(), dataloader_state_path)
save_json(data=self.state.model_dump(), filename=trainer_state_path)
logger.info(f"Training state saved to {checkpoint_dir}")
self._cuda_sync_and_empty_cache()
def _load_from_checkpoint(self, checkpoint_dirname: str):
"""Loads the training state from a checkpoint."""
checkpoint_dir = Path(checkpoint_dirname)
device_rank_info = get_device_rank_info()
model_path = checkpoint_dir / "model"
optimizer_path = checkpoint_dir / "optimizer"
dataloader_state_path = checkpoint_dir / "dataloader.pt"
trainer_state_path = checkpoint_dir / "trainer_state.json"
telemetry_state_path = (
checkpoint_dir / f"telemetry_rank{device_rank_info.rank:04}.json"
)
if not checkpoint_dir.exists():
raise ValueError(f"Checkpoint directory does not exist: {checkpoint_dir}")
if not model_path.exists():
raise ValueError(
f"Invalid checkpoint, model state folder does not exist: {model_path}"
)
if not optimizer_path.exists():
raise ValueError(
"Invalid checkpoint, optimizer state folder does not exist: "
f"{optimizer_path}"
)
if self.is_using_fsdp:
storage_options = StateDictOptions(
full_state_dict=self.fsdp_params.state_dict_type
== StateDictType.FULL_STATE_DICT,
cpu_offload=self.fsdp_params.cpu_offload,
ignore_frozen_params=False,
strict=True,
broadcast_from_rank0=False,
)
else:
storage_options = None
model_state_dict, optimizer_state_dict = get_state_dict(
model=self.model,
optimizers=self.optimizer,
options=storage_options,
)
dcp.load(model_state_dict, checkpoint_id=model_path)
dcp.load(optimizer_state_dict, checkpoint_id=optimizer_path)
if dataloader_state_path.exists():
self.train_dataloader.load_state_dict(torch.load(dataloader_state_path))
if trainer_state_path.exists():
self.state = TrainingState.model_validate(
load_json(trainer_state_path), strict=True
)
if telemetry_state_path.exists():
self.telemetry.load_state_dict(load_json(telemetry_state_path))
self.log(f"Resumed training from checkpoint: {checkpoint_dirname}")
#
# Logging
#
def log(self, message: str):
"""Logs a message if the process is the local process zero."""
if not is_local_process_zero():
return
logger.info(message)
def log_metrics(self, metrics: dict[str, Any], step: int) -> None:
"""Logs metrics to wandb and tensorboard."""
# Log to console and log file
if not is_world_process_zero():
return
self.log(pformat(metrics))
# Log to Weights and Biases
if self.params.enable_wandb:
wandb.log(metrics, step=self.state.global_step)
# Log to tensorboard
if self.params.enable_tensorboard and self.tensorboard_writer:
for key, value in metrics.items():
self.tensorboard_writer.add_scalar(key, value, self.state.global_step)
def _init_logging(
self,
) -> None:
"""Initializes logging."""
if not is_world_process_zero():
return
self.log(f"Logging to {self.params.output_dir}")
if self.params.enable_wandb:
project_name = os.environ.get("WANDB_PROJECT", "oumi")
self.log(f"Logging to Weights and Biases project: '{project_name}'")
run = wandb.init(
project=project_name, name=self.params.run_name, job_type="train"
)
self.log(f"View wandb run {run.id} at: {run.get_url()}")
wandb.watch(self.model)
if self.params.enable_tensorboard:
tensorboard_folder = Path(self.params.output_dir) / "tensorboard"
self.log(f"Logging to tensorboard folder: '{tensorboard_folder}'")
self.tensorboard_writer = tensorboard.SummaryWriter(
log_dir=tensorboard_folder
)
else:
self.tensorboard_writer = None
#
# Data loading
#
def _get_train_dataloader(self) -> StatefulDataLoader:
"""Returns the training dataloader."""
# At this point, "auto" must be pre-resolved to `int`.
assert isinstance(self.params.dataloader_num_workers, int)
prefetch_factor = (
self.params.dataloader_prefetch_factor
if self.params.dataloader_num_workers > 0
else None
)
# IterDataPipe is a subclass of IterableDataset.
if isinstance(self.train_dataset, IterableDataset):
# TODO: configure sharding for iterable datasets
sampler = None
shuffle = None
else:
# Configure sampler for map datasets. If using multiple GPUs,
# we use a DistributedSampler to make sure each worker gets a
# different subset of the dataset.
# In non-distributed mode, we iterate over the full dataset.
if is_distributed():
# TODO: OPE-219 this strategy should only be enabled for DDP
# and FSDP with NO_SHARDING
device_info = get_device_rank_info()
# Distribute the dataset across all GPU workers
# Each rank will get a subset of the dataset
sampler = DistributedSampler(
self.train_dataset,
num_replicas=device_info.world_size,
rank=device_info.rank,
seed=self.params.seed,
shuffle=True,
)
shuffle = False
else:
# If not distributed, let the dataloader handle shuffling
sampler = None
shuffle = True
# Keep track of the sampler so we can update it after each epoch.
self._sampler = sampler
return StatefulDataLoader(
self.train_dataset,
batch_size=self.params.per_device_train_batch_size,
shuffle=shuffle,
sampler=self._sampler,
num_workers=self.params.dataloader_num_workers,
pin_memory=self.device_type == "cuda",
prefetch_factor=prefetch_factor,
pin_memory_device=self.device,
snapshot_every_n_steps=self.params.save_steps,
collate_fn=self.collator_fn,
)
def _get_eval_dataloader(self) -> DataLoader:
"""Returns the evaluation dataloader."""
if not self.eval_dataset:
raise ValueError("No evaluation dataset provided.")
# At this point, "auto" must be pre-resolved to `int`.
assert isinstance(self.params.dataloader_num_workers, int)
return DataLoader(
self.eval_dataset,
batch_size=self.params.per_device_eval_batch_size,
shuffle=False,
num_workers=self.params.dataloader_num_workers,
collate_fn=self.collator_fn,
)
def _estimate_total_training_steps(self) -> int:
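"""Estimates the total number of training steps (used for the progress bar and the LR scheduler)."""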
# If max_steps is set, use it.
if self.params.max_steps > 0:
return self.params.max_steps
num_epochs = self.params.num_train_epochs
if num_epochs > 0:
num_dataset_examples = 0
try:
if not isinstance(self.train_dataset, IterableDataset):
num_dataset_examples = len(self.train_dataset) # type: ignore
elif hasattr(self.train_dataset, "datapipe"):
# Hacky way to get the example count from
# MapToIterConverterIterDataPipe.
# FIXME Remove DataPipes OPE-811
num_dataset_examples = len(self.train_dataset.datapipe) # type: ignore
except Exception:
num_dataset_examples = 0
if num_dataset_examples > 0:
world_size = get_device_rank_info().world_size
batch_size = self.params.per_device_train_batch_size
steps_per_epoch_per_device = math.ceil(
float(num_dataset_examples) / (batch_size * world_size)
)
return int(num_epochs * max(steps_per_epoch_per_device, 1))
raise ValueError(
"Unable to estimate `total_training_steps` "
+ (f"in {num_epochs} epochs" if num_epochs > 0 else "")
+ ". Please define `max_steps` training parameter!"
)
def _set_sampler_epoch(self, epoch: int) -> None:
"""Sets the current epoch on sampler, if it exists and supports it."""
if self._sampler and hasattr(self._sampler, "set_epoch"):
self.log(f"Setting sampler epoch to {epoch}.")
self._sampler.set_epoch(epoch)
#
# Handle callbacks
#
def _process_callbacks(
self, event: str, logs: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
"""Process callbacks.
Extremely hacky way to handle HF callbacks.
Just here to unblock debugging with our MfuCallback.
"""
logs = logs or {}
for callback in self.callbacks:
if hasattr(callback, event):
action = getattr(callback, event)
action(args=self.params, state=None, control=None, logs=logs)
return logs