# Training Configuration

## Introduction
This guide covers the configuration options available for training in Oumi. The configuration system is designed to be:
- **Modular**: Each aspect of training (model, data, optimization, etc.) is configured separately
- **Type-safe**: All configuration options are validated at runtime
- **Flexible**: Supports various training scenarios, from single-GPU to distributed training
- **Extensible**: Easy to add new configuration options and validate them
The configuration system is built on the `TrainingConfig` class, which contains all training settings. This class is composed of several parameter classes:

- **Model Configuration**: Model architecture and loading settings
- **Data Configuration**: Dataset and data loading configuration
- **Training Configuration**: Core training parameters
- **PEFT Configuration**: Parameter-efficient fine-tuning options
- **FSDP Configuration**: Distributed training settings
All configuration files in Oumi are YAML files, which provide a human-readable format for specifying training settings. The configuration system automatically validates these files and converts them to the appropriate Python objects.
## Basic Structure
A typical configuration file has this structure:
```yaml
model:  # Model settings
  model_name: "HuggingFaceTB/SmolLM2-135M-Instruct"
  trust_remote_code: true

data:  # Dataset settings
  train:
    datasets:
      - dataset_name: "your_dataset"
        split: "train"

training:  # Training parameters
  output_dir: "output/my_run"
  num_train_epochs: 3
  learning_rate: 5e-5

peft:  # Optional PEFT settings
  peft_method: "lora"
  lora_r: 8

fsdp:  # Optional FSDP settings
  enable_fsdp: false
```
Each section in the configuration file maps to a specific parameter class and contains settings relevant to that aspect of training. The following sections detail each configuration component.
## Configuration Components

### Model Configuration

Configure the model architecture and loading using the `ModelParams` class:
```yaml
model:
  # Required
  model_name: "meta-llama/Llama-2-7b-hf"  # Model ID or path (REQUIRED)

  # Model loading
  adapter_model: null             # Path to adapter model (auto-detected if model_name is an adapter)
  tokenizer_name: null            # Custom tokenizer name/path (defaults to model_name)
  tokenizer_pad_token: null       # Override pad token
  tokenizer_kwargs: {}            # Additional tokenizer args
  model_max_length: null          # Max sequence length (positive int or null)
  load_pretrained_weights: true   # Load pretrained weights
  trust_remote_code: false        # Allow remote code execution (use with trusted models only)

  # Model precision and hardware
  torch_dtype_str: "float32"      # Model precision (float32/float16/bfloat16/float64)
  device_map: "auto"              # Device placement strategy (auto/null)
  compile: false                  # JIT compile model (use TrainingParams.compile for training)

  # Attention and optimization
  attn_implementation: null       # Attention impl (null/sdpa/flash_attention_2/eager)
  enable_liger_kernel: false      # Enable Liger CUDA kernel for potential speedup

  # Model behavior
  chat_template: null             # Chat formatting template
  freeze_layers: []               # Layer names to freeze during training

  # Additional settings
  model_kwargs: {}                # Additional model constructor args
```
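For instance, a minimal sketch of a memory-friendly setup on recent GPUs might combine half-precision weights with an explicit attention implementation. This is an illustrative assumption rather than a shipped recipe; `flash_attention_2` additionally assumes the `flash-attn` package is installed, otherwise `sdpa` is a safe choice:

```yaml
model:
  model_name: "meta-llama/Llama-2-7b-hf"
  torch_dtype_str: "bfloat16"               # half-precision weights to reduce memory vs. float32
  model_max_length: 4096                    # cap the sequence length
  attn_implementation: "flash_attention_2"  # assumes flash-attn is installed; use "sdpa" otherwise
  trust_remote_code: false
```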
### Data Configuration

Configure datasets and data loading using the `DataParams` class. Each split (`train`/`validation`/`test`) is configured using `DatasetSplitParams`, and individual datasets are configured using `DatasetParams`:
```yaml
data:
  train:  # Training dataset configuration
    datasets:  # List of datasets for this split
      - dataset_name: "text_sft"       # Required: Dataset format/type
        dataset_path: "/path/to/data"  # Optional: Path for local datasets
        subset: null                   # Optional: Dataset subset name
        split: "train"                 # Dataset split (default: "train")
        sample_count: null             # Optional: Number of examples to sample
        mixture_proportion: null       # Optional: Proportion in mixture (0-1)
        shuffle: false                 # Whether to shuffle before sampling
        seed: null                     # Random seed for shuffling
        shuffle_buffer_size: 1000      # Size of shuffle buffer
        trust_remote_code: false       # Trust remote code when loading
        transform_num_workers: null    # Workers for dataset processing
        dataset_kwargs: {}             # Additional dataset constructor args

    # Split-level settings
    collator_name: "text_with_padding"   # Data collator type
    pack: false                          # Pack text into constant-length chunks
    stream: false                        # Enable dataset streaming
    mixture_strategy: "first_exhausted"  # Strategy for mixing datasets
    seed: null                           # Random seed for mixing
    use_torchdata: false                 # Use `torchdata` (experimental)

  validation:  # Optional validation dataset config
    datasets:
      - dataset_name: "text_sft"
        dataset_path: "/path/to/val"
        split: "validation"
```
Notes:

- When using multiple datasets in a split with `mixture_proportion` (see the sketch after this list):
  - All datasets in the split must specify a `mixture_proportion`.
  - The sum of all proportions must equal 1.0.
  - The `mixture_strategy` determines how datasets are combined:
    - `first_exhausted`: Stops when any dataset is exhausted.
    - `all_exhausted`: Continues until all datasets are exhausted (may oversample).
- When `pack` is enabled:
  - `stream` must also be enabled.
  - `target_col` must be specified.
- All splits must use the same collator type if one is specified.
- If a collator is specified for validation/test, it must also be specified for train.
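As a sketch of the mixture and packing rules above (the dataset paths and proportions here are illustrative, not a shipped recipe):

```yaml
data:
  train:
    datasets:
      - dataset_name: "text_sft"
        dataset_path: "/path/to/dataset_a"
        mixture_proportion: 0.7   # every dataset in the split must set this...
      - dataset_name: "text_sft"
        dataset_path: "/path/to/dataset_b"
        mixture_proportion: 0.3   # ...and the proportions must sum to 1.0
    mixture_strategy: "all_exhausted"  # keep sampling until every dataset is exhausted
    seed: 123                          # seed the mixing for reproducibility
    pack: true                         # packing requires streaming...
    stream: true
    target_col: "prompt"               # ...and an explicit target column
```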
### Training Configuration

Configure the training process using the `TrainingParams` class:
```yaml
training:
  # Basic settings
  output_dir: "output"                  # Directory for saving outputs
  run_name: null                        # Unique identifier for the run
  seed: 42                              # Random seed for reproducibility

  # Training duration
  num_train_epochs: 3                   # Number of training epochs
  max_steps: -1                         # Max training steps (-1 to use epochs)

  # Batch size settings
  per_device_train_batch_size: 8        # Training batch size per device
  per_device_eval_batch_size: 8         # Evaluation batch size per device
  gradient_accumulation_steps: 1        # Steps before weight update

  # Optimization
  learning_rate: 5e-5                   # Initial learning rate
  optimizer: "adamw_torch"              # Optimizer type ("adam", "adamw", "adamw_torch", "adamw_torch_fused", "sgd", "adafactor")
                                        # "adamw_8bit", "paged_adamw_8bit", "paged_adamw", "paged_adamw_32bit" (requires bitsandbytes)
  weight_decay: 0.0                     # Weight decay for regularization
  max_grad_norm: 1.0                    # Max gradient norm for clipping

  # Optimizer-specific settings
  adam_beta1: 0.9                       # Adam beta1 parameter
  adam_beta2: 0.999                     # Adam beta2 parameter
  adam_epsilon: 1e-8                    # Adam epsilon parameter
  sgd_momentum: 0.0                     # SGD momentum (if using SGD)

  # Learning rate schedule
  lr_scheduler_type: "linear"           # LR scheduler type
  warmup_ratio: null                    # Warmup ratio of total steps
  warmup_steps: null                    # Number of warmup steps

  # Mixed precision and performance
  mixed_precision_dtype: "none"         # Mixed precision type ("none", "fp16", "bf16")
  compile: false                        # Whether to JIT compile the model
  enable_gradient_checkpointing: false  # Trade compute for memory

  # Checkpointing
  save_steps: 500                       # Save every N steps
  save_epoch: false                     # Save at the end of each epoch
  save_final_model: true                # Save model at the end of training
  resume_from_checkpoint: null          # Path to resume from
  try_resume_from_last_checkpoint: false  # Try auto-resume from the last checkpoint

  # Evaluation
  eval_strategy: "steps"                # When to evaluate ("no", "steps", "epoch")
  eval_steps: 500                       # Evaluate every N steps
  metrics_function: null                # Name of metrics function to use

  # Logging
  log_level: "info"                     # Main logger level
  dep_log_level: "warning"              # Dependencies logger level
  enable_wandb: false                   # Enable Weights & Biases logging
  enable_tensorboard: true              # Enable TensorBoard logging
  logging_strategy: "steps"             # When to log ("steps", "epoch", "no")
  logging_steps: 50                     # Log every N steps
  logging_first_step: false             # Log first-step metrics

  # DataLoader settings
  dataloader_num_workers: 0             # Number of dataloader workers (int or "auto")
  dataloader_prefetch_factor: null      # Batches to prefetch per worker (requires workers > 0)
  dataloader_main_process_only: null    # Iterate dataloader on main process only (auto if null)

  # Distributed training
  ddp_find_unused_parameters: false     # Find unused parameters in DDP
  nccl_default_timeout_minutes: null    # NCCL timeout in minutes

  # Performance monitoring
  include_performance_metrics: false      # Include token statistics
  include_alternative_mfu_metrics: false  # Include alternative MFU metrics
  log_model_summary: false                # Print model layer summary
  empty_device_cache_steps: null          # Steps between cache clearing
```
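One relationship worth keeping in mind when tuning the batch settings above: the effective global batch size is the per-device batch size multiplied by the gradient accumulation steps and the number of devices. A minimal sketch (the two-GPU count is only an assumption for illustration):

```yaml
training:
  per_device_train_batch_size: 8  # examples per device per micro-batch
  gradient_accumulation_steps: 4  # micro-batches accumulated before each optimizer step
  # On 2 GPUs this yields an effective global batch size of 8 * 4 * 2 = 64 examples per update.
```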
### PEFT Configuration

Configure parameter-efficient fine-tuning using the `PeftParams` class:
```yaml
peft:
  # LoRA settings
  lora_r: 8                          # Rank of update matrices
  lora_alpha: 8                      # Scaling factor
  lora_dropout: 0.0                  # Dropout probability
  lora_target_modules: null          # Modules to apply LoRA to
  lora_modules_to_save: null         # Modules to unfreeze and train
  lora_bias: "none"                  # Bias training type
  lora_task_type: "CAUSAL_LM"        # Task type for adaptation
  lora_init_weights: "DEFAULT"       # Initialization of LoRA weights

  # Q-LoRA settings
  q_lora: false                      # Enable quantization
  q_lora_bits: 4                     # Quantization bits
  bnb_4bit_quant_type: "fp4"         # 4-bit quantization type
  use_bnb_nested_quant: false        # Use nested quantization
  bnb_4bit_quant_storage: "uint8"    # Storage type for params
  bnb_4bit_compute_dtype: "float32"  # Compute type for params
```
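As an illustration of how the LoRA and Q-LoRA settings combine, here is a hedged sketch of a QLoRA-style configuration. The ranks are illustrative, and the `"nf4"` quantization type and bfloat16 compute dtype are common bitsandbytes choices rather than the defaults shown above:

```yaml
peft:
  lora_r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  q_lora: true                        # quantize the frozen base model
  q_lora_bits: 4
  bnb_4bit_quant_type: "nf4"          # assumed bitsandbytes quant type ("fp4" is the default above)
  use_bnb_nested_quant: true          # nested (double) quantization to save a bit more memory
  bnb_4bit_compute_dtype: "bfloat16"  # compute in bf16 while storing 4-bit weights
```

Remember that PEFT also needs to be switched on in the training section (`use_peft: True`), as in the LoRA recipe later on this page.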
### FSDP Configuration

Configure fully sharded data parallel (FSDP) training using the `FSDPParams` class:
```yaml
fsdp:
  enable_fsdp: false                  # Enable FSDP training
  sharding_strategy: "FULL_SHARD"     # How to shard the model
  cpu_offload: false                  # Offload to CPU
  mixed_precision: null               # Mixed precision type
  backward_prefetch: "BACKWARD_PRE"   # When to prefetch params
  forward_prefetch: false             # Prefetch forward results
  use_orig_params: null               # Use original module params
  state_dict_type: "FULL_STATE_DICT"  # Checkpoint format

  # Auto wrapping settings
  auto_wrap_policy: "NO_WRAP"         # How to wrap layers
  min_num_params: 100000              # Min params for wrapping
  transformer_layer_cls: null         # Transformer layer class

  # Other settings
  sync_module_states: true            # Sync states across processes
```
Notes on FSDP sharding strategies (a short sketch follows this list):

- `FULL_SHARD`: Shards model parameters, gradients, and optimizer states. Most memory-efficient, but may impact performance.
- `SHARD_GRAD_OP`: Shards gradients and optimizer states only. Balances memory and performance.
- `HYBRID_SHARD`: Shards parameters within a node, replicates across nodes.
- `NO_SHARD`: No sharding (use DDP instead).
- `HYBRID_SHARD_ZERO2`: Uses `SHARD_GRAD_OP` within a node, replicates across nodes.
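For example, a single-node multi-GPU run that trades some memory savings for throughput might use `SHARD_GRAD_OP` with transformer-based wrapping. This is a sketch, not a tuned recipe; the layer class depends on the model architecture (`LlamaDecoderLayer` is shown only as an example):

```yaml
fsdp:
  enable_fsdp: true
  sharding_strategy: "SHARD_GRAD_OP"          # shard gradients and optimizer states, keep parameters replicated
  auto_wrap_policy: "TRANSFORMER_BASED_WRAP"  # wrap at the transformer-block level
  transformer_layer_cls: "LlamaDecoderLayer"  # example layer class for Llama-family models
```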
## Example Configurations
You can find these examples and many more in the Recipes section.
We aim to provide a comprehensive (and growing) set of recipes for all the common training scenarios:
### Full Fine-tuning (SFT)

This example shows how to fine-tune a small model (`SmolLM2-135M`) without any parameter-efficient methods:

`configs/recipes/smollm/sft/135m/quickstart_train.yaml`
```yaml
# Class: oumi.core.configs.TrainingConfig
# https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py

# SFT config for SmolLM 135M Instruct.

model:
  model_name: "HuggingFaceTB/SmolLM2-135M-Instruct"
  model_max_length: 2048
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"
  load_pretrained_weights: True
  trust_remote_code: True

data:
  train:
    datasets:
      - dataset_name: "yahma/alpaca-cleaned"
    target_col: "prompt"

training:
  trainer_type: TRL_SFT
  save_final_model: True
  save_steps: 100
  max_steps: 10
  per_device_train_batch_size: 4
  gradient_accumulation_steps: 4

  ddp_find_unused_parameters: False
  optimizer: "adamw_torch"
  learning_rate: 2.0e-05
  compile: False

  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  logging_steps: 5
  log_model_summary: False
  empty_device_cache_steps: 50
  output_dir: "output/smollm135m.fft"
  include_performance_metrics: True
```
### Parameter-Efficient Fine-tuning (LoRA)

This example shows how to fine-tune a large model (`Llama-3.1-70B`) using LoRA:

`configs/recipes/llama3_1/sft/70b_lora/train.yaml`
```yaml
# Class: oumi.core.configs.TrainingConfig
# https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py

# Lora config for Llama 70B.
# Borrows param values from:
# https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama3_1/70B_lora.yaml

model:
  model_name: "meta-llama/Meta-Llama-3.1-70B-Instruct"
  model_max_length: 2048
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"
  load_pretrained_weights: True
  trust_remote_code: True

data:
  train:
    datasets:
      - dataset_name: "yahma/alpaca-cleaned" # 51,760 examples
    target_col: "prompt"
    use_async_dataset: True

training:
  trainer_type: "TRL_SFT"
  use_peft: True
  save_steps: 200
  num_train_epochs: 1
  per_device_train_batch_size: 2
  gradient_accumulation_steps: 1

  enable_gradient_checkpointing: True
  gradient_checkpointing_kwargs:
    use_reentrant: False
  ddp_find_unused_parameters: False
  optimizer: "adamw_torch_fused"
  learning_rate: 3.0e-04
  warmup_steps: 100
  weight_decay: 0.01
  compile: False

  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  logging_steps: 100
  log_model_summary: False
  empty_device_cache_steps: 50
  output_dir: "output/llama70b.lora"
  include_performance_metrics: True
  enable_wandb: True

peft:
  lora_r: 16
  lora_alpha: 32
  lora_dropout: 0.0
  lora_target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"

fsdp:
  enable_fsdp: True
  forward_prefetch: True
  auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
  transformer_layer_cls: "LlamaDecoderLayer"
```
### Distributed Training (FSDP)

This example shows how to fine-tune a medium-sized model (`Llama-3.1-8B`) using FSDP for distributed training:

`configs/recipes/llama3_1/sft/8b_full/train.yaml`
```yaml
# Class: oumi.core.configs.TrainingConfig
# https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py

# SFT config for Llama 8B.
# Borrows param values from:
# https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama3_1/8B_full.yaml

model:
  model_name: "meta-llama/Meta-Llama-3.1-8B-Instruct"
  model_max_length: 8192
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"
  load_pretrained_weights: True
  trust_remote_code: True
  # Improves training speed by 20% with default config.
  enable_liger_kernel: True

data:
  train:
    datasets:
      - dataset_name: "yahma/alpaca-cleaned"
    target_col: "prompt"
    use_async_dataset: True

training:
  trainer_type: "TRL_SFT"
  save_steps: 800
  num_train_epochs: 3
  per_device_train_batch_size: 2
  gradient_accumulation_steps: 1

  enable_gradient_checkpointing: True
  gradient_checkpointing_kwargs:
    use_reentrant: False
  ddp_find_unused_parameters: False
  optimizer: "adamw_torch_fused"
  learning_rate: 2.0e-05
  compile: False

  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  logging_steps: 100
  log_model_summary: False
  empty_device_cache_steps: 50
  output_dir: "output/llama8b.fft"
  include_performance_metrics: True
  enable_wandb: True

fsdp:
  enable_fsdp: True
  sharding_strategy: "HYBRID_SHARD"
  forward_prefetch: True
  auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
  transformer_layer_cls: "LlamaDecoderLayer"
```
### Vision-Language Fine-tuning

This example shows how to fine-tune a vision-language model (`LLaVA-7B`):

`configs/recipes/vision/llava_7b/sft/train.yaml`
```yaml
# Class: oumi.core.configs.TrainingConfig
# https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py

model:
  model_name: "llava-hf/llava-1.5-7b-hf"
  torch_dtype_str: "bfloat16"
  model_max_length: 1024
  trust_remote_code: True
  attn_implementation: "sdpa"
  chat_template: "llava"
  freeze_layers:
    - "vision_tower"

data:
  train:
    collator_name: "vision_language_with_padding"
    use_torchdata: True
    datasets:
      - dataset_name: "merve/vqav2-small"
        split: "validation"
        shuffle: True
        seed: 42
        transform_num_workers: "auto"
        dataset_kwargs:
          processor_name: "llava-hf/llava-1.5-7b-hf"
          # limit: 8192 # Uncomment to limit dataset size!
          return_tensors: True

      # Below are examples of other vision SFT datasets
      # - dataset_name: "HuggingFaceH4/llava-instruct-mix-vsft"
      #   split: "train"
      #   shuffle: True
      #   seed: 42
      #   transform_num_workers: "auto"
      #   dataset_kwargs:
      #     processor_name: "llava-hf/llava-1.5-7b-hf"
      #     limit: 8192
      #     return_tensors: True
      # - dataset_name: "coco_captions"
      #   split: "train"
      #   trust_remote_code: True
      #   dataset_kwargs:
      #     processor_name: "llava-hf/llava-1.5-7b-hf"
      #     limit: 8192
      #     return_tensors: True
      # - dataset_name: vision_language_jsonl
      #   dataset_path: "training.jsonl" # See notebook for example how to generate this file
      #   dataset_kwargs:
      #     data_column: "messages"
      #     processor_name: "llava-hf/llava-1.5-7b-hf"

training:
  output_dir: "output/vlm_finetuned"
  trainer_type: "TRL_SFT"
  enable_gradient_checkpointing: True
  per_device_train_batch_size: 6
  gradient_accumulation_steps: 8
  max_steps: 20

  gradient_checkpointing_kwargs:
    # Reentrant docs: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
    use_reentrant: False
  ddp_find_unused_parameters: False
  empty_device_cache_steps: 2
  compile: False

  optimizer: "adamw_torch_fused"
  learning_rate: 2e-5
  warmup_ratio: 0.03
  weight_decay: 0.0
  lr_scheduler_type: "cosine"

  logging_steps: 5
  save_steps: 0
  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32
  include_performance_metrics: True
  log_model_summary: False
  enable_wandb: True
```