# Source code for oumi.core.configs.base_config
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import logging
import re
from collections.abc import Iterator
from io import StringIO
from pathlib import Path
from typing import Any, Optional, TypeVar, Union, cast
from omegaconf import OmegaConf
from oumi.core.configs.params.base_params import BaseParams
# Type variable bound to `BaseConfig`: the classmethod constructors below annotate
# `cls` as `type[T]` and return `T`, so they are typed as returning the subclass
# they are invoked on.
T = TypeVar("T", bound="BaseConfig")
# CLI arguments starting with any of these prefixes are dropped before parsing.
_CLI_IGNORED_PREFIXES = ["--local-rank"]


def _filter_ignored_args(arg_list: list[str]) -> list[str]:
    """Returns the CLI arguments that do not start with an ignored prefix.

    Args:
        arg_list: Command line arguments list.

    Returns:
        list[str]: A new list with the surviving arguments in their original
        order; `arg_list` itself is not modified.
    """
    # `str.startswith` accepts a tuple and matches if any member is a prefix.
    ignored_prefixes = tuple(_CLI_IGNORED_PREFIXES)
    kept_args: list[str] = []
    for candidate in arg_list:
        if candidate.startswith(ignored_prefixes):
            continue
        kept_args.append(candidate)
    return kept_args
def _read_config_without_interpolation(config_path: str) -> str:
    """Reads a configuration file without interpolating variables.

    Every "${" in the file that is not already preceded by a backslash gets a
    backslash inserted in front of it; sequences that are already escaped are
    left untouched. Nothing else in the text is changed.

    Args:
        config_path: The path to the configuration file.

    Returns:
        str: The stringified configuration.
    """
    # Decode as UTF-8 explicitly instead of relying on the platform's locale
    # default, so configs containing non-ASCII text read the same everywhere.
    with open(config_path, encoding="utf-8") as f:
        stringified_config = f.read()

    pattern = r"(?<!\\)\$\{"  # Matches "${" but not "\${"
    stringified_config = re.sub(pattern, "\\${", stringified_config)
    return stringified_config
# [docs]
@dataclasses.dataclass
class BaseConfig:
    """Base dataclass for configs that are loaded, merged and saved via OmegaConf.

    Provides YAML serialization (`to_yaml`), alternate constructors that merge a
    YAML file and/or CLI overrides onto the structured schema of the calling
    class (`from_yaml`, `from_yaml_and_arg_list`), and a validation hook
    (`finalize_and_validate`).
    """

    def to_yaml(self, config_path: Union[str, Path, StringIO]) -> None:
        """Saves the configuration to a YAML file.

        Args:
            config_path: Destination handed to `OmegaConf.save`: a file path or
                an in-memory `StringIO` buffer.
        """
        OmegaConf.save(config=self, f=config_path)

    @classmethod
    def from_yaml(
        cls: type[T], config_path: Union[str, Path], ignore_interpolation: bool = True
    ) -> T:
        """Loads a configuration from a YAML file.

        Args:
            config_path: The path to the YAML file.
            ignore_interpolation: If True, then any interpolation variables in the
                configuration file will be escaped.

        Returns:
            BaseConfig: The merged configuration object.

        Raises:
            TypeError: If the merged object is not an instance of `cls`.
        """
        schema = OmegaConf.structured(cls)
        if ignore_interpolation:
            stringified_config = _read_config_without_interpolation(str(config_path))
            file_config = OmegaConf.create(stringified_config)
        else:
            file_config = OmegaConf.load(config_path)
        # Merging onto the structured schema validates the file against `cls`.
        config = OmegaConf.to_object(OmegaConf.merge(schema, file_config))
        if not isinstance(config, cls):
            raise TypeError(f"config is not {cls}")
        return cast(T, config)

    @classmethod
    def from_yaml_and_arg_list(
        cls: type[T],
        config_path: Optional[str],
        arg_list: list[str],
        logger: Optional[logging.Logger] = None,
        ignore_interpolation: bool = True,
    ) -> T:
        """Loads a configuration from various sources.

        If both YAML and arguments list are provided, then
        parameters specified in `arg_list` have higher precedence.

        Args:
            config_path: The path to the YAML file.
            arg_list: Command line arguments list.
            logger: (optional) Logger.
            ignore_interpolation: If True, then any interpolation variables in the
                configuration file will be escaped.

        Returns:
            BaseConfig: The merged configuration object.

        Raises:
            TypeError: If the merged object is not an instance of `cls`.
            Exception: Any error raised while merging the configs is logged
                (when `logger` is given) and re-raised unchanged.
        """
        # Start with an empty typed config. This forces OmegaConf to validate
        # that all other configs are of this structured type as well.
        all_configs = [OmegaConf.structured(cls)]

        # Override with configuration file if provided.
        if config_path is not None:
            if ignore_interpolation:
                stringified_config = _read_config_without_interpolation(config_path)
                all_configs.append(OmegaConf.create(stringified_config))
            else:
                # Load the raw YAML so interpolations stay unescaped, mirroring
                # the non-escaping branch of `from_yaml`. Calling
                # `cls.from_yaml(config_path)` here would silently fall back to
                # its default `ignore_interpolation=True` and escape them anyway.
                all_configs.append(OmegaConf.load(config_path))

        # Filter out CLI arguments that should be ignored.
        arg_list = _filter_ignored_args(arg_list)

        # Override with CLI arguments.
        all_configs.append(OmegaConf.from_cli(arg_list))

        try:
            # Merge and validate configs
            config = OmegaConf.merge(*all_configs)
        except Exception:
            if logger:
                configs_str = "\n\n".join([f"{config}" for config in all_configs])
                logger.exception(
                    f"Failed to merge {len(all_configs)} Omega configs:\n{configs_str}"
                )
            raise

        config = OmegaConf.to_object(config)
        if not isinstance(config, cls):
            # Report the expected class itself; `type(cls)` would only print the
            # metaclass.
            raise TypeError(f"config {type(config)} is not {cls}")
        return cast(T, config)

    def finalize_and_validate(self) -> None:
        """Finalizes and validates the top level params objects.

        Calls `finalize_and_validate()` on every field value that is a
        `BaseParams` instance, then runs this object's own
        `__finalize_and_validate__()` hook.
        """
        for _, attr_value in self:
            if isinstance(attr_value, BaseParams):
                attr_value.finalize_and_validate()

        self.__finalize_and_validate__()

    def __finalize_and_validate__(self) -> None:
        """Finalizes and validates the parameters of this object.

        This method can be overridden by subclasses to implement custom
        validation logic.

        In case of validation errors, this method should raise a `ValueError`
        or other appropriate exception.
        """

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        """Returns an iterator over field names and values.

        Note: for an attribute to be a field, it must be declared in the
        dataclass definition and have a type annotation.
        """
        for param in dataclasses.fields(self):
            yield param.name, getattr(self, param.name)