Source code for oumi.inference.native_text_inference_engine

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import PIL.Image
import torch
import transformers
from tqdm import tqdm
from transformers import BatchEncoding
from typing_extensions import override

from oumi.builders import (
    build_model,
    build_processor,
    build_tokenizer,
    is_image_text_llm,
)
from oumi.core.configs import GenerationParams, InferenceConfig, ModelParams
from oumi.core.inference import BaseInferenceEngine
from oumi.core.processors.base_processor import BaseProcessor
from oumi.core.types.conversation import Conversation, Message, Role, Type
from oumi.utils.image_utils import load_pil_image_from_bytes
from oumi.utils.logging import logger


class NativeTextInferenceEngine(BaseInferenceEngine):
    """Engine for running text-to-text model inference."""

    def __init__(
        self,
        model_params: ModelParams,
        *,
        generation_params: Optional[GenerationParams] = None,
    ):
        """Initializes the inference Engine.

        Args:
            model_params: The model parameters to use for inference.
            generation_params: Parameters for generation.
        """
        super().__init__(
            model_params=model_params, generation_params=generation_params
        )
        self._model = build_model(self._model_params)
        self._tokenizer = build_tokenizer(self._model_params)
        self._processor: Optional[BaseProcessor] = None
        if not hasattr(self._model, "generate"):
            raise ValueError(
                f"Model {self._model_params.model_name} does not support generation."
            )
        if is_image_text_llm(self._model_params):
            # Only enable Processor for vision language models for now.
            self._processor = build_processor(
                self._model_params.model_name,
                self._tokenizer,
                trust_remote_code=self._model_params.trust_remote_code,
            )
        # https://stackoverflow.com/questions/69609401/suppress-huggingface-logging-warning-setting-pad-token-id-to-eos-token-id
        self._model.generation_config.pad_token_id = self._tokenizer.pad_token_id

    def _make_batches(
        self, input: list[Conversation], batch_size: int
    ) -> list[list[Conversation]]:
        """Splits the input into batches of the specified size.

        Args:
            input: A list of text prompts.
            batch_size: The number of sequences to generate in parallel.

        Returns:
            List[List[str]]: A list of batches of text prompts.
        """
        return [input[i : i + batch_size] for i in range(0, len(input), batch_size)]

    def _update_stop_criteria(
        self, generation_params: GenerationParams
    ) -> GenerationParams:
        """Updates the stop tokens/strings in the generation params, if needed.

        Args:
            generation_params: Parameters for generation during inference.

        Returns:
            GenerationParams: Updated generation params.

        Note:
            model.generate accepts both `stop_strings` and `stop_token_ids` as stop
            criteria. Though these are defined as lists in our generation config
            (for compatibility with other APIs), in this API they could also be
            single values (a `str` or an `int`). If both are provided, we will stop
            at the first one that is found, either a stop string or a stop token id.
        """
        if self._tokenizer.eos_token and generation_params.stop_strings:
            if self._tokenizer.eos_token not in generation_params.stop_strings:
                logger.warning(
                    f"User-defined EOS token(s) {generation_params.stop_strings} do NOT"
                    f" include the tokenizer's default EOS token"
                    f" `{self._tokenizer.eos_token}`."
                )
        if self._tokenizer.eos_token_id and generation_params.stop_token_ids:
            if self._tokenizer.eos_token_id not in generation_params.stop_token_ids:
                logger.warning(
                    f"User-defined EOS token id(s) {generation_params.stop_token_ids}"
                    f" do NOT include the tokenizer's default EOS token id"
                    f" `{self._tokenizer.eos_token_id}`."
                )

        if not generation_params.stop_token_ids and not generation_params.stop_strings:
            if self._tokenizer.eos_token_id:
                eos_token_id = self._tokenizer.eos_token_id
                logger.info(f"Setting EOS token id to `{eos_token_id}`")
                if not isinstance(eos_token_id, int):
                    raise RuntimeError(
                        f"Tokenizer's `eos_token_id` is not an integer: "
                        f"{eos_token_id}. Type: {type(eos_token_id)}"
                    )
                generation_params.stop_token_ids = [eos_token_id]
            elif self._tokenizer.eos_token:
                eos_token = self._tokenizer.eos_token
                logger.info(f"Setting EOS token to `{eos_token}`")
                if not isinstance(eos_token, str):
                    raise RuntimeError(
                        f"Tokenizer's `eos_token` is not a string: "
                        f"{eos_token}. Type: {type(eos_token)}"
                    )
                generation_params.stop_strings = [eos_token]
            else:
                logger.warning("No EOS token defined.")
        return generation_params

    def _apply_chat_template_impl(self, conversation: Conversation) -> str:
        if self._processor is None:
            return self._tokenizer.apply_chat_template(
                conversation,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
        return self._processor.apply_chat_template(
            conversation,  # type: ignore
            add_generation_prompt=True,
        )

    def _generate_batch_encoding_with_tokenizer(
        self, text_prompts: list[str]
    ) -> BatchEncoding:
        return self._tokenizer(text_prompts, return_tensors="pt", padding=True)

    def _generate_batch_encoding_with_processor(
        self, text_prompts: list[str], conversations: list[Conversation]
    ) -> BatchEncoding:
        assert len(text_prompts) == len(conversations)
        assert self._processor is not None

        pil_images: list[PIL.Image.Image] = []
        for i, conversation in enumerate(conversations):
            image_items = [
                item for m in conversation.messages for item in m.image_content_items
            ]
            num_images = len(image_items)
            if num_images >= 1:
                if num_images > 1:
                    # FIXME OPE-355 Support multiple images
                    logger.warning(
                        conversation.append_id_to_string(
                            f"A conversation contains multiple images ({num_images}). "
                            "Only 1 image is currently supported. "
                            "Using the last image."
                        )
                    )
                if len(pil_images) != i:
                    raise ValueError(
                        conversation.append_id_to_string(
                            "Either all conversations in a batch must contain images, "
                            "or none of them."
                        )
                    )
                image_item = image_items[-1]
                if image_item.type != Type.IMAGE_BINARY:
                    raise NotImplementedError(
                        conversation.append_id_to_string(
                            "Only binary image messages (`IMAGE_BINARY`) "
                            f"are supported. Actual: {image_item.type}"
                        )
                    )
                elif image_item.binary is None or len(image_item.binary) == 0:
                    raise ValueError(
                        conversation.append_id_to_string(
                            "No image bytes in a binary image message (`IMAGE_BINARY`)!"
                        )
                    )
                image = load_pil_image_from_bytes(image_item.binary)
                pil_images.append(image)

        batch = self._processor(
            text=text_prompts,
            images=(pil_images if len(pil_images) > 0 else None),
            return_tensors="pt",
            padding=True,
        )
        return batch

    def _infer(
        self,
        input: list[Conversation],
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        """Runs batch inference for a model using the provided configuration.

        Args:
            input: A list of conversations to run inference on.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Conversations with generated model responses appended.
        """
        generation_params = (
            inference_config.generation
            if inference_config and inference_config.generation
            else self._generation_params
        )
        model_device = next(self._model.parameters()).device
        if generation_params.batch_size is None:
            logger.warning("Batch size not specified. Defaulting to 1.")
            generation_params.batch_size = 1
        batched_input: list[list[Conversation]] = self._make_batches(
            input, generation_params.batch_size
        )
        num_batches: int = len(batched_input)
        input_batches: list[BatchEncoding] = [BatchEncoding()] * num_batches

        for batch_index in range(num_batches):
            batch = batched_input[batch_index]
            text_prompts: list[str] = [
                self._apply_chat_template_impl(conversation) for conversation in batch
            ]
            if self._processor is None:
                batch = self._generate_batch_encoding_with_tokenizer(text_prompts)
            else:
                batch = self._generate_batch_encoding_with_processor(
                    text_prompts, batch
                )
            input_batches[batch_index] = batch.to(model_device)

        # Validate or (if needed) set the End Of Sequence (EOS) tokens/strings.
        generation_params = self._update_stop_criteria(generation_params)

        # Create a GenerationConfig object with the new parameters.
        # Documentation: https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig
        use_sampling = generation_params.use_sampling
        extra_kwargs = {}
        min_p, temperature = generation_params.min_p, generation_params.temperature
        if use_sampling:
            extra_kwargs["min_p"] = min_p
            extra_kwargs["temperature"] = temperature
        elif min_p > 0.0 or temperature > 0.0:
            logger.debug(
                f"The sampling params: min_p: {min_p} and temperature: {temperature} "
                "are ignored because sampling is disabled!"
            )
        generation_config = transformers.GenerationConfig(
            max_new_tokens=generation_params.max_new_tokens,
            top_p=generation_params.top_p,
            frequency_penalty=generation_params.frequency_penalty,
            presence_penalty=generation_params.presence_penalty,
            do_sample=use_sampling,
            include_stop_str_in_output=False,
            detokenize=True,
            seed=generation_params.seed,
            stop_strings=generation_params.stop_strings,
            eos_token_id=generation_params.stop_token_ids,
            num_beams=generation_params.num_beams,
            use_cache=generation_params.use_cache,
            **extra_kwargs,
        )

        # Skip using a progress bar for single turns.
        disable_tqdm = len(input) < 2

        # Generate model outputs (batch mode).
        output_conversations = []
        for batch_index in tqdm(
            range(len(input_batches)),
            desc="Generating Model Responses",
            disable=disable_tqdm,
        ):
            batch = input_batches[batch_index]
            output_batch = self._model.generate(
                **batch, generation_config=generation_config, tokenizer=self._tokenizer
            )

            # For each batch, remove the prepended prompts from all model responses.
            if generation_params.exclude_prompt_from_response:
                new_batch_data = []
                for response_index, response in enumerate(output_batch.data):
                    prompt = input_batches[batch_index]["input_ids"][response_index]  # type: ignore
                    # Sanity check
                    prompt_as_list = prompt.tolist()
                    response_prefix_as_list = response[: len(prompt)].tolist()
                    if prompt_as_list != response_prefix_as_list:
                        raise RuntimeError(
                            "Inconsistent prompt prefix content! "
                            f"\nRequest: {prompt_as_list} "
                            f"\nResponse: {response_prefix_as_list}"
                        )
                    new_batch_data.append(response[len(prompt) :])
                output_batch.data = torch.stack(new_batch_data, dim=0)

            output_batch_decoded = self._tokenizer.batch_decode(
                output_batch.data,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for conversation, response in zip(
                batched_input[batch_index], output_batch_decoded
            ):
                messages = [
                    *conversation.messages,
                    Message(role=Role.ASSISTANT, content=response),
                ]
                new_conversation = Conversation(
                    messages=messages,
                    metadata=conversation.metadata,
                    conversation_id=conversation.conversation_id,
                )
                if inference_config and inference_config.output_path:
                    self._save_conversation(
                        new_conversation, inference_config.output_path
                    )
                output_conversations.append(new_conversation)

        return output_conversations
    @override
    def infer_online(
        self,
        input: list[Conversation],
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        """Runs model inference online.

        Args:
            input: A list of conversations to run inference on.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        return self._infer(input, inference_config)
    @override
    def infer_from_file(
        self,
        input_filepath: str,
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        """Runs model inference on inputs in the provided file.

        This is a convenience method to prevent boilerplate from asserting the
        existence of input_filepath in the generation_params.

        Args:
            input_filepath: Path to the input file containing prompts for generation.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        input = self._read_conversations(input_filepath)
        return self._infer(input, inference_config)
    @override
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine."""
        return {
            "batch_size",
            "exclude_prompt_from_response",
            "frequency_penalty",
            "max_new_tokens",
            "min_p",
            "presence_penalty",
            "seed",
            "stop_strings",
            "stop_token_ids",
            "temperature",
            "top_p",
            "use_sampling",
            "use_cache",
            "num_beams",
        }
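
The snippet below is a minimal usage sketch, not part of the module source above. It assumes `NativeTextInferenceEngine` is importable from `oumi.inference`, that `ModelParams` and `GenerationParams` accept the keyword fields referenced in the code (e.g. `model_name`, `max_new_tokens`, `batch_size`), and that `Role.USER` exists alongside `Role.ASSISTANT`; verify these names against your installed Oumi version.

# Hypothetical usage sketch; names marked "assumed" are not taken verbatim
# from the module above.
from oumi.core.configs import GenerationParams, ModelParams
from oumi.core.types.conversation import Conversation, Message, Role
from oumi.inference import NativeTextInferenceEngine  # assumed re-export path

# Build the engine; `model_name`, `max_new_tokens`, and `batch_size` are
# assumed constructor fields based on the attribute accesses in the source.
engine = NativeTextInferenceEngine(
    model_params=ModelParams(model_name="HuggingFaceTB/SmolLM2-135M-Instruct"),
    generation_params=GenerationParams(max_new_tokens=64, batch_size=1),
)

# Run online inference on a single-turn conversation. `inference_config` is
# optional, so the engine's own generation params are used here.
conversations = [
    Conversation(messages=[Message(role=Role.USER, content="Say hello.")])
]
results = engine.infer_online(conversations)

# The engine appends the model's reply as the final ASSISTANT message.
print(results[0].messages[-1].content)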