Source code for oumi.core.datasets.base_pretraining_dataset
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
from typing_extensions import override

from oumi.core.datasets.base_iterable_dataset import BaseIterableDataset
from oumi.core.tokenizers import BaseTokenizer
class BasePretrainingDataset(BaseIterableDataset):
    """Base class for pretraining iterable datasets.

    This class extends BaseIterableDataset to provide functionality specific to
    pretraining tasks.

    Attributes:
        tokenizer (BaseTokenizer): The tokenizer used for text encoding.
        seq_length (int): The desired sequence length for model inputs.
        concat_token_id (int): The ID of the token used to concatenate documents.

    Example:
        >>> from transformers import AutoTokenizer
        >>> from oumi.builders import build_tokenizer
        >>> from oumi.core.configs import ModelParams
        >>> from oumi.core.datasets import BasePretrainingDataset
        >>> tokenizer = build_tokenizer(ModelParams(model_name="gpt2"))
        >>> dataset = BasePretrainingDataset(
        ...     dataset_name="wikimedia/wikipedia",
        ...     subset="20231101.en",
        ...     split="train",
        ...     tokenizer=tokenizer,
        ...     seq_length=512
        ... )
        >>> example = next(iter(dataset))
    """

    def __init__(
        self,
        *,
        tokenizer: BaseTokenizer,
        seq_length: int,
        dataset_text_field: str = "text",
        append_concat_token: bool = True,
        add_special_tokens: bool = True,
        skip_last: bool = True,
        **kwargs,
    ):
        """Initializes a new instance of the BasePretrainingDataset class."""
        if append_concat_token and tokenizer.eos_token_id is None:
            raise ValueError(
                "Tokenizer must have an EOS token if append_concat_token is enabled."
            )
        self.concat_token_id = tokenizer.eos_token_id if append_concat_token else None
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self._dataset_text_field = dataset_text_field
        self._append_concat_token = append_concat_token
        self._add_special_tokens = add_special_tokens
        self._skip_last = skip_last
        super().__init__(**kwargs)
    def __iter__(self):
        """Iterates over the dataset and yields samples of a specified sequence length.

        The underlying dataset is a stream of documents. Each document is
        expected to contain a text field `self._dataset_text_field` that will
        be tokenized. Training samples are then yielded in sequences of length
        `self.seq_length`.

        Since this iterator might return samples packed from different
        documents, we optionally use `self.concat_token_id` to separate the
        sequences from different documents.
        """
        buffer = []
        for document in self.data:
            if self._append_concat_token and len(buffer) > 0:
                # We started preprocessing a new document,
                # so we need to append the concatenation token to mark the end
                # of the previous document.
                buffer.append(self.concat_token_id)

            # Pre-process and tokenize the document.
            document_tokens = self.transform(document[self._dataset_text_field])
            buffer.extend(document_tokens)

            # Yield sequences of the specified length.
            # Otherwise, resume pre-processing the next document.
            while len(buffer) >= self.seq_length:
                # We have enough tokens to create a fully packed sample.
                yield self._create_training_sample(buffer[: self.seq_length])
                buffer = buffer[self.seq_length :]

        # Finished iterating on the dataset, yield the remaining buffer.
        if len(buffer) > 0:
            if not self._skip_last or len(buffer) == self.seq_length:
                yield self._create_training_sample(buffer)
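    # Illustrative sketch (not from the original source): how the packing loop
    # above behaves with hypothetical values seq_length=4, concat_token_id=0,
    # and two documents tokenized as [5, 6, 7] and [8, 9, 10, 11]:
    #
    #   buffer after doc 1: [5, 6, 7]                 (< 4 tokens, nothing yielded)
    #   buffer after doc 2: [5, 6, 7, 0, 8, 9, 10, 11]
    #   yielded samples:    [5, 6, 7, 0] then [8, 9, 10, 11]
    #
    # A trailing partial buffer shorter than seq_length is dropped when
    # `skip_last` is True (the default), since samples are never padded.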
    @override
    def transform(self, sample: Any) -> list[int]:
        """Preprocesses the inputs in the given sample."""
        return self.tokenize(sample)
    def tokenize(self, text: str) -> list[int]:
        """Tokenizes the given text.

        Should not apply any padding/truncation to allow for packing.
        """
        return self.tokenizer.encode(
            text=text,
            return_tensors=None,
            max_length=None,
            padding=False,
            truncation=False,
            add_special_tokens=self._add_special_tokens,
        )
    def _create_training_sample(self, tokens: list) -> dict[str, torch.Tensor]:
        """Creates a training sample from the given tokens."""
        input_ids = torch.tensor(tokens)
        attention_mask = torch.ones_like(input_ids)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": input_ids.clone(),
        }
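
# Minimal usage sketch, adapted from the class docstring example above; not
# part of the original module. It assumes the "wikimedia/wikipedia" dataset
# and the GPT-2 tokenizer are reachable (e.g., downloadable from the Hugging
# Face Hub).
if __name__ == "__main__":
    from oumi.builders import build_tokenizer
    from oumi.core.configs import ModelParams

    tokenizer = build_tokenizer(ModelParams(model_name="gpt2"))
    dataset = BasePretrainingDataset(
        dataset_name="wikimedia/wikipedia",
        subset="20231101.en",
        split="train",
        tokenizer=tokenizer,
        seq_length=512,
    )
    # Each sample is a dict of equal-length 1-D tensors:
    # "input_ids", "attention_mask", and "labels" (a clone of input_ids).
    sample = next(iter(dataset))
    print({name: tensor.shape for name, tensor in sample.items()})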