Source code for oumi.datasets.vision_language.huggingface
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generic class for using HuggingFace vision-language datasets.

Allows users to specify the image, question, and answer columns at the config level.
"""

import base64
from pathlib import Path
from typing import Any, Optional, Union

import pandas as pd
from typing_extensions import override

from oumi.core.datasets import VisionLanguageSftDataset
from oumi.core.registry import register_dataset
from oumi.core.types.conversation import (
    ContentItem,
    Conversation,
    Message,
    Role,
    Type,
)
@register_dataset("hf_vision")
class HuggingFaceVisionDataset(VisionLanguageSftDataset):
    """Converts HuggingFace Vision-Language Datasets to Oumi Message format.

    This dataset handles standard HuggingFace datasets that contain:

    - An image column (containing image data or paths)
    - A question/prompt column (text input)
    - An optional answer column (text output)

    Example::

        dataset = HuggingFaceVisionDataset(
            hf_dataset_path="HuggingFaceM4/VQAv2",
            image_column="image",
            question_column="question",
            answer_column="answer",
        )
    """

    def __init__(
        self,
        *,
        hf_dataset_path: str,
        image_column: str,
        question_column: str,
        answer_column: Optional[str] = None,
        system_prompt_column: Optional[str] = None,
        system_prompt: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Initializes a new instance of the HuggingFaceVisionDataset class.

        Args:
            hf_dataset_path: Path to the HuggingFace dataset.
            image_column: Name of the column containing image data.
            question_column: Name of the column containing the question/prompt text.
            answer_column: Optional name of the column containing the answer text.
            system_prompt: Optional system prompt to add as the first message.
            system_prompt_column: Optional name of the column containing system
                prompts.
            **kwargs: Additional arguments passed to the parent class.
        """
        if not hf_dataset_path:
            raise ValueError("The `hf_dataset_path` parameter must be provided.")
        if not image_column:
            raise ValueError("The `image_column` parameter must be provided.")
        if not question_column:
            raise ValueError("The `question_column` parameter must be provided.")

        self.image_column = image_column
        self.question_column = question_column
        self.answer_column = answer_column
        self.system_prompt = system_prompt
        self.system_prompt_column = system_prompt_column

        if system_prompt and system_prompt_column:
            raise ValueError(
                "Only one of `system_prompt` or `system_prompt_column` "
                "can be provided."
            )

        if Path(hf_dataset_path).exists():
            # If the path exists, it's a local dataset
            kwargs["dataset_path"] = hf_dataset_path
            kwargs["dataset_name"] = "hf_vision"
        else:
            # Otherwise, assume it's a remote dataset
            kwargs["dataset_name"] = hf_dataset_path

        super().__init__(**kwargs)

    def _get_image_content_item(self, image_data) -> ContentItem:
        """Create a ContentItem for the image data.

        Args:
            image_data: Image data from the dataset (could be bytes, PIL Image, etc.).

        Returns:
            ContentItem containing the image data.
        """
        if isinstance(image_data, bytes):
            # Raw bytes
            return ContentItem(
                type=Type.IMAGE_BINARY,
                binary=image_data,
            )
        elif hasattr(image_data, "bytes"):
            # PIL Image or similar object with a `bytes` attribute
            return ContentItem(
                type=Type.IMAGE_BINARY,
                binary=image_data.bytes,
            )
        elif isinstance(image_data, dict) and "bytes" in image_data:
            # Dict with bytes
            return ContentItem(
                type=Type.IMAGE_BINARY,
                binary=image_data["bytes"],
            )
        elif isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                return ContentItem(type=Type.IMAGE_URL, content=image_data)
            else:
                # Assume it's a base64-encoded image
                return ContentItem(
                    type=Type.IMAGE_BINARY, binary=base64.b64decode(image_data)
                )
        else:
            raise ValueError(
                f"Unsupported image data type: {type(image_data)}. "
                "Expected bytes, PIL Image, URL string, or base64 encoded string."
            )
    @override
    def transform_conversation(self, example: Union[dict, pd.Series]) -> Conversation:
        """Preprocesses the inputs of the example and returns a Conversation.

        Args:
            example: An example containing image, question, and optionally answer data.

        Returns:
            Conversation: A Conversation object containing the messages.
        """
        messages = []

        # Validate required columns
        required = {
            "image_column": self.image_column,
            "question_column": self.question_column,
        }
        if self.answer_column:
            required["answer_column"] = self.answer_column
        if self.system_prompt_column:
            required["system_prompt_column"] = self.system_prompt_column

        for column_name, column_var in required.items():
            if column_var not in example:
                raise ValueError(
                    f"The column '{column_name}' (specified as {column_var}) "
                    f"is not present in the example. "
                    f"Available columns: {list(example.keys())}"
                )

        # Add system prompt if available (either static or from column)
        if self.system_prompt:
            system_message_content = self.system_prompt
        else:
            system_message_content = self._process_text_value(
                example.get(self.system_prompt_column)
            )
        if system_message_content:
            messages.append(
                Message(role=Role.SYSTEM, content=system_message_content)
            )

        # Extract and process the data
        image_data = example[self.image_column]
        question_text = self._process_text_value(example[self.question_column])

        # Create the image content item
        image_content_item = self._get_image_content_item(image_data)

        # Create the user message with image and text
        user_message = Message(
            role=Role.USER,
            content=[
                image_content_item,
                ContentItem(type=Type.TEXT, content=question_text),
            ],
        )
        messages.append(user_message)

        # Add assistant response if answer column is specified and present
        if self.answer_column and self.answer_column in example:
            answer_text = self._process_text_value(example[self.answer_column])
            assistant_message = Message(role=Role.ASSISTANT, content=answer_text)
            messages.append(assistant_message)

        return Conversation(messages=messages)
    def _process_text_value(self, s: Any) -> str:
        """Process a text value.

        Args:
            s: The text value to process.

        Returns:
            The processed text value.
        """
        if s is None:
            return ""
        if isinstance(s, str):
            # The data contains occasional `\n` at the beginning or end
            # of text values. Let's strip them.
            return s.strip()
        raise ValueError(f"Unsupported text value type: {type(s)}")
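For orientation, below is a minimal usage sketch, not part of the module above. It reuses the "HuggingFaceM4/VQAv2" path and column names from the class docstring, and it assumes that extra keyword arguments such as `split` are accepted by the `VisionLanguageSftDataset` parent constructor via `**kwargs`; check that base class for the exact options.

# Minimal usage sketch (assumptions: the remote dataset is reachable and the
# parent class accepts a `split` keyword; the example payload is illustrative).
from oumi.datasets.vision_language.huggingface import HuggingFaceVisionDataset

dataset = HuggingFaceVisionDataset(
    hf_dataset_path="HuggingFaceM4/VQAv2",  # remote HF Hub dataset, per the class docstring
    image_column="image",
    question_column="question",
    answer_column="answer",
    split="validation",  # assumed to be forwarded to the parent class via **kwargs
)

# transform_conversation() maps one raw example to an Oumi Conversation:
# an optional SYSTEM message, a USER message holding the image plus the question
# text, and an ASSISTANT message holding the answer when an answer column is set.
example = {
    "image": "https://example.com/cat.png",  # URL strings become IMAGE_URL content items
    "question": "What animal is in the picture?",
    "answer": "A cat.",
}
conversation = dataset.transform_conversation(example)
for message in conversation.messages:
    print(message.role)

Note that if `hf_dataset_path` points to an existing local path instead, the constructor passes it to the parent class as `dataset_path` and uses the registered name "hf_vision" as `dataset_name`, as shown in `__init__` above.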