Source code for oumi.core.inference.base_inference_engine
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import dataclasses
import hashlib
import json
import time
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional

import jsonlines
from hdrh.histogram import HdrHistogram
from tqdm import tqdm

from oumi.core.configs import (
    GenerationParams,
    InferenceConfig,
    ModelParams,
)
from oumi.core.types.conversation import Conversation
from oumi.utils.logging import logger
from oumi.utils.math_utils import is_power_of_two
class BaseInferenceEngine(ABC):
    """Base class for running model inference."""

    _model_params: ModelParams
    """The model parameters."""

    _generation_params: GenerationParams
    """The generation parameters."""

    def __init__(
        self,
        model_params: ModelParams,
        *,
        generation_params: Optional[GenerationParams] = None,
    ):
        """Initializes the inference engine.

        Args:
            model_params: The model parameters.
            generation_params: The generation parameters.
        """
        self._model_params = copy.deepcopy(model_params)
        if generation_params:
            self._check_unsupported_params(generation_params)
        else:
            generation_params = GenerationParams()
        self._generation_params = generation_params
        self._latency_histogram_online = HdrHistogram(1, 60 * 1000, 1)
        self._latency_histogram_from_file = HdrHistogram(20, 180 * 1000, 1)
        self._dataset_hash = None
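    # A minimal construction sketch (illustrative only): `MyEngine` is a
    # hypothetical concrete subclass, and the `ModelParams`/`GenerationParams`
    # field names used below are assumptions, not guaranteed by this module.
    #
    #     engine = MyEngine(
    #         model_params=ModelParams(model_name="my-model"),
    #         generation_params=GenerationParams(max_new_tokens=256),
    #     )
    #
    # Any generation parameter set to a non-default value that the engine does
    # not report in `get_supported_params()` is logged as a warning by
    # `_check_unsupported_params` and ignored.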
    def infer(
        self,
        input: Optional[list[Conversation]] = None,
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        """Runs model inference.

        Args:
            input: A list of conversations to run inference on. Optional.
            inference_config: Parameters for inference. If not specified, a default
                config is inferred.

        Returns:
            List[Conversation]: Inference output.
        """
        if input is not None and (
            inference_config and inference_config.input_path is not None
        ):
            raise ValueError(
                "Only one of input or inference_config.input_path should be provided."
            )

        # Ensure the inference config has up-to-date generation parameters.
        if inference_config:
            if inference_config.generation:
                self._check_unsupported_params(inference_config.generation)
            elif self._generation_params:
                inference_config = copy.deepcopy(inference_config)
                inference_config.generation = self._generation_params
                # Warn the user: They provided an inference config without
                # generation params, so what was the point of providing it in
                # the first place?
                logger.warning(
                    "No generation parameters provided in the inference config. Using "
                    "the generation parameters that the engine was initialized with."
                )

        output_path = inference_config.output_path if inference_config else None

        # Load input conversations either from file or use provided input.
        conversations_to_process: list[Conversation] = []
        if input is not None:
            conversations_to_process = input
        elif inference_config and inference_config.input_path is not None:
            conversations_to_process = self._read_conversations(
                inference_config.input_path,
            )
        else:
            raise ValueError(
                "One of input or inference_config.input_path must be provided."
            )

        unique_conversation_ids = set()
        for i, conversation in enumerate(conversations_to_process):
            if conversation.conversation_id is not None:
                if conversation.conversation_id in unique_conversation_ids:
                    raise ValueError(
                        f"Conversation ID {conversation.conversation_id} is not unique."
                    )
                unique_conversation_ids.add(conversation.conversation_id)
                continue

            # Generate a deterministic conversation ID based on index if none exists.
            messages = conversation.messages
            all_str_message_content = ",".join(
                [
                    message.content
                    for message in messages
                    if isinstance(message.content, str)
                ]
            )
            content_hash = hashlib.sha256(all_str_message_content.encode()).hexdigest()
            id_name = str(i) + "_" + content_hash
            conversation.conversation_id = str(
                uuid.uuid5(
                    uuid.NAMESPACE_DNS,
                    id_name,
                )
            )
            unique_conversation_ids.add(conversation.conversation_id)

        # Calculate the hash of the dataset to use for referencing the inference
        # filename between runs, given the same model and generation parameters.
        dataset_hash = ""
        if len(conversations_to_process) > 0:
            row_hashes = [
                hashlib.sha256(c.to_json().encode()).hexdigest()
                for c in conversations_to_process
            ]
            dataset_hash = hashlib.sha256(",".join(row_hashes).encode()).hexdigest()
        self._dataset_hash = dataset_hash

        completed_conversations = self._load_from_scratch(output_path)

        # Filter out already completed conversations.
        remaining_conversations = self._filter_incomplete_conversations(
            conversations_to_process, completed_conversations
        )
        if len(remaining_conversations) < len(conversations_to_process):
            logger.info(
                f"Found {len(completed_conversations)} completed conversations. "
"f"Processing remaining {len(remaining_conversations)} conversations.")# Run inference only on remaining conversationsstart_time=time.perf_counter()histogram=self._latency_histogram_onlineinference_results=self._infer_online(remaining_conversations,inference_config)histogram.record_value((time.perf_counter()-start_time)*1e3)self._maybe_log_latency_histogram(histogram)iflen(inference_results)==len(conversations_to_process):final_results=inference_resultselse:# Incomplete inference results were saved to scratch file.# Load all results from scratch to get all results.final_results=self._load_from_scratch(output_path)iflen(final_results)!=len(conversations_to_process):raiseValueError(f"Number of final results ({len(final_results)}) does not match "f"number of conversations to process "f"({len(conversations_to_process)}).")self._cleanup_scratch_file(output_path)sorted_conversations={conv.conversation_id:convforconvinconversations_to_process}forconvinfinal_results:sorted_conversations[conv.conversation_id]=convfinal_results=list(sorted_conversations.values())ifinference_configandinference_config.output_path:self._save_conversations(final_results,inference_config.output_path)returnfinal_results
    def _maybe_log_latency_histogram(self, histogram: Optional[HdrHistogram]) -> None:
        """Logs the histogram if it is not None.

        Args:
            histogram: The histogram to log.
        """
        if histogram is None:
            return
        total_count = histogram.get_total_count()
        # TODO: Define a better way to enable/configure this logging.
        if not (
            isinstance(total_count, int)
            and total_count >= 2
            and is_power_of_two(total_count)
        ):
            return
        p50 = histogram.get_value_at_percentile(50)
        p90 = histogram.get_value_at_percentile(90)
        p99 = histogram.get_value_at_percentile(99)
        logger.debug(
            f"{self.__class__.__name__}: "
            f"Latency Histogram: {total_count} samples recorded:"
            f"\tp50: {p50:.1f}ms\tp90: {p90:.1f}ms\tp99: {p99:.1f}ms"
        )

    def _read_conversations(self, input_filepath: str) -> list[Conversation]:
        """Reads conversations from a file in Oumi chat format.

        Args:
            input_filepath: The path to the file containing the conversations.

        Returns:
            List[Conversation]: A list of conversations read from the file.
        """
        conversations = []
        with open(input_filepath) as f:
            for line in f:
                # Only parse non-empty lines.
                if line.strip():
                    conversation = Conversation.from_json(line)
                    conversations.append(conversation)
        return conversations

    def _load_from_scratch(self, output_filepath: Optional[str]) -> list[Conversation]:
        """Loads conversations from a scratch file.

        Args:
            output_filepath: The path to the output file. This is used to determine
                the location of the scratch file.

        Returns:
            list[Conversation]: A list of conversations loaded from the scratch file,
                or an empty list if the scratch file does not exist or is empty.
        """
        scratch_filepath = self._get_scratch_filepath(output_filepath)
        if not Path(scratch_filepath).exists():
            return []
        conversations = []
        with jsonlines.open(scratch_filepath, mode="r") as reader:
            for line in reader:
                conversations.append(Conversation.from_dict(line))
        return conversations

    def _filter_incomplete_conversations(
        self,
        input_conversations: list[Conversation],
        completed_conversations: list[Conversation],
    ) -> list[Conversation]:
        """Filters out conversations that have already been completed.

        Args:
            input_conversations: List of conversations to run inference on.
            completed_conversations: List of conversations already completed.

        Returns:
            list[Conversation]: List of conversations that still need inference
                results.
        """
        completed_ids = {
            conv.conversation_id
            for conv in completed_conversations
            if conv.conversation_id is not None
        }
        return [
            conv
            for conv in input_conversations
            if conv.conversation_id not in completed_ids
        ]

    def _get_scratch_filepath(self, output_filepath: Optional[str]) -> str:
        """Returns a scratch filepath for the given output filepath.

        For example, if the output filepath is "/foo/bar/output.json", the scratch
        filepath will be "/foo/bar/scratch/output.json".

        If no output filepath is provided, a temporary file named
        "temp_inference_output_<hash>.jsonl" is used instead, placed under the
        user's Oumi cache directory ("~/.cache/oumi/tmp"), where the hash is
        derived from the model, generation, and dataset parameters.

        Args:
            output_filepath: The output filepath.

        Returns:
            str: The scratch filepath.
"""ifoutput_filepathisnotNone:original_filepath=Path(output_filepath)returnstr(original_filepath.parent/"scratch"/original_filepath.name)model_params=self._model_paramsmodel_params_str=json.dumps(dataclasses.asdict(model_params))generation_params=self._generation_paramsgeneration_params_str=json.dumps(dataclasses.asdict(generation_params))inference_hash=hashlib.sha256(f"{model_params_str}_{generation_params_str}_{self._dataset_hash}".encode()).hexdigest()path_prefix=Path.home()/".cache"/"oumi"/"tmp"returnstr(path_prefix/f"temp_inference_output_{inference_hash}.jsonl")def_save_conversation_to_scratch(self,conversation:Conversation,output_filepath:Optional[str])->None:"""Appends a conversation to a file in Oumi chat format. Args: conversation: The conversation to save. output_filepath: The path to the file where the conversation should be saved. """# Make the directory if it doesn't exist.scratch_filepath=self._get_scratch_filepath(output_filepath)Path(scratch_filepath).parent.mkdir(parents=True,exist_ok=True)withjsonlines.open(scratch_filepath,mode="a")aswriter:json_obj=conversation.to_dict()writer.write(json_obj)def_cleanup_scratch_file(self,output_filepath:Optional[str])->None:"""Delete the scratch file from the file system if it exists. Args: output_filepath: The path to the output file. This is used to determine the location of the scratch file. """scratch_filepath=self._get_scratch_filepath(output_filepath)ifPath(scratch_filepath).exists():Path(scratch_filepath).unlink()def_save_conversations(self,conversations:list[Conversation],output_filepath:str)->None:"""Saves conversations to a file in Oumi chat format. Args: conversations: A list of conversations to save. output_filepath: The path to the file where the conversations should be saved. """# Make the directory if it doesn't exist.Path(output_filepath).parent.mkdir(parents=True,exist_ok=True)withjsonlines.open(output_filepath,mode="w")aswriter:forconversationintqdm(conversations,desc="Saving conversations"):json_obj=conversation.to_dict()writer.write(json_obj)def_check_unsupported_params(self,generation_params:GenerationParams):"""Checks for unsupported parameters and logs warnings. If a parameter is not supported, and a non-default value is provided, a warning is logged. """supported_params=self.get_supported_params()default_generation_params=GenerationParams()forparam_name,valueingeneration_params:ifparam_namenotinsupported_params:is_non_default_value=(getattr(default_generation_params,param_name)!=value)ifis_non_default_value:logger.warning(f"{self.__class__.__name__} does not support {param_name}. "f"Received value: {param_name}={value}. ""This parameter will be ignored.")
    @abstractmethod
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine.

        Override this method in derived classes to specify which parameters
        are supported.

        Returns:
            Set[str]: A set of supported parameter names.
        """
        raise NotImplementedError
    @abstractmethod
    def _infer_online(
        self,
        input: list[Conversation],
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        """Runs model inference online.

        Args:
            input: A list of conversations to run inference on.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        raise NotImplementedError
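    # A minimal subclass sketch (illustrative only; the supported parameter
    # names and the model call are placeholders, and writing each result to the
    # scratch file is one way, assumed here, to make runs resumable):
    #
    #     class MyEngine(BaseInferenceEngine):
    #         def get_supported_params(self) -> set[str]:
    #             return {"max_new_tokens", "temperature"}
    #
    #         def _infer_online(self, input, inference_config=None):
    #             output_path = (
    #                 inference_config.output_path if inference_config else None
    #             )
    #             results = []
    #             for conversation in input:
    #                 updated = ...  # run the backing model on `conversation`
    #                 self._save_conversation_to_scratch(updated, output_path)
    #                 results.append(updated)
    #             return results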
    def apply_chat_template(
        self, conversation: Conversation, **tokenizer_kwargs
    ) -> str:
        """Applies the chat template to the conversation.

        Args:
            conversation: The conversation to apply the chat template to.
            tokenizer_kwargs: Additional keyword arguments to pass to the tokenizer.

        Returns:
            str: The conversation with the chat template applied.
        """
        tokenizer = getattr(self, "_tokenizer", None)
        if tokenizer is None:
            raise ValueError("Tokenizer is not initialized.")
        if tokenizer.chat_template is None:
            raise ValueError("Tokenizer does not have a chat template.")
        if "tokenize" not in tokenizer_kwargs:
            tokenizer_kwargs["tokenize"] = False
        return tokenizer.apply_chat_template(conversation, **tokenizer_kwargs)
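# A minimal usage sketch for `apply_chat_template` (illustrative only; assumes
# the concrete engine defines a `_tokenizer` attribute whose chat template
# accepts `add_generation_prompt`):
#
#     prompt = engine.apply_chat_template(conversation, add_generation_prompt=True)
#
# Unless overridden via `tokenizer_kwargs`, the template is applied with
# `tokenize=False`, so the result is a prompt string rather than token IDs.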