Source code for oumi.core.inference.base_inference_engine

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional

import jsonlines
from tqdm import tqdm

from oumi.core.configs import (
    GenerationParams,
    InferenceConfig,
    ModelParams,
)
from oumi.core.types.conversation import Conversation
from oumi.utils.logging import logger


class BaseInferenceEngine(ABC):
    """Base class for running model inference."""

    _model_params: ModelParams
    """The model parameters."""

    _generation_params: GenerationParams
    """The generation parameters."""

    def __init__(
        self,
        model_params: ModelParams,
        *,
        generation_params: Optional[GenerationParams] = None,
    ):
        """Initializes the inference engine.

        Args:
            model_params: The model parameters.
            generation_params: The generation parameters.
        """
        self._model_params = copy.deepcopy(model_params)
        if generation_params:
            self._check_unsupported_params(generation_params)
        else:
            generation_params = GenerationParams()
        self._generation_params = generation_params
    def infer(
        self,
        input: Optional[list[Conversation]] = None,
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        """Runs model inference.

        Args:
            input: A list of conversations to run inference on. Optional.
            inference_config: Parameters for inference. If not specified, a default
                config is inferred.

        Returns:
            List[Conversation]: Inference output.
        """
        if input is not None and (
            inference_config and inference_config.input_path is not None
        ):
            raise ValueError(
                "Only one of input or inference_config.input_path should be provided."
            )
        if inference_config and inference_config.generation:
            generation_params = inference_config.generation
            self._check_unsupported_params(generation_params)
        else:
            generation_params = self._generation_params
        if input is not None:
            return self.infer_online(input, inference_config)
        elif inference_config and inference_config.input_path is not None:
            return self.infer_from_file(inference_config.input_path, inference_config)
        else:
            raise ValueError(
                "One of input or inference_config.input_path must be provided."
            )
    def _read_conversations(self, input_filepath: str) -> list[Conversation]:
        """Reads conversations from a file in Oumi chat format.

        Args:
            input_filepath: The path to the file containing the conversations.

        Returns:
            List[Conversation]: A list of conversations read from the file.
        """
        conversations = []
        with open(input_filepath) as f:
            for line in f:
                # Only parse non-empty lines.
                if line.strip():
                    conversation = Conversation.from_json(line)
                    conversations.append(conversation)
        return conversations

    def _get_scratch_filepath(self, output_filepath: str) -> str:
        """Returns a scratch filepath for the given output filepath.

        For example, if the output filepath is "/foo/bar/output.json", the
        scratch filepath will be "/foo/bar/scratch/output.json".

        Args:
            output_filepath: The output filepath.

        Returns:
            str: The scratch filepath.
        """
        original_filepath = Path(output_filepath)
        return str(original_filepath.parent / "scratch" / original_filepath.name)

    def _save_conversation(
        self, conversation: Conversation, output_filepath: str
    ) -> None:
        """Appends a conversation to a file in Oumi chat format.

        Args:
            conversation: The conversation to save.
            output_filepath: The path to the file where the conversation should be
                saved.
        """
        # Make the directory if it doesn't exist.
        Path(output_filepath).parent.mkdir(parents=True, exist_ok=True)
        with jsonlines.open(output_filepath, mode="a") as writer:
            json_obj = conversation.to_dict()
            writer.write(json_obj)

    def _save_conversations(
        self, conversations: list[Conversation], output_filepath: str
    ) -> None:
        """Saves conversations to a file in Oumi chat format.

        Args:
            conversations: A list of conversations to save.
            output_filepath: The path to the file where the conversations should be
                saved.
        """
        # Make the directory if it doesn't exist.
        Path(output_filepath).parent.mkdir(parents=True, exist_ok=True)
        with jsonlines.open(output_filepath, mode="w") as writer:
            for conversation in tqdm(conversations, desc="Saving conversations"):
                json_obj = conversation.to_dict()
                writer.write(json_obj)

    def _check_unsupported_params(self, generation_params: GenerationParams):
        """Checks for unsupported parameters and logs warnings.

        If a parameter is not supported, and a non-default value is provided,
        a warning is logged.
        """
        supported_params = self.get_supported_params()
        default_generation_params = GenerationParams()
        for param_name, value in generation_params:
            if param_name not in supported_params:
                is_non_default_value = (
                    getattr(default_generation_params, param_name) != value
                )
                if is_non_default_value:
                    logger.warning(
                        f"{self.__class__.__name__} does not support {param_name}. "
                        f"Received value: {param_name}={value}. "
                        "This parameter will be ignored."
                    )
    @abstractmethod
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine.

        Override this method in derived classes to specify which parameters
        are supported.

        Returns:
            Set[str]: A set of supported parameter names.
        """
        raise NotImplementedError
    @abstractmethod
    def infer_online(
        self,
        input: list[Conversation],
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        """Runs model inference online.

        Args:
            input: A list of conversations to run inference on.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        raise NotImplementedError
    @abstractmethod
    def infer_from_file(
        self,
        input_filepath: str,
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        """Runs model inference on inputs in the provided file.

        This is a convenience method to prevent boilerplate from asserting the
        existence of input_filepath in the generation_params.

        Args:
            input_filepath: Path to the input file containing prompts for generation.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        raise NotImplementedError
    def apply_chat_template(
        self, conversation: Conversation, **tokenizer_kwargs
    ) -> str:
        """Applies the chat template to the conversation.

        Args:
            conversation: The conversation to apply the chat template to.
            tokenizer_kwargs: Additional keyword arguments to pass to the tokenizer.

        Returns:
            str: The conversation with the chat template applied.
        """
        tokenizer = getattr(self, "_tokenizer", None)
        if tokenizer is None:
            raise ValueError("Tokenizer is not initialized.")
        if tokenizer.chat_template is None:
            raise ValueError("Tokenizer does not have a chat template.")
        if "tokenize" not in tokenizer_kwargs:
            tokenizer_kwargs["tokenize"] = False
        return tokenizer.apply_chat_template(conversation, **tokenizer_kwargs)
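
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the Oumi source). A concrete engine only
# needs to implement the three abstract methods above. The class below is a
# minimal, hypothetical sketch: it echoes its inputs instead of calling a
# model, and the "NoOpInferenceEngine" name and the parameter/field names in
# the comments are assumptions for demonstration only.
class NoOpInferenceEngine(BaseInferenceEngine):
    def get_supported_params(self) -> set[str]:
        # Declare which GenerationParams fields this engine honors (names
        # assumed); non-default values for any other field trigger the
        # warning in _check_unsupported_params.
        return {"max_new_tokens", "temperature"}

    def infer_online(
        self,
        input: list[Conversation],
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        # A real engine would run the model here and append its responses to
        # each conversation; this sketch returns the inputs unchanged.
        return input

    def infer_from_file(
        self,
        input_filepath: str,
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        # Reuse the base class's JSONL reader, then defer to the online path.
        conversations = self._read_conversations(input_filepath)
        return self.infer_online(conversations, inference_config)


# Example call sites (hypothetical values): infer() accepts exactly one of
# `input` or `inference_config.input_path`.
#
#   engine = NoOpInferenceEngine(ModelParams(model_name="my-model"))
#   outputs = engine.infer(input=conversations)
#   outputs = engine.infer(inference_config=InferenceConfig(input_path="in.jsonl"))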