# Source code for oumi.utils.conversation_utils

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
from typing import Any, Union

import PIL.Image

from oumi.core.types.conversation import ContentItem, Message, Type
from oumi.utils.image_utils import (
    DEFAULT_IMAGE_MODE,
    load_image_png_bytes_from_path,
    load_image_png_bytes_from_url,
    load_pil_image_from_bytes,
    load_pil_image_from_path,
    load_pil_image_from_url,
)


def load_image_bytes_to_content_item(
    item: ContentItem, mode: str = DEFAULT_IMAGE_MODE
) -> ContentItem:
    """Inlines image bytes into a content item that references an image.

    Items of type `IMAGE_PATH` or `IMAGE_URL` are loaded through the
    corresponding `load_image_png_bytes_from_*` helper and re-wrapped as a new
    `IMAGE_BINARY` item. Any other item is handed back untouched.

    Args:
        item: An input message content item.
        mode: The requested image mode e.g., "RGB", "HSV", "RGBA",
            "P" (8-bit pixels, using a color palette).
            For details, see https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes

    Returns:
        A new `IMAGE_BINARY` content item if the input was `IMAGE_PATH` or
        `IMAGE_URL`; otherwise the input item itself.

    Raises:
        ValueError: If the item is `IMAGE_PATH` or `IMAGE_URL` and its
            `content` is None.
    """
    if item.type == Type.IMAGE_PATH:
        if item.content is None:
            raise ValueError("Image path is None")
        image_bytes = load_image_png_bytes_from_path(item.content, mode=mode)
    elif item.type == Type.IMAGE_URL:
        if item.content is None:
            raise ValueError("Image URL is None")
        image_bytes = load_image_png_bytes_from_url(item.content, mode=mode)
    else:
        # Nothing to load: text items and already-inlined images pass through.
        return item

    return ContentItem(type=Type.IMAGE_BINARY, binary=image_bytes)
def load_pil_image_from_content_item(
    image_item: ContentItem, mode: str = DEFAULT_IMAGE_MODE
) -> PIL.Image.Image:
    """Builds a PIL image from an image content item.

    The item's type selects the source: a path, a URL (both read from
    `content`), or inline bytes (read from `binary`).

    Args:
        image_item: A content item representing an image.
        mode: The requested image mode e.g., "RGB", "HSV", "RGBA",
            "P" (8-bit pixels, using a color palette).
            For details, see https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes

    Returns:
        Image.Image: A PIL image.

    Raises:
        ValueError: If the item is not one of the image types, or if the field
            holding the image source (`content` or `binary`) is None.
    """
    item_type = image_item.type

    if item_type == Type.IMAGE_BINARY:
        if image_item.binary is None:
            raise ValueError("Image binary is None")
        return load_pil_image_from_bytes(image_item.binary, mode=mode)

    if item_type == Type.IMAGE_PATH:
        if image_item.content is None:
            raise ValueError("Image path is None")
        return load_pil_image_from_path(image_item.content, mode=mode)

    if item_type == Type.IMAGE_URL:
        if image_item.content is None:
            raise ValueError("Image URL is None")
        return load_pil_image_from_url(image_item.content, mode=mode)

    raise ValueError(
        f"Unsupported content item type: {image_item.type}. Not an image!"
    )
def base64encode_content_item_image_bytes(
    item: ContentItem, *, add_mime_prefix: bool = True
) -> str:
    """Encodes the inline image bytes of a content item as a base64 string.

    Args:
        item: An input message content item of image type
            (one of `IMAGE_BINARY`, `IMAGE_PATH`, `IMAGE_URL`)
            with the pre-populated `binary` field.
        add_mime_prefix: Whether to add MIME prefix `data:image/png;base64,`

    Returns:
        String containing base64 encoded image bytes `<BASE64_VALUE>`.
        If `add_mime_prefix` is True, then the following format is used:
        `data:image/png;base64,<BASE64_VALUE>`.

    Raises:
        ValueError: If the item is not an image, or if its `binary` field is
            empty or missing.
    """
    if not item.is_image():
        raise ValueError(f"Message type is not an image: {item.type}")
    if not item.binary:
        raise ValueError(f"No image bytes in message: {item.type}")

    encoded = base64.b64encode(item.binary).decode(encoding="utf8")
    if add_mime_prefix:
        return "data:image/png;base64," + encoded
    return encoded
