Source code for oumi.inference.sglang_inference_engine

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import json
from typing import Any, NamedTuple

import pydantic
from typing_extensions import override

from oumi.builders import (
    build_processor,
    build_tokenizer,
    is_image_text_llm,
)
from oumi.core.configs import (
    GenerationParams,
    ModelParams,
    RemoteParams,
)
from oumi.core.processors.base_processor import BaseProcessor
from oumi.core.types.conversation import Conversation, Message, Role, Type
from oumi.inference.remote_inference_engine import RemoteInferenceEngine
from oumi.utils.conversation_utils import base64encode_content_item_image_bytes
from oumi.utils.logging import logger


class _SamplingParams(NamedTuple):
    """It's a clone of `sglang.lang.ir.SglSamplingParams`.

    Only includes a subset of parameters supported in oumi.
    Unsupported params are left commented out for reference.
    """

    max_new_tokens: int = 128
    # min_new_tokens: int = 0
    stop: str | list[str] = ""
    stop_token_ids: list[int] | None = None
    temperature: float = 1.0
    top_p: float = 1.0
    # top_k: int = -1  # -1 means disable
    min_p: float = 0.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    # ignore_eos: bool = False
    # return_logprob: bool | None = None
    # logprob_start_len: int | None = None
    # top_logprobs_num: int | None = None
    # return_text_in_logprobs: bool | None = None
    json_schema: str | None = None

    # For constrained generation:
    # dtype: str | None = None
    regex: str | None = None


class SGLangInferenceEngine(RemoteInferenceEngine):
    """Engine for running SGLang inference."""

    def __init__(
        self,
        model_params: ModelParams,
        *,
        remote_params: RemoteParams | None = None,
        generation_params: GenerationParams | None = None,
    ):
        """Initializes the SGLang inference engine.

        Args:
            model_params: The model parameters to use for inference.
            remote_params: Remote server params.
            generation_params: The generation parameters to use for inference.
        """
        if remote_params is None:
            raise ValueError("remote_params is required")

        super().__init__(
            model_params=model_params,
            generation_params=generation_params,
            remote_params=remote_params,
        )
        self._tokenizer = build_tokenizer(self._model_params)
        self._processor: BaseProcessor | None = None
        if is_image_text_llm(self._model_params):
            # Only enable Processor for vision language models for now.
            self._processor = build_processor(
                self._model_params.model_name,
                self._tokenizer,
                trust_remote_code=self._model_params.trust_remote_code,
            )
        # TODO Launch a local SGLang server if requested.

    def _create_sampling_params(
        self, generation_params: GenerationParams
    ) -> _SamplingParams:
        regex: str | None = None
        json_schema: str | None = None
        if generation_params.guided_decoding is not None:
            if generation_params.guided_decoding.regex is not None:
                regex = generation_params.guided_decoding.regex
            else:
                json_schema_value = None
                if generation_params.guided_decoding.json is not None:
                    json_schema_value = generation_params.guided_decoding.json
                elif (
                    generation_params.guided_decoding.choice is not None
                    and len(generation_params.guided_decoding.choice) > 0
                ):
                    json_schema_value = {
                        "enum": generation_params.guided_decoding.choice
                    }

                if isinstance(json_schema_value, str):
                    json_schema = json_schema_value
                elif isinstance(json_schema_value, dict):
                    json_schema = json.dumps(json_schema_value, ensure_ascii=False)
                elif isinstance(json_schema_value, pydantic.BaseModel) or (
                    isinstance(json_schema_value, type)
                    and issubclass(json_schema_value, pydantic.BaseModel)
                ):
                    json_schema = json.dumps(json_schema_value.model_json_schema())
                else:
                    raise ValueError(
                        "Unsupported type of generation_params.guided_decoding.json: "
                        f"{type(generation_params.guided_decoding.json)}"
                    )

        return _SamplingParams(
            max_new_tokens=generation_params.max_new_tokens,
            temperature=generation_params.temperature,
            top_p=generation_params.top_p,
            min_p=generation_params.min_p,
            frequency_penalty=generation_params.frequency_penalty,
            presence_penalty=generation_params.presence_penalty,
            stop=(generation_params.stop_strings or []),
            stop_token_ids=generation_params.stop_token_ids,
            regex=regex,
            json_schema=json_schema,
        )

    def _create_sampling_params_as_dict(
        self, generation_params: GenerationParams
    ) -> dict[str, Any]:
        return self._create_sampling_params(generation_params)._asdict()

    def _apply_chat_template_impl(self, conversation: Conversation) -> str:
        if self._processor is None:
            return self._tokenizer.apply_chat_template(
                conversation,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
        return self._processor.apply_chat_template(
            conversation,  # type: ignore
            add_generation_prompt=True,
        )

    def _create_image_data_as_str(self, conversation: Conversation) -> str | None:
        image_items = [
            item for m in conversation.messages for item in m.image_content_items
        ]
        num_images = len(image_items)
        if num_images <= 0:
            return None
        elif num_images > 1:
            # FIXME OPE-355 Support multiple images
            logger.warning(
                conversation.append_id_to_string(
                    f"A conversation contains multiple images ({num_images}). "
                    "Only 1 image is currently supported. Using the last image."
                )
            )
        image_item = image_items[-1]

        if image_item.type == Type.IMAGE_BINARY:
            if not image_item.binary:
                raise ValueError(
                    conversation.append_id_to_string(
                        f"No image bytes in message: {image_item.type}"
                    )
                )
            return base64encode_content_item_image_bytes(image_item)

        assert image_item.type in (Type.IMAGE_PATH, Type.IMAGE_URL)
        image_path_or_url = image_item.content
        if not image_path_or_url:
            friendly_type_name = (
                "image path" if image_item.type == Type.IMAGE_PATH else "image URL"
            )
            raise ValueError(
                conversation.append_id_to_string(
                    f"Empty {friendly_type_name} in message: {image_item.type}"
                )
            )
        return image_path_or_url

    @override
    def _convert_conversation_to_api_input(
        self, conversation: Conversation, generation_params: GenerationParams
    ) -> dict[str, Any]:
        """Converts a conversation to SGLang Native API input.

        See https://sgl-project.github.io/references/sampling_params.html for details.

        Args:
            conversation: The Oumi Conversation object to convert.
            generation_params: Parameters for text generation.

        Returns:
            Dict[str, Any]: A dictionary containing the formatted input for the
                SGLang server native API: the templated prompt text, the sampling
                params, and optional image data.
        """
        # Chat templates loaded by the SGLang server are generally different from
        # oumi's chat template, hence we apply the template here ourselves.
        prompt = self._apply_chat_template_impl(conversation)
        sampling_params_dict = self._create_sampling_params_as_dict(generation_params)

        body = {
            "text": prompt,
            "sampling_params": sampling_params_dict,
        }
        image_data = self._create_image_data_as_str(conversation)
        if image_data:
            body["image_data"] = image_data
        return body

    @override
    def _convert_api_output_to_conversation(
        self, response: dict[str, Any], original_conversation: Conversation
    ) -> Conversation:
        """Converts an SGLang Native API response to a conversation."""
        new_message = Message(
            content=response["text"],
            role=Role.ASSISTANT,
        )
        return Conversation(
            messages=[*original_conversation.messages, new_message],
            metadata=original_conversation.metadata,
            conversation_id=original_conversation.conversation_id,
        )

    @override
    def _get_request_headers(self, remote_params: RemoteParams) -> dict[str, str]:
        return {
            "Content-Type": "application/json",
        }
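
    # Illustrative sketch (an assumption for clarity, not Oumi test output):
    # for a text-only conversation, `_convert_conversation_to_api_input` above
    # yields a native SGLang `/generate` request body shaped like:
    #
    #   {
    #       "text": "<prompt rendered via the oumi chat template>",
    #       "sampling_params": {"max_new_tokens": 128, "temperature": 1.0, ...},
    #   }
    #
    # plus an optional "image_data" entry (base64-encoded bytes, a path, or a
    # URL) for vision-language models.
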
    @override
    @functools.cache
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine."""
        return {
            "frequency_penalty",
            "guided_decoding",
            "max_new_tokens",
            "min_p",
            "presence_penalty",
            "stop_strings",
            "stop_token_ids",
            "temperature",
            "top_p",
        }
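
# Example usage (a minimal sketch: the model name and server URL below are
# illustrative assumptions, and an SGLang server must already be running):
#
#   from oumi.core.configs import ModelParams, RemoteParams
#   from oumi.core.types.conversation import Conversation, Message, Role
#
#   engine = SGLangInferenceEngine(
#       ModelParams(model_name="meta-llama/Llama-3.2-1B-Instruct"),
#       remote_params=RemoteParams(api_url="http://localhost:30000/generate"),
#   )
#   conversation = Conversation(messages=[Message(role=Role.USER, content="Hi!")])
#   results = engine.infer([conversation])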