Source code for oumi.inference.remote_inference_engine

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import copy
import json
import os
import tempfile
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Optional

import aiofiles
import aiohttp
import jsonlines
import pydantic
from tqdm.asyncio import tqdm
from typing_extensions import override

from oumi.core.async_utils import safe_asyncio_run
from oumi.core.configs import (
    GenerationParams,
    InferenceConfig,
    ModelParams,
    RemoteParams,
)
from oumi.core.inference import BaseInferenceEngine
from oumi.core.types.conversation import (
    Conversation,
    Message,
    Role,
)
from oumi.utils.conversation_utils import (
    convert_message_to_json_content_list,
    create_list_of_message_json_dicts,
)

_AUTHORIZATION_KEY: str = "Authorization"
_BATCH_PURPOSE = "batch"
_BATCH_ENDPOINT = "/v1/chat/completions"


class BatchStatus(Enum):
    """Status of a batch inference job."""

    VALIDATING = "validating"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    EXPIRED = "expired"
    CANCELLED = "cancelled"


@dataclass
class BatchInfo:
    """Information about a batch job."""

    id: str
    status: BatchStatus
    total_requests: int = 0
    completed_requests: int = 0
    failed_requests: int = 0
    endpoint: Optional[str] = None
    input_file_id: Optional[str] = None
    batch_completion_window: Optional[str] = None
    output_file_id: Optional[str] = None
    error_file_id: Optional[str] = None
    error: Optional[str] = None
    created_at: Optional[datetime] = None
    in_progress_at: Optional[datetime] = None
    expires_at: Optional[datetime] = None
    finalizing_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    failed_at: Optional[datetime] = None
    expired_at: Optional[datetime] = None
    canceling_at: Optional[datetime] = None
    canceled_at: Optional[datetime] = None
    metadata: Optional[dict[str, Any]] = None

    @staticmethod
    def _convert_timestamp(timestamp: Optional[int]) -> Optional[datetime]:
        """Convert Unix timestamp to datetime.

        Args:
            timestamp: Unix timestamp in seconds

        Returns:
            datetime: Converted datetime or None if timestamp is None
        """
        return datetime.fromtimestamp(timestamp) if timestamp is not None else None

    @classmethod
    def from_api_response(cls, response: dict[str, Any]) -> "BatchInfo":
        """Create BatchInfo from API response dictionary.

        Args:
            response: Raw API response dictionary

        Returns:
            BatchInfo: Parsed batch information
        """
        return cls(
            id=response["id"],
            status=BatchStatus(response["status"]),
            endpoint=response.get("endpoint"),
            input_file_id=response.get("input_file_id"),
            batch_completion_window=response.get("batch_completion_window"),
            output_file_id=response.get("output_file_id"),
            error_file_id=response.get("error_file_id"),
            error=response.get("error"),
            created_at=cls._convert_timestamp(response.get("created_at")),
            in_progress_at=cls._convert_timestamp(response.get("in_progress_at")),
            expires_at=cls._convert_timestamp(response.get("expires_at")),
            finalizing_at=cls._convert_timestamp(response.get("finalizing_at")),
            completed_at=cls._convert_timestamp(response.get("completed_at")),
            failed_at=cls._convert_timestamp(response.get("failed_at")),
            expired_at=cls._convert_timestamp(response.get("expired_at")),
            canceling_at=cls._convert_timestamp(response.get("cancelling_at")),
            canceled_at=cls._convert_timestamp(response.get("cancelled_at")),
            total_requests=response.get("request_counts", {}).get("total", 0),
            completed_requests=response.get("request_counts", {}).get("completed", 0),
            failed_requests=response.get("request_counts", {}).get("failed", 0),
            metadata=response.get("metadata"),
        )

    @property
    def is_terminal(self) -> bool:
        """Return True if the batch is in a terminal state."""
        return self.status in (
            BatchStatus.COMPLETED,
            BatchStatus.FAILED,
            BatchStatus.EXPIRED,
            BatchStatus.CANCELLED,
        )

    @property
    def completion_percentage(self) -> float:
        """Return the percentage of completed requests."""
        return (
            (100 * self.completed_requests / self.total_requests)
            if self.total_requests > 0
            else 0.0
        )

    @property
    def has_errors(self) -> bool:
        """Return True if the batch has any errors."""
        return bool(self.error) or self.failed_requests > 0
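
# Example (illustrative, not part of the module): building a BatchInfo from a
# minimal API-style response dict. With 50 of 200 requests completed,
# completion_percentage evaluates to 25.0 and is_terminal is False:
#
#   info = BatchInfo.from_api_response(
#       {
#           "id": "batch_123",
#           "status": "in_progress",
#           "request_counts": {"total": 200, "completed": 50, "failed": 0},
#       }
#   )
#   assert info.completion_percentage == 25.0 and not info.is_terminal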


@dataclass
class BatchListResponse:
    """Response from listing batch jobs."""

    batches: list[BatchInfo]
    first_id: Optional[str] = None
    last_id: Optional[str] = None
    has_more: bool = False


@dataclass
class FileInfo:
    """Information about a file."""

    id: str
    filename: str
    bytes: int
    created_at: int
    purpose: str


@dataclass
class FileListResponse:
    """Response from listing files."""

    files: list[FileInfo]
    has_more: bool = False


class RemoteInferenceEngine(BaseInferenceEngine):
    """Engine for running inference against a server implementing the OpenAI API."""

    base_url: Optional[str] = None
    """The base URL for the remote API."""

    api_key_env_varname: Optional[str] = None
    """The environment variable name for the API key."""

    def __init__(
        self,
        model_params: ModelParams,
        *,
        generation_params: Optional[GenerationParams] = None,
        remote_params: Optional[RemoteParams] = None,
    ):
        """Initializes the inference engine.

        Args:
            model_params: The model parameters to use for inference.
            generation_params: Generation parameters to use for inference.
            remote_params: Remote server params.
        """
        super().__init__(
            model_params=model_params, generation_params=generation_params
        )
        self._model = model_params.model_name
        self._adapter_model = model_params.adapter_model

        if remote_params:
            remote_params = copy.deepcopy(remote_params)
        else:
            remote_params = RemoteParams()

        if not remote_params.api_url:
            remote_params.api_url = self.base_url
        if not remote_params.api_key_env_varname:
            remote_params.api_key_env_varname = self.api_key_env_varname

        self._remote_params = remote_params
        self._remote_params.finalize_and_validate()

    @staticmethod
    def _get_list_of_message_json_dicts(
        messages: list[Message],
        *,
        group_adjacent_same_role_turns: bool,
    ) -> list[dict[str, Any]]:
        return create_list_of_message_json_dicts(
            messages, group_adjacent_same_role_turns=group_adjacent_same_role_turns
        )

    def _convert_conversation_to_api_input(
        self, conversation: Conversation, generation_params: GenerationParams
    ) -> dict[str, Any]:
        """Converts a conversation to an OpenAI input.

        Documentation: https://platform.openai.com/docs/api-reference/chat/create

        Args:
            conversation: The conversation to convert.
            generation_params: Parameters for generation during inference.

        Returns:
            Dict[str, Any]: A dictionary representing the OpenAI input.
        """
        api_input = {
            "model": self._model,
            "messages": [
                {
                    "content": convert_message_to_json_content_list(message),
                    "role": message.role.value,
                }
                for message in conversation.messages
            ],
            "max_completion_tokens": generation_params.max_new_tokens,
            "temperature": generation_params.temperature,
            "top_p": generation_params.top_p,
            "frequency_penalty": generation_params.frequency_penalty,
            "presence_penalty": generation_params.presence_penalty,
            "n": 1,  # Number of completions to generate for each prompt.
            "seed": generation_params.seed,
            "logit_bias": generation_params.logit_bias,
        }

        if generation_params.stop_strings:
            api_input["stop"] = generation_params.stop_strings

        if generation_params.guided_decoding:
            json_schema = generation_params.guided_decoding.json
            if json_schema is not None:
                if isinstance(json_schema, type) and issubclass(
                    json_schema, pydantic.BaseModel
                ):
                    schema_name = json_schema.__name__
                    schema_value = json_schema.model_json_schema()
                elif isinstance(json_schema, dict):
                    # Use a generic name if no schema name is provided.
                    schema_name = "Response"
                    schema_value = json_schema
                elif isinstance(json_schema, str):
                    # Use a generic name if no schema name is provided.
                    schema_name = "Response"
                    # Try to parse the schema as a JSON string.
                    schema_value = json.loads(json_schema)
                else:
                    raise ValueError(
                        f"Got unsupported JSON schema type: {type(json_schema)}. "
                        "Please provide a Pydantic model or a JSON schema as a "
                        "string or dict."
                    )

                api_input["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": schema_name,
                        "schema": schema_value,
                    },
                }
            else:
                raise ValueError(
                    "Only JSON schema guided decoding is supported, got "
                    f"'{generation_params.guided_decoding}'."
                )

        return api_input
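
    # Example of the request body produced by _convert_conversation_to_api_input
    # (illustrative sketch; the concrete values below are assumptions and depend
    # on the configured ModelParams and GenerationParams):
    #
    #   {
    #       "model": "my-model",
    #       "messages": [{"role": "user", "content": [...]}],
    #       "max_completion_tokens": 256,
    #       "temperature": 0.0,
    #       "top_p": 1.0,
    #       "n": 1,
    #       ...
    #   }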

    def _convert_api_output_to_conversation(
        self, response: dict[str, Any], original_conversation: Conversation
    ) -> Conversation:
        """Converts an API response to a conversation.

        Args:
            response: The API response to convert.
            original_conversation: The original conversation.

        Returns:
            Conversation: The conversation including the generated response.
        """
        message = response["choices"][0]["message"]
        return Conversation(
            messages=[
                *original_conversation.messages,
                Message(
                    content=message["content"],
                    role=Role(message["role"]),
                ),
            ],
            metadata=original_conversation.metadata,
            conversation_id=original_conversation.conversation_id,
        )

    def _get_api_key(self, remote_params: RemoteParams) -> Optional[str]:
        if not remote_params:
            return None

        if remote_params.api_key:
            return remote_params.api_key

        if remote_params.api_key_env_varname:
            return os.environ.get(remote_params.api_key_env_varname)

        return None

    def _get_request_headers(
        self, remote_params: Optional[RemoteParams]
    ) -> dict[str, str]:
        headers = {}
        if not remote_params:
            return headers

        headers[_AUTHORIZATION_KEY] = f"Bearer {self._get_api_key(remote_params)}"
        return headers

    async def _query_api(
        self,
        conversation: Conversation,
        semaphore: asyncio.Semaphore,
        session: aiohttp.ClientSession,
        inference_config: Optional[InferenceConfig] = None,
    ) -> Conversation:
        """Queries the API with the provided input.

        Args:
            conversation: The conversation to run inference on.
            semaphore: Semaphore to limit concurrent requests.
            session: The aiohttp session to use for the request.
            inference_config: Parameters for inference.

        Returns:
            Conversation: Inference output.
        """
        if inference_config is None:
            remote_params = self._remote_params
            generation_params = self._generation_params
            output_path = None
        else:
            remote_params = inference_config.remote_params or self._remote_params
            generation_params = inference_config.generation or self._generation_params
            output_path = inference_config.output_path

        assert remote_params.api_url
        async with semaphore:
            api_input = self._convert_conversation_to_api_input(
                conversation, generation_params
            )
            headers = self._get_request_headers(remote_params)
            failure_reason = None
            # Retry the request if it fails.
            for _ in range(remote_params.max_retries + 1):
                async with session.post(
                    remote_params.api_url,
                    json=api_input,
                    headers=headers,
                    timeout=remote_params.connection_timeout,
                ) as response:
                    response_json = await response.json()
                    if response.status == 200:
                        result = self._convert_api_output_to_conversation(
                            response_json, conversation
                        )
                        if output_path:
                            # Write what we have so far to our scratch directory.
                            self._save_conversation(
                                result,
                                self._get_scratch_filepath(output_path),
                            )
                        await asyncio.sleep(remote_params.politeness_policy)
                        return result
                    else:
                        failure_reason = (
                            response_json.get("error").get("message")
                            if response_json and response_json.get("error")
                            else None
                        )
                        await asyncio.sleep(remote_params.politeness_policy)
            raise RuntimeError(
                f"Failed to query API after {remote_params.max_retries} retries. "
                + (f"Reason: {failure_reason}" if failure_reason else "")
            )

    async def _infer(
        self,
        input: list[Conversation],
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        """Runs model inference on the provided input.

        Args:
            input: A list of conversations to run inference on.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        # Limit the number of HTTP connections to the number of workers.
        connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
        # Control the number of concurrent tasks via a semaphore.
        semaphore = asyncio.BoundedSemaphore(self._remote_params.num_workers)
        async with aiohttp.ClientSession(connector=connector) as session:
            tasks = [
                self._query_api(
                    conversation,
                    semaphore,
                    session,
                    inference_config=inference_config,
                )
                for conversation in input
            ]
            disable_tqdm = len(tasks) < 2
            return await tqdm.gather(*tasks, disable=disable_tqdm)

    @override
    def infer_online(
        self,
        input: list[Conversation],
        inference_config: Optional[InferenceConfig] = None,
    ) -> list[Conversation]:
        """Runs model inference online.

        Args:
            input: A list of conversations to run inference on.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        conversations = safe_asyncio_run(self._infer(input, inference_config))
        if inference_config and inference_config.output_path:
            self._save_conversations(conversations, inference_config.output_path)
        return conversations
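
    # Usage sketch (illustrative, not part of the module; the URL, model name,
    # and env var name are assumptions, and any OpenAI-compatible endpoint works):
    #
    #   engine = RemoteInferenceEngine(
    #       ModelParams(model_name="gpt-4o-mini"),
    #       remote_params=RemoteParams(
    #           api_url="https://api.openai.com/v1/chat/completions",
    #           api_key_env_varname="OPENAI_API_KEY",
    #       ),
    #   )
    #   conversation = Conversation(
    #       messages=[Message(role=Role.USER, content="Hello!")]
    #   )
    #   results = engine.infer_online([conversation])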

    @override
    def infer_from_file(
        self, input_filepath: str, inference_config: Optional[InferenceConfig] = None
    ) -> list[Conversation]:
        """Runs model inference on inputs in the provided file.

        This is a convenience method to prevent boilerplate from asserting the
        existence of input_filepath in the generation_params.

        Args:
            input_filepath: Path to the input file containing prompts for generation.
            inference_config: Parameters for inference.

        Returns:
            List[Conversation]: Inference output.
        """
        input = self._read_conversations(input_filepath)
        conversations = safe_asyncio_run(self._infer(input, inference_config))
        if inference_config and inference_config.output_path:
            self._save_conversations(conversations, inference_config.output_path)
        return conversations

    @override
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine."""
        return {
            "frequency_penalty",
            "guided_decoding",
            "logit_bias",
            "max_new_tokens",
            "presence_penalty",
            "seed",
            "stop_strings",
            "temperature",
            "top_p",
        }
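
    # Note: some of these parameters are renamed when building the API request in
    # _convert_conversation_to_api_input: `max_new_tokens` is sent as the API's
    # `max_completion_tokens` field, and `stop_strings` is sent as `stop`.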

    #
    # Batch inference
    #

    def infer_batch(
        self,
        conversations: list[Conversation],
        inference_config: Optional[InferenceConfig] = None,
    ) -> str:
        """Creates a new batch inference job.

        Args:
            conversations: List of conversations to process in the batch.
            inference_config: Parameters for inference.

        Returns:
            str: The batch job ID.
        """
        generation_params = (
            inference_config.generation
            if inference_config
            else self._generation_params
        )
        return safe_asyncio_run(self._create_batch(conversations, generation_params))

    def get_batch_status(
        self,
        batch_id: str,
    ) -> BatchInfo:
        """Gets the status of a batch inference job.

        Args:
            batch_id: The batch job ID.

        Returns:
            BatchInfo: Current status of the batch job.
        """
        return safe_asyncio_run(self._get_batch_status(batch_id))

    def list_batches(
        self,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> BatchListResponse:
        """Lists batch jobs.

        Args:
            after: Cursor for pagination (batch ID to start after).
            limit: Maximum number of batches to return (1-100).

        Returns:
            BatchListResponse: List of batch jobs.
        """
        return safe_asyncio_run(
            self._list_batches(
                after=after,
                limit=limit,
            )
        )

    def get_batch_results(
        self,
        batch_id: str,
        conversations: list[Conversation],
    ) -> list[Conversation]:
        """Gets the results of a completed batch job.

        Args:
            batch_id: The batch job ID.
            conversations: Original conversations used to create the batch.

        Returns:
            List[Conversation]: The processed conversations with responses.

        Raises:
            RuntimeError: If the batch failed or has not completed.
        """
        return safe_asyncio_run(
            self._get_batch_results_with_mapping(batch_id, conversations)
        )
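
    # Batch workflow sketch (illustrative, not part of the module): submit a
    # batch, poll until it reaches a terminal state, then map the results back
    # onto the original conversations. Assumes `engine` was constructed as in
    # the earlier sketch and that `time` is imported; the poll interval is
    # arbitrary:
    #
    #   batch_id = engine.infer_batch(conversations)
    #   while not engine.get_batch_status(batch_id).is_terminal:
    #       time.sleep(60)
    #   results = engine.get_batch_results(batch_id, conversations)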

    async def _upload_batch_file(
        self,
        batch_requests: list[dict],
    ) -> str:
        """Uploads a JSONL file containing batch requests.

        Args:
            batch_requests: List of request objects to include in the batch.

        Returns:
            str: The uploaded file ID.
        """
        # Create a temporary JSONL file.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".jsonl", delete=False
        ) as tmp:
            with jsonlines.Writer(tmp) as writer:
                for request in batch_requests:
                    writer.write(request)
            tmp_path = tmp.name

        try:
            # Upload the file.
            connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
            async with aiohttp.ClientSession(connector=connector) as session:
                headers = self._get_request_headers(self._remote_params)

                # Create form data with the file.
                form = aiohttp.FormData()
                async with aiofiles.open(tmp_path, "rb") as f:
                    file_data = await f.read()
                    form.add_field("file", file_data, filename="batch_requests.jsonl")
                form.add_field("purpose", _BATCH_PURPOSE)

                async with session.post(
                    f"{self._remote_params.api_url}/files",
                    data=form,
                    headers=headers,
                ) as response:
                    if response.status != 200:
                        raise RuntimeError(
                            f"Failed to upload batch file: {await response.text()}"
                        )
                    data = await response.json()
                    return data["id"]
        finally:
            # Clean up the temporary file.
            Path(tmp_path).unlink()

    async def _create_batch(
        self,
        conversations: list[Conversation],
        generation_params: GenerationParams,
    ) -> str:
        """Creates a new batch job.

        Args:
            conversations: List of conversations to process in the batch.
            generation_params: Generation parameters.

        Returns:
            str: The batch job ID.
        """
        # Prepare batch requests.
        batch_requests = []
        for i, conv in enumerate(conversations):
            api_input = self._convert_conversation_to_api_input(
                conv, generation_params
            )
            batch_requests.append(
                {
                    "custom_id": f"request-{i}",
                    "method": "POST",
                    "url": _BATCH_ENDPOINT,
                    "body": api_input,
                }
            )

        # Upload the batch file.
        file_id = await self._upload_batch_file(batch_requests)

        # Create the batch.
        connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
        async with aiohttp.ClientSession(connector=connector) as session:
            headers = self._get_request_headers(self._remote_params)
            async with session.post(
                f"{self._remote_params.api_url}/batches",
                json={
                    "input_file_id": file_id,
                    "endpoint": _BATCH_ENDPOINT,
                    "batch_completion_window": (
                        self._remote_params.batch_completion_window
                    ),
                },
                headers=headers,
            ) as response:
                if response.status != 200:
                    raise RuntimeError(
                        f"Failed to create batch: {await response.text()}"
                    )
                data = await response.json()
                return data["id"]
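
    # Each line of the uploaded JSONL file is one request object of the form
    # built in _create_batch above (shown here for reference; the body is the
    # chat-completions payload from _convert_conversation_to_api_input):
    #
    #   {"custom_id": "request-0", "method": "POST",
    #    "url": "/v1/chat/completions", "body": {...}}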

    async def _get_batch_status(
        self,
        batch_id: str,
    ) -> BatchInfo:
        """Gets the status of a batch job.

        Args:
            batch_id: ID of the batch job.

        Returns:
            BatchInfo: Current status of the batch job.
        """
        connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
        async with aiohttp.ClientSession(connector=connector) as session:
            headers = self._get_request_headers(self._remote_params)
            async with session.get(
                f"{self._remote_params.api_url}/batches/{batch_id}",
                headers=headers,
            ) as response:
                if response.status != 200:
                    raise RuntimeError(
                        f"Failed to get batch status: {await response.text()}"
                    )
                data = await response.json()
                return BatchInfo.from_api_response(data)

    async def _list_batches(
        self,
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> BatchListResponse:
        """Lists batch jobs.

        Args:
            after: Cursor for pagination (batch ID to start after).
            limit: Maximum number of batches to return (1-100).

        Returns:
            BatchListResponse: List of batch jobs.
        """
        connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
        async with aiohttp.ClientSession(connector=connector) as session:
            headers = self._get_request_headers(self._remote_params)
            params = {}
            if after:
                params["after"] = after
            if limit:
                params["limit"] = str(limit)
            async with session.get(
                f"{self._remote_params.api_url}/batches",
                headers=headers,
                params=params,
            ) as response:
                if response.status != 200:
                    raise RuntimeError(
                        f"Failed to list batches: {await response.text()}"
                    )
                data = await response.json()
                batches = [
                    BatchInfo.from_api_response(batch_data)
                    for batch_data in data["data"]
                ]
                return BatchListResponse(
                    batches=batches,
                    first_id=data.get("first_id"),
                    last_id=data.get("last_id"),
                    has_more=data.get("has_more", False),
                )

    async def _get_batch_results_with_mapping(
        self,
        batch_id: str,
        conversations: list[Conversation],
    ) -> list[Conversation]:
        """Gets the results of a completed batch job and maps them to conversations.

        Args:
            batch_id: ID of the batch job.
            conversations: Original conversations used to create the batch.

        Returns:
            List[Conversation]: The processed conversations with responses.

        Raises:
            RuntimeError: If batch status is not completed or if there are errors.
        """
        # Get the batch status first.
        batch_info = await self._get_batch_status(batch_id)

        if not batch_info.is_terminal:
            raise RuntimeError(
                f"Batch is not in terminal state. Status: {batch_info.status}"
            )

        if batch_info.has_errors:
            # Download the error file if there are failed requests.
            if batch_info.error_file_id:
                error_content = await self._download_file(batch_info.error_file_id)
                raise RuntimeError(f"Batch has failed requests: {error_content}")
            raise RuntimeError(f"Batch failed with error: {batch_info.error}")

        # Download the results file.
        if not batch_info.output_file_id:
            raise RuntimeError("No output file available")
        results_content = await self._download_file(batch_info.output_file_id)

        # Parse the results.
        processed_conversations = []
        for line, conv in zip(results_content.splitlines(), conversations):
            result = json.loads(line)
            if result.get("error"):
                raise RuntimeError(f"Batch request failed: {result['error']}")
            processed_conv = self._convert_api_output_to_conversation(
                result["response"]["body"], conv
            )
            processed_conversations.append(processed_conv)

        return processed_conversations

    #
    # File operations
    #

    def list_files(
        self,
        purpose: Optional[str] = None,
        limit: Optional[int] = None,
        order: str = "desc",
        after: Optional[str] = None,
    ) -> FileListResponse:
        """Lists files."""
        return safe_asyncio_run(
            self._list_files(
                purpose=purpose,
                limit=limit,
                order=order,
                after=after,
            )
        )

    def get_file(
        self,
        file_id: str,
    ) -> FileInfo:
        """Gets information about a file."""
        return safe_asyncio_run(self._get_file(file_id))

    def delete_file(
        self,
        file_id: str,
    ) -> bool:
        """Deletes a file."""
        return safe_asyncio_run(self._delete_file(file_id))

    def get_file_content(
        self,
        file_id: str,
    ) -> str:
        """Gets a file's content."""
        return safe_asyncio_run(self._download_file(file_id))

    async def _list_files(
        self,
        purpose: Optional[str] = None,
        limit: Optional[int] = None,
        order: str = "desc",
        after: Optional[str] = None,
    ) -> FileListResponse:
        """Lists files.

        Args:
            purpose: Only return files with this purpose.
            limit: Maximum number of files to return (1-10000).
            order: Sort order (asc or desc).
            after: Cursor for pagination.

        Returns:
            FileListResponse: List of files.
        """
        connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
        async with aiohttp.ClientSession(connector=connector) as session:
            headers = self._get_request_headers(self._remote_params)
            params = {"order": order}
            if purpose:
                params["purpose"] = purpose
            if limit:
                params["limit"] = str(limit)
            if after:
                params["after"] = after
            async with session.get(
                f"{self._remote_params.api_url}/files",
                headers=headers,
                params=params,
            ) as response:
                if response.status != 200:
                    raise RuntimeError(
                        f"Failed to list files: {await response.text()}"
                    )
                data = await response.json()
                files = [
                    FileInfo(
                        id=file["id"],
                        filename=file["filename"],
                        bytes=file["bytes"],
                        created_at=file["created_at"],
                        purpose=file["purpose"],
                    )
                    for file in data["data"]
                ]
                return FileListResponse(
                    files=files,
                    has_more=len(files) == limit if limit else False,
                )

    async def _get_file(
        self,
        file_id: str,
    ) -> FileInfo:
        """Gets information about a file.

        Args:
            file_id: ID of the file.

        Returns:
            FileInfo: File information.
        """
        connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
        async with aiohttp.ClientSession(connector=connector) as session:
            headers = self._get_request_headers(self._remote_params)
            async with session.get(
                f"{self._remote_params.api_url}/files/{file_id}",
                headers=headers,
            ) as response:
                if response.status != 200:
                    raise RuntimeError(f"Failed to get file: {await response.text()}")
                data = await response.json()
                return FileInfo(
                    id=data["id"],
                    filename=data["filename"],
                    bytes=data["bytes"],
                    created_at=data["created_at"],
                    purpose=data["purpose"],
                )

    async def _delete_file(
        self,
        file_id: str,
    ) -> bool:
        """Deletes a file.

        Args:
            file_id: ID of the file to delete.

        Returns:
            bool: True if deletion was successful.
        """
        connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
        async with aiohttp.ClientSession(connector=connector) as session:
            headers = self._get_request_headers(self._remote_params)
            async with session.delete(
                f"{self._remote_params.api_url}/files/{file_id}",
                headers=headers,
            ) as response:
                if response.status != 200:
                    raise RuntimeError(
                        f"Failed to delete file: {await response.text()}"
                    )
                data = await response.json()
                return data.get("deleted", False)

    async def _download_file(
        self,
        file_id: str,
    ) -> str:
        """Downloads a file's content.

        Args:
            file_id: ID of the file to download.

        Returns:
            str: The file content.
        """
        connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
        async with aiohttp.ClientSession(connector=connector) as session:
            headers = self._get_request_headers(self._remote_params)
            async with session.get(
                f"{self._remote_params.api_url}/files/{file_id}/content",
                headers=headers,
            ) as response:
                if response.status != 200:
                    raise RuntimeError(
                        f"Failed to download file: {await response.text()}"
                    )
                return await response.text()
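
    # File-operations sketch (illustrative, not part of the module): list
    # uploaded batch input files and fetch the content of the first one.
    # Assumes `engine` was constructed as in the earlier sketch:
    #
    #   listing = engine.list_files(purpose="batch", limit=10)
    #   if listing.files:
    #       content = engine.get_file_content(listing.files[0].id)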