oumi.core.trainers#
Core trainers module for the Oumi (Open Universal Machine Intelligence) library.
This module provides trainer implementations for the Oumi framework, designed to support the training process for different types of models and tasks.
Example
>>> from oumi.core.trainers import Trainer
>>> trainer = Trainer(model=my_model, train_dataset=my_dataset)
>>> trainer.train()
Note
- For detailed information on each trainer, please refer to their respective
class documentation.
- class oumi.core.trainers.BaseTrainer[source]#
Bases:
ABC
- abstractmethod save_model(config: TrainingConfig, final: bool = True) → None [source]#
Saves the model’s state dictionary to the specified output directory.
- Parameters:
config (TrainingConfig) – The Oumi training config.
final (bool) – Whether this is the final model being saved during training.
- Returns:
None
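Because save_model is an abstractmethod, any concrete trainer must implement it before it can be instantiated. A minimal stand-in sketch of that contract (plain Python, not the Oumi classes; all names besides save_model are illustrative):

```python
from abc import ABC, abstractmethod


class MiniBaseTrainer(ABC):
    """Hypothetical stand-in mirroring BaseTrainer's abstract contract."""

    @abstractmethod
    def save_model(self, config, final: bool = True) -> None:
        """Save the model's state dictionary to the output directory."""


class GoodTrainer(MiniBaseTrainer):
    # Implements the abstract method, so instantiation succeeds.
    def save_model(self, config, final: bool = True) -> None:
        print(f"saving to {config['output_dir']} (final={final})")


class IncompleteTrainer(MiniBaseTrainer):
    pass  # forgot to implement save_model


GoodTrainer().save_model({"output_dir": "ckpts"})  # works
try:
    IncompleteTrainer()  # raises TypeError: abstract method not implemented
except TypeError as e:
    print("cannot instantiate:", e)
```

This is the standard Python `abc` enforcement mechanism: subclasses that omit save_model fail at instantiation time rather than at save time.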
- class oumi.core.trainers.HuggingFaceTrainer(hf_trainer: Trainer, processor: BaseProcessor | None = None)[source]#
Bases:
BaseTrainer
- save_model(config: TrainingConfig, final: bool = True) → None [source]#
Saves the model’s weights to the specified output directory.
- Parameters:
config – The Oumi training config.
final – Whether this is the final model being saved during training.
  - Applies optimizations for the final model checkpoint.
  - In the case of FSDP, this will always save the FULL_STATE_DICT instead of the default STATE_DICT.
- Returns:
None
- class oumi.core.trainers.Trainer(model: Module, processing_class: PreTrainedTokenizerBase | None, args: TrainingParams, train_dataset: Dataset, processor: BaseProcessor | None = None, eval_dataset: Dataset | None = None, callbacks: list[TrainerCallback] | None = None, data_collator: Callable | None = None, config: TrainingConfig | None = None, **kwargs)[source]#
Bases:
BaseTrainer
- log_metrics(metrics: dict[str, Any], step: int) → None [source]#
Logs metrics to wandb and tensorboard.
- save_model(config: TrainingConfig, final: bool = True) → None [source]#
Saves the model.
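The log_metrics interface takes a dict of metric names to values plus a global step. A hypothetical in-memory recorder sketches that calling pattern (a stand-in, not the wandb/tensorboard-backed Oumi implementation):

```python
from typing import Any


class MetricsRecorder:
    """Hypothetical stand-in for Trainer.log_metrics: records
    (step, metrics) pairs instead of forwarding to wandb/tensorboard."""

    def __init__(self) -> None:
        self.history: list[tuple[int, dict[str, Any]]] = []

    def log_metrics(self, metrics: dict[str, Any], step: int) -> None:
        # Copy the dict so later mutation by the caller cannot
        # alter what was recorded at this step.
        self.history.append((step, dict(metrics)))


rec = MetricsRecorder()
rec.log_metrics({"loss": 0.91, "lr": 3e-4}, step=10)
rec.log_metrics({"loss": 0.74, "lr": 3e-4}, step=20)
print(rec.history[-1])  # most recent step's metrics
```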
- class oumi.core.trainers.VerlGrpoTrainer(processing_class: PreTrainedTokenizerBase | None, config: TrainingConfig, reward_funcs: list[Callable], train_dataset: Dataset, eval_dataset: Dataset, cache_dir: str | Path = PosixPath('/home/runner/.cache/oumi/verl_datasets'), **kwargs)[source]#
Bases:
BaseTrainer
verl GRPO Trainer.
This class wraps verl’s RayPPOTrainer. Despite its name, RayPPOTrainer supports other RL algorithms besides PPO, including GRPO, which is used here.
For documentation on the underlying verl RayPPOTrainer, see https://verl.readthedocs.io/en/latest/examples/config.html.
- save_model(config: TrainingConfig, final: bool = True) → None [source]#
Saves the model.
- Parameters:
config – The Oumi training config.
final – Whether this is the final model being saved during training.
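All trainers above share the same save_model(config, final) convention: write a checkpoint to the configured output directory, with final selecting end-of-training handling. A minimal sketch of that calling pattern, using a hypothetical stand-in config and trainer rather than Oumi's TrainingConfig:

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass
class FakeConfig:
    """Hypothetical stand-in exposing only an output_dir field."""
    output_dir: str


class FakeTrainer:
    """Hypothetical trainer following the save_model(config, final) contract."""

    def __init__(self) -> None:
        self.state = {"step": 100, "weights": [0.1, 0.2]}

    def save_model(self, config: FakeConfig, final: bool = True) -> None:
        out = Path(config.output_dir)
        out.mkdir(parents=True, exist_ok=True)
        # Illustrative only: distinguish the final artifact from
        # intermediate checkpoints by filename.
        name = "model_final.json" if final else "model_checkpoint.json"
        (out / name).write_text(json.dumps(self.state))


with tempfile.TemporaryDirectory() as d:
    trainer = FakeTrainer()
    trainer.save_model(FakeConfig(output_dir=d), final=True)
    print(sorted(p.name for p in Path(d).iterdir()))  # ['model_final.json']
```

The final flag matters most for the real trainers: as documented above, HuggingFaceTrainer applies checkpoint optimizations (e.g. FSDP FULL_STATE_DICT) only when final is True.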