Source code for oumi.datasets.vision_language.coco_captions
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing_extensions import override

from oumi.core.datasets import VisionLanguageSftDataset
from oumi.core.registry import register_dataset
from oumi.core.types.conversation import (
ContentItem,
Conversation,
Message,
Role,
Type,
)
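
# Column names in the HuggingFaceM4/COCO dataset schema referenced below.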
_COCO_COLUMN_SENTENCES = "sentences"
_COCO_COLUMN_RAW = "raw"
_COCO_COLUMN_IMAGE = "image"
_COCO_COLUMN_PATH = "path"
_COCO_COLUMN_BYTES = "bytes"


@register_dataset("coco_captions")
class COCOCaptionsDataset(VisionLanguageSftDataset):
"""Dataset class for the `HuggingFaceM4/COCO` dataset."""
default_dataset = "HuggingFaceM4/COCO"
    default_prompt = "Describe this image:"

    @override
def transform_conversation(self, example: dict) -> Conversation:
"""Transform a single conversation example into a Conversation object."""
input_text = self.default_prompt
for required_key in (_COCO_COLUMN_SENTENCES, _COCO_COLUMN_IMAGE):
if required_key not in example:
raise ValueError(
"Training example doesn't contain '{required_key}' key. "
f"Available keys: {example.keys()}."
)
if _COCO_COLUMN_RAW not in example[_COCO_COLUMN_SENTENCES]:
raise ValueError(
"Training example doesn't contain 'sentences.raw' key. Available keys "
f"under 'sentences.': {example[_COCO_COLUMN_SENTENCES].keys()}."
)
output_text = example[_COCO_COLUMN_SENTENCES][_COCO_COLUMN_RAW]
user_items: list[ContentItem] = []
if _COCO_COLUMN_BYTES in example[_COCO_COLUMN_IMAGE]:
user_items.append(
ContentItem(
binary=example[_COCO_COLUMN_IMAGE][_COCO_COLUMN_BYTES],
type=Type.IMAGE_BINARY,
)
)
elif _COCO_COLUMN_PATH in example[_COCO_COLUMN_IMAGE]:
user_items.append(
ContentItem(
content=example[_COCO_COLUMN_IMAGE][_COCO_COLUMN_PATH],
type=Type.IMAGE_PATH,
)
)
else:
raise ValueError(
"Training example contains none of required keys: "
"'image.bytes', 'image.path'. "
f"Available keys under 'image.': {example[_COCO_COLUMN_IMAGE].keys()}."
)
user_items.append(ContentItem(type=Type.TEXT, content=input_text))
return Conversation(
messages=[
Message(role=Role.USER, content=user_items),
Message(role=Role.ASSISTANT, content=output_text),
]
)
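
# Illustrative usage sketch (not part of the original module): how
# `transform_conversation` maps a COCO-style example to a Conversation. The
# constructor arguments and the byte string below are placeholder assumptions.
#
#   dataset = COCOCaptionsDataset(...)  # args required by VisionLanguageSftDataset
#   example = {
#       "sentences": {"raw": "A cat sitting on a windowsill."},
#       "image": {"bytes": b"<raw image bytes>"},
#   }
#   conversation = dataset.transform_conversation(example)
#   # messages[0]: USER turn with the image item plus "Describe this image:"
#   # messages[1]: ASSISTANT turn containing the raw caption text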