Source code for oumi.inference.vllm_inference_engine
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
import math

import torch
from typing_extensions import override

from oumi.builders import build_tokenizer
from oumi.core.configs import GenerationParams, InferenceConfig, ModelParams
from oumi.core.inference import BaseInferenceEngine
from oumi.core.types.conversation import Conversation, Message, Role
from oumi.utils.conversation_utils import create_list_of_message_json_dicts
from oumi.utils.logging import logger
from oumi.utils.model_caching import get_local_filepath_for_gguf
from oumi.utils.peft_utils import get_lora_rank

try:
    import vllm  # pyright: ignore[reportMissingImports]
    from vllm.entrypoints.chat_utils import (  # pyright: ignore[reportMissingImports]
        ChatCompletionMessageParam,
    )
    from vllm.lora.request import LoRARequest  # pyright: ignore[reportMissingImports]
    from vllm.sampling_params import (  # pyright: ignore[reportMissingImports]
        GuidedDecodingParams as VLLMGuidedDecodingParams,
    )
    from vllm.sampling_params import (  # pyright: ignore[reportMissingImports]
        SamplingParams,
    )
except ModuleNotFoundError:
    vllm = None

[docs]
class VLLMInferenceEngine(BaseInferenceEngine):
    """Engine for running vLLM inference locally."""

    def __init__(
        self,
        model_params: ModelParams,
        *,
        generation_params: GenerationParams | None = None,
        tensor_parallel_size: int = -1,
        quantization: str | None = None,
        enable_prefix_caching: bool = True,
        gpu_memory_utilization: float = 0.9,
        enforce_eager: bool = True,
        max_num_seqs: int | None = None,
    ):
        """Initializes the inference engine.

        Args:
            model_params: The model parameters to use for inference.
            generation_params: The generation parameters to use for inference.
            tensor_parallel_size: The number of tensor parallel processes to use.
                If set to -1, we will use all the available GPUs.
            quantization: The quantization method to use for inference.
            enable_prefix_caching: Whether to enable prefix caching.
            gpu_memory_utilization: The fraction of available GPU memory the model's
                executor will use. It can range from 0 to 1. Defaults to 0.9,
                i.e., 90% memory utilization.
            enforce_eager: Whether to enforce eager execution. Defaults to True.
                If False, will use eager mode and CUDA graph in hybrid mode.
            max_num_seqs: Maximum number of sequences per iteration.
        """
        super().__init__(
            model_params=model_params, generation_params=generation_params
        )
        if not vllm:
            raise RuntimeError(
                "vLLM is not installed. "
                "Please install the GPU dependencies for this package."
            )
        if not (
            math.isfinite(gpu_memory_utilization)
            and gpu_memory_utilization > 0
            and gpu_memory_utilization <= 1.0
        ):
            raise ValueError(
                "GPU memory utilization must be within (0, 1]. Got "
                f"{gpu_memory_utilization}."
            )

        # Infer the `quantization` type from the model's kwargs.
        if model_params.model_kwargs:
            if not quantization:
                # Check if quantization is BitsAndBytes.
                bnb_quantization_kwargs = ["load_in_4bit", "load_in_8bit"]
                for key in bnb_quantization_kwargs:
                    if model_params.model_kwargs.get(key):
                        quantization = "bitsandbytes"
                        break
            if not quantization and model_params.model_kwargs.get("filename"):
                # Check if quantization is GGUF.
                gguf_filename = str(model_params.model_kwargs.get("filename"))
                if gguf_filename.lower().endswith(".gguf"):
                    quantization = "gguf"
                    if (
                        not model_params.tokenizer_name
                        or model_params.tokenizer_name == model_params.model_name
                    ):
                        raise ValueError(
                            "GGUF quantization with the VLLM engine requires that you "
                            "explicitly set the `tokenizer_name` in `model_params`."
                        )

        vllm_kwargs = {}

        # Set the proper VLLM keys for the quantization type.
        if quantization and quantization == "bitsandbytes":
            vllm_kwargs["load_format"] = "bitsandbytes"
            logger.info("VLLM engine loading a `bitsandbytes` quantized model.")
        elif quantization and quantization == "gguf":
            # Download the GGUF file from HuggingFace to a local cache.
            gguf_local_path = get_local_filepath_for_gguf(
                repo_id=model_params.model_name,
                filename=gguf_filename,
            )
            # Overwrite `model_name` with the locally cached GGUF model.
            model_params = copy.deepcopy(model_params)
            model_params.model_name = gguf_local_path
            logger.info("VLLM engine loading a `GGUF` quantized model.")

        if tensor_parallel_size <= 0:
            if torch.cuda.device_count() > 1:
                tensor_parallel_size = torch.cuda.device_count()
            else:
                tensor_parallel_size = 1

        self._lora_request = None
        if model_params.adapter_model:
            # ID should be unique for this adapter, but isn't enforced by vLLM.
            self._lora_request = LoRARequest(
                lora_name="oumi_lora_adapter",
                lora_int_id=1,
                lora_path=model_params.adapter_model,
            )
            logger.info(f"Loaded LoRA adapter: {model_params.adapter_model}")
            lora_rank = get_lora_rank(model_params.adapter_model)
            vllm_kwargs["max_lora_rank"] = lora_rank
            logger.info(f"Setting vLLM max LoRA rank to {lora_rank}")

        if max_num_seqs is not None:
            vllm_kwargs["max_num_seqs"] = max_num_seqs

        self._tokenizer = build_tokenizer(model_params)
        self._llm = vllm.LLM(
            model=model_params.model_name,
            tokenizer=model_params.tokenizer_name,
            trust_remote_code=model_params.trust_remote_code,
            dtype=model_params.torch_dtype_str,
            # TODO: these params should be settable via config,
            # but they don't belong to model_params
            quantization=quantization,
            tensor_parallel_size=tensor_parallel_size,
            enable_prefix_caching=enable_prefix_caching,
            enable_lora=self._lora_request is not None,
            max_model_len=model_params.model_max_length,
            gpu_memory_utilization=gpu_memory_utilization,
            enforce_eager=enforce_eager,
            **vllm_kwargs,
        )
        # Ensure the tokenizer is set properly.
        self._llm.set_tokenizer(self._tokenizer)

    def _convert_conversation_to_vllm_input(
        self, conversation: Conversation
    ) -> list[ChatCompletionMessageParam]:
        """Converts a conversation to a list of vllm input messages.

        Args:
            conversation: The conversation to convert.

        Returns:
            List[ChatCompletionMessageParam]: A list of vllm input messages.
        """
        result: list[ChatCompletionMessageParam] = []
        for json_dict in create_list_of_message_json_dicts(
            conversation.messages, group_adjacent_same_role_turns=True
        ):
            for key in ("role", "content"):
                if key not in json_dict:
                    raise RuntimeError(f"The required field '{key}' is missing!")
            if not isinstance(json_dict["content"], (str, list)):
                raise RuntimeError(
                    "The 'content' field must be `str` or `list`. "
                    f"Actual: {type(json_dict['content'])}."
                )
            result.append(
                {"role": json_dict["role"], "content": json_dict["content"]}
            )
        return result

    def _infer(
        self,
        input: list[Conversation],
        inference_config: InferenceConfig | None = None,
    ) -> list[Conversation]:
        """Runs model inference on the provided input.

        Documentation: https://docs.vllm.ai/en/stable/dev/sampling_params.html

        Args:
            input: A list of conversations to run inference on.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        generation_params = (
            inference_config.generation
            if inference_config and inference_config.generation
            else self._generation_params
        )

        if generation_params.guided_decoding is not None:
            guided_decoding = VLLMGuidedDecodingParams.from_optional(
                json=generation_params.guided_decoding.json,
                regex=generation_params.guided_decoding.regex,
                choice=generation_params.guided_decoding.choice,
            )
        else:
            guided_decoding = None

        sampling_params = SamplingParams(
            n=1,
            max_tokens=generation_params.max_new_tokens,
            temperature=generation_params.temperature,
            top_p=generation_params.top_p,
            frequency_penalty=generation_params.frequency_penalty,
            presence_penalty=generation_params.presence_penalty,
            stop=generation_params.stop_strings,
            stop_token_ids=generation_params.stop_token_ids,
            min_p=generation_params.min_p,
            guided_decoding=guided_decoding,
        )

        output_conversations = []
        vllm_conversations = []
        non_skipped_conversations = []
        for conversation in input:
            if not conversation.messages:
                logger.warning("Conversation must have at least one message.")
                continue
            vllm_input = self._convert_conversation_to_vllm_input(conversation)
            vllm_conversations.append(vllm_input)
            non_skipped_conversations.append(conversation)

        if len(vllm_conversations) == 0:
            return []

        enable_tqdm = len(vllm_conversations) >= 2

        # Note: vLLM performs continuous batching under the hood.
        # We pass all the conversations and let vLLM handle the rest.
        chat_responses = self._llm.chat(
            vllm_conversations,
            sampling_params=sampling_params,
            lora_request=self._lora_request,
            use_tqdm=enable_tqdm,
            chat_template=None,
            chat_template_content_format="auto",
        )

        for conversation, chat_response in zip(
            non_skipped_conversations, chat_responses
        ):
            new_messages = [
                Message(content=message.text, role=Role.ASSISTANT)
                for message in chat_response.outputs
                if len(chat_response.outputs) > 0
            ]
            messages = [
                *conversation.messages,
                *new_messages,
            ]
            new_conversation = Conversation(
                messages=messages,
                metadata=conversation.metadata,
                conversation_id=conversation.conversation_id,
            )
            output_conversations.append(new_conversation)

        if inference_config and inference_config.output_path:
            self._save_conversations(
                output_conversations,
                inference_config.output_path,
            )
        return output_conversations

[docs]
    @override
    def infer_online(
        self,
        input: list[Conversation],
        inference_config: InferenceConfig | None = None,
    ) -> list[Conversation]:
        """Runs model inference online.

        Args:
            input: A list of conversations to run inference on.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        return self._infer(input, inference_config)

[docs]
    @override
    def infer_from_file(
        self,
        input_filepath: str,
        inference_config: InferenceConfig | None = None,
    ) -> list[Conversation]:
        """Runs model inference on inputs in the provided file.

        This is a convenience method that avoids the boilerplate of asserting
        the existence of `input_filepath` in the generation params.

        Args:
            input_filepath: Path to the input file containing prompts for generation.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        input = self._read_conversations(input_filepath)
        return self._infer(input, inference_config)

[docs]
    @override
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine."""
        return {
            "frequency_penalty",
            "guided_decoding",
            "max_new_tokens",
            "min_p",
            "presence_penalty",
            "stop_strings",
            "stop_token_ids",
            "temperature",
            "top_p",
        }
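
Below is a minimal usage sketch, separate from the module source above, showing how this engine is typically driven through infer_online. The model name, generation settings, and prompt are illustrative assumptions only; any vLLM-compatible model referenced via ModelParams should work the same way.

from oumi.core.configs import GenerationParams, ModelParams
from oumi.core.types.conversation import Conversation, Message, Role
from oumi.inference.vllm_inference_engine import VLLMInferenceEngine

# Hypothetical model and settings, chosen for illustration only.
engine = VLLMInferenceEngine(
    model_params=ModelParams(model_name="meta-llama/Llama-3.1-8B-Instruct"),
    generation_params=GenerationParams(max_new_tokens=256, temperature=0.7),
    tensor_parallel_size=-1,  # -1 => use all available GPUs
)

conversation = Conversation(
    messages=[Message(role=Role.USER, content="Summarize what prefix caching does.")]
)

# infer_online returns new Conversation objects with the assistant reply appended.
results = engine.infer_online([conversation])
print(results[0].messages[-1].content)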