Source code for oumi.datasets.vision_language.llava_instruct_mix_vsft
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing_extensions import override

from oumi.core.datasets import VisionLanguageSftDataset
from oumi.core.registry import register_dataset
from oumi.core.types.conversation import (
    ContentItem,
    Conversation,
    Message,
    Role,
    Type,
)
from oumi.utils.logging import logger
@register_dataset("HuggingFaceH4/llava-instruct-mix-vsft")
class LlavaInstructMixVsftDataset(VisionLanguageSftDataset):
    """Dataset class for the `HuggingFaceH4/llava-instruct-mix-vsft` dataset."""

    default_dataset = "HuggingFaceH4/llava-instruct-mix-vsft"

    def _process_text_value(self, s: str) -> str:
        # The data contains occasional `\n` at the beginning or end
        # of text values. Let's strip them.
        return s.strip() if s else ""

    def _parse_user_messages(
        self, message_list: list[dict], images: list[dict]
    ) -> Message:
        role = Role.USER
        if len(message_list) not in (1, 2):
            raise ValueError(
                f"The `content` field for '{role}' must "
                f"contain 1 or 2 elements (question, and, optionally, image). "
                f"Actual: {len(message_list)}"
            )

        text_items: list[ContentItem] = []
        image_items: list[ContentItem] = []
        for user_message in message_list:
            message_type = user_message["type"]
            if message_type == "text":
                text_items.append(
                    ContentItem(
                        type=Type.TEXT,
                        content=self._process_text_value(user_message["text"]),
                    )
                )
            elif message_type == "image":
                image_index = int(user_message["index"])
                if not (image_index >= 0 and image_index < len(images)):
                    raise ValueError(
                        f"Image index is out-of-bounds. "
                        f"Index: {image_index} "
                        f"Image count: {len(images)}"
                    )
                image_dict = images[image_index]
                if "bytes" in image_dict and image_dict["bytes"]:
                    image_items.append(
                        ContentItem(
                            type=Type.IMAGE_BINARY,
                            binary=image_dict["bytes"],
                        )
                    )
                elif "path" in image_dict and image_dict["path"]:
                    image_items.append(
                        ContentItem(
                            type=Type.IMAGE_PATH,
                            content=image_dict["path"],
                        )
                    )
                else:
                    raise ValueError(
                        f"Image element must include 'bytes' or 'path'. "
                        f"Actual keys: {image_dict.keys()}"
                    )
            else:
                raise ValueError(
                    f"{role}'s question has unknown type: '{message_type}'"
                )

        if len(text_items) != 1:
            raise ValueError(
                f"{role}'s turn must include 1 text question. "
                f"Actual: {len(text_items)}"
            )
        if len(image_items) > 1:
            raise ValueError(
                f"{role}'s turn must include max 1 image. "
                f"Actual: {len(image_items)}"
            )

        # Add image messages before text messages!
        return Message(role=role, content=(image_items + text_items))

    def _parse_assistant_messages(self, message_list: list[dict]) -> Message:
        role = Role.ASSISTANT
        if len(message_list) != 1:
            raise ValueError(
                f"The `content` field for {role} must "
                f"contain exactly 1 element (response). "
                f"Actual: {len(message_list)}"
            )
        response_type = message_list[0]["type"]
        if response_type != "text":
            raise ValueError(
                f"{role}'s response is expected to be text. "
                f"Actual: {response_type}"
            )

        return Message(
            role=role,
            content=self._process_text_value(message_list[0]["text"]),
        )
    @override
    def transform_conversation(self, example: dict) -> Conversation:
        """Transform a dataset example into a Conversation object."""
        example_messages = example.get("messages")
        if example_messages is None or len(example_messages) == 0:
            raise ValueError("No messages in input example.")

        images = example.get("images")
        if images is None or len(images) == 0:
            raise ValueError("No images in input example.")
        elif len(images) != 1:
            logger.warning(f"Example contains multiple images: {len(images)}")

        messages: list[Message] = []
        for message in example_messages:
            message_list = message.get("content")
            if message_list is None or len(message_list) == 0:
                raise ValueError("Missing or empty `content` field in message.")
            if message["role"] == "user":
                messages.append(self._parse_user_messages(message_list, images))
            elif message["role"] == "assistant":
                messages.append(self._parse_assistant_messages(message_list))
            else:
                raise ValueError(f"Unknown role: {message['role']}")

        return Conversation(messages=messages)
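

# Usage sketch (illustrative only, not part of the module): it shows the raw
# example shape that `transform_conversation` expects, i.e. a `messages` list of
# role/content entries plus a parallel `images` list. The example dict, the image
# path, and the `__new__` bypass of `__init__` below are assumptions for
# illustration; a real instantiation would go through the
# VisionLanguageSftDataset constructor and its data-loading arguments.
if __name__ == "__main__":
    example = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "index": 0},
                    {"type": "text", "text": "What is shown in this image?\n"},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "A cat sitting on a sofa."}],
            },
        ],
        # "bytes" is empty, so the parser falls back to the "path" field.
        "images": [{"path": "/tmp/example.jpg", "bytes": None}],
    }

    # Bypass __init__ purely for this sketch; transform_conversation only uses
    # stateless helper methods defined above.
    dataset = LlavaInstructMixVsftDataset.__new__(LlavaInstructMixVsftDataset)
    conversation = dataset.transform_conversation(example)

    # The user turn carries the image item first, then the stripped text
    # question; the assistant turn is plain text.
    print(conversation.messages[0].role, conversation.messages[1].role)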