# Source code for oumi.inference.remote_vllm_inference_engine
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing_extensions import override
from oumi.core.configs import GenerationParams
from oumi.core.types.conversation import Conversation
from oumi.inference.remote_inference_engine import RemoteInferenceEngine
_CONTENT_KEY: str = "content"
_ROLE_KEY: str = "role"
class RemoteVLLMInferenceEngine(RemoteInferenceEngine):
    """Engine for running inference against Remote vLLM.

    Requests are built in the OpenAI chat-completions format, extended with
    the vLLM request fields ``guided_json`` / ``guided_regex`` /
    ``guided_choice`` and ``stop_token_ids`` when those parameters are set.
    """

    @property
    @override
    def base_url(self) -> Optional[str]:
        """Return the default base URL for the Remote vLLM API.

        No default endpoint is defined for this engine, so this returns ``None``.
        """
        return None

    @property
    @override
    def api_key_env_varname(self) -> Optional[str]:
        """Return the default environment variable name for the Remote vLLM API key.

        No default variable name is defined for this engine, so this returns
        ``None``.
        """
        return None

    @override
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine.

        Every name listed here is a ``GenerationParams`` field that
        ``_convert_conversation_to_api_input`` forwards in the request body.
        """
        return {
            "frequency_penalty",
            "guided_decoding",
            "logit_bias",
            "max_new_tokens",
            "presence_penalty",
            "seed",
            "stop_strings",
            "stop_token_ids",
            "temperature",
            "top_p",
        }

    @override
    def _convert_conversation_to_api_input(
        self, conversation: Conversation, generation_params: GenerationParams
    ) -> dict[str, Any]:
        """Converts a conversation to an OpenAI input.

        Documentation: https://platform.openai.com/docs/api-reference/chat/create

        Args:
            conversation: The conversation to convert.
            generation_params: Parameters for generation during inference.

        Returns:
            Dict[str, Any]: A dictionary representing the OpenAI input.
        """
        api_input: dict[str, Any] = {
            # Use the adapter model name when one is set, else the base model.
            "model": self._adapter_model or self._model,
            "messages": self._get_list_of_message_json_dicts(
                conversation.messages, group_adjacent_same_role_turns=True
            ),
            # NOTE: `max_completion_tokens` is the planned replacement for
            # `max_tokens`; switch once vLLM supports it.
            # See https://github.com/vllm-project/vllm/issues/9845
            "max_tokens": generation_params.max_new_tokens,
            "temperature": generation_params.temperature,
            "top_p": generation_params.top_p,
            "frequency_penalty": generation_params.frequency_penalty,
            "presence_penalty": generation_params.presence_penalty,
            "n": 1,  # Number of completions to generate for each prompt.
            "seed": generation_params.seed,
            "logit_bias": generation_params.logit_bias,
        }

        guided = generation_params.guided_decoding
        if guided is not None:
            # At most one constraint is sent, with priority json > regex > choice.
            # All three use `is not None` so an explicitly supplied but empty
            # JSON schema (e.g. `{}`) is forwarded rather than silently dropped.
            if guided.json is not None:
                api_input["guided_json"] = guided.json
            elif guided.regex is not None:
                api_input["guided_regex"] = guided.regex
            elif guided.choice is not None:
                api_input["guided_choice"] = guided.choice

        # Empty lists are treated as "not set" and omitted from the request.
        if generation_params.stop_strings:
            api_input["stop"] = generation_params.stop_strings
        if generation_params.stop_token_ids:
            api_input["stop_token_ids"] = generation_params.stop_token_ids

        return api_input