Out of Memory (OOM)#

Introduction#

Out of Memory (OOM) errors are a common challenge when working with large language models and datasets.

In this guide, we will discuss a few strategies to reduce GPU memory requirements.

Best Practices#

  • Always monitor memory usage and performance metrics when applying these optimizations, using nvidia-smi and Oumi’s telemetry output (see the sketch after this list).

  • Combine multiple techniques for best results, but introduce changes gradually to isolate their effects.

  • Some techniques may trade off speed and model accuracy for memory efficiency. Choose the right balance for your specific use case.
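
For a quick spot-check of GPU memory from inside a training script, PyTorch’s built-in allocator counters are handy. The sketch below only reads those statistics; it is not a substitute for nvidia-smi or Oumi’s telemetry:

    import torch

    if torch.cuda.is_available():
        # Bytes currently held by tensors vs. the peak since startup.
        allocated_gib = torch.cuda.memory_allocated() / 1024**3
        peak_gib = torch.cuda.max_memory_allocated() / 1024**3
        print(f"allocated: {allocated_gib:.2f} GiB, peak: {peak_gib:.2f} GiB")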

Training Optimizations#

  1. Reduce batch size (compensating with gradient accumulation, as shown in the sketch below):

    from oumi.core.configs import TrainingConfig, TrainingParams
    
    config = TrainingConfig(
        training=TrainingParams(
            per_device_train_batch_size=8,  # Decrease this value
            gradient_accumulation_steps=4,  # Increase this value
        ),
    )
    
    training:
        per_device_train_batch_size: 8  # Decrease this value
        gradient_accumulation_steps: 4  # Increase this value
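
    The two knobs move together: what matters for optimization dynamics is the effective batch size, which is their product (times the number of data-parallel GPUs). A minimal sketch of the arithmetic, with num_gpus assumed to be 2 purely for illustration:

    per_device_train_batch_size = 8
    gradient_accumulation_steps = 4
    num_gpus = 2  # assumption for illustration
    # Gradients are accumulated over several small forward/backward passes,
    # so the optimizer sees the same effective batch with less peak memory.
    effective_batch_size = (
        per_device_train_batch_size * gradient_accumulation_steps * num_gpus
    )
    print(effective_batch_size)  # 64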
    
  2. Enable gradient checkpointing, which trades compute for memory by recomputing activations during the backward pass instead of storing them:

    config = TrainingConfig(
        training=TrainingParams(
            enable_gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
        ),
    )
    
    training:
        enable_gradient_checkpointing: true
        gradient_checkpointing_kwargs:
            use_reentrant: false
    
  3. Use fused optimizers:

    config = TrainingConfig(
        training=TrainingParams(
            optimizer="adamw_torch_fused",
        ),
    )
    
    training:
        optimizer: adamw_torch_fused
    
  4. Use mixed precision training:

    config = TrainingConfig(
        training=TrainingParams(
            mixed_precision_dtype="bf16",  # or "fp16"
        ),
    )
    
    training:
        mixed_precision_dtype: bf16  # or fp16
    
  5. Train in half-precision:

    from oumi.core.configs import ModelParams
    
    config = TrainingConfig(
        model=ModelParams(
            torch_dtype_str="bfloat16",  # or "float16"
        ),
    )
    
    model:
        torch_dtype_str: bfloat16  # or float16
    
  6. Empty GPU cache more frequently:

    config = TrainingConfig(
        training=TrainingParams(
            empty_device_cache_steps=50,  # Clear GPU cache every 50 steps
        ),
    )
    
    training:
        empty_device_cache_steps: 50  # Clear GPU cache every 50 steps
    
  7. Tune CUDA allocator settings:

    It’s sometimes possible to eliminate OOM errors (e.g., OOMs caused by GPU VRAM fragmentation) by tuning the CUDA allocator configuration, as described in PyTorch’s Optimizing Memory Usage guide, e.g., by switching to a different allocator or tuning garbage-collection settings. Example:

    envs:
        PYTORCH_CUDA_ALLOC_CONF: "garbage_collection_threshold:0.8,max_split_size_mb:128"
    
    export PYTORCH_CUDA_ALLOC_CONF="garbage_collection_threshold:0.8,max_split_size_mb:128"
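
    If you prefer setting this from Python, note that the allocator reads PYTORCH_CUDA_ALLOC_CONF when CUDA is first initialized, so it must be set before any CUDA work happens. A minimal sketch:

    import os

    # Must run before torch initializes CUDA (ideally before `import torch`).
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
        "garbage_collection_threshold:0.8,max_split_size_mb:128"
    )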
    
  8. Use Paged Adam:

    config = TrainingConfig(
        training=TrainingParams(
            optimizer="paged_adamw_32bit",
        ),
    )
    
    training:
        optimizer: paged_adamw_32bit
    

    Note

    Paged Adam requires bitsandbytes to be installed.
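
    A quick pre-flight check before selecting a paged optimizer (a sketch assuming the standard bitsandbytes package name):

    import importlib.util

    if importlib.util.find_spec("bitsandbytes") is None:
        raise RuntimeError("Install bitsandbytes to use paged_adamw_32bit")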

Model Configuration#

  1. Use an efficient attention implementation (SDPA or FlashAttention-2):

    config = TrainingConfig(
        model=ModelParams(
            attn_implementation="sdpa",  # or "flash_attention2"
        ),
    )
    
    model:
        attn_implementation: sdpa  # or flash_attention_2
    
  2. Enable model compilation:

    config = TrainingConfig(
        training=TrainingParams(
            compile=True,
        ),
    )
    
    training:
        compile: true
    
  3. Enable Liger Kernels:

    from oumi.core.configs import ModelParams
    
    config = TrainingConfig(
        model=ModelParams(
            enable_liger_kernel=True,
        ),
    )
    
    model:
        enable_liger_kernel: true
    
  4. Reduce training sequence length:

    config = TrainingConfig(
        model=ModelParams(
            model_max_length=2048,  # Reduce sequence length
        ),
    )
    
    model:
        model_max_length: 2048  # Reduce sequence length
    
  5. Selectively freeze layers:

    config = TrainingConfig(
        model=ModelParams(
            freeze_layers=["layer.0", "layer.1", "layer.2"],
        ),
    )
    
    model:
        freeze_layers:
            - layer.0
            - layer.1
            - layer.2
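
    To illustrate what prefix-based freezing typically does under the hood, here is a hedged sketch on a toy model (the layer names are hypothetical, and the matching logic may differ from Oumi’s internals):

    import torch.nn as nn

    model = nn.ModuleDict(
        {"layer": nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])}
    )
    freeze_layers = ["layer.0", "layer.1", "layer.2"]
    for name, param in model.named_parameters():
        # Frozen parameters keep no gradients, saving optimizer-state memory.
        if any(name.startswith(prefix) for prefix in freeze_layers):
            param.requires_grad = False
    print(sum(p.requires_grad for p in model.parameters()))  # 2 (layer.3 only)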
    
  6. Enable ring attention:

    Added in version 0.2.0 (coming soon).

    config = TrainingConfig(
        model=ModelParams(
            attn_implementation="ring_attention",
        ),
    )

    model:
        attn_implementation: ring_attention

Parameter-Efficient Fine-Tuning (PEFT)#

  1. Enable LoRA:

    from oumi.core.configs import PeftParams
    
    config = TrainingConfig(
        training=TrainingParams(use_peft=True),
        peft=PeftParams(
            lora_r=16,
            lora_alpha=32,
            lora_dropout=0.05,
        ),
    )
    
    training:
        use_peft: true
    
    peft:
        lora_r: 16
        lora_alpha: 32
        lora_dropout: 0.05
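
    To see why LoRA helps with memory, compare trainable-parameter counts: only the low-rank adapters require gradients and optimizer state. The sketch below uses the Hugging Face peft library directly with a small model (gpt2) purely for illustration; Oumi wires this up internally when use_peft is set:

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("gpt2")
    lora = get_peft_model(
        base, LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05)
    )
    # Optimizer state (the bulk of training memory) shrinks with the
    # trainable-parameter count reported here.
    lora.print_trainable_parameters()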
    

Distributed Training with FSDP#

If you have access to multiple GPUs, you can leverage FSDP to distribute the training process across them. To run FSDP jobs, make sure to invoke your training job with torchrun so that it runs on multiple GPUs/nodes. We also provide the oumi distributed wrapper, which automatically tries to set the flags needed for torchrun. For example, you can simply run oumi distributed torchrun -m oumi train -c path/to/train.yaml.

  1. Enable distributed training:

    from oumi.core.configs import FSDPParams
    from oumi.core.configs.params.fsdp_params import ShardingStrategy
    
    config = TrainingConfig(
        fsdp=FSDPParams(
            enable_fsdp=True,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
        ),
    )
    
    fsdp:
        enable_fsdp: true
        sharding_strategy: FULL_SHARD
    
  2. Enable CPU offloading:

    config = TrainingConfig(
        fsdp=FSDPParams(
            enable_fsdp=True,
            cpu_offload=True,
        ),
    )
    
    fsdp:
        enable_fsdp: true
        cpu_offload: true
    
  3. Disable forward prefetch:

    config = TrainingConfig(
        fsdp=FSDPParams(
            enable_fsdp=True,
            forward_prefetch=False,
        ),
    )
    
    fsdp:
        enable_fsdp: true
        forward_prefetch: false
    
  4. Disable backward prefetch:

    from oumi.core.configs.params.fsdp_params import BackwardPrefetch
    
    config = TrainingConfig(
        fsdp=FSDPParams(
            enable_fsdp=True,
            backward_prefetch=BackwardPrefetch.NO_PREFETCH,
        ),
    )
    
    fsdp:
        enable_fsdp: true
        backward_prefetch: NO_PREFETCH
    

    Attention

    Disabling FSDP’s forward and backward prefetch can lead to significantly slower training times; use with caution.