Source code for oumi.models.mlp

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module defines the MLPEncoder class, which is a simple text encoder."""

from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F

from oumi.core import registry
from oumi.core.models.base_model import BaseModel


@registry.register("MlpEncoder", registry.RegistryType.MODEL)
class MLPEncoder(BaseModel):
    """Token-wise encoder: an embedding lookup followed by a two-layer MLP.

    No pooling is performed over the sequence axis, so one logit vector is
    produced for every token id in the input.
    """

    def __init__(
        self, input_dim: int = 768, hidden_dim: int = 128, output_dim: int = 10
    ):
        """Initialize the MLPEncoder.

        Args:
            input_dim (int): Number of rows in the embedding table, i.e. token
                ids passed to `forward` must lie in `[0, input_dim)`.
            hidden_dim (int): Width of the embedding vectors and of the hidden
                layer.
            output_dim (int): Size of the last dimension of the logits.
        """
        super().__init__()
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Forward pass of the MLP model.

        Args:
            input_ids (torch.LongTensor): Token ids, e.g. of shape
                (batch_size, sequence_length).
            labels (torch.LongTensor, optional): Target class indices with one
                entry per element of `input_ids`. They are flattened to 1-D
                before the loss is computed, and entries equal to -1 are
                excluded from the loss.
            **kwargs: Additional keyword arguments provided by the tokenizer.
                Not used in this model.

        Returns:
            dict: A dictionary containing the model outputs.
                The dictionary has the following keys:
                - "logits" (torch.Tensor): The output logits, with the shape of
                  `input_ids` plus a trailing dimension of size `output_dim`,
                  e.g. (batch_size, sequence_length, output_dim).
                - "loss" (torch.Tensor, optional): The computed loss tensor
                  if labels is not None.
        """
        x = self.embedding(input_ids)
        x = self.fc1(x)
        x = self.relu(x)
        logits = self.fc2(x)

        outputs = {"logits": logits}

        if labels is not None:
            # `reshape` rather than `view`: labels that arrive as a
            # non-contiguous tensor (e.g. a column slice of a larger batch)
            # would make `view` raise, while `reshape` copies only if needed.
            loss = self.criterion(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-1,
            )
            outputs["loss"] = loss

        return outputs

    @property
    def criterion(self) -> Callable:
        """Returns the criterion function for the MLP model.

        `forward` calls it to compute the loss when labels are supplied.

        Returns:
            Callable: The functional cross-entropy loss,
                `torch.nn.functional.cross_entropy`.
        """
        return F.cross_entropy