Custom Models#

This guide explains how to create custom models in Oumi. We’ll use the MLPEncoder as a concrete example to demonstrate best practices and requirements.

Core Concepts#

Before diving into the implementation, let’s understand the key concepts and components:

  1. Base Model Interface

    • All Oumi models inherit from oumi.core.models.BaseModel

    • This provides a consistent interface and enforces implementation of required methods

    • The base class extends torch.nn.Module for PyTorch compatibility

  2. Model Registry

    • Models are registered with the @registry.register decorator so they can be looked up by name

    • The registered name is what you reference as model_name in configs and with oumi.builders.build_model()

  3. Model Outputs

    • Models return dictionaries containing torch.Tensor

    • Required outputs depend on the task (e.g., logits for classification)

    • Loss is included in outputs during training

  4. Loss Functions

    • Models define their training criterion via the criterion property

    • Common losses are available in torch.nn.functional

    • Custom losses can be implemented as needed

  5. Integration Points

    • Registered models plug into Oumi’s training configs, the oumi.builders.build_model() builder, and the CLI

    • The training system handles device placement, distributed training, and other runtime details
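
The snippet below is a minimal sketch tying these concepts together; the class name MyEncoder and its parameters are illustrative, not part of Oumi:

import torch
import torch.nn as nn
from torch.nn import functional as F

from oumi.core import registry
from oumi.core.models.base_model import BaseModel

@registry.register("MyEncoder", registry.RegistryType.MODEL)  # 2. registered by name
class MyEncoder(BaseModel):  # 1. BaseModel extends torch.nn.Module
    def __init__(self, input_dim: int = 16, num_classes: int = 2):
        super().__init__()
        self.head = nn.Linear(input_dim, num_classes)

    def forward(self, inputs: torch.Tensor, labels=None, **kwargs) -> dict[str, torch.Tensor]:
        logits = self.head(inputs)
        outputs = {"logits": logits}  # 3. outputs are a dict of tensors
        if labels is not None:
            outputs["loss"] = self.criterion(logits, labels)  # loss added during training
        return outputs

    @property
    def criterion(self):  # 4. training criterion
        return F.cross_entropy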

Configuration#

The configuration is part of the overall TrainingConfig and is defined under the model section:

model:
  # Required
  model_name: "my_custom_model"    # Model ID or path
  model_kwargs:                    # Parameters passed to model constructor
    input_dim: 768
    hidden_dim: 128
    output_dim: 10

  # Optional settings
  trust_remote_code: false         # Allow remote code execution
  torch_dtype_str: "float32"       # Model precision
  device_map: "auto"              # Device placement strategy

Key points about configuration:

  • Model parameters are defined in the model section of the training config

  • model_kwargs contains parameters passed to the model’s constructor

  • Configuration can be loaded from YAML files or created programmatically (see the sketch below)
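
For example, the model section above maps directly onto a ModelParams object; a minimal sketch of the programmatic equivalent:

from oumi.core.configs import ModelParams

model_params = ModelParams(
    model_name="my_custom_model",
    model_kwargs={"input_dim": 768, "hidden_dim": 128, "output_dim": 10},
    torch_dtype_str="float32",
    device_map="auto",
)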

Implementing a Custom Model#

Overview#

At a high level, an Oumi model:

  1. Inherits from oumi.core.models.BaseModel (which extends torch.nn.Module)

  2. Implements a forward pass that returns a dictionary of tensors

  3. Defines a loss criterion for training

  4. Follows PyTorch and (optionally) Hugging Face conventions

Here’s the complete implementation of the oumi.models.mlp.MLPEncoder, a simple text encoder model:

MLPEncoder
from typing import Callable, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F

from oumi.core import registry
from oumi.core.models.base_model import BaseModel

@registry.register("MLPEncoder", registry.RegistryType.MODEL)
class MLPEncoder(BaseModel):
    def __init__(
        self, input_dim: int = 768, hidden_dim: int = 128, output_dim: int = 10
    ):
        """Initialize the MLPEncoder.

        Args:
            input_dim (int): The input dimension (used as the embedding vocabulary size).
            hidden_dim (int): The hidden dimension.
            output_dim (int): The output dimension.
        """
        super().__init__()

        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Forward pass of the MLP model.

        Args:
            input_ids (torch.LongTensor): The input tensor of shape
                (batch_size, sequence_length).
            labels (torch.LongTensor, optional): The target labels tensor
                of shape (batch_size, sequence_length).
            **kwargs: Additional keyword arguments provided by the tokenizer.
                Not used in this model.

        Returns:
            dict[str, torch.Tensor]: A dictionary containing the model outputs.
                The dictionary has the following keys:
                - "logits" (torch.Tensor): The output logits tensor of
                  shape (batch_size, num_classes).
                - "loss" (torch.Tensor, optional): The computed loss tensor
                  if labels is not None.
        """
        x = self.embedding(input_ids)
        x = self.fc1(x)
        x = self.relu(x)
        logits = self.fc2(x)
        outputs = {"logits": logits}

        if labels is not None:
            loss = self.criterion(
                logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1
            )
            outputs["loss"] = loss

        return outputs

    @property
    def criterion(self) -> Callable:
        """Returns the criterion function for the MLP model.

        The criterion function is used to compute the loss during training.

        Returns:
            Callable: The torch.nn.functional.cross_entropy function.
        """
        return F.cross_entropy

Implementation Breakdown#

Let’s break down each component of the implementation:

1. Model Registration and Base Class#

@registry.register("MLPEncoder", registry.RegistryType.MODEL)
class MLPEncoder(BaseModel):

2. Model Architecture#

def __init__(self, input_dim: int = 768, hidden_dim: int = 128, output_dim: int = 10):
    super().__init__()
    self.embedding = nn.Embedding(input_dim, hidden_dim)  # token embedding lookup
    self.fc1 = nn.Linear(hidden_dim, hidden_dim)          # hidden projection
    self.relu = nn.ReLU()                                 # non-linearity
    self.fc2 = nn.Linear(hidden_dim, output_dim)          # classification head

3. Forward Pass#

def forward(self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None, **kwargs) -> dict[str, torch.Tensor]:

  • Implements the required torch.nn.Module.forward() method

  • Takes input tensors and optional labels

  • Returns a dictionary with model outputs

  • Computes loss during training if labels are provided

4. Loss Function#

@property
def criterion(self) -> Callable:
    return F.cross_entropy
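
Because criterion is just a property returning a callable, swapping in a custom loss is a small change. Below is a sketch using a label-smoothed cross-entropy; the subclass name is illustrative, not part of Oumi:

from functools import partial
from typing import Callable

from torch.nn import functional as F

from oumi.models import MLPEncoder

class SmoothedMLPEncoder(MLPEncoder):  # illustrative subclass
    @property
    def criterion(self) -> Callable:
        # Cross-entropy with label smoothing; forward() calls it the same
        # way as the default criterion.
        return partial(F.cross_entropy, label_smoothing=0.1)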

Using Custom Models via the CLI#

See Customizing Oumi to quickly enable your model when using the CLI.

Testing Models#

Oumi uses pytest for testing models. Here’s an example test for the MLPEncoder:

import pytest
import torch
from oumi.models import MLPEncoder

def test_mlp_encoder_forward():
    """Test forward pass of MLPEncoder."""
    # Initialize model
    model = MLPEncoder(input_dim=100, hidden_dim=32, output_dim=10)

    # Create dummy inputs
    batch_size, seq_len = 16, 8
    input_ids = torch.randint(0, 100, (batch_size, seq_len))
    labels = torch.randint(0, 10, (batch_size, seq_len))

    # Test forward pass without labels
    outputs = model(input_ids=input_ids)
    assert "logits" in outputs
    assert outputs["logits"].shape == (batch_size, 10)

    # Test forward pass with labels
    outputs = model(input_ids=input_ids, labels=labels)
    assert "loss" in outputs
    assert outputs["loss"].shape == ()

Using the Model#

There are three main ways to use custom models in Oumi:

1. Using build_model#

You can create model instances programmatically using the build_model function:

from oumi.builders import build_model
from oumi.core.configs import ModelParams

# Create model parameters
model_params = ModelParams(
    model_name="MLPEncoder",  # Name used in @registry.register
    model_kwargs={
        "input_dim": 1000,
        "hidden_dim": 256,
        "output_dim": 10,
    },
    torch_dtype_str="float32",
    device_map="auto",
)

# Build the model
model = build_model(model_params=model_params)
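
Once built, the model can be called like any torch.nn.Module. A quick sanity check with dummy inputs (a sketch; the output shape follows the model_kwargs above):

import torch

# Keep the dummy token ids on the same device as the model.
device = next(model.parameters()).device
input_ids = torch.randint(0, 1000, (2, 16), device=device)

outputs = model(input_ids=input_ids)
print(outputs["logits"].shape)  # torch.Size([2, 16, 10])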

2. Using Training Configuration#

More commonly, you’ll define the model as part of a training configuration:

# train_config.yaml
model:
  model_name: "MLPEncoder"
  model_kwargs:
    input_dim: 1000
    hidden_dim: 256
    output_dim: 10
  torch_dtype_str: "float32"
  device_map: "auto"

data:
  train:
    datasets:
      - dataset_name: "text_classification"
        dataset_path: "path/to/data"
        split: "train"

training:
  output_dir: "outputs/mlp_run"
  num_train_epochs: 3
  learning_rate: 1e-4
  per_device_train_batch_size: 32

Then use it in your training script:

from oumi.core.configs import TrainingConfig
from oumi.train import train

# Load and run training
config = TrainingConfig.from_yaml("train_config.yaml")
train(config)

Key points about using models in Oumi:

  • Models are instantiated through the oumi.builders.build_model() function

  • All constructor parameters go in model_kwargs

  • Models can be configured through YAML for training

  • The training system handles device placement, distributed training, etc.

3. Standard PyTorch Training Loop#

The model can be used in a standard PyTorch training loop:

import torch

from oumi.models import MLPEncoder

# Initialize model and optimizer
model = MLPEncoder()
optimizer = torch.optim.Adam(model.parameters())

# Training loop; `dataloader` should yield dicts with "input_ids" and "labels"
model.train()
for batch in dataloader:
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs["loss"]
    loss.backward()
    optimizer.step()

See Also#