Source code for oumi.utils.logging
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import warnings
from pathlib import Path
from typing import Optional, Union
[docs]
def get_logger(
    name: str,
    level: str = "info",
    log_dir: Optional[Union[str, Path]] = None,
) -> logging.Logger:
    """Gets a logger instance with the specified name and log level.

    The logger is configured with ``configure_logger`` only the first time it is
    requested. If a logger with this name already exists, it is returned as-is
    and ``level`` / ``log_dir`` are ignored (use ``update_logger_level`` to
    change the level of an existing logger).

    Args:
        name: The name of the logger.
        level (optional): The log level to set for the logger when it is first
            configured. Defaults to "info".
        log_dir (optional): Directory to store log files when the logger is
            first configured. Defaults to None.

    Returns:
        logging.Logger: The logger instance.
    """
    existing = logging.Logger.manager.loggerDict.get(name)
    # ``loggerDict`` also contains ``logging.PlaceHolder`` entries for the parent
    # names of loggers created earlier (e.g. "oumi" once "oumi.foo" exists).
    # Such placeholders were never configured, so treat them like a missing
    # logger instead of skipping configuration.
    if not isinstance(existing, logging.Logger):
        configure_logger(name, level=level, log_dir=log_dir)
    return logging.getLogger(name)
def _detect_rank() -> int:
"""Detects rank.
Reading the rank from the environment variables instead of
get_device_rank_info to avoid circular imports.
"""
for var_name in (
"RANK",
"SKYPILOT_NODE_RANK", # SkyPilot
"PMI_RANK", # HPC
):
rank = os.environ.get(var_name, None)
if rank is not None:
rank = int(rank)
if rank < 0:
raise ValueError(f"Negative rank: {rank} specified in '{var_name}'!")
return rank
return 0
[docs]
def configure_logger(
    name: str,
    level: str = "info",
    log_dir: Optional[Union[str, Path]] = None,
) -> None:
    """Configures a logger with the specified name and log level.

    Any handlers already attached to the logger are removed and closed. The
    global leader (rank 0) gets a stdout handler. When ``log_dir`` is given,
    every rank also gets a file handler writing to ``rank_XXXX.log`` inside it.
    Propagation to ancestor loggers is disabled.

    Args:
        name: The name of the logger.
        level: The log level name (case-insensitive) applied to the logger and
            its handlers. Defaults to "info".
        log_dir: Directory to store log files; created if it does not exist.
            Defaults to None (no file logging).

    Raises:
        ValueError: If ``level`` is not a known log level name, or if the rank
            environment variable is not a non-negative integer.
    """
    logger = logging.getLogger(name)

    # Remove any existing handlers. Close them as well, so that a FileHandler
    # from a previous configuration releases its open file.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Configure the logger.
    level_name = level.upper()
    logger.setLevel(level_name)

    device_rank = _detect_rank()
    formatter = logging.Formatter(
        "[%(asctime)s][%(name)s]"
        f"[rank{device_rank}]"
        "[pid:%(process)d][%(threadName)s]"
        "[%(levelname)s][%(filename)s:%(lineno)s] %(message)s"
    )

    # Add a console handler only for the global leader (rank 0).
    if device_rank == 0:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        console_handler.setLevel(level_name)
        logger.addHandler(console_handler)

    # Add a per-rank file handler if log_dir is provided.
    if log_dir:
        log_path = Path(log_dir)
        log_path.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_path / f"rank_{device_rank:04d}.log")
        file_handler.setFormatter(formatter)
        file_handler.setLevel(level_name)
        logger.addHandler(file_handler)

    logger.propagate = False
[docs]
def update_logger_level(name: str, level: str = "info") -> None:
    """Sets the log level of a named logger and of every handler attached to it.

    The logger is obtained through ``get_logger``, which configures it with
    ``level`` when it is not yet registered.

    Args:
        name (str): Name of the logger to update.
        level (str, optional): Log level name (case-insensitive) to apply.
            Defaults to "info".
    """
    target_logger = get_logger(name, level=level)
    normalized_level = level.upper()
    target_logger.setLevel(normalized_level)
    for attached_handler in target_logger.handlers:
        attached_handler.setLevel(normalized_level)
[docs]
def configure_dependency_warnings(level: Union[str, int] = "info") -> None:
    """Silences UserWarnings from noisy dependencies unless running at DEBUG.

    When ``level`` resolves to something less verbose than ``logging.DEBUG``,
    "ignore" filters for ``UserWarning`` are installed for the ``torch``,
    ``huggingface_hub`` and ``transformers`` modules. A ``level`` that is
    neither a str nor an int is treated as DEBUG, so no filters are added.

    Args:
        level (str | int, optional): Log level name or numeric value.
            Defaults to "info".

    Raises:
        TypeError: If a level name does not map to a numeric level.
    """
    if isinstance(level, str):
        resolved_level = logging.getLevelName(level.upper())
        if not isinstance(resolved_level, int):
            raise TypeError(
                f"getLevelName() mapped log level name to non-integer: "
                f"{type(resolved_level)}!"
            )
    elif isinstance(level, int):
        resolved_level = int(level)
    else:
        resolved_level = logging.DEBUG

    # Debug mode (or anything more verbose) keeps all dependency warnings.
    if resolved_level <= logging.DEBUG:
        return

    for noisy_module in ("torch", "huggingface_hub", "transformers"):
        warnings.filterwarnings(
            action="ignore", category=UserWarning, module=noisy_module
        )
# Package-level default logger named "oumi", obtained through get_logger with its
# default level when this module is imported.
logger: logging.Logger = get_logger(name="oumi")