# Key names used when encoding message content items as JSON dicts.
_JSON_DICT_KEY_TYPE: str = "type"
_JSON_DICT_KEY_TEXT: str = "text"
_JSON_DICT_KEY_IMAGE_URL: str = "image_url"
_JSON_DICT_KEY_URL: str = "url"
def convert_message_content_item_to_json_dict(
    item: ContentItem,
) -> dict[str, Any]:
    """Encodes a single message content item as a JSON dict.

    Text items become `{"type": <text type>, "text": <content>}`. Image items
    become `{"type": <image-url type>, "image_url": {"url": <value>}}`, where
    the value is a base64 `data:` string when image bytes are available, and
    the item's own `content` otherwise.

    Args:
        item: The message content item to get the content for.

    Returns:
        Dict[str, Any]: The content for the message.

    Raises:
        ValueError: If the item is neither text nor an image.
    """
    if item.type == Type.TEXT:
        return {
            _JSON_DICT_KEY_TYPE: Type.TEXT.value,
            _JSON_DICT_KEY_TEXT: (item.content or ""),
        }
    if not item.is_image():
        raise ValueError(f"Unsupported message type: {item.type}")

    # URL items are never loaded here; they are emitted as-is unless they
    # already carry bytes. Other image items without bytes are loaded first.
    must_load_bytes = item.type != Type.IMAGE_URL and not item.binary
    if must_load_bytes:
        item = load_image_bytes_to_content_item(item)

    if item.binary:
        url_value = base64encode_content_item_image_bytes(item, add_mime_prefix=True)
    else:
        assert (
            item.type == Type.IMAGE_URL
        ), f"Unexpected message type: {item.type}. Must be a code bug."
        url_value = item.content or ""

    return {
        _JSON_DICT_KEY_TYPE: Type.IMAGE_URL.value,
        _JSON_DICT_KEY_IMAGE_URL: {_JSON_DICT_KEY_URL: url_value},
    }
def convert_content_items_to_json_list(
    content_items: list[ContentItem],
) -> list[dict[str, Any]]:
    """Encodes each content item as a JSON dict, preserving input order.

    Args:
        content_items: A list of content items.

    Returns:
        list[Dict[str, Any]]: One JSON dict per input content item.
    """
    return list(map(convert_message_content_item_to_json_dict, content_items))
def convert_message_to_json_content_list(
    message: Message,
) -> list[dict[str, Any]]:
    """Encodes all content items of a message as a list of JSON dicts.

    Args:
        message: The message to get the content for.

    Returns:
        list[Dict[str, Any]]: One JSON dict per content item of the message.
    """
    items = message.content_items
    return convert_content_items_to_json_list(items)
def convert_message_to_json_content(
    message: Message,
) -> Union[str, list[dict[str, Any]]]:
    """Returns the message content in its JSON-ready form.

    Args:
        message: The message to get the content for.

    Returns:
        The message's `content` unchanged if it is a string; otherwise the
        message's content items encoded as a list of JSON dicts.
    """
    content = message.content
    if not isinstance(content, str):
        assert isinstance(content, list)
        return convert_content_items_to_json_list(message.content_items)
    return content
def create_list_of_message_json_dicts(
    messages: list[Message],
    *,
    group_adjacent_same_role_turns: bool,
) -> list[dict[str, Any]]:
    """Encodes messages as a list of `{"role": ..., "content": ...}` dicts.

    Image content items are converted through
    `convert_message_to_json_content_list`.

    Args:
        messages: The input messages.
        group_adjacent_same_role_turns: Whether to pack adjacent messages
            from the same role into a single element in output list.

    Returns:
        list[Dict[str, Any]]: The list of messages encoded as nested JSON dicts.
    """
    total = len(messages)
    encoded: list[dict[str, Any]] = []

    start = 0
    while start < total:
        # Find the end (exclusive) of the run of messages emitted together.
        stop = start + 1
        if group_adjacent_same_role_turns:
            run_role = messages[start].role
            while stop < total and run_role == messages[stop].role:
                stop += 1

        group = messages[start:stop]
        first = group[0]
        entry: dict[str, Any] = {
            "role": first.role.value,
        }

        if len(group) == 1 and first.contains_single_text_content_item_only():
            # Set "content" to a primitive string value, which is the common
            # convention for text-only models.
            entry["content"] = first.text_content_items[0].content
        else:
            # Set "content" to be a list of dictionaries for more complex cases.
            entry["content"] = [
                content_dict
                for member in group
                for content_dict in convert_message_to_json_content_list(member)
            ]

        encoded.append(entry)
        start = stop

    return encoded