# Source code for oumi.datasets.vision_language.the_cauldron
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import numpy as np
from oumi.core.datasets import VisionLanguageSftDataset
from oumi.core.registry import register_dataset
from oumi.core.types.conversation import ContentItem, Conversation, Message, Role, Type
@register_dataset("HuggingFaceM4/the_cauldron")
class TheCauldronDataset(VisionLanguageSftDataset):
    """Dataset class for the `HuggingFaceM4/the_cauldron` dataset.

    The `HuggingFaceM4/the_cauldron` dataset is a comprehensive collection of
    50 vision-language datasets, primarily training sets, used
    for fine-tuning the Idefics2 vision-language model.

    The datasets cover various domains such as general visual question answering,
    captioning, OCR, document understanding, chart/figure understanding,
    table understanding, reasoning, logic, maths, textbook/academic questions,
    differences between images, and screenshot to code.

    Each raw example handled by ``transform_conversation`` must provide:

    - ``"images"``: a non-empty list or ``np.ndarray`` of dicts, each storing
      the image data as ``bytes`` under the ``"bytes"`` key.
    - ``"texts"``: a non-empty list or ``np.ndarray`` of dicts, each with a
      ``"user"`` and an ``"assistant"`` entry.
    """

    default_dataset = "HuggingFaceM4/the_cauldron"

    def transform_conversation(self, example: dict[str, Any]) -> Conversation:
        """Transform raw data into a conversation with images.

        All images are attached once, ahead of the text of the first user turn.
        Every entry of ``example["texts"]`` then contributes one user message
        followed by one assistant message.

        Args:
            example: Raw dataset row containing ``"images"`` and ``"texts"``.

        Returns:
            A conversation with two messages (user, assistant) per texts entry.

        Raises:
            ValueError: If ``"images"`` or ``"texts"`` is missing, is not a list
                or ``np.ndarray``, is empty, or contains a malformed entry.
        """
        for required_key in ("images", "texts"):
            if required_key not in example:
                raise ValueError(
                    f"Example doesn't contain '{required_key}'. "
                    f"Actual keys: {sorted(example.keys())}"
                )
            if not isinstance(example[required_key], (list, np.ndarray)):
                actual_type = type(example[required_key])
                raise ValueError(
                    f"Example's '{required_key}' must be a list or np.ndarray. "
                    f"Actual type: {actual_type}"
                )

        # Images are fully validated before any text entry is examined.
        image_content_items = self._parse_image_content_items(
            self._to_plain_list(example["images"])
        )
        messages_list = self._build_messages(
            image_content_items, self._to_plain_list(example["texts"])
        )
        return Conversation(messages=messages_list)

    @staticmethod
    def _to_plain_list(value: Any) -> list[Any]:
        """Return ``value`` as a list; ``np.ndarray`` goes through ``tolist()``."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        return value

    @staticmethod
    def _parse_image_content_items(images_list: list[Any]) -> list[ContentItem]:
        """Validate raw image entries and wrap each one as an IMAGE_BINARY item.

        Args:
            images_list: The example's ``"images"`` field as a plain list.

        Returns:
            One ``ContentItem`` of type ``Type.IMAGE_BINARY`` per image.

        Raises:
            ValueError: If there are no images, an entry is not a dict, an entry
                has no ``"bytes"`` key, or its ``"bytes"`` value is not ``bytes``.
        """
        num_images = len(images_list)
        if num_images <= 0:
            raise ValueError("Example contains no images.")

        image_content_items: list[ContentItem] = []
        for idx, image_item in enumerate(images_list):
            if not isinstance(image_item, dict):
                actual_type = type(image_item)
                raise ValueError(
                    f"Example image type is not `dict`. Actual type: {actual_type} "
                    f"for image {idx + 1} of {num_images}"
                )
            # Report a missing key through the same descriptive ValueError style
            # as the other checks, rather than letting a bare KeyError escape.
            if "bytes" not in image_item:
                raise ValueError(
                    "Example image doesn't contain 'bytes'. "
                    f"Actual keys: {sorted(image_item.keys())} "
                    f"for image {idx + 1} of {num_images}"
                )
            image_bytes = image_item["bytes"]
            if not isinstance(image_bytes, bytes):
                actual_type = type(image_bytes)
                raise ValueError(
                    f"Example image type is not `bytes`. Actual type: {actual_type} "
                    f"for image {idx + 1} of {num_images}"
                )
            image_content_items.append(
                ContentItem(type=Type.IMAGE_BINARY, binary=image_bytes)
            )
        return image_content_items

    @staticmethod
    def _build_messages(
        image_content_items: list[ContentItem], texts_list: list[Any]
    ) -> list[Message]:
        """Turn the texts entries into alternating user/assistant messages.

        The image items are prepended to the first user message only.

        Args:
            image_content_items: Image items to attach to the first user turn.
            texts_list: The example's ``"texts"`` field as a plain list.

        Returns:
            Messages in order: user, assistant, user, assistant, ...

        Raises:
            ValueError: If there are no texts entries, or an entry is not a dict
                holding both ``"user"`` and ``"assistant"`` keys.
        """
        num_texts = len(texts_list)
        if num_texts <= 0:
            raise ValueError(f"Example must contain some 'texts'. Got: {num_texts}")

        messages_list: list[Message] = []
        for idx, text_entry in enumerate(texts_list):
            if not isinstance(text_entry, dict):
                actual_type = type(text_entry)
                raise ValueError(
                    f"Texts entry must be a `dict`. "
                    f"Actual type: {actual_type} "
                    f"for text entry {idx + 1} of {num_texts}"
                )
            if not (("user" in text_entry) and ("assistant" in text_entry)):
                raise ValueError(
                    f"Texts entry must contain both 'user' and 'assistant' keys. "
                    f"Got: {sorted(text_entry.keys())} "
                    f"for text entry {idx + 1} of {num_texts}"
                )
            if idx == 0:
                # Only include image(s) once for the first turn.
                user_content = image_content_items + [
                    ContentItem(type=Type.TEXT, content=text_entry["user"]),
                ]
                messages_list.append(Message(role=Role.USER, content=user_content))
            else:
                messages_list.append(
                    Message(role=Role.USER, content=text_entry["user"])
                )
            messages_list.append(
                Message(role=Role.ASSISTANT, content=text_entry["assistant"]),
            )
        return messages_list