Source code for oumi.core.configs.params.remote_params

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional

import numpy as np

from oumi.core.configs.params.base_params import BaseParams


@dataclass
class RemoteParams(BaseParams):
    """Parameters for running inference against a remote API.

    Field values are range-checked in ``__post_init__``. The presence of
    ``api_url`` is checked separately in ``finalize_and_validate``.
    """

    api_url: Optional[str] = None
    """URL of the API endpoint to use for inference."""

    api_key: Optional[str] = None
    """API key to use for authentication."""

    api_key_env_varname: Optional[str] = None
    """Name of the environment variable containing the API key for authentication."""

    max_retries: int = 3
    """Maximum number of retries to attempt when calling an API."""

    connection_timeout: float = 20.0
    """Timeout in seconds for a request to an API."""

    num_workers: int = 1
    """Number of workers to use for parallel inference."""

    politeness_policy: float = 0.0
    """Politeness policy to use when calling an API.

    If greater than zero, this is the amount of time in seconds a worker will sleep
    before making a subsequent request.
    """

    batch_completion_window: Optional[str] = "24h"
    """Time window for batch completion. Currently only '24h' is supported.

    Only used for batch inference.
    """

    def __post_init__(self) -> None:
        """Validate the remote parameters.

        Raises:
            ValueError: If ``num_workers`` is less than 1, if
                ``politeness_policy`` is negative or not finite, if
                ``connection_timeout`` is negative or NaN, or if
                ``max_retries`` is negative.
        """
        if self.num_workers < 1:
            raise ValueError(
                "Number of num_workers must be greater than or equal to 1."
            )
        if self.politeness_policy < 0:
            raise ValueError("Politeness policy must be greater than or equal to 0.")
        if self.connection_timeout < 0:
            raise ValueError("Connection timeout must be greater than or equal to 0.")
        # Catches NaN and +inf; -inf is already rejected by the sign check above.
        if not np.isfinite(self.politeness_policy):
            raise ValueError("Politeness policy must be finite.")
        if self.max_retries < 0:
            raise ValueError("Max retries must be greater than or equal to 0.")
        # NaN compares False against every number, so the `< 0` check above
        # cannot catch it. Infinity is deliberately still accepted here.
        # This check runs last so that inputs which were already rejected keep
        # raising the same message as before.
        if np.isnan(self.connection_timeout):
            raise ValueError("Connection timeout must not be NaN.")

    def finalize_and_validate(self) -> None:
        """Finalize the remote parameters.

        Raises:
            ValueError: If ``api_url`` is ``None`` or an empty string.
        """
        if not self.api_url:
            raise ValueError("The API URL must be provided in remote_params.")