# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import json
import os
import tempfile
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Optional
import aiofiles
import aiohttp
import jsonlines
import pydantic
from tqdm.asyncio import tqdm
from typing_extensions import override
from oumi.core.async_utils import safe_asyncio_run
from oumi.core.configs import (
GenerationParams,
InferenceConfig,
ModelParams,
RemoteParams,
)
from oumi.core.inference import BaseInferenceEngine
from oumi.core.types.conversation import (
Conversation,
Message,
Role,
)
from oumi.utils.conversation_utils import (
convert_message_to_json_content_list,
create_list_of_message_json_dicts,
)
_AUTHORIZATION_KEY: str = "Authorization"
_BATCH_PURPOSE = "batch"
_BATCH_ENDPOINT = "/v1/chat/completions"
class BatchStatus(Enum):
"""Status of a batch inference job."""
VALIDATING = "validating"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"
EXPIRED = "expired"
CANCELLED = "cancelled"
@dataclass
class BatchInfo:
"""Information about a batch job."""
id: str
status: BatchStatus
total_requests: int = 0
completed_requests: int = 0
failed_requests: int = 0
endpoint: Optional[str] = None
input_file_id: Optional[str] = None
batch_completion_window: Optional[str] = None
output_file_id: Optional[str] = None
error_file_id: Optional[str] = None
error: Optional[str] = None
created_at: Optional[datetime] = None
in_progress_at: Optional[datetime] = None
expires_at: Optional[datetime] = None
finalizing_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
failed_at: Optional[datetime] = None
expired_at: Optional[datetime] = None
canceling_at: Optional[datetime] = None
canceled_at: Optional[datetime] = None
metadata: Optional[dict[str, Any]] = None
@staticmethod
def _convert_timestamp(timestamp: Optional[int]) -> Optional[datetime]:
"""Convert Unix timestamp to datetime.
Args:
timestamp: Unix timestamp in seconds
Returns:
datetime: Converted datetime or None if timestamp is None
"""
return datetime.fromtimestamp(timestamp) if timestamp is not None else None
@classmethod
def from_api_response(cls, response: dict[str, Any]) -> "BatchInfo":
"""Create BatchInfo from API response dictionary.
Args:
response: Raw API response dictionary
Returns:
BatchInfo: Parsed batch information
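        Example:
            A minimal illustrative response (field values are made up)::

                info = BatchInfo.from_api_response(
                    {
                        "id": "batch_123",
                        "status": "completed",
                        "output_file_id": "file_456",
                        "created_at": 1714000000,
                        "request_counts": {"total": 2, "completed": 2, "failed": 0},
                    }
                )
                assert info.status == BatchStatus.COMPLETED
                assert info.completion_percentage == 100.0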
"""
return cls(
id=response["id"],
status=BatchStatus(response["status"]),
endpoint=response.get("endpoint"),
input_file_id=response.get("input_file_id"),
batch_completion_window=response.get("batch_completion_window"),
output_file_id=response.get("output_file_id"),
error_file_id=response.get("error_file_id"),
error=response.get("error"),
created_at=cls._convert_timestamp(response.get("created_at")),
in_progress_at=cls._convert_timestamp(response.get("in_progress_at")),
expires_at=cls._convert_timestamp(response.get("expires_at")),
finalizing_at=cls._convert_timestamp(response.get("finalizing_at")),
completed_at=cls._convert_timestamp(response.get("completed_at")),
failed_at=cls._convert_timestamp(response.get("failed_at")),
expired_at=cls._convert_timestamp(response.get("expired_at")),
canceling_at=cls._convert_timestamp(response.get("cancelling_at")),
canceled_at=cls._convert_timestamp(response.get("cancelled_at")),
total_requests=response.get("request_counts", {}).get("total", 0),
completed_requests=response.get("request_counts", {}).get("completed", 0),
failed_requests=response.get("request_counts", {}).get("failed", 0),
metadata=response.get("metadata"),
)
@property
def is_terminal(self) -> bool:
"""Return True if the batch is in a terminal state."""
return self.status in (
BatchStatus.COMPLETED,
BatchStatus.FAILED,
BatchStatus.EXPIRED,
BatchStatus.CANCELLED,
)
@property
def completion_percentage(self) -> float:
"""Return the percentage of completed requests."""
return (
(100 * self.completed_requests / self.total_requests)
if self.total_requests > 0
else 0.0
)
@property
def has_errors(self) -> bool:
"""Return True if the batch has any errors."""
return bool(self.error) or self.failed_requests > 0
@dataclass
class BatchListResponse:
"""Response from listing batch jobs."""
batches: list[BatchInfo]
first_id: Optional[str] = None
last_id: Optional[str] = None
has_more: bool = False
@dataclass
class FileInfo:
"""Information about a file."""
id: str
filename: str
bytes: int
created_at: int
purpose: str
@dataclass
class FileListResponse:
"""Response from listing files."""
files: list[FileInfo]
has_more: bool = False
class RemoteInferenceEngine(BaseInferenceEngine):
"""Engine for running inference against a server implementing the OpenAI API."""
base_url: Optional[str] = None
"""The base URL for the remote API."""
api_key_env_varname: Optional[str] = None
"""The environment variable name for the API key."""
def __init__(
self,
model_params: ModelParams,
*,
generation_params: Optional[GenerationParams] = None,
remote_params: Optional[RemoteParams] = None,
):
"""Initializes the inference Engine.
Args:
model_params: The model parameters to use for inference.
generation_params: Generation parameters to use for inference.
remote_params: Remote server params.
**kwargs: Additional keyword arguments.
"""
super().__init__(model_params=model_params, generation_params=generation_params)
self._model = model_params.model_name
self._adapter_model = model_params.adapter_model
if remote_params:
remote_params = copy.deepcopy(remote_params)
else:
remote_params = RemoteParams()
if not remote_params.api_url:
remote_params.api_url = self.base_url
if not remote_params.api_key_env_varname:
remote_params.api_key_env_varname = self.api_key_env_varname
self._remote_params = remote_params
self._remote_params.finalize_and_validate()
@staticmethod
def _get_list_of_message_json_dicts(
messages: list[Message],
*,
group_adjacent_same_role_turns: bool,
) -> list[dict[str, Any]]:
return create_list_of_message_json_dicts(
messages, group_adjacent_same_role_turns=group_adjacent_same_role_turns
)
def _convert_conversation_to_api_input(
self, conversation: Conversation, generation_params: GenerationParams
) -> dict[str, Any]:
"""Converts a conversation to an OpenAI input.
Documentation: https://platform.openai.com/docs/api-reference/chat/create
Args:
conversation: The conversation to convert.
generation_params: Parameters for generation during inference.
Returns:
Dict[str, Any]: A dictionary representing the OpenAI input.
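        Example:
            A rough sketch of the produced payload for a single user turn; the
            values shown are illustrative, not defaults, and the exact message
            content format is produced by convert_message_to_json_content_list::

                {
                    "model": "my-model",
                    "messages": [
                        {"role": "user", "content": [{"type": "text", "text": "Hi"}]}
                    ],
                    "max_completion_tokens": 256,
                    "temperature": 0.0,
                    "top_p": 1.0,
                    "frequency_penalty": 0.0,
                    "presence_penalty": 0.0,
                    "n": 1,
                    "seed": None,
                    "logit_bias": {},
                }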
"""
api_input = {
"model": self._model,
"messages": [
{
"content": convert_message_to_json_content_list(message),
"role": message.role.value,
}
for message in conversation.messages
],
"max_completion_tokens": generation_params.max_new_tokens,
"temperature": generation_params.temperature,
"top_p": generation_params.top_p,
"frequency_penalty": generation_params.frequency_penalty,
"presence_penalty": generation_params.presence_penalty,
"n": 1, # Number of completions to generate for each prompt.
"seed": generation_params.seed,
"logit_bias": generation_params.logit_bias,
}
if generation_params.stop_strings:
api_input["stop"] = generation_params.stop_strings
if generation_params.guided_decoding:
json_schema = generation_params.guided_decoding.json
if json_schema is not None:
if isinstance(json_schema, type) and issubclass(
json_schema, pydantic.BaseModel
):
schema_name = json_schema.__name__
schema_value = json_schema.model_json_schema()
elif isinstance(json_schema, dict):
# Use a generic name if no schema is provided.
schema_name = "Response"
schema_value = json_schema
elif isinstance(json_schema, str):
# Use a generic name if no schema is provided.
schema_name = "Response"
# Try to parse as JSON string
schema_value = json.loads(json_schema)
else:
                    raise ValueError(
                        "Got unsupported JSON schema type: "
                        f"{type(json_schema)}. Please provide a Pydantic model "
                        "or a JSON schema as a string or dict."
                    )
api_input["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": schema_name,
"schema": schema_value,
},
}
else:
                raise ValueError(
                    "Only JSON schema guided decoding is supported, got "
                    f"'{generation_params.guided_decoding}'."
                )
return api_input
def _convert_api_output_to_conversation(
self, response: dict[str, Any], original_conversation: Conversation
) -> Conversation:
"""Converts an API response to a conversation.
Args:
response: The API response to convert.
original_conversation: The original conversation.
Returns:
Conversation: The conversation including the generated response.
"""
message = response["choices"][0]["message"]
return Conversation(
messages=[
*original_conversation.messages,
Message(
content=message["content"],
role=Role(message["role"]),
),
],
metadata=original_conversation.metadata,
conversation_id=original_conversation.conversation_id,
)
def _get_api_key(self, remote_params: RemoteParams) -> Optional[str]:
if not remote_params:
return None
if remote_params.api_key:
return remote_params.api_key
if remote_params.api_key_env_varname:
return os.environ.get(remote_params.api_key_env_varname)
return None
def _get_request_headers(
self, remote_params: Optional[RemoteParams]
) -> dict[str, str]:
headers = {}
if not remote_params:
return headers
headers[_AUTHORIZATION_KEY] = f"Bearer {self._get_api_key(remote_params)}"
return headers
async def _query_api(
self,
conversation: Conversation,
semaphore: asyncio.Semaphore,
session: aiohttp.ClientSession,
inference_config: Optional[InferenceConfig] = None,
) -> Conversation:
"""Queries the API with the provided input.
Args:
conversation: The conversations to run inference on.
semaphore: Semaphore to limit concurrent requests.
session: The aiohttp session to use for the request.
inference_config: Parameters for inference.
Returns:
Conversation: Inference output.
"""
if inference_config is None:
remote_params = self._remote_params
generation_params = self._generation_params
output_path = None
else:
remote_params = inference_config.remote_params or self._remote_params
generation_params = inference_config.generation or self._generation_params
output_path = inference_config.output_path
assert remote_params.api_url
async with semaphore:
api_input = self._convert_conversation_to_api_input(
conversation, generation_params
)
headers = self._get_request_headers(remote_params)
failure_reason = None
# Retry the request if it fails.
for _ in range(remote_params.max_retries + 1):
async with session.post(
remote_params.api_url,
json=api_input,
headers=headers,
timeout=remote_params.connection_timeout,
) as response:
response_json = await response.json()
if response.status == 200:
result = self._convert_api_output_to_conversation(
response_json, conversation
)
if output_path:
# Write what we have so far to our scratch directory.
self._save_conversation(
result,
self._get_scratch_filepath(output_path),
)
await asyncio.sleep(remote_params.politeness_policy)
return result
else:
failure_reason = (
response_json.get("error").get("message")
if response_json and response_json.get("error")
else None
)
await asyncio.sleep(remote_params.politeness_policy)
raise RuntimeError(
f"Failed to query API after {remote_params.max_retries} retries. "
+ (f"Reason: {failure_reason}" if failure_reason else "")
)
async def _infer(
self,
input: list[Conversation],
inference_config: Optional[InferenceConfig] = None,
) -> list[Conversation]:
"""Runs model inference on the provided input.
Args:
input: A list of conversations to run inference on.
inference_config: Parameters for inference.
Returns:
List[Conversation]: Inference output.
"""
# Limit number of HTTP connections to the number of workers.
connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
# Control the number of concurrent tasks via a semaphore.
semaphore = asyncio.BoundedSemaphore(self._remote_params.num_workers)
async with aiohttp.ClientSession(connector=connector) as session:
tasks = [
self._query_api(
conversation,
semaphore,
session,
inference_config=inference_config,
)
for conversation in input
]
disable_tqdm = len(tasks) < 2
return await tqdm.gather(*tasks, disable=disable_tqdm)
@override
def infer_online(
self,
input: list[Conversation],
inference_config: Optional[InferenceConfig] = None,
) -> list[Conversation]:
"""Runs model inference online.
Args:
input: A list of conversations to run inference on.
inference_config: Parameters for inference.
Returns:
List[Conversation]: Inference output.
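        Example:
            An illustrative sketch; it assumes an OpenAI-compatible
            chat-completions endpoint and an API key exported in the
            OPENAI_API_KEY environment variable::

                from oumi.core.configs import ModelParams, RemoteParams
                from oumi.core.types.conversation import Conversation, Message, Role

                engine = RemoteInferenceEngine(
                    ModelParams(model_name="gpt-4o-mini"),
                    remote_params=RemoteParams(
                        api_url="https://api.openai.com/v1/chat/completions",
                        api_key_env_varname="OPENAI_API_KEY",
                    ),
                )
                conversations = [
                    Conversation(messages=[Message(role=Role.USER, content="Hi!")])
                ]
                results = engine.infer_online(conversations)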
"""
conversations = safe_asyncio_run(self._infer(input, inference_config))
if inference_config and inference_config.output_path:
self._save_conversations(conversations, inference_config.output_path)
return conversations
@override
def infer_from_file(
self, input_filepath: str, inference_config: Optional[InferenceConfig] = None
) -> list[Conversation]:
"""Runs model inference on inputs in the provided file.
        This is a convenience method that avoids the boilerplate of asserting
        the existence of input_filepath in the generation_params.
Args:
input_filepath: Path to the input file containing prompts for
generation.
inference_config: Parameters for inference.
Returns:
List[Conversation]: Inference output.
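        Example:
            An illustrative sketch; it assumes conversations.jsonl exists and
            holds one serialized conversation per line, and that `engine` was
            constructed with valid remote parameters::

                results = engine.infer_from_file("conversations.jsonl")
                for conversation in results:
                    print(conversation.messages[-1].content)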
"""
input = self._read_conversations(input_filepath)
conversations = safe_asyncio_run(self._infer(input, inference_config))
if inference_config and inference_config.output_path:
self._save_conversations(conversations, inference_config.output_path)
return conversations
@override
def get_supported_params(self) -> set[str]:
"""Returns a set of supported generation parameters for this engine."""
return {
"frequency_penalty",
"guided_decoding",
"logit_bias",
"max_new_tokens",
"presence_penalty",
"seed",
"stop_strings",
"temperature",
"top_p",
}
#
# Batch inference
#
def infer_batch(
self,
conversations: list[Conversation],
inference_config: Optional[InferenceConfig] = None,
) -> str:
"""Creates a new batch inference job.
Args:
conversations: List of conversations to process in batch
inference_config: Parameters for inference
Returns:
str: The batch job ID
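        Example:
            An illustrative sketch; it assumes the remote API implements the
            OpenAI Batch API and that `engine` was constructed with valid
            remote parameters::

                batch_id = engine.infer_batch(conversations)
                info = engine.get_batch_status(batch_id)
                if info.status == BatchStatus.COMPLETED:
                    results = engine.get_batch_results(batch_id, conversations)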
"""
generation_params = (
inference_config.generation if inference_config else self._generation_params
)
return safe_asyncio_run(self._create_batch(conversations, generation_params))
def get_batch_status(
self,
batch_id: str,
) -> BatchInfo:
"""Gets the status of a batch inference job.
Args:
batch_id: The batch job ID
Returns:
BatchInfo: Current status of the batch job
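        Example:
            An illustrative polling sketch; it assumes batch_id was returned by
            an earlier infer_batch call and `engine` is a constructed instance::

                import time

                info = engine.get_batch_status(batch_id)
                while not info.is_terminal:
                    time.sleep(60)
                    info = engine.get_batch_status(batch_id)
                print(f"{info.status}: {info.completion_percentage:.0f}% done")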
"""
return safe_asyncio_run(self._get_batch_status(batch_id))
def list_batches(
self,
after: Optional[str] = None,
limit: Optional[int] = None,
) -> BatchListResponse:
"""Lists batch jobs.
Args:
after: Cursor for pagination (batch ID to start after)
limit: Maximum number of batches to return (1-100)
Returns:
BatchListResponse: List of batch jobs
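        Example:
            An illustrative pagination sketch; it assumes `engine` was
            constructed with valid remote parameters::

                page = engine.list_batches(limit=20)
                while True:
                    for batch in page.batches:
                        print(batch.id, batch.status)
                    if not page.has_more:
                        break
                    page = engine.list_batches(after=page.last_id, limit=20)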
"""
return safe_asyncio_run(
self._list_batches(
after=after,
limit=limit,
)
)
def get_batch_results(
self,
batch_id: str,
conversations: list[Conversation],
) -> list[Conversation]:
"""Gets the results of a completed batch job.
Args:
batch_id: The batch job ID
conversations: Original conversations used to create the batch
Returns:
List[Conversation]: The processed conversations with responses
Raises:
RuntimeError: If the batch failed or has not completed
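        Example:
            An illustrative sketch; it assumes batch_id refers to a completed
            job that was created from the same conversations list::

                results = engine.get_batch_results(batch_id, conversations)
                for conversation in results:
                    print(conversation.messages[-1].content)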
"""
return safe_asyncio_run(
self._get_batch_results_with_mapping(batch_id, conversations)
)
async def _upload_batch_file(
self,
batch_requests: list[dict],
) -> str:
"""Uploads a JSONL file containing batch requests.
Args:
batch_requests: List of request objects to include in the batch
Returns:
str: The uploaded file ID
"""
# Create temporary JSONL file
with tempfile.NamedTemporaryFile(
mode="w", suffix=".jsonl", delete=False
) as tmp:
with jsonlines.Writer(tmp) as writer:
for request in batch_requests:
writer.write(request)
tmp_path = tmp.name
try:
# Upload the file
connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
async with aiohttp.ClientSession(connector=connector) as session:
headers = self._get_request_headers(self._remote_params)
# Create form data with file
form = aiohttp.FormData()
async with aiofiles.open(tmp_path, "rb") as f:
file_data = await f.read()
form.add_field("file", file_data, filename="batch_requests.jsonl")
form.add_field("purpose", _BATCH_PURPOSE)
async with session.post(
f"{self._remote_params.api_url}/files",
data=form,
headers=headers,
) as response:
if response.status != 200:
raise RuntimeError(
f"Failed to upload batch file: {await response.text()}"
)
data = await response.json()
return data["id"]
finally:
# Clean up temporary file
Path(tmp_path).unlink()
async def _create_batch(
self,
conversations: list[Conversation],
generation_params: GenerationParams,
) -> str:
"""Creates a new batch job.
Args:
conversations: List of conversations to process in batch
generation_params: Generation parameters
Returns:
str: The batch job ID
"""
# Prepare batch requests
batch_requests = []
for i, conv in enumerate(conversations):
api_input = self._convert_conversation_to_api_input(conv, generation_params)
batch_requests.append(
{
"custom_id": f"request-{i}",
"method": "POST",
"url": _BATCH_ENDPOINT,
"body": api_input,
}
)
# Upload batch file
file_id = await self._upload_batch_file(batch_requests)
# Create batch
connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
async with aiohttp.ClientSession(connector=connector) as session:
headers = self._get_request_headers(self._remote_params)
async with session.post(
f"{self._remote_params.api_url}/batches",
json={
"input_file_id": file_id,
"endpoint": _BATCH_ENDPOINT,
"batch_completion_window": (
self._remote_params.batch_completion_window
),
},
headers=headers,
) as response:
if response.status != 200:
raise RuntimeError(
f"Failed to create batch: {await response.text()}"
)
data = await response.json()
return data["id"]
async def _get_batch_status(
self,
batch_id: str,
) -> BatchInfo:
"""Gets the status of a batch job.
Args:
batch_id: ID of the batch job
Returns:
BatchInfo: Current status of the batch job
"""
connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
async with aiohttp.ClientSession(connector=connector) as session:
headers = self._get_request_headers(self._remote_params)
async with session.get(
f"{self._remote_params.api_url}/batches/{batch_id}",
headers=headers,
) as response:
if response.status != 200:
raise RuntimeError(
f"Failed to get batch status: {await response.text()}"
)
data = await response.json()
return BatchInfo.from_api_response(data)
async def _list_batches(
self,
after: Optional[str] = None,
limit: Optional[int] = None,
) -> BatchListResponse:
"""Lists batch jobs.
Args:
after: Cursor for pagination (batch ID to start after)
limit: Maximum number of batches to return (1-100)
Returns:
BatchListResponse: List of batch jobs
"""
connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
async with aiohttp.ClientSession(connector=connector) as session:
headers = self._get_request_headers(self._remote_params)
params = {}
if after:
params["after"] = after
if limit:
params["limit"] = str(limit)
async with session.get(
f"{self._remote_params.api_url}/batches",
headers=headers,
params=params,
) as response:
if response.status != 200:
raise RuntimeError(
f"Failed to list batches: {await response.text()}"
)
data = await response.json()
batches = [
BatchInfo.from_api_response(batch_data)
for batch_data in data["data"]
]
return BatchListResponse(
batches=batches,
first_id=data.get("first_id"),
last_id=data.get("last_id"),
has_more=data.get("has_more", False),
)
async def _get_batch_results_with_mapping(
self,
batch_id: str,
conversations: list[Conversation],
) -> list[Conversation]:
"""Gets the results of a completed batch job and maps them to conversations.
Args:
batch_id: ID of the batch job
conversations: Original conversations used to create the batch
Returns:
List[Conversation]: The processed conversations with responses
Raises:
RuntimeError: If batch status is not completed or if there are errors
"""
# Get batch status first
batch_info = await self._get_batch_status(batch_id)
if not batch_info.is_terminal:
raise RuntimeError(
f"Batch is not in terminal state. Status: {batch_info.status}"
)
if batch_info.has_errors:
# Download error file if there are failed requests
if batch_info.error_file_id:
error_content = await self._download_file(batch_info.error_file_id)
raise RuntimeError(f"Batch has failed requests: {error_content}")
raise RuntimeError(f"Batch failed with error: {batch_info.error}")
# Download results file
if not batch_info.output_file_id:
raise RuntimeError("No output file available")
results_content = await self._download_file(batch_info.output_file_id)
        # Parse results. Requests are matched back to their originating
        # conversations via the "custom_id" assigned in _create_batch
        # ("request-{i}"), since the API does not guarantee that output lines
        # preserve the input order.
        results_by_custom_id: dict[str, dict[str, Any]] = {}
        for line in results_content.splitlines():
            if not line.strip():
                continue
            result = json.loads(line)
            results_by_custom_id[result["custom_id"]] = result
        processed_conversations = []
        for i, conv in enumerate(conversations):
            result = results_by_custom_id[f"request-{i}"]
            if result.get("error"):
                raise RuntimeError(f"Batch request failed: {result['error']}")
            processed_conv = self._convert_api_output_to_conversation(
                result["response"]["body"], conv
            )
            processed_conversations.append(processed_conv)
        return processed_conversations
#
# File operations
#
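    # Example (illustrative sketch; assumes the remote API exposes OpenAI-style
    # "/files" endpoints and that `engine` is an already-constructed instance):
    #
    #     files = engine.list_files(purpose="batch", limit=10)
    #     for f in files.files:
    #         print(f.id, f.filename, f.bytes)
    #     content = engine.get_file_content(files.files[0].id)
    #     engine.delete_file(files.files[0].id)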
def list_files(
self,
purpose: Optional[str] = None,
limit: Optional[int] = None,
order: str = "desc",
after: Optional[str] = None,
) -> FileListResponse:
"""Lists files."""
return safe_asyncio_run(
self._list_files(
purpose=purpose,
limit=limit,
order=order,
after=after,
)
)
def get_file(
self,
file_id: str,
) -> FileInfo:
"""Gets information about a file."""
return safe_asyncio_run(self._get_file(file_id))
def delete_file(
self,
file_id: str,
) -> bool:
"""Deletes a file."""
return safe_asyncio_run(self._delete_file(file_id))
def get_file_content(
self,
file_id: str,
) -> str:
"""Gets a file's content."""
return safe_asyncio_run(self._download_file(file_id))
async def _list_files(
self,
purpose: Optional[str] = None,
limit: Optional[int] = None,
order: str = "desc",
after: Optional[str] = None,
) -> FileListResponse:
"""Lists files.
Args:
purpose: Only return files with this purpose
limit: Maximum number of files to return (1-10000)
order: Sort order (asc or desc)
after: Cursor for pagination
Returns:
FileListResponse: List of files
"""
connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
async with aiohttp.ClientSession(connector=connector) as session:
headers = self._get_request_headers(self._remote_params)
params = {"order": order}
if purpose:
params["purpose"] = purpose
if limit:
params["limit"] = str(limit)
if after:
params["after"] = after
async with session.get(
f"{self._remote_params.api_url}/files",
headers=headers,
params=params,
) as response:
if response.status != 200:
raise RuntimeError(f"Failed to list files: {await response.text()}")
data = await response.json()
files = [
FileInfo(
id=file["id"],
filename=file["filename"],
bytes=file["bytes"],
created_at=file["created_at"],
purpose=file["purpose"],
)
for file in data["data"]
]
return FileListResponse(
files=files, has_more=len(files) == limit if limit else False
)
async def _get_file(
self,
file_id: str,
) -> FileInfo:
"""Gets information about a file.
Args:
file_id: ID of the file
Returns:
FileInfo: File information
"""
connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
async with aiohttp.ClientSession(connector=connector) as session:
headers = self._get_request_headers(self._remote_params)
async with session.get(
f"{self._remote_params.api_url}/files/{file_id}",
headers=headers,
) as response:
if response.status != 200:
raise RuntimeError(f"Failed to get file: {await response.text()}")
data = await response.json()
return FileInfo(
id=data["id"],
filename=data["filename"],
bytes=data["bytes"],
created_at=data["created_at"],
purpose=data["purpose"],
)
async def _delete_file(
self,
file_id: str,
) -> bool:
"""Deletes a file.
Args:
file_id: ID of the file to delete
Returns:
bool: True if deletion was successful
"""
connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
async with aiohttp.ClientSession(connector=connector) as session:
headers = self._get_request_headers(self._remote_params)
async with session.delete(
f"{self._remote_params.api_url}/files/{file_id}",
headers=headers,
) as response:
if response.status != 200:
raise RuntimeError(
f"Failed to delete file: {await response.text()}"
)
data = await response.json()
return data.get("deleted", False)
async def _download_file(
self,
file_id: str,
) -> str:
"""Downloads a file's content.
Args:
file_id: ID of the file to download
Returns:
str: The file content
"""
connector = aiohttp.TCPConnector(limit=self._remote_params.num_workers)
async with aiohttp.ClientSession(connector=connector) as session:
headers = self._get_request_headers(self._remote_params)
async with session.get(
f"{self._remote_params.api_url}/files/{file_id}/content",
headers=headers,
) as response:
if response.status != 200:
raise RuntimeError(
f"Failed to download file: {await response.text()}"
)
return await response.text()