Training Configuration#

Introduction#

This guide covers the configuration options available for training in Oumi. The configuration system is designed to be:

  • Modular: Each aspect of training (model, data, optimization, etc.) is configured separately

  • Type-safe: All configuration options are validated at runtime

  • Flexible: Supports various training scenarios from single-GPU to distributed training

  • Extensible: Easy to add new configuration options and validate them

The configuration system is built on the TrainingConfig class, which contains all training settings. This class is composed of several parameter classes: ModelParams (model), DataParams (data), TrainingParams (training), PeftParams (peft), and FSDPParams (fsdp).

All configuration files in Oumi are YAML files, which provide a human-readable format for specifying training settings. The configuration system automatically validates these files and converts them to the appropriate Python objects.

Basic Structure#

A typical configuration file has this structure:

model:  # Model settings
  model_name: "HuggingFaceTB/SmolLM2-135M-Instruct"
  trust_remote_code: true

data:   # Dataset settings
  train:
    datasets:
      - dataset_name: "your_dataset"
        split: "train"

training:  # Training parameters
  output_dir: "output/my_run"
  num_train_epochs: 3
  learning_rate: 5e-5

peft:  # Optional PEFT settings
  peft_method: "lora"
  lora_r: 8

fsdp:  # Optional FSDP settings
  enable_fsdp: false

Each section in the configuration file maps to a specific parameter class and contains settings relevant to that aspect of training. The following sections detail each configuration component.

Configuration Components#

Model Configuration#

Configure the model architecture and loading using the ModelParams class:

model:
  # Required
  model_name: "meta-llama/Llama-2-7b-hf"    # Model ID or path (REQUIRED)

  # Model loading
  adapter_model: null                        # Path to adapter model (auto-detected if model_name is an adapter)
  tokenizer_name: null                       # Custom tokenizer name/path (defaults to model_name)
  tokenizer_pad_token: null                  # Override pad token
  tokenizer_kwargs: {}                       # Additional tokenizer args
  model_max_length: null                     # Max sequence length (positive int or null)
  load_pretrained_weights: true              # Load pretrained weights
  trust_remote_code: false                   # Allow remote code execution (use with trusted models only)

  # Model precision and hardware
  torch_dtype_str: "float32"                 # Model precision (float32/float16/bfloat16/float64)
  device_map: "auto"                         # Device placement strategy (auto/null)
  compile: false                             # JIT compile model (use TrainingParams.compile for training)

  # Attention and optimization
  attn_implementation: null                  # Attention impl (null/sdpa/flash_attention_2/eager)
  enable_liger_kernel: false                 # Enable Liger CUDA kernel for potential speedup

  # Model behavior
  chat_template: null                        # Chat formatting template
  freeze_layers: []                          # Layer names to freeze during training

  # Additional settings
  model_kwargs: {}                           # Additional model constructor args
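
For instance, a memory-conscious setup for a chat model might combine half-precision weights with an explicit attention implementation and a sequence-length cap. The values below are illustrative, not defaults:

model:
  model_name: "meta-llama/Meta-Llama-3.1-8B-Instruct"
  model_max_length: 4096                     # Cap sequence length to bound activation memory
  torch_dtype_str: "bfloat16"                # Half-precision weights
  attn_implementation: "sdpa"                # Or "flash_attention_2" if that package is installed
  trust_remote_code: true                    # Only enable for models you trust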

Data Configuration#

Configure datasets and data loading using the DataParams class. Each split (train/validation/test) is configured using DatasetSplitParams, and individual datasets are configured using DatasetParams:

data:
  train:  # Training dataset configuration
    datasets:  # List of datasets for this split
      - dataset_name: "text_sft"            # Required: Dataset format/type
        dataset_path: "/path/to/data"       # Optional: Path for local datasets
        subset: null                        # Optional: Dataset subset name
        split: "train"                      # Dataset split (default: "train")
        sample_count: null                  # Optional: Number of examples to sample
        mixture_proportion: null            # Optional: Proportion in mixture (0-1)
        shuffle: false                      # Whether to shuffle before sampling
        seed: null                          # Random seed for shuffling
        shuffle_buffer_size: 1000           # Size of shuffle buffer
        trust_remote_code: false            # Trust remote code when loading
        transform_num_workers: null         # Workers for dataset processing
        dataset_kwargs: {}                  # Additional dataset constructor args

    # Split-level settings
    collator_name: "text_with_padding"      # Data collator type
    pack: false                             # Pack text into constant-length chunks
    stream: false                           # Enable dataset streaming
    mixture_strategy: "first_exhausted"     # Strategy for mixing datasets
    seed: null                              # Random seed for mixing
    use_torchdata: false                    # Use `torchdata` (experimental)

  validation:  # Optional validation dataset config
    datasets:
      - dataset_name: "text_sft"
        dataset_path: "/path/to/val"
        split: "validation"

Notes:

  • When using multiple datasets in a split with mixture_proportion (see the sketch after these notes):

    • All datasets must specify a mixture_proportion

    • The sum of all proportions must equal 1.0

    • The mixture_strategy determines how datasets are combined:

      • first_exhausted: Stops when any dataset is exhausted

      • all_exhausted: Continues until all datasets are exhausted (may oversample)

  • When pack is enabled (also illustrated in the sketch below):

    • stream must also be enabled

    • target_col must be specified

  • All splits must use the same collator type if specified

  • If a collator is specified for validation/test, it must also be specified for train
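
As a concrete illustration of the rules above, the sketch below (paths and column name are placeholders) mixes two local datasets with proportions that sum to 1.0 and enables packing together with streaming and a target column:

data:
  train:
    datasets:
      - dataset_name: "text_sft"
        dataset_path: "/path/to/dataset_a"
        mixture_proportion: 0.7             # All proportions in the split must sum to 1.0
      - dataset_name: "text_sft"
        dataset_path: "/path/to/dataset_b"
        mixture_proportion: 0.3
    mixture_strategy: "all_exhausted"       # Keep sampling until every dataset is exhausted (may oversample)
    seed: 42                                # Fix the mixing order for reproducibility
    pack: true                              # Packing requires streaming and a target column
    stream: true
    target_col: "prompt"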

Training Configuration#

Configure the training process using the TrainingParams class:

training:
  # Basic settings
  output_dir: "output"                    # Directory for saving outputs
  run_name: null                          # Unique identifier for the run
  seed: 42                                # Random seed for reproducibility

  # Training duration
  num_train_epochs: 3                     # Number of training epochs
  max_steps: -1                           # Max training steps (-1 to use epochs)

  # Batch size settings
  per_device_train_batch_size: 8          # Training batch size per device
  per_device_eval_batch_size: 8           # Evaluation batch size per device
  gradient_accumulation_steps: 1          # Steps before weight update

  # Optimization
  learning_rate: 5e-5                     # Initial learning rate
  optimizer: "adamw_torch"                # Optimizer type ("adam", "adamw", "adamw_torch", "adamw_torch_fused", "sgd", "adafactor")
                                          # "adamw_8bit", "paged_adamw_8bit", "paged_adamw", "paged_adamw_32bit" (requires bitsandbytes)
  weight_decay: 0.0                       # Weight decay for regularization
  max_grad_norm: 1.0                      # Max gradient norm for clipping

  # Optimizer specific settings
  adam_beta1: 0.9                         # Adam beta1 parameter
  adam_beta2: 0.999                       # Adam beta2 parameter
  adam_epsilon: 1e-8                      # Adam epsilon parameter
  sgd_momentum: 0.0                       # SGD momentum (if using SGD)

  # Learning rate schedule
  lr_scheduler_type: "linear"             # LR scheduler type
  warmup_ratio: null                      # Warmup ratio of total steps
  warmup_steps: null                      # Number of warmup steps

  # Mixed precision and performance
  mixed_precision_dtype: "none"           # Mixed precision type ("none", "fp16", "bf16")
  compile: false                          # Whether to JIT compile model
  enable_gradient_checkpointing: false    # Trade compute for memory

  # Checkpointing
  save_steps: 500                         # Save every N steps
  save_epoch: false                       # Save at end of each epoch
  save_final_model: true                  # Save model at end of training
  resume_from_checkpoint: null            # Path to resume from
  try_resume_from_last_checkpoint: false  # Try auto-resume from last checkpoint

  # Evaluation
  eval_strategy: "steps"                  # When to evaluate ("no", "steps", "epoch")
  eval_steps: 500                         # Evaluate every N steps
  metrics_function: null                  # Name of metrics function to use

  # Logging
  log_level: "info"                       # Main logger level
  dep_log_level: "warning"                # Dependencies logger level
  enable_wandb: false                     # Enable Weights & Biases logging
  enable_tensorboard: true                # Enable TensorBoard logging
  logging_strategy: "steps"               # When to log ("steps", "epoch", "no")
  logging_steps: 50                       # Log every N steps
  logging_first_step: false               # Log first step metrics

  # DataLoader settings
  dataloader_num_workers: 0               # Number of dataloader workers (int or "auto")
  dataloader_prefetch_factor: null        # Batches to prefetch per worker (requires workers > 0)
  dataloader_main_process_only: null      # Iterate dataloader on main process only (auto if null)

  # Distributed training
  ddp_find_unused_parameters: false       # Find unused parameters in DDP
  nccl_default_timeout_minutes: null      # NCCL timeout in minutes

  # Performance monitoring
  include_performance_metrics: false      # Include token statistics
  include_alternative_mfu_metrics: false  # Include alternative MFU metrics
  log_model_summary: false                # Print model layer summary
  empty_device_cache_steps: null          # Steps between cache clearing
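
As an example of how these settings combine in practice, the sketch below (values are illustrative) uses bf16 mixed precision, a cosine schedule with warmup, gradient checkpointing, and epoch-based evaluation and checkpointing:

training:
  output_dir: "output/my_run"
  num_train_epochs: 3
  per_device_train_batch_size: 8
  gradient_accumulation_steps: 4          # Effective per-device batch size of 32
  learning_rate: 2.0e-05
  lr_scheduler_type: "cosine"
  warmup_ratio: 0.03                      # Fraction of total steps used for warmup
  mixed_precision_dtype: "bf16"
  enable_gradient_checkpointing: true     # Trade extra compute for lower memory
  eval_strategy: "epoch"                  # Requires a validation split in the data config
  save_epoch: true                        # Checkpoint at the end of every epoch
  logging_steps: 50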

PEFT Configuration#

Configure parameter-efficient fine-tuning using the PeftParams class:

peft:
  # LoRA settings
  lora_r: 8                          # Rank of update matrices
  lora_alpha: 8                      # Scaling factor
  lora_dropout: 0.0                  # Dropout probability
  lora_target_modules: null          # Modules to apply LoRA to
  lora_modules_to_save: null         # Modules to unfreeze and train
  lora_bias: "none"                  # Bias training type
  lora_task_type: "CAUSAL_LM"        # Task type for adaptation
  lora_init_weights: "DEFAULT"       # Initialization of LoRA weights

  # Q-LoRA settings
  q_lora: false                      # Enable quantization
  q_lora_bits: 4                     # Quantization bits
  bnb_4bit_quant_type: "fp4"         # 4-bit quantization type
  use_bnb_nested_quant: false        # Use nested quantization
  bnb_4bit_quant_storage: "uint8"    # Storage type for params
  bnb_4bit_compute_dtype: "float32"  # Compute type for params
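
For example, a typical Q-LoRA setup (a sketch with illustrative values) quantizes the base model to 4-bit NF4, computes in bfloat16, and applies LoRA to the attention projections:

peft:
  q_lora: true                       # Quantize the frozen base model
  q_lora_bits: 4
  bnb_4bit_quant_type: "nf4"         # NF4 is the type recommended in the QLoRA paper
  use_bnb_nested_quant: true         # Nested (double) quantization saves additional memory
  bnb_4bit_compute_dtype: "bfloat16" # Keep compute in bf16 while weights stay 4-bit
  lora_r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  lora_target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"
    - "o_proj"

Remember to also set use_peft: true in the training section (as in the LoRA recipe below) so the trainer actually applies the adapter.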

FSDP Configuration#

Configure fully sharded data parallel training using the FSDPParams class:

fsdp:
  enable_fsdp: false                        # Enable FSDP training
  sharding_strategy: "FULL_SHARD"           # How to shard model
  cpu_offload: false                        # Offload to CPU
  mixed_precision: null                     # Mixed precision type
  backward_prefetch: "BACKWARD_PRE"         # When to prefetch params
  forward_prefetch: false                   # Prefetch forward results
  use_orig_params: null                     # Use original module params
  state_dict_type: "FULL_STATE_DICT"        # Checkpoint format

  # Auto wrapping settings
  auto_wrap_policy: "NO_WRAP"               # How to wrap layers
  min_num_params: 100000                    # Min params for wrapping
  transformer_layer_cls: null               # Transformer layer class

  # Other settings
  sync_module_states: true                  # Sync states across processes

Notes on FSDP sharding strategies (see the sketch after this list):

  • FULL_SHARD: Shards model parameters, gradients, and optimizer states. Most memory efficient but may impact performance.

  • SHARD_GRAD_OP: Shards gradients and optimizer states only. Balances memory and performance.

  • HYBRID_SHARD: Shards parameters within a node, replicates across nodes.

  • NO_SHARD: No sharding (use DDP instead).

  • HYBRID_SHARD_ZERO2: Uses SHARD_GRAD_OP within node, replicates across nodes.
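
For instance, on a single node where FULL_SHARD is unnecessarily slow, a reasonable middle ground is to shard only gradients and optimizer states while wrapping per transformer block. This is an illustrative sketch; the Llama recipes below use the FULL_SHARD default and HYBRID_SHARD:

fsdp:
  enable_fsdp: true
  sharding_strategy: "SHARD_GRAD_OP"          # Parameters stay replicated; gradients/optimizer states are sharded
  auto_wrap_policy: "TRANSFORMER_BASED_WRAP"  # Wrap each transformer block separately
  transformer_layer_cls: "LlamaDecoderLayer"  # Layer class of the model being trained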

Example Configurations#

You can find these examples and many more in the Recipes section.

We aim to provide a comprehensive (and growing) set of recipes covering all common training scenarios.

Full Fine-tuning (SFT)#

This example shows how to fine-tune a small model (‘SmolLM2-135M’) without any parameter-efficient methods:

configs/recipes/smollm/sft/135m/quickstart_train.yaml
# Class: oumi.core.configs.TrainingConfig
# https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py

# SFT config for SmolLM 135M Instruct.

model:
  model_name: "HuggingFaceTB/SmolLM2-135M-Instruct"
  model_max_length: 2048
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"
  load_pretrained_weights: True
  trust_remote_code: True

data:
  train:
    datasets:
      - dataset_name: "yahma/alpaca-cleaned"
    target_col: "prompt"

training:
  trainer_type: TRL_SFT
  save_final_model: True
  save_steps: 100
  max_steps: 10
  per_device_train_batch_size: 4
  gradient_accumulation_steps: 4

  ddp_find_unused_parameters: False
  optimizer: "adamw_torch"
  learning_rate: 2.0e-05
  compile: False

  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  logging_steps: 5
  log_model_summary: False
  empty_device_cache_steps: 50
  output_dir: "output/smollm135m.fft"
  include_performance_metrics: True

Parameter-Efficient Fine-tuning (LoRA)#

This example shows how to fine-tune a large model (‘Llama-3.1-70B’) using LoRA:

configs/recipes/llama3_1/sft/70b_lora/train.yaml
# Class: oumi.core.configs.TrainingConfig
# https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py

# Lora config for Llama 70B.
# Borrows param values from:
# https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama3_1/70B_lora.yaml

model:
  model_name: "meta-llama/Meta-Llama-3.1-70B-Instruct"
  model_max_length: 2048
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"
  load_pretrained_weights: True
  trust_remote_code: True

data:
  train:
    datasets:
      - dataset_name: "yahma/alpaca-cleaned" # 51,760 examples
    target_col: "prompt"
    use_async_dataset: True

training:
  trainer_type: "TRL_SFT"
  use_peft: True
  save_steps: 200
  num_train_epochs: 1
  per_device_train_batch_size: 2
  gradient_accumulation_steps: 1

  enable_gradient_checkpointing: True
  gradient_checkpointing_kwargs:
    use_reentrant: False
  ddp_find_unused_parameters: False
  optimizer: "adamw_torch_fused"
  learning_rate: 3.0e-04
  warmup_steps: 100
  weight_decay: 0.01
  compile: False

  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  logging_steps: 100
  log_model_summary: False
  empty_device_cache_steps: 50
  output_dir: "output/llama70b.lora"
  include_performance_metrics: True
  enable_wandb: True

peft:
  lora_r: 16
  lora_alpha: 32
  lora_dropout: 0.0
  lora_target_modules:
    - "q_proj"
    - "k_proj"
    - "v_proj"

fsdp:
  enable_fsdp: True
  forward_prefetch: True
  auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
  transformer_layer_cls: "LlamaDecoderLayer"

Distributed Training (FSDP)#

This example shows how to fine-tune a medium-sized model (‘Llama-3.1-8B’) using FSDP for distributed training:

configs/recipes/llama3_1/sft/8b_full/train.yaml
# Class: oumi.core.configs.TrainingConfig
# https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py

# SFT config for Llama 8B.
# Borrows param values from:
# https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama3_1/8B_full.yaml

model:
  model_name: "meta-llama/Meta-Llama-3.1-8B-Instruct"
  model_max_length: 8192
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"
  load_pretrained_weights: True
  trust_remote_code: True
  # Improves training speed by 20% with default config.
  enable_liger_kernel: True

data:
  train:
    datasets:
      - dataset_name: "yahma/alpaca-cleaned"
    target_col: "prompt"
    use_async_dataset: True

training:
  trainer_type: "TRL_SFT"
  save_steps: 800
  num_train_epochs: 3
  per_device_train_batch_size: 2
  gradient_accumulation_steps: 1

  enable_gradient_checkpointing: True
  gradient_checkpointing_kwargs:
    use_reentrant: False
  ddp_find_unused_parameters: False
  optimizer: "adamw_torch_fused"
  learning_rate: 2.0e-05
  compile: False

  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  logging_steps: 100
  log_model_summary: False
  empty_device_cache_steps: 50
  output_dir: "output/llama8b.fft"
  include_performance_metrics: True
  enable_wandb: True

fsdp:
  enable_fsdp: True
  sharding_strategy: "HYBRID_SHARD"
  forward_prefetch: True
  auto_wrap_policy: "TRANSFORMER_BASED_WRAP"
  transformer_layer_cls: "LlamaDecoderLayer"

Vision-Language Fine-tuning#

This example shows how to fine-tune a vision-language model (‘LLaVA-7B’):

configs/recipes/vision/llava_7b/sft/train.yaml
# Class: oumi.core.configs.TrainingConfig
# https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py

model:
  model_name: "llava-hf/llava-1.5-7b-hf"
  torch_dtype_str: "bfloat16"
  model_max_length: 1024
  trust_remote_code: True
  attn_implementation: "sdpa"
  chat_template: "llava"
  freeze_layers:
    - "vision_tower"

data:
  train:
    collator_name: "vision_language_with_padding"
    use_torchdata: True
    datasets:
      - dataset_name: "merve/vqav2-small"
        split: "validation"
        shuffle: True
        seed: 42
        transform_num_workers: "auto"
        dataset_kwargs:
          processor_name: "llava-hf/llava-1.5-7b-hf"
          # limit: 8192 # Uncomment to limit dataset size!
          return_tensors: True

      # Below are examples of other vision SFT datasets
      # - dataset_name: "HuggingFaceH4/llava-instruct-mix-vsft"
      #   split: "train"
      #   shuffle: True
      #   seed: 42
      #   transform_num_workers: "auto"
      #   dataset_kwargs:
      #     processor_name: "llava-hf/llava-1.5-7b-hf"
      #     limit: 8192
      #     return_tensors: True
      # - dataset_name: "coco_captions"
      #   split: "train"
      #   trust_remote_code: True
      #   dataset_kwargs:
      #     processor_name: "llava-hf/llava-1.5-7b-hf"
      #     limit: 8192
      #     return_tensors: True
      # - dataset_name: vision_language_jsonl
      #   dataset_path: "training.jsonl"  # See notebook for example how to generate this file
      #   dataset_kwargs:
      #     data_column: "messages"
      #     processor_name: "llava-hf/llava-1.5-7b-hf"

training:
  output_dir: "output/vlm_finetuned"
  trainer_type: "TRL_SFT"
  enable_gradient_checkpointing: True
  per_device_train_batch_size: 6
  gradient_accumulation_steps: 8
  max_steps: 20

  gradient_checkpointing_kwargs:
    # Reentrant docs: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
    use_reentrant: False
  ddp_find_unused_parameters: False
  empty_device_cache_steps: 2
  compile: False

  optimizer: "adamw_torch_fused"
  learning_rate: 2e-5
  warmup_ratio: 0.03
  weight_decay: 0.0
  lr_scheduler_type: "cosine"

  logging_steps: 5
  save_steps: 0
  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32
  include_performance_metrics: True
  log_model_summary: False
  enable_wandb: True