# Source code for oumi.core.tuners.base_tuner

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC, abstractmethod
from typing import Any, Callable

from oumi.core.configs.tuning_config import TuningConfig, TuningParams


class BaseTuner(ABC):
    """Abstract base class for hyperparameter tuners.

    This class defines the interface that all tuner implementations must
    follow, allowing for different optimization backends (Optuna, Ray Tune,
    etc.) while maintaining a consistent API.
    """

    def __init__(self, tuning_params: TuningParams):
        """Initialize the tuner with configuration parameters.

        Args:
            tuning_params: Configuration for the tuning process.
        """
        self.tuning_params = tuning_params
        # Backend-specific study object. It starts as ``None``; subclasses are
        # expected to populate it in ``create_study()``.
        self._study: Any = None

    @abstractmethod
    def create_study(self) -> None:
        """Create a new optimization study.

        Implementations should initialize the tuner's internal study object
        with the appropriate configuration (study name, direction, etc.).
        """

    @abstractmethod
    def suggest_parameters(self, trial: Any) -> dict[str, Any]:
        """Suggest hyperparameters for a trial.

        Args:
            trial: The trial object from the underlying tuner backend.

        Returns:
            Dictionary mapping parameter names to suggested values.
        """

    @abstractmethod
    def optimize(
        self,
        objective_fn: Callable[..., Any],
        n_trials: int,
    ) -> None:
        """Run the optimization process.

        Args:
            objective_fn: Function that takes suggested parameters and returns
                a dictionary of metric values.
            n_trials: Number of trials to run.

        Returns:
            None
        """

    @abstractmethod
    def get_best_trial(self) -> dict[str, Any]:
        """Get the best trial from the study when optimizing a single objective.

        Returns:
            Dictionary containing the best parameters and their metric values.
        """

    @abstractmethod
    def get_best_trials(self) -> list[dict[str, Any]]:
        """Get the best trials from the study when optimizing multiple objectives.

        Returns:
            List of dictionaries, one per best trial, each containing that
            trial's parameters and metric values.
        """

    @abstractmethod
    def save_study(self, config: TuningConfig) -> None:
        """Save the study object to the output directory.

        Args:
            config (TuningConfig): The Oumi tuning config. Presumably it
                supplies the output directory; confirm in implementations.

        Returns:
            None
        """