# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import copy
import math
import os
import time
from contextlib import contextmanager
from pathlib import Path
from pprint import pformat
from typing import Any, Callable, Optional, cast

import pydantic
import safetensors.torch
import torch
import torch.amp
import torch.distributed.checkpoint as dcp
import torch.utils.tensorboard as tensorboard

import mlflow  # isort: skip
import wandb  # isort: skip

from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_state_dict,
)
from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm.auto import tqdm
from transformers import TrainerCallback

from oumi.core.configs import MixedPrecisionDtype, TrainingConfig, TrainingParams
from oumi.core.configs.params.fsdp_params import FSDPParams, StateDictType
from oumi.core.distributed import (
    barrier,
    get_device_rank_info,
    is_distributed,
    is_local_process_zero,
    is_world_process_zero,
    prepare_model_for_distributed,
)
from oumi.core.processors.base_processor import BaseProcessor
from oumi.core.tokenizers import BaseTokenizer
from oumi.core.trainers.base_trainer import BaseTrainer
from oumi.models.layers.ring_attention import (
    apply_zigzag_ring_attn_monkey_patch_llama as apply_ring_attention_monkey_patch,
)
from oumi.models.layers.ring_attention import (
    prepare_zigzag_ring_attn_inputs as prepare_seq_parallel_inputs,
)
from oumi.performance.telemetry import TelemetryTracker
from oumi.utils.io_utils import load_json, save_json
from oumi.utils.logging import logger

torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn


class TrainingState(pydantic.BaseModel):
    epoch: int = 0
    global_step: int = 0
    total_tokens_seen: int = 0

class Trainer(BaseTrainer):
    def __init__(
        self,
        model: torch.nn.Module,
        processing_class: Optional[BaseTokenizer],
        args: TrainingParams,
        train_dataset: Dataset,
        processor: Optional[BaseProcessor] = None,
        eval_dataset: Optional[Dataset] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        data_collator: Optional[Callable] = None,
        config: Optional[TrainingConfig] = None,
        **kwargs,
    ):
        """Initializes the Oumi trainer."""
        # Importing these here to avoid circular dependencies
        from oumi.builders.lr_schedules import build_lr_scheduler
        from oumi.builders.optimizers import build_optimizer

        self.telemetry = TelemetryTracker()
        self.start_time = time.perf_counter()

        self.collator_fn = data_collator

        self.processing_class = processing_class
        self._processor = processor
        self.params = copy.deepcopy(args)
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.max_norm = (
            float(args.max_grad_norm) if args.max_grad_norm is not None else None
        )

        self.config = config or TrainingConfig()
        self.fsdp_params = self.config.fsdp or FSDPParams()
        self.is_using_fsdp = self.fsdp_params.enable_fsdp

        # TODO OPE-333 Define a param to enable ring attention + check pre-conditions:
        # 1. Flash Attention (`is_ring_attention_available()`),
        # 2. CUDA and distributed multi-GPU training (otherwise, pointless).
        # 3. Supported model type.
        self.is_using_ring_attention = False

        self.params.finalize_and_validate()

        self.state = TrainingState()

        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"

        # Enable mixed precision bf16/fp16 training if requested.
        # Model dtype has been verified to be fp32 if this is the case.
        self.mixed_precision_ctx = contextlib.nullcontext()
        mixed_precision_dtype = None
        if self.params.mixed_precision_dtype == MixedPrecisionDtype.BF16:
            mixed_precision_dtype = torch.bfloat16
        elif self.params.mixed_precision_dtype == MixedPrecisionDtype.FP16:
            mixed_precision_dtype = torch.float16
        if mixed_precision_dtype:
            self.mixed_precision_ctx = torch.amp.autocast(
                device_type=self.device_type,
                enabled=True,
                dtype=mixed_precision_dtype,
            )

        # We want to enable gradient scaling for fp16 mixed precision training
        # to prevent gradient underflows. This is not needed for bf16 since it has
        # the same dynamic range as fp32.
        # See here for details:
        # https://pytorch.org/docs/stable/amp.html#gradient-scaling
        self.scaler = torch.amp.GradScaler(
            device=self.device_type,
            enabled=self.params.mixed_precision_dtype == MixedPrecisionDtype.FP16,
        )

        device_info = get_device_rank_info()

        # TODO: OPE-218 - give users fine-grained control on device placement
        # TODO: OPE-217 - non-leader models should be on meta
        if torch.cuda.is_available():
            self.device = f"cuda:{device_info.local_rank}"
            torch.cuda.set_device(self.device)
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"

        # ----------------------------------
        # Prepare model for training
        # ----------------------------------
        if args.enable_gradient_checkpointing:
            model.gradient_checkpointing_enable(args.gradient_checkpointing_kwargs)
        model.to(self.device)

        if is_distributed():
            # Wrap model for distributed training
            with self._telemetry_block("wrap model for distributed"):
                model = prepare_model_for_distributed(
                    model,
                    self.config,
                    ddp_find_unused_parameters=self.params.ddp_find_unused_parameters,
                )

        # Apply ring attention monkey patch if enabled
        if self.is_using_ring_attention:
            apply_ring_attention_monkey_patch()

        if self.params.compile:
            self.log("Compiling model...")
            with self._telemetry_block("compile model"):
                model = cast(torch.nn.Module, torch.compile(model))

        self.model = model

        self.callbacks = callbacks if callbacks is not None else []

        self.optimizer = build_optimizer(self.model, self.params)
        self.lr_scheduler = build_lr_scheduler(
            optimizer=self.optimizer,
            training_params=self.params,
            current_epoch=self.state.epoch,
            num_training_steps=self._estimate_total_training_steps(),
        )

        self.train_dataloader = self._get_train_dataloader()
        self.eval_dataloader = self._get_eval_dataloader() if eval_dataset else None

        self._init_logging()

    #
    # Training
    #
    def train(self, resume_from_checkpoint: Optional[str] = None):
        """Trains the model."""
        if resume_from_checkpoint:
            with torch.profiler.record_function("load_from_checkpoint"):
                self._load_from_checkpoint(resume_from_checkpoint)

        total_steps = self._estimate_total_training_steps()

        self.start_time = time.perf_counter()

        # Make sure all workers start at the same time.
        barrier()

        with tqdm(
            total=total_steps,
            desc="Training",
            disable=not is_world_process_zero(),
        ) as progress_bar:
            while True:
                epoch = self.state.epoch
                if self.params.max_steps > 0:
                    if self.state.global_step >= self.params.max_steps:
                        self.log(
                            f"Reached {self.state.global_step} global steps. "
                            "Training completed."
                        )
                        break
                elif (
                    self.params.num_train_epochs > 0
                    and epoch >= self.params.num_train_epochs
                ):
                    self.log(f"Reached {epoch} epochs. Training completed.")
                    break

                with torch.profiler.record_function(f"epoch_{epoch}"):
                    self._set_sampler_epoch(epoch)
                    self._train_epoch(progress_bar)

                    if self.params.save_epoch:
                        self.save_state()

                    if (
                        self.eval_dataloader
                        and self.params.eval_strategy == "epoch"
                        and is_world_process_zero()
                    ):
                        # TODO: OPE-223 - only the global leader is used for evaluation
                        # To enable distributed evaluation, the eval function needs
                        # to be updated to aggregate metrics across all workers.
                        self.evaluate()

                    self.state.epoch += 1
                    barrier()

        self._process_callbacks("on_train_end")

        self.log(
            f"Training finished! Global step: {self.state.global_step} "
            f"Training runtime: {time.perf_counter() - self.start_time}s"
        )

        if self.params.enable_mlflow:
            mlflow.end_run()
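
    # Illustrative note (an assumption inferred from `save_state()` and
    # `_load_from_checkpoint()` below, not a documented API guarantee): to resume
    # training, pass the directory previously written by `save_state()`, e.g.:
    #
    #     trainer.train(resume_from_checkpoint=config.training.output_dir)
    #
    # The directory must contain the `model/` and `optimizer/` DCP folders; the
    # dataloader, trainer, and telemetry states are restored if present.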
    @contextmanager
    def _telemetry_block(self, name: str):
        with (
            torch.profiler.record_function(name) as record_function_context,
            self.telemetry.timer(name) as timer_context,
        ):
            yield (record_function_context, timer_context)

    @staticmethod
    def _cuda_sync_and_empty_cache() -> None:
        if torch.cuda.is_available() and torch.cuda.is_initialized():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

    def _train_epoch(self, progress_bar: tqdm) -> None:
        """Trains the model for one epoch."""
        epoch_start_time = time.perf_counter()

        self.model.train()
        self._cuda_sync_and_empty_cache()
        self.optimizer.zero_grad(set_to_none=True)

        micro_step = 0
        data_iter = iter(self.train_dataloader)

        gradient_accumulation_steps = max(1, self.params.gradient_accumulation_steps)

        while True:
            with torch.profiler.record_function(
                "microstep" if gradient_accumulation_steps > 1 else "step"
            ):
                if micro_step % gradient_accumulation_steps == 0:
                    self._process_callbacks("on_step_begin")

                # True if `max_steps` is configured and we reached the limit.
                stop_on_max_steps_limit = (
                    self.params.max_steps > 0
                    and (self.state.global_step + 1) >= self.params.max_steps
                )

                # End of global step. May include multiple micro steps
                # if gradient_accumulation_steps > 1.
                end_of_global_step = (
                    (micro_step + 1) % gradient_accumulation_steps
                ) == 0

                with self._telemetry_block("fetching batch"):
                    try:
                        batch = next(data_iter)
                    except StopIteration:
                        # FIXME Update metrics and log
                        self.log("End of epoch")
                        break

                # Count tokens on CPU.
                with self._telemetry_block("computing tokens"):
                    if self.processing_class is not None and "input_ids" in batch:
                        num_tokens = (
                            batch["input_ids"]
                            .ne(self.processing_class.pad_token_id)
                            .sum()
                            .item()
                        )
                        self.state.total_tokens_seen += num_tokens

                with self._telemetry_block("moving batch to device"):
                    if not self.is_using_fsdp and not self.is_using_ring_attention:
                        batch = {
                            k: v.to(self.device, non_blocking=True)
                            for k, v in batch.items()
                        }

                with self.mixed_precision_ctx, self._telemetry_block("model forward"):
                    self.model.require_backward_grad_sync = (  # type: ignore
                        end_of_global_step or stop_on_max_steps_limit
                    )
                    if self.is_using_ring_attention:
                        # Prepare inputs for ring attention
                        prepared_inputs = prepare_seq_parallel_inputs(
                            batch["input_ids"],
                            batch.get("position_ids"),
                            batch.get("labels"),
                            get_device_rank_info().rank,
                            get_device_rank_info().world_size,
                            self.device,
                        )
                        outputs = self.model(**prepared_inputs)
                    else:
                        outputs = self.model(**batch)
                    loss = outputs["loss"] / gradient_accumulation_steps

                with self._telemetry_block("loss backward"):
                    self.scaler.scale(loss).backward()

                if end_of_global_step or stop_on_max_steps_limit:
                    with self._telemetry_block("optimizer step"):
                        self.scaler.unscale_(self.optimizer)
                        if self.max_norm is not None and self.max_norm > 0:
                            torch.nn.utils.clip_grad_norm_(
                                self.model.parameters(), max_norm=self.max_norm
                            )

                        # save lr for logging
                        last_lr = self.lr_scheduler.get_last_lr()[0]

                        # step optimizer, scaler, and lr schedule
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        self.lr_scheduler.step()
                        self.optimizer.zero_grad(set_to_none=True)

                    self.state.global_step += 1

                    if self.params.telemetry.track_gpu_temperature:
                        self.telemetry.record_gpu_temperature()

                    progress_bar.update(1)
                    self._process_callbacks("on_step_end")

                    if (
                        self.params.logging_steps > 0
                        and not (
                            self.state.global_step == 1
                            and self.params.logging_first_step
                        )
                        and (
                            stop_on_max_steps_limit
                            or (
                                self.state.global_step % self.params.logging_steps == 0
                            )
                        )
                    ):
                        # Log metrics
                        elapsed = time.perf_counter() - self.start_time
                        loss_value = loss.item() * gradient_accumulation_steps
                        metrics = {
                            "train/loss": loss_value,
                            "learning_rate": last_lr,
                            "epoch": self.state.epoch,
                            "global_step": self.state.global_step,
                            "total_tokens_seen": self.state.total_tokens_seen,
                            "global_steps_per_second": self.state.global_step / elapsed,
                            "tokens_per_second": self.state.total_tokens_seen / elapsed,
                            "tokens_per_step_per_gpu": self.state.total_tokens_seen
                            / self.state.global_step,
                        }
                        callback_metrics = self._process_callbacks("on_log", metrics)
                        metrics.update(callback_metrics)
                        self.log_metrics(metrics, self.state.global_step)

                        if is_local_process_zero():
                            self.telemetry.print_summary()

                    if (
                        self.params.save_steps > 0
                        and self.state.global_step % self.params.save_steps == 0
                    ):
                        self.save_state()

                    if (
                        self.eval_dataloader
                        and self.params.eval_steps > 0
                        and self.state.global_step % self.params.eval_steps == 0
                        and is_world_process_zero()
                    ):
                        # TODO: OPE-223 - only the global leader is used for evaluation
                        # To enable distributed evaluation, the eval function needs
                        # to be updated to aggregate metrics across all workers.
                        self.evaluate()

                    if stop_on_max_steps_limit:
                        self.log(
                            f"Reached {self.params.max_steps} max steps condition."
                        )
                        break

                micro_step += 1

        self.log(
            f"End of epoch. "
            f"Global step: {self.state.global_step}. "
            f"Epoch runtime: {time.perf_counter() - epoch_start_time}s"
        )

    #
    # Evaluation
    #
    @torch.no_grad()
    def evaluate(self) -> dict[str, float]:
        """Evaluates the model on the evaluation dataset."""
        if self.eval_dataloader is None:
            raise ValueError("No evaluation dataloader provided.")

        self.model.eval()
        eval_losses = []

        for batch in tqdm(
            self.eval_dataloader,
            desc="Evaluating",
            disable=not is_local_process_zero(),
        ):
            batch = {k: v.to(self.device) for k, v in batch.items()}
            outputs = self.model(**batch)
            eval_losses.append(outputs.loss.item())

        eval_loss = sum(eval_losses) / len(eval_losses)
        perplexity = torch.exp(torch.tensor(eval_loss))

        results = {"val/loss": eval_loss, "val/perplexity": perplexity.item()}

        self.log("Finished evaluation.")
        self.log_metrics(results, self.state.global_step)

        self.model.train()
        return results
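
    # Note: `evaluate()` reports perplexity as exp(mean eval loss); for example,
    # a mean cross-entropy loss of 2.0 corresponds to a perplexity of
    # exp(2.0) ≈ 7.39.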
    #
    # Checkpointing
    #
    def save_model(self, config: TrainingConfig, final: bool = True) -> None:
        """Saves the model."""
        self._cuda_sync_and_empty_cache()
        if is_world_process_zero():
            output_dir = Path(config.training.output_dir)
            output_dir.mkdir(exist_ok=True)

            model_path = output_dir / "model.safetensors"
            safetensors.torch.save_model(model=self.model, filename=str(model_path))
            self.log(f"Model saved to {model_path}.")

            if self._processor is not None:
                self._processor.save_config(output_dir)
                logger.info(f"Processor config has been saved at {output_dir}.")

        self._cuda_sync_and_empty_cache()
    def save_state(self):
        """Saves the training state."""
        self._cuda_sync_and_empty_cache()

        checkpoint_dir = Path(self.params.output_dir)

        if is_local_process_zero():
            checkpoint_dir.mkdir(exist_ok=True)

        if (
            self.params.telemetry.collect_telemetry_for_all_ranks
            or is_world_process_zero()
        ):
            telemetry_dir = self.params.telemetry_dir
            if telemetry_dir:
                device_rank_info = get_device_rank_info()
                telemetry_state_path = (
                    telemetry_dir / f"telemetry_rank{device_rank_info.rank:04}.json"
                )
                save_json(
                    data=self.telemetry.state_dict(),
                    filename=telemetry_state_path,
                )

        if self.is_using_fsdp:
            storage_options = StateDictOptions(
                full_state_dict=self.fsdp_params.state_dict_type
                == StateDictType.FULL_STATE_DICT,
                cpu_offload=self.fsdp_params.cpu_offload,
                ignore_frozen_params=False,
                strict=True,
                broadcast_from_rank0=False,  # TODO: make this configurable
            )
        else:
            storage_options = None

        model_state_dict, optimizer_state_dict = get_state_dict(
            model=self.model,
            optimizers=self.optimizer,
            options=storage_options,
        )

        model_path = checkpoint_dir / "model"
        optimizer_path = checkpoint_dir / "optimizer"
        dataloader_state_path = checkpoint_dir / "dataloader.pt"
        trainer_state_path = checkpoint_dir / "trainer_state.json"

        dcp.save(model_state_dict, checkpoint_id=model_path)
        dcp.save(optimizer_state_dict, checkpoint_id=optimizer_path)

        if is_world_process_zero():
            torch.save(self.train_dataloader.state_dict(), dataloader_state_path)
            save_json(data=self.state.model_dump(), filename=trainer_state_path)

        logger.info(f"Training state saved to {checkpoint_dir}")
        self._cuda_sync_and_empty_cache()
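
    # Checkpoint layout written by `save_state()` (a sketch derived from the code
    # above; the exact contents of the DCP folders depend on the FSDP/state-dict
    # configuration):
    #
    #   <output_dir>/
    #       model/               # distributed checkpoint (DCP) model state
    #       optimizer/           # distributed checkpoint (DCP) optimizer state
    #       dataloader.pt        # StatefulDataLoader state (saved by rank 0)
    #       trainer_state.json   # TrainingState: epoch, global_step, total_tokens_seen
    #
    # Per-rank telemetry (`telemetry_rank****.json`) is written under
    # `params.telemetry_dir`, which may or may not coincide with this directory;
    # `_load_from_checkpoint` looks for it inside the checkpoint directory.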
    def _load_from_checkpoint(self, checkpoint_dirname: str):
        """Loads the training state from a checkpoint."""
        checkpoint_dir = Path(checkpoint_dirname)

        device_rank_info = get_device_rank_info()

        model_path = checkpoint_dir / "model"
        optimizer_path = checkpoint_dir / "optimizer"
        dataloader_state_path = checkpoint_dir / "dataloader.pt"
        trainer_state_path = checkpoint_dir / "trainer_state.json"
        telemetry_state_path = (
            checkpoint_dir / f"telemetry_rank{device_rank_info.rank:04}.json"
        )

        if not checkpoint_dir.exists():
            raise ValueError(f"Checkpoint directory does not exist: {checkpoint_dir}")
        if not model_path.exists():
            raise ValueError(
                f"Invalid checkpoint, model state folder does not exist: {model_path}"
            )
        if not optimizer_path.exists():
            raise ValueError(
                "Invalid checkpoint, optimizer state folder does not exist: "
                f"{optimizer_path}"
            )

        if self.is_using_fsdp:
            storage_options = StateDictOptions(
                full_state_dict=self.fsdp_params.state_dict_type
                == StateDictType.FULL_STATE_DICT,
                cpu_offload=self.fsdp_params.cpu_offload,
                ignore_frozen_params=False,
                strict=True,
                broadcast_from_rank0=False,
            )
        else:
            storage_options = None

        model_state_dict, optimizer_state_dict = get_state_dict(
            model=self.model,
            optimizers=self.optimizer,
            options=storage_options,
        )
        dcp.load(model_state_dict, checkpoint_id=model_path)
        dcp.load(optimizer_state_dict, checkpoint_id=optimizer_path)

        if dataloader_state_path.exists():
            self.train_dataloader.load_state_dict(torch.load(dataloader_state_path))
        if trainer_state_path.exists():
            self.state = TrainingState.model_validate(
                load_json(trainer_state_path), strict=True
            )
        if telemetry_state_path.exists():
            self.telemetry.load_state_dict(load_json(telemetry_state_path))

        self.log(f"Resumed training from checkpoint: {checkpoint_dirname}")

    #
    # Logging
    #
    def log(self, message: str):
        """Logs a message if the process is the local process zero."""
        if not is_local_process_zero():
            return
        logger.info(message)
    def log_metrics(self, metrics: dict[str, Any], step: int) -> None:
        """Logs metrics to wandb and tensorboard."""
        # Log to console and log file
        if not is_world_process_zero():
            return
        self.log(pformat(metrics))

        # Log to Weights and Biases
        if self.params.enable_wandb:
            wandb.log(metrics, step=self.state.global_step)

        # Log to tensorboard
        if self.params.enable_tensorboard and self.tensorboard_writer:
            for key, value in metrics.items():
                self.tensorboard_writer.add_scalar(key, value, self.state.global_step)
    def _init_logging(
        self,
    ) -> None:
        """Initializes logging."""
        if not is_world_process_zero():
            return

        self.log(f"Logging to {self.params.output_dir}")

        if self.params.enable_wandb:
            project_name = os.environ.get("WANDB_PROJECT", "oumi")
            self.log(f"Logging to Weights and Biases project: '{project_name}'")
            run = wandb.init(
                project=project_name, name=self.params.run_name, job_type="train"
            )
            self.log(f"View wandb run {run.id} at: {run.get_url()}")
            wandb.watch(self.model)

        if self.params.enable_tensorboard:
            tensorboard_folder = Path(self.params.output_dir) / "tensorboard"
            self.log(f"Logging to tensorboard folder: '{tensorboard_folder}'")
            self.tensorboard_writer = tensorboard.SummaryWriter(
                log_dir=tensorboard_folder
            )
        else:
            self.tensorboard_writer = None

        if self.params.enable_mlflow:
            self.mlflow_run = mlflow.start_run()

    #
    # Data loading
    #

    def _get_train_dataloader(self) -> StatefulDataLoader:
        """Returns the training dataloader."""
        # At this point, "auto" must be pre-resolved to `int`.
        assert isinstance(self.params.dataloader_num_workers, int)

        prefetch_factor = (
            self.params.dataloader_prefetch_factor
            if self.params.dataloader_num_workers > 0
            else None
        )

        # IterDataPipe is a subclass of IterableDataset.
        if isinstance(self.train_dataset, IterableDataset):
            # TODO: configure sharding for iterable datasets
            sampler = None
            shuffle = None
        else:
            # Configure sampler for map datasets. If using multiple GPUs,
            # we use a DistributedSampler to make sure each worker gets a
            # different subset of the dataset.
            # In non-distributed mode, we iterate over the full dataset.
            if is_distributed():
                # TODO: OPE-219 this strategy should only be enabled for DDP
                # and FSDP with NO_SHARDING
                device_info = get_device_rank_info()

                # Distribute the dataset across all GPU workers
                # Each rank will get a subset of the dataset
                sampler = DistributedSampler(
                    self.train_dataset,
                    num_replicas=device_info.world_size,
                    rank=device_info.rank,
                    seed=self.params.seed,
                    shuffle=True,
                )
                shuffle = False
            else:
                # If not distributed, let the dataloader handle shuffling
                sampler = None
                shuffle = True

        # Keeping track of the sampler so we can update it after each epoch
        self._sampler = sampler

        return StatefulDataLoader(
            self.train_dataset,
            batch_size=self.params.per_device_train_batch_size,
            shuffle=shuffle,
            sampler=self._sampler,
            num_workers=self.params.dataloader_num_workers,
            pin_memory=self.device_type == "cuda",
            prefetch_factor=prefetch_factor,
            pin_memory_device=self.device,
            snapshot_every_n_steps=self.params.save_steps,
            collate_fn=self.collator_fn,
        )

    def _get_eval_dataloader(self) -> DataLoader:
        """Returns the evaluation dataloader."""
        if not self.eval_dataset:
            raise ValueError("No evaluation dataset provided.")

        # At this point, "auto" must be pre-resolved to `int`.
        assert isinstance(self.params.dataloader_num_workers, int)

        return DataLoader(
            self.eval_dataset,
            batch_size=self.params.per_device_eval_batch_size,
            shuffle=False,
            num_workers=self.params.dataloader_num_workers,
            collate_fn=self.collator_fn,
        )

    def _estimate_total_training_steps(self) -> int:
        # If max_steps is set, use it.
        if self.params.max_steps > 0:
            return self.params.max_steps

        num_epochs = self.params.num_train_epochs
        if num_epochs > 0:
            num_dataset_examples = 0
            try:
                if not isinstance(self.train_dataset, IterableDataset):
                    num_dataset_examples = len(self.train_dataset)  # type: ignore
                elif hasattr(self.train_dataset, "datapipe"):
                    # Hacky way to get examples count from
                    # MapToIterConverterIterDataPipe.
                    # FIXME Remove DataPipes OPE-811
                    num_dataset_examples = len(
                        self.train_dataset.datapipe  # type: ignore
                    )
            except Exception:
                num_dataset_examples = 0

            if num_dataset_examples > 0:
                world_size = get_device_rank_info().world_size
                batch_size = self.params.per_device_train_batch_size
                steps_per_epoch_per_device = math.ceil(
                    float(num_dataset_examples) / (batch_size * world_size)
                )
                return int(num_epochs * max(steps_per_epoch_per_device, 1))

        raise ValueError(
            "Unable to estimate `total_training_steps` "
            + (f"in {num_epochs} epochs" if num_epochs > 0 else "")
            + ". Please define the `max_steps` training parameter!"
        )

    def _set_sampler_epoch(self, epoch: int) -> None:
        """Sets the current epoch on the sampler, if it exists and supports it."""
        if self._sampler and hasattr(self._sampler, "set_epoch"):
            self.log(f"Setting sampler epoch to {epoch}.")
            self._sampler.set_epoch(epoch)

    #
    # Handle callbacks
    #

    def _process_callbacks(
        self, event: str, logs: Optional[dict[str, Any]] = None
    ) -> dict[str, Any]:
        """Process callbacks.

        Extremely hacky way to handle HF callbacks.
        Just here to unblock debugging with our MfuCallback.
        """
        logs = logs or {}
        for callback in self.callbacks:
            if hasattr(callback, event):
                action = getattr(callback, event)
                action(args=self.params, state=None, control=None, logs=logs)
        return logs
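

# ---------------------------------------------------------------------------
# Example usage (a hedged sketch, not part of the module's API): the names
# `model`, `tokenizer`, and `train_dataset` below are placeholders that would be
# produced by the corresponding Oumi builders or by user code.
#
#   from oumi.core.configs import TrainingConfig
#
#   config = TrainingConfig()  # typically parsed from a YAML config
#   trainer = Trainer(
#       model=model,
#       processing_class=tokenizer,
#       args=config.training,
#       train_dataset=train_dataset,
#       config=config,
#   )
#   trainer.train()            # optionally: resume_from_checkpoint=...
#   trainer.save_model(config)
# ---------------------------------------------------------------------------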