# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from tqdm import tqdm
from typing_extensions import override

from oumi.core.constants import LABEL_IGNORE_INDEX
from oumi.core.datasets.base_map_dataset import BaseMapDataset
from oumi.core.datasets.base_sft_dataset import BaseSftDataset
from oumi.utils.logging import logger


class PackedSftDataset(BaseMapDataset):
"""A dataset that packs samples from a base SFT dataset to maximize efficiency."""
def __init__(
self,
base_dataset: BaseSftDataset,
max_seq_len: int,
show_progress: bool = True,
split_samples: bool = False,
concat_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
enable_padding: bool = True,
**kwargs,
):
"""Initialize the PackedSftDataset.
Args:
base_dataset: The base SFT dataset to pack samples from.
max_seq_len: Maximum sequence length for packed samples.
show_progress: Whether to show progress bar during packing.
Defaults to True.
split_samples: Whether to split samples that are longer than max_seq_len.
If False, samples longer than max_seq_len will be skipped.
Defaults to False.
concat_token_id: Token ID to use for concatenating samples.
If None, samples will be concatenated without a separator token.
Defaults to None.
pad_token_id: Token ID to use for padding.
Required if enable_padding is True. Defaults to None.
enable_padding: Whether to pad sequences to max_seq_len.
If True, pad_token_id must be provided. Defaults to True.
**kwargs: Additional arguments passed to BaseMapDataset.
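
        Example:
            A minimal sketch; ``base`` stands for any already-constructed
            ``BaseSftDataset`` and ``tokenizer`` for its tokenizer (both are
            illustrative names, not part of this module)::

                packed = PackedSftDataset(
                    base_dataset=base,
                    max_seq_len=2048,
                    concat_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                )
                pack = packed[0]  # {"input_ids": Tensor, "labels": Tensor}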
"""
super().__init__(**kwargs, dataset_name=base_dataset.dataset_name)
self.base_dataset = base_dataset
self._max_seq_len = max_seq_len
self._disable_tqdm = not show_progress
self._split_samples = split_samples
self._concat_token_id = concat_token_id
self._pad_token_id = pad_token_id
self._enable_padding = enable_padding
self._data: list[dict[str, torch.Tensor]] = []
if self._enable_padding and self._pad_token_id is None:
raise ValueError(
"`pad_token_id` must be provided if `enable_padding` is True"
)
self._check_dataset_compatibility()
self._load_data()

    @override
def _load_data(self) -> None:
"""Pack the base dataset into constant-length samples."""
buffer = self._get_empty_buffer()
iterator = range(len(self.base_dataset))
for idx in tqdm(
iterator,
desc="Packing dataset",
dynamic_ncols=True,
disable=self._disable_tqdm,
):
sample = self.base_dataset[idx]
sample_len = len(sample["input_ids"])
if sample_len > self._max_seq_len and not self._split_samples:
                # We can't split samples, and the sample is too long to fit in
                # the context window, so there is no way to handle this sample.
logger.warning(
f"Dataset sample is too long ({sample_len} > {self._max_seq_len}). "
"Please set `split_samples=True` or increase `max_seq_len`. "
"This sample will be skipped."
)
continue
if (
self._get_potential_sample_len(sample=sample, buffer=buffer)
== self._max_seq_len
):
                # The sample exactly fills the pack: finalize it and start a
                # new buffer.
self._append_sample_to_buffer(sample=sample, buffer=buffer)
self._append_packed_sample_to_dataset(buffer)
buffer = self._get_empty_buffer()
continue
elif (
self._get_potential_sample_len(sample=sample, buffer=buffer)
< self._max_seq_len
):
# We still have space in the buffer, so we add the sample to it
# and keep going
self._append_sample_to_buffer(sample=sample, buffer=buffer)
continue
# We don't have space in the buffer, so we need to create a new pack
if self._split_samples:
self._append_sample_to_buffer(sample=sample, buffer=buffer)
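                # Slice full `max_seq_len`-token packs off the front of the
                # buffer until the remainder is shorter than one pack.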
while self._get_sample_len(buffer) >= self._max_seq_len:
finished_sample, buffer = self._split_sample(
buffer, cutoff=self._max_seq_len
)
self._append_packed_sample_to_dataset(finished_sample)
            else:
                # We're not allowed to split samples, but buffer + sample is
                # too large: flush the current pack and start a fresh buffer.
                if self._get_sample_len(buffer) == 0:
                    # Defensive branch: over-long samples are skipped above,
                    # so an empty buffer can never overflow here.
                    self._append_sample_to_buffer(sample=sample, buffer=buffer)
                    self._append_packed_sample_to_dataset(buffer)
                    buffer = self._get_empty_buffer()
                else:
                    self._append_packed_sample_to_dataset(buffer)
                    buffer = self._get_empty_buffer()
                    self._append_sample_to_buffer(sample=sample, buffer=buffer)

        # Handle any samples remaining in the buffer after the last iteration.
if self._get_sample_len(buffer) > 0:
if self._split_samples:
while self._get_sample_len(buffer) > 0:
finished_sample, buffer = self._split_sample(
buffer, cutoff=self._max_seq_len
)
self._append_packed_sample_to_dataset(finished_sample)
else:
self._append_packed_sample_to_dataset(buffer)

    @override
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
"""Get a pack from the dataset by index."""
if idx >= len(self):
raise IndexError(f"Index {idx} is out of bounds for PackedSftDataset")
return self._data[idx]

    #
# Private methods
#
def _append_packed_sample_to_dataset(self, buffer: dict[str, list]) -> None:
"""Creates a fixed length training sample from the buffer and add to dataset."""
buffer_len = self._get_sample_len(buffer)
if buffer_len > self._max_seq_len:
raise ValueError(
"Buffer is too long "
f"({buffer_len} >= {self._max_seq_len}). "
"Please increase `max_seq_len`."
)
# Convert lists to tensors
sample = {k: torch.tensor(v, dtype=torch.long) for k, v in buffer.items()}
        # Pad to `max_seq_len` if needed. Labels are padded with
        # LABEL_IGNORE_INDEX so padded positions are excluded from the loss;
        # `input_ids` are padded with `pad_token_id`.
if self._enable_padding and self._pad_token_id is not None:
if buffer_len < self._max_seq_len:
pad_length = self._max_seq_len - buffer_len
for name, value in sample.items():
if name == "labels":
pad_value = LABEL_IGNORE_INDEX
else:
pad_value = self._pad_token_id
sample[name] = torch.cat(
[
sample[name],
torch.full(
(pad_length,), fill_value=pad_value, dtype=torch.long
),
]
)
self._data.append(sample)

    def _append_sample_to_buffer(
self, sample: dict[str, list], buffer: dict[str, list]
) -> None:
"""Append a single training sample to the buffer.
If concat token is enabled, and if and only if we actually concatenate
two samples, we add the concat token in between the two samples
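
        Example (illustrative): with ``concat_token_id=2``, appending
        ``{"input_ids": [5, 6], "labels": [5, 6]}`` to a buffer holding
        ``{"input_ids": [3, 4], "labels": [3, 4]}`` yields
        ``input_ids == [3, 4, 2, 5, 6]`` and
        ``labels == [3, 4, LABEL_IGNORE_INDEX, 5, 6]``.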
"""
if len(sample["input_ids"]) == 0:
# Nothing to add
return
should_add_concat_token = self._concat_token_id is not None
if len(buffer["input_ids"]) == 0:
# Buffer is empty, so we're not concatenating two different samples
# no need to add concat token
should_add_concat_token = False
if should_add_concat_token:
buffer["input_ids"].append(self._concat_token_id)
buffer["labels"].append(LABEL_IGNORE_INDEX) # exclude from loss
buffer["input_ids"].extend(sample["input_ids"])
buffer["labels"].extend(sample["labels"])

    def _split_sample(
self, sample: dict[str, list], cutoff: int
) -> tuple[dict[str, list], dict[str, list]]:
"""Split a sample into two parts at the cutoff point.
Args:
sample: Dictionary containing lists to split
cutoff: Index at which to split the lists
Returns:
Tuple of two dictionaries containing the split lists
"""
first_half = {k: v[:cutoff] for k, v in sample.items()}
second_half = {k: v[cutoff:] for k, v in sample.items()}
return first_half, second_half

    def _get_empty_buffer(self) -> dict[str, list]:
"""Get an empty buffer with all required fields."""
return {
"input_ids": [],
"labels": [],
}

    def _get_sample_len(self, buffer: dict[str, list]) -> int:
"""Get the length of the samples in the buffer."""
return len(buffer["input_ids"])

    def _get_potential_sample_len(
self, sample: dict[str, list], buffer: dict[str, list]
) -> int:
"""Get the length of the samples in the buffer."""
buffer_len = self._get_sample_len(buffer)
sample_len = self._get_sample_len(sample)
# In case we don't need to add a concat token
if self._concat_token_id is None or buffer_len == 0 or sample_len == 0:
return buffer_len + sample_len
# In case we need to add a concat token
return buffer_len + sample_len + 1

    def _check_dataset_compatibility(self) -> None:
"""Check the base dataset for errors."""
if len(self.base_dataset) == 0:
raise ValueError("Base dataset is empty. Cannot pack empty dataset.")
keys = set(self.base_dataset[0].keys())
if "input_ids" not in keys:
raise ValueError("Base dataset must contain 'input_ids' key.")
if "labels" not in keys:
raise ValueError("Base dataset must contain 'labels' key.")
        if keys != {"input_ids", "labels"}:
logger.warning(
"Base dataset contains additional keys. "
"Only 'input_ids' and 'labels' will be used."
)