# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import os
import time
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable
from pathlib import Path
from typing import Any, NamedTuple, Optional, Union, cast

import datasets
import pandas as pd
from torch.utils.data import MapDataPipe

from oumi.utils.hf_datasets_utils import is_cached_to_disk_hf_dataset
from oumi.utils.logging import logger
from oumi.utils.torch_utils import estimate_sample_dict_size_in_bytes, get_shape_as_list


class _ExamplesIndicesRange(NamedTuple):
    """A valid sub-range of example indices."""

    start_index: int
    end_index: int


class _InferredFeatureMap(NamedTuple):
    feature_map: datasets.Features
    """Inferred feature map."""

    is_feature_map_optimized: bool
    """Indicates whether the original feature map was optimized.

    In optimized feature maps, large features use the inferred `ArrayXD` Arrow
    column type (not `Sequence`), which supports larger datasets with more elements.
    """

    element_size_in_bytes: int
    """Estimated maximum element size in bytes."""

    multimodal: bool
    """Whether the features are multimodal."""


class BaseMapDataset(MapDataPipe, ABC):
    """Abstract base class for map datasets."""

    _data: pd.DataFrame
    dataset_name: str
    dataset_path: Optional[str] = None
    default_dataset: Optional[str] = None
    default_subset: Optional[str] = None
    trust_remote_code: bool
    transform_num_workers: Optional[Union[str, int]] = None

    def __init__(
        self,
        *,
        dataset_name: Optional[str],
        dataset_path: Optional[str] = None,
        subset: Optional[str] = None,
        split: Optional[str] = None,
        trust_remote_code: bool = False,
        transform_num_workers: Optional[Union[str, int]] = None,
        **kwargs,
    ) -> None:
        """Initializes a new instance of the BaseMapDataset class."""
        dataset_type_name = self.__class__.__name__
        logger.info(
            f"Creating map dataset (type: {dataset_type_name}) "
            f"dataset_name: '{dataset_name}', dataset_path: '{dataset_path}'..."
        )
        if len(kwargs) > 0:
            logger.debug(
                f"Unknown arguments: {', '.join(kwargs.keys())}. "
                "Please check the class constructor for supported arguments "
                f"(type: {dataset_type_name})."
            )
        dataset_name = dataset_name or self.default_dataset
        if dataset_name is None:
            raise ValueError(
                "Please specify a dataset_name or "
                "set the default_dataset class attribute "
                f"(type: {dataset_type_name})."
            )
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.dataset_subset = subset or self.default_subset
        self.split = split
        self.trust_remote_code = trust_remote_code
        self.transform_num_workers = transform_num_workers

    #
    # Main API
    #

    def __getitem__(self, idx: int) -> dict:
        """Gets the item at the specified index.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            dict: The item at the specified index.
        """
        sample = self.raw(idx)
        processed = self.transform(sample)
        return processed

    def __len__(self) -> int:
        """Gets the number of items in the dataset.

        Returns:
            int: The number of items in the dataset.
        """
        return len(self._data)

    @property
    def data(self) -> pd.DataFrame:
        """Returns the underlying dataset data."""
        return self._data

    def raw(self, idx: int) -> pd.Series:
        """Returns the raw data at the specified index.

        Args:
            idx (int): The index of the data to retrieve.

        Returns:
            pd.Series: The raw data at the specified index.
        """
        return self._data.iloc[idx]

    def as_generator(self) -> Generator[dict[str, Any], None, None]:
        """Returns a generator for the dataset."""
        for idx in range(len(self)):
            yield self[idx]

    def _as_generator_over_shards(
        self, shards: list[_ExamplesIndicesRange]
    ) -> Generator[dict[str, Any], None, None]:
        """Returns a sharded generator for the dataset."""
        for shard in shards:
            for idx in range(shard.start_index, shard.end_index):
                yield self[idx]

    def _detect_features_and_estimate_element_size_bytes(
        self, samples_iter: Iterable[dict[str, Any]]
    ) -> _InferredFeatureMap:
        """Infers the feature map and estimates the maximum element size in bytes."""
        samples_list = list(samples_iter)

        def _dummy_generator():
            yield from samples_list

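        # Materialize the sampled examples into a small in-memory HF dataset so that
        # `datasets` infers a baseline feature schema, which is refined below.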
        sample_dataset = cast(
            datasets.Dataset,
            datasets.Dataset.from_generator(_dummy_generator, keep_in_memory=True),
        )
        if len(sample_dataset) <= 0:
            raise ValueError("Empty sample dataset!")
        max_elem_bytes = max(
            [estimate_sample_dict_size_in_bytes(elem) for elem in samples_list]
        )

        features = sample_dataset.features.copy()
        is_feature_map_optimized: bool = False
        is_multimodal: bool = False
        # At this time, we care mostly about `pixel_values` as it's by far the largest
        # feature (e.g., 15MB for Llama 3.2 Vision), which causes serialization errors
        # for large datasets if saved in the default format, which is
        # a nested sequence (of sequences (of sequences ...)).
        # TODO: Tune feature types for other features for efficiency.
        if "pixel_values" in samples_list[0]:
            is_multimodal = True
            inferred_features = []
            variable_shapes_detected: bool = False
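            # Map each example's `pixel_values` rank (2D..5D) to the matching
            # fixed-shape `ArrayXD` feature type.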
            for elem in samples_list:
                shape = tuple(get_shape_as_list(elem["pixel_values"]))
                shape_dims = len(shape)
                if shape_dims == 2:
                    feature_def = datasets.Array2D(dtype="float32", shape=shape)
                elif shape_dims == 3:
                    feature_def = datasets.Array3D(dtype="float32", shape=shape)
                elif shape_dims == 4:
                    feature_def = datasets.Array4D(dtype="float32", shape=shape)
                elif shape_dims == 5:
                    feature_def = datasets.Array5D(dtype="float32", shape=shape)
                else:
                    raise ValueError(
                        "The `pixel_values` feature has unsupported dimensionality "
                        f"({shape_dims}D). Must be 2D...5D."
                    )
                inferred_features.append(feature_def)
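            # `ArrayXD` needs a single fixed shape here; if any two sampled examples
            # disagree, keep the default nested-sequence representation instead.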
            for i in range(1, len(samples_list)):
                if (
                    type(inferred_features[i - 1]),
                    inferred_features[i - 1].dtype,
                    inferred_features[i - 1].shape,
                ) != (
                    type(inferred_features[i]),
                    inferred_features[i].dtype,
                    inferred_features[i].shape,
                ):
                    variable_shapes_detected = True
                    logger.warning(
                        f"The `pixel_values` feature has variable shapes: "
                        f"{inferred_features[i - 1]} vs {inferred_features[i]}!"
                    )

            if not variable_shapes_detected:
                # Re-define the feature to be `ArrayXD`
                # if all shapes are the same.
                features["pixel_values"] = inferred_features[0]
                is_feature_map_optimized = True
                logger.info(
                    "The `pixel_values` feature has this inferred type: "
                    f"{inferred_features[0]}"
                )

        del sample_dataset
        return _InferredFeatureMap(
            feature_map=features,
            is_feature_map_optimized=is_feature_map_optimized,
            element_size_in_bytes=max_elem_bytes,
            multimodal=is_multimodal,
        )

    def _compute_effective_transform_num_workers(self) -> int:
        """Returns the effective number of dataset transform workers.

        Guaranteed to be a positive integer (>= 1).
        Returns 1 if no parallelism is used.
        """
        num_proc = None
        if self.transform_num_workers is not None:
            if isinstance(self.transform_num_workers, int):
                num_proc = self.transform_num_workers
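            # "auto" resolves to the machine's CPU count (capped at 8 below).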
            elif self.transform_num_workers == "auto":
                num_proc = os.cpu_count()
            if num_proc is not None:
                # Limit the max number of sub-processes.
                num_proc = min(8, num_proc)
        assert (
            num_proc is None or num_proc > 0
        ), f"transform_num_workers: {self.transform_num_workers}"
        num_proc = max(1, num_proc if num_proc is not None else 1)
        assert num_proc >= 1
        return num_proc

    def to_hf(
        self, return_iterable: bool = False
    ) -> Union[datasets.Dataset, datasets.IterableDataset]:
        """Converts the dataset to a Hugging Face dataset.

        Args:
            return_iterable: Whether to return an iterable dataset.
                Iterable datasets aren't cached to disk, which can sometimes be
                advantageous. For example, if transformed examples are very large
                (e.g., if `pixel_values` are large for multimodal data), or if you
                don't want to post-process the whole dataset before training starts.

        Returns:
            A Hugging Face dataset: either `datasets.Dataset` or
            `datasets.IterableDataset`, depending on the value of `return_iterable`.
        """
        _MAX_SHARD_SIZE = 1 * 1024 * 1024 * 1024  # ~1GB

        dataset_type_name = self.__class__.__name__
        num_proc = self._compute_effective_transform_num_workers()
        total_examples = len(self)
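        # Probe a handful of examples spread evenly across the dataset to infer the
        # output feature schema and to estimate the size of a transformed example.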
        output_features: _InferredFeatureMap = (
            self._detect_features_and_estimate_element_size_bytes(
                self._as_generator_over_shards(
                    [
                        _ExamplesIndicesRange(start_index=i, end_index=(i + 1))
                        for i in range(0, total_examples, max(1, total_examples // 8))
                    ]
                )
            )
        )
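        # Spread examples evenly across workers, then cap the shard size so that a
        # single writer shard stays under ~1GB of transformed data.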
        elements_per_shard: int = int(math.ceil(float(total_examples) / num_proc))
        if output_features.element_size_in_bytes > 0:
            elements_per_shard = min(
                elements_per_shard,
                _MAX_SHARD_SIZE // output_features.element_size_in_bytes,
            )
        # Clamp `writer_batch_size` to the [1, 200] range for multimodal data,
        # or [1, 1000] otherwise.
        writer_batch_size = max(
            1, min(elements_per_shard, 200 if output_features.multimodal else 1000)
        )

        logger.info(
            f"{dataset_type_name}: features={output_features.feature_map.keys()}"
        )
        logger.debug(
            f"{dataset_type_name}: features={output_features} "
            f"examples={total_examples} "
            f"writer_batch_size={writer_batch_size} num_proc={num_proc}"
        )
        # If the feature map isn't "optimized", ignore it to fall back
        # to the default behavior of `from_generator()`.
        feature_map = (
            output_features.feature_map
            if output_features.is_feature_map_optimized
            else None
        )

        start_time = time.perf_counter()
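        # Shard the generation when multiple transform workers are requested, or when
        # the transformed dataset is likely too large to write as a single shard.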
        if num_proc > 1 or (
            output_features.element_size_in_bytes * total_examples > _MAX_SHARD_SIZE
        ):
            starts: list[int] = list(range(0, total_examples, writer_batch_size))
            stops: list[int] = starts[1:] + [total_examples]
            shards: list[_ExamplesIndicesRange] = [
                _ExamplesIndicesRange(start_index=start, end_index=stop)
                for start, stop in zip(starts, stops)
            ]
            if return_iterable:
                result = datasets.IterableDataset.from_generator(
                    self._as_generator_over_shards,
                    gen_kwargs={"shards": shards},
                    features=feature_map,
                )
            else:
                result = datasets.Dataset.from_generator(
                    self._as_generator_over_shards,
                    gen_kwargs={"shards": shards},
                    keep_in_memory=False,
                    num_proc=(num_proc if num_proc > 1 else None),
                    features=feature_map,
                    writer_batch_size=writer_batch_size,
                )
        else:
            if return_iterable:
                result = datasets.IterableDataset.from_generator(
                    self.as_generator,
                    features=feature_map,
                )
            else:
                result = datasets.Dataset.from_generator(
                    self.as_generator,
                    keep_in_memory=False,
                    features=feature_map,
                    writer_batch_size=writer_batch_size,
                )

        duration_sec = time.perf_counter() - start_time
        logger.info(
            f"Finished transforming dataset ({dataset_type_name})! "
            f"Speed: {total_examples / duration_sec:.2f} examples/sec. "
            f"Examples: {total_examples}. "
            f"Duration: {duration_sec:.1f} sec. Transform workers: {num_proc}."
        )

        if return_iterable:
            result = cast(datasets.IterableDataset, result)
            logger.debug(f"{dataset_type_name}: IterableDataset: {result}")
        else:
            result = cast(datasets.Dataset, result)
            logger.debug(
                f"{dataset_type_name}: MapDataset: {result}\n\n"
                f"Arrow schema: {result.features.arrow_schema}"
            )
        return result

    #
    # Abstract Methods
    #
    @abstractmethod
    def transform(self, sample: pd.Series) -> dict:
        """Preprocesses the inputs in the given sample.

        Args:
            sample (pd.Series): A raw example from the dataset.

        Returns:
            dict: The transformed example.
        """
        raise NotImplementedError

    #
    # Data Loading
    #

    def _load_data(self) -> pd.DataFrame:
        """Loads the dataset from the specified source.

        Returns:
            pd.DataFrame: The loaded dataset.
        """
        if self.dataset_path:
            result = self._load_local_dataset(self.dataset_path)
        else:
            result = self._load_hf_hub_dataset()

        # Reclaim memory after data loading.
        gc.collect()

        logger.info(
            f"Loaded DataFrame with shape: {result.shape}. Columns:\n{result.dtypes}"
        )
        return result

    def _load_local_dataset(self, path: str) -> pd.DataFrame:
        """Loads the dataset from the specified local source.

        Returns:
            pd.DataFrame: The loaded dataset.
        """
        dataset_path = Path(path)
        if not dataset_path.exists():
            raise FileNotFoundError(f"File not found: {dataset_path}")

        if dataset_path.suffix.lower() == ".jsonl" and dataset_path.is_file():
            result = self._load_jsonl_dataset(dataset_path)
        elif dataset_path.suffix.lower() == ".parquet" and dataset_path.is_file():
            result = self._load_parquet_dataset(dataset_path)
        elif is_cached_to_disk_hf_dataset(dataset_path):
            result = self._load_dataset_from_disk(dataset_path)
        else:
            raise ValueError(f"File format not supported for {self.dataset_name}")
        return result

    def _load_hf_hub_dataset(self) -> pd.DataFrame:
        """Loads the dataset from the specified Hugging Face Hub source.

        Returns:
            pd.DataFrame: The loaded dataset.
        """
        splits_or_dataset = datasets.load_dataset(
            path=self.dataset_name,
            name=self.dataset_subset,
            split=self.split,
            trust_remote_code=self.trust_remote_code,
        )
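        # When no `split` is given, `load_dataset` returns a dict of splits;
        # narrow it down to a single `datasets.Dataset` below.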
        if isinstance(
            splits_or_dataset, (datasets.IterableDataset, datasets.IterableDatasetDict)
        ):
            raise ValueError("IterableDataset is not supported with this class.")

        # Grab a single dataset split
        if isinstance(splits_or_dataset, datasets.Dataset):
            dataset = splits_or_dataset
        elif self.split is not None:
            dataset = splits_or_dataset[self.split]
        elif len(splits_or_dataset) == 1:
            dataset = next(iter(splits_or_dataset.values()))
        else:
            raise ValueError(
                "Multiple splits found in the dataset. Please specify a single split. "
                f"Available splits: {list(splits_or_dataset.keys())}"
            )

        logger.info(
            "\n".join(
                [
                    "Dataset Info:",
                    f"\tSplit: {dataset.split}",
                    f"\tVersion: {dataset.version}",
                    f"\tDataset size: {dataset.dataset_size}",
                    f"\tDownload size: {dataset.download_size}",
                    f"\tSize: {dataset.size_in_bytes} bytes",
                    f"\tRows: {len(dataset)}",
                    f"\tColumns: {dataset.column_names}",
                ]
            )
        )

        result = dataset.to_pandas()
        del dataset
        return cast(pd.DataFrame, result)

    def _load_jsonl_dataset(self, path: Path) -> pd.DataFrame:
        return pd.read_json(path, lines=True)

    def _load_parquet_dataset(self, path: Path) -> pd.DataFrame:
        return pd.read_parquet(path)

    def _load_dataset_from_disk(self, path: Path) -> pd.DataFrame:
        dataset: datasets.Dataset = datasets.Dataset.load_from_disk(path)
        result = dataset.to_pandas()
        del dataset
        return cast(pd.DataFrame, result)