Source code for oumi.datasets.vision_language.llava_instruct_mix_vsft
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing_extensions import override

from oumi.core.datasets import VisionLanguageSftDataset
from oumi.core.registry import register_dataset
from oumi.core.types.conversation import (
ContentItem,
Conversation,
Message,
Role,
Type,
)
from oumi.utils.logging import logger


@register_dataset("HuggingFaceH4/llava-instruct-mix-vsft")
class LlavaInstructMixVsftDataset(VisionLanguageSftDataset):
"""Dataset class for the `HuggingFaceH4/llava-instruct-mix-vsft` dataset."""
default_dataset = "HuggingFaceH4/llava-instruct-mix-vsft"
    def _process_text_value(self, s: str) -> str:
        # The data contains occasional `\n` at the beginning or end
        # of text values. Let's strip them.
        return s.strip() if s else ""

def _parse_user_messages(
self, message_list: list[dict], images: list[dict]
) -> Message:
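        """Parses a user turn into a single multimodal `Message`.

        Args:
            message_list: Content elements of the user turn; must contain
                exactly one "text" element and, optionally, one "image" element.
            images: The example's image records, each carrying image "bytes"
                or a file "path".

        Returns:
            A `Message` with any image content item placed before the text item.
        """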
role = Role.USER
if len(message_list) not in (1, 2):
            raise ValueError(
                f"The `content` field for '{role}' must "
                f"contain 1 or 2 elements (a question and, optionally, an image). "
                f"Actual: {len(message_list)}"
            )
text_items: list[ContentItem] = []
image_items: list[ContentItem] = []
for user_message in message_list:
message_type = user_message["type"]
if message_type == "text":
text_items.append(
ContentItem(
type=Type.TEXT,
content=self._process_text_value(user_message["text"]),
)
)
elif message_type == "image":
image_index = int(user_message["index"])
                if not (0 <= image_index < len(images)):
                    raise ValueError(
                        f"Image index is out of bounds. "
                        f"Index: {image_index}, image count: {len(images)}"
                    )
image_dict = images[image_index]
if "bytes" in image_dict and image_dict["bytes"]:
image_items.append(
ContentItem(
type=Type.IMAGE_BINARY,
binary=image_dict["bytes"],
)
)
elif "path" in image_dict and image_dict["path"]:
image_items.append(
ContentItem(
type=Type.IMAGE_PATH,
content=image_dict["path"],
)
)
else:
raise ValueError(
f"Image element must include 'bytes' or 'path'. "
f"Actual keys: {image_dict.keys()}"
)
else:
raise ValueError(
f"{role}'s question has unknown type: '{message_type}'"
)
if len(text_items) != 1:
raise ValueError(
f"{role}'s turn must include 1 text question. "
f"Actual: {len(text_items)}"
)
if len(image_items) > 1:
raise ValueError(
f"{role}'s turn must include max 1 image. "
f"Actual: {len(image_items)}"
)
        # Note: image content items must precede the text item in the message.
        return Message(role=role, content=(image_items + text_items))

def _parse_assistant_messages(self, message_list: list[dict]) -> Message:
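        """Parses an assistant turn into a text `Message`.

        Args:
            message_list: Content elements of the assistant turn; must contain
                exactly one "text" element.

        Returns:
            A `Message` with the stripped response text.
        """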
role = Role.ASSISTANT
if len(message_list) != 1:
raise ValueError(
f"The `content` field for {role} must "
f"contain exactly 1 element (response). "
f"Actual: {len(message_list)}"
)
response_type = message_list[0]["type"]
if response_type != "text":
raise ValueError(
f"{role}'s response is expected to be text. Actual: {response_type}"
)
return Message(
role=role,
content=self._process_text_value(message_list[0]["text"]),
)