Source code for oumi.datasets.vision_language.mnist_sft
# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.fromtypingimportAny,Optionalfromtyping_extensionsimportoverridefromoumi.core.datasetsimportVisionLanguageSftDatasetfromoumi.core.registryimportregister_datasetfromoumi.core.types.conversationimport(ContentItem,Conversation,Message,Role,Type,)
@register_dataset("mnist_sft")
class MnistSftDataset(VisionLanguageSftDataset):
    """MNIST dataset formatted as SFT data.

    MNIST is a well-known small dataset, can be useful for quick tests,
    prototyping, debugging.

    Each example is turned into a two-message conversation: a user message
    holding the digit image plus a fixed text question, and an assistant
    message holding the digit label as a string.
    """

    default_dataset = "ylecun/mnist"

    def __init__(
        self,
        *,
        dataset_name: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Initializes a new instance of the MnistSftDataset class.

        Args:
            dataset_name: Accepted but not used: the parent initializer is
                always called with ``dataset_name="ylecun/mnist"``.
            **kwargs: Forwarded unchanged to the parent class initializer.
        """
        super().__init__(
            dataset_name="ylecun/mnist",
            **kwargs,
        )

    @staticmethod
    def _to_digit(value: Any) -> int:
        """Converts a raw MNIST label into an integer digit.

        Args:
            value: The raw label; anything accepted by ``int()``.

        Returns:
            The label as an ``int`` in the inclusive range 0..9.

        Raises:
            ValueError: If ``value`` cannot be converted with ``int()``, or if
                the converted value is outside the 0..9 range.
        """
        try:
            result = int(value)
        except Exception as err:
            # Chain explicitly so the underlying conversion failure is reported
            # as the direct cause of the ValueError.
            raise ValueError(
                f"Failed to convert MNIST 'label' ({value}) to an integer!"
            ) from err
        if not 0 <= result <= 9:
            raise ValueError(f"MNIST digit ({result}) is not in [0,9] range!")
        return result

    @override
    def transform_conversation(self, example: dict) -> Conversation:
        """Transform a single MNIST example into a Conversation object.

        Args:
            example: A mapping that provides the label under ``"label"`` and
                the image payload under ``example["image"]["bytes"]``.

        Returns:
            A conversation with a user message (binary image item followed by
            the text question) and an assistant message containing the digit
            as a string.

        Raises:
            ValueError: If the label is not convertible to a digit in 0..9.
        """
        input_text = "What digit is in this picture?"
        output_digit = self._to_digit(example["label"])

        return Conversation(
            messages=[
                Message(
                    role=Role.USER,
                    content=[
                        ContentItem(
                            type=Type.IMAGE_BINARY,
                            binary=example["image"]["bytes"],
                        ),
                        ContentItem(type=Type.TEXT, content=input_text),
                    ],
                ),
                Message(role=Role.ASSISTANT, content=str(output_digit)),
            ]
        )