# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.frompathlibimportPathfromtypingimportAny,Optional,Unionimportpandasaspdfromtyping_extensionsimportoverridefromoumi.core.datasetsimportBaseSftDatasetfromoumi.core.registryimportregister_datasetfromoumi.core.types.conversationimportConversationfromoumi.utils.io_utilsimportload_json,load_jsonlines
@register_dataset("text_sft")
@register_dataset("text_sft_jsonl")
class TextSftJsonLinesDataset(BaseSftDataset):
    """TextSftJsonLinesDataset for loading SFT data in oumi and alpaca formats.

    This dataset class is designed to work with JSON Lines (.jsonl) or
    JSON (.json) files containing text-based supervised fine-tuning (SFT)
    data. It supports loading data either from a file or from a provided
    list of data samples in oumi and alpaca formats.

    Supported formats:
    1. JSONL or JSON of conversations (Oumi format)
    2. JSONL or JSON of Alpaca-style turns (instruction, input, output)

    Args:
        dataset_path (Optional[Union[str, Path]]): Path to the dataset file
            (.jsonl or .json).
        data (Optional[List[Dict[str, Any]]]): List of conversation dicts
            if not loading from a file.
        format (Optional[str]): The format of the data. Either "oumi" or
            "alpaca". If not provided, the format will be auto-detected.
        **kwargs: Additional arguments to pass to the parent class.

    Examples:
        Loading conversations from a JSONL file with auto-detection:

        >>> from oumi.datasets import TextSftJsonLinesDataset
        >>> dataset = TextSftJsonLinesDataset( # doctest: +SKIP
        ...     dataset_path="/path/to/your/dataset.jsonl"
        ... )

        Loading Alpaca-style data from a JSON file:

        >>> from oumi.datasets import TextSftJsonLinesDataset
        >>> dataset = TextSftJsonLinesDataset( # doctest: +SKIP
        ...     dataset_path="/path/to/your/dataset.json",
        ...     format="alpaca"
        ... )

        Loading from a list of data samples:

        >>> from oumi.datasets import TextSftJsonLinesDataset
        >>> data_samples = [
        ...     {"messages": [{"role": "user", "content": "Hello"},
        ...                   {"role": "assistant", "content": "Hi there!"}]},
        ...     {"messages": [{"role": "user", "content": "How are you?"},
        ...                   {"role": "assistant", "content": "great!"}]}
        ... ]
        >>> dataset = TextSftJsonLinesDataset(
        ...     data=data_samples,
        ... )
    """

    default_dataset = "custom"

    def __init__(
        self,
        dataset_path: Optional[Union[str, Path]] = None,
        data: Optional[list[dict[str, Any]]] = None,
        format: Optional[str] = None,
        **kwargs,
    ):
        """Initializes a new instance of the TextSftJsonLinesDataset class.

        Args:
            dataset_path (Optional): Path to the JSON lines dataset file.
            data (Optional): List of conversation dicts if not loading
                from a file.
            format (Optional): The format of the data. Either "oumi" or
                "alpaca". If not provided, the format will be auto-detected.
            **kwargs: Additional arguments to pass to the parent class.

        Raises:
            ValueError: If neither dataset_path nor data is provided, if both
                are provided, or if the file extension or format is
                unsupported.
        """
        if dataset_path is not None and data is not None:
            raise ValueError(
                "Either dataset_path or data must be provided, but not both"
            )

        self._data_column: str = "_messages_column"
        self._dataset_path: Optional[Path] = (
            Path(dataset_path) if dataset_path else None
        )

        if data is not None:
            data_frame = pd.DataFrame({self._data_column: data})
        elif self._dataset_path is not None:
            # Hoist the normalized extension; it was computed twice before.
            suffix = self._dataset_path.suffix.lower()
            if suffix == ".jsonl":
                data = load_jsonlines(self._dataset_path)
            elif suffix == ".json":
                data = load_json(self._dataset_path)
            else:
                raise ValueError(
                    f"Unsupported file format: {self._dataset_path.suffix}. "
                    "Use .jsonl or .json file extensions."
                )
            data_frame = pd.DataFrame({self._data_column: data})
        else:
            raise ValueError("Dataset path or data must be provided")

        self._data: pd.DataFrame = data_frame

        if format and format not in ["oumi", "alpaca"]:
            raise ValueError(
                f"Invalid format: {format}. "
                "Supported formats are 'oumi', and 'alpaca'."
            )

        # Auto-detect the format from the first row when not given explicitly.
        self._format: str = format if format else self._detect_format(data_frame)

        super().__init__(**kwargs)

    def _detect_format(self, data_frame: pd.DataFrame) -> str:
        """Detect the format of the data based on the first item.

        Args:
            data_frame: The DataFrame containing the data.

        Returns:
            str: The detected format ("oumi" or "alpaca").

        Raises:
            ValueError: If the format cannot be detected, or if a "messages"
                entry is present but malformed.
        """
        first_item = data_frame[self._data_column].iloc[0]

        if not isinstance(first_item, dict):
            raise ValueError(
                "Invalid data format. Each item in the dataset should be a "
                f"dictionary. Found type: {type(first_item)}. "
                "Please check your data structure and try again."
            )

        if "messages" in first_item:
            messages = first_item["messages"]
            # The Oumi format requires a list of message dicts, each carrying
            # "role" and "content" keys.
            if isinstance(messages, list) and all(
                isinstance(m, dict) and "role" in m and "content" in m
                for m in messages
            ):
                return "oumi"
            # BUG FIX: this branch previously returned the label
            # "conversations", which __init__ rejects and
            # transform_conversation does not handle, so the failure only
            # surfaced later as a confusing "Unsupported format" error.
            # Fail fast with an actionable message instead.
            raise ValueError(
                "Invalid conversation format. The 'messages' key must map to "
                "a list of dictionaries, each containing 'role' and "
                "'content' keys."
            )
        elif all(key in first_item for key in ["instruction", "input", "output"]):
            return "alpaca"

        raise ValueError(
            "Unable to auto-detect format. "
            "The data structure doesn't match any supported format. "
            "Please specify the format manually or ensure your data follows "
            "one of these structures:\n"
            "1. Conversations format: "
            "{'messages': [{'role': 'user', 'content': '...'}, ...]}\n"
            "2. Alpaca format: "
            "{'instruction': '...', 'input': '...', 'output': '...'}\n"
        )

    @override
    def _load_data(self) -> pd.DataFrame:
        # Data is already loaded and validated in __init__.
        return self._data

    @override
    def transform_conversation(self, example: dict) -> Conversation:
        """Transform a single conversation example into a Conversation object.

        Args:
            example: The input example containing the messages or
                Alpaca-style turn.

        Returns:
            Conversation: A Conversation object containing the messages.

        Raises:
            ValueError: If the example cannot be converted to a Conversation.
        """
        conversation_dict = example[self._data_column]

        if self._format == "oumi":
            try:
                return Conversation.model_validate(conversation_dict)
            except Exception as e:
                raise ValueError(
                    f"Invalid conversation format. "
                    f"Expected a dictionary with a 'messages' key "
                    f"containing a list of message dictionaries. "
                    f"Error: {str(e)}"
                ) from e
        elif self._format == "alpaca":
            return self._alpaca_to_conversation(conversation_dict)
        else:
            raise ValueError(f"Unsupported format: {self._format}")

    def _alpaca_to_conversation(self, turn: dict) -> Conversation:
        """Convert an Alpaca-style turn to a Conversation object.

        Args:
            turn: A dictionary containing 'instruction', 'input', and
                'output' keys.

        Returns:
            Conversation: A Conversation object representing the
                Alpaca-style turn.

        Raises:
            ValueError: If the turn doesn't contain the required keys.
        """
        required_keys = ["instruction", "input", "output"]
        if not all(key in turn for key in required_keys):
            raise ValueError(
                "Invalid Alpaca format. The turn must contain all of these "
                # BUG FIX: this fragment was missing its f-prefix, so the
                # literal text "{required_keys}" appeared in the message
                # instead of the actual key list.
                f"keys: {required_keys}. "
                f"Found keys: {list(turn.keys())}"
            )

        messages = [
            # Instruction and (possibly empty) input are merged into a single
            # user turn; strip() drops the trailing blank lines left behind
            # when 'input' is an empty string.
            {
                "role": "user",
                "content": f"{turn['instruction']}\n\n{turn['input']}".strip(),
            },
            {"role": "assistant", "content": turn["output"]},
        ]
        return Conversation(messages=messages)