Source code for oumi.cli.distributed_run

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import enum
import os
import sys
import time
from subprocess import Popen
from sys import stderr, stdout
from typing import Any, Final, NamedTuple, Optional

import typer

import oumi.cli.cli_utils as cli_utils
from oumi.utils.logging import logger

# Port range [1024, 65535] is generally available
# for application use w/o root permissions (non-privileged)
_MASTER_PORT_MIN_VALID_VALUE: Final[int] = 1024
_MASTER_PORT_MAX_VALID_VALUE: Final[int] = 65535

# Environment variables that must all be defined once a SkyPilot run has been
# detected; a missing one is reported as a `ValueError`.
_SKY_ENV_VARS = {
    "SKYPILOT_NODE_RANK",
    "SKYPILOT_NODE_IPS",
    "SKYPILOT_NUM_GPUS_PER_NODE",
}

# Environment variables that must all be defined once a Polaris (PBS) run has been
# detected; a missing one is reported as a `ValueError`.
_POLARIS_ENV_VARS = {
    "PBS_NODEFILE",
    "PBS_JOBID",
}

# Names of the environment variables consulted for the master address/port when
# falling back to a local-machine run.
# NOTE(review): PyTorch's conventional variable name is `MASTER_ADDR`; confirm that
# "MASTER_ADDRESS" is the intended name here.
_MASTER_ADDR_ENV = "MASTER_ADDRESS"
_MASTER_PORT_ENV = "MASTER_PORT"

# Fallback master address (local-machine runs) and master port (all backends).
_DEFAULT_MASTER_ADDR = "127.0.0.1"
_DEFAULT_MASTER_PORT = 8007


class _RunBackend(str, enum.Enum):
    """Names of the run backends: SkyPilot, Polaris, or the local machine.

    Members are also `str` instances, so they compare equal to their string values.
    """

    SKYPILOT = "SkyPilot"
    POLARIS = "Polaris"
    LOCAL_MACHINE = "LocalMachine"


class _WorldInfo(NamedTuple):
    """Size of a distributed run: node count and GPUs available on each node."""

    num_nodes: int
    """Total number of nodes (machines)."""
    gpus_per_node: int
    """Number of GPU-s per node."""


class _ProcessRunInfo:
    """Validated, read-only description of this node's place in a distributed run.

    Holds the node rank, the world size (nodes and GPUs per node), the master
    address/port, and the list of node IPs. All values are checked in `__init__`
    and exposed through read-only properties.
    """

    def __init__(
        self,
        node_rank: int,
        world_info: _WorldInfo,
        master_address: str,
        master_port: int,
        node_ips: list[str],
    ):
        """Initializes run info, and validates arguments.

        Args:
            node_rank: Rank of this node, in the `[0, world_info.num_nodes)` range.
            world_info: Number of nodes and number of GPUs per node.
            master_address: Non-empty address of the master node.
            master_port: Master port, within the non-privileged port range.
            node_ips: List of node IPs.

        Raises:
            ValueError: If any of the arguments fails validation.
        """
        if world_info.num_nodes <= 0 or world_info.gpus_per_node <= 0:
            raise ValueError(
                f"Non-positive number of nodes or GPUs per node: {world_info}"
            )
        if not (0 <= node_rank < world_info.num_nodes):
            raise ValueError(
                f"Node rank {node_rank} is out of range: [0, {world_info.num_nodes})."
            )
        if not master_address:
            raise ValueError(f"Empty master address: {master_address}.")
        if not (
            _MASTER_PORT_MIN_VALID_VALUE <= master_port <= _MASTER_PORT_MAX_VALID_VALUE
        ):
            raise ValueError(
                f"Master port: {master_port} is outside of valid range: "
                f"[{_MASTER_PORT_MIN_VALID_VALUE}, {_MASTER_PORT_MAX_VALID_VALUE}]."
            )

        self._world_info = world_info
        self._node_rank = int(node_rank)
        self._master_address = master_address
        self._master_port = master_port
        self._node_ips = node_ips

    @property
    def node_rank(self) -> int:
        """Node rank in the [0, num_nodes) range."""
        return self._node_rank

    @property
    def num_nodes(self) -> int:
        """Total number of nodes (machines)."""
        return self._world_info.num_nodes

    @property
    def gpus_per_node(self) -> int:
        """Number of GPU-s per node."""
        return self._world_info.gpus_per_node

    @property
    def total_gpus(self) -> int:
        """Total number of GPU-s across all nodes (`num_nodes * gpus_per_node`)."""
        return self._world_info.num_nodes * self._world_info.gpus_per_node

    @property
    def master_address(self) -> str:
        """Master address."""
        return self._master_address

    @property
    def node_ips(self) -> list[str]:
        """List of node IPs."""
        return self._node_ips

    @property
    def master_port(self) -> int:
        """Master port."""
        return self._master_port

    def __repr__(self) -> str:
        """Returns the `repr` of a dict holding all public fields."""
        fields_dict: dict[str, Any] = {
            "node_rank": self.node_rank,
            "num_nodes": self.num_nodes,
            "gpus_per_node": self.gpus_per_node,
            "total_gpus": self.total_gpus,
            "master_address": self.master_address,
            "master_port": self.master_port,
            "node_ips": self.node_ips,
        }
        return repr(fields_dict)


#
# Commands
#
def torchrun(
    ctx: typer.Context,
    level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
    """Launches a `torchrun` sub-process with auto-detected distributed settings.

    The node count, node rank, processes per node, master address and master port
    are taken from the detected process run info; any extra CLI arguments in
    `ctx.args` are appended to the command unchanged.

    Args:
        ctx: The Typer context object.
        level: The logging level for the specified command.
    """
    try:
        run_info: _ProcessRunInfo = _detect_process_run_info(os.environ.copy())
    except (ValueError, RuntimeError):
        logger.exception("Failed to detect process run info!")
        raise

    try:
        torchrun_cmd: list[str] = [
            "torchrun",
            f"--nnodes={run_info.num_nodes}",
            f"--node-rank={run_info.node_rank}",
            f"--nproc-per-node={run_info.gpus_per_node}",
            f"--master-addr={run_info.master_address}",
            f"--master-port={run_info.master_port}",
            # User-supplied arguments always come after the auto-configured ones.
            *ctx.args,
        ]
        _run_subprocess(torchrun_cmd, rank=run_info.node_rank)
    except Exception:
        logger.exception(f"`torchrun` failed (Rank: {run_info.node_rank})!")
        raise
def accelerate(
    ctx: typer.Context,
    level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
    """Launches an `accelerate` sub-process with auto-detected distributed settings.

    The machine count, machine rank, total process count, main process IP and
    main process port are taken from the detected process run info. A leading
    sub-command in `ctx.args` (e.g. "launch") is kept directly after `accelerate`;
    the remaining arguments are appended after the auto-configured options.

    Args:
        ctx: The Typer context object.
        level: The logging level for the specified command.
    """
    try:
        run_info: _ProcessRunInfo = _detect_process_run_info(os.environ.copy())
    except (ValueError, RuntimeError):
        logger.exception("Failed to detect process run info!")
        raise

    try:
        # Work on a copy so that popping the sub-command leaves `ctx.args` intact.
        remaining_args = copy.deepcopy(ctx.args)
        launch_cmd: list[str] = ["accelerate"]

        # A non-empty first argument that is not an option is treated as a
        # sub-command and must directly follow `accelerate`
        # ("accelerate launch ...").
        starts_with_subcommand = (
            bool(remaining_args)
            and bool(remaining_args[0])
            and not remaining_args[0].startswith("-")
        )
        if starts_with_subcommand:
            launch_cmd.append(remaining_args.pop(0))

        launch_cmd += [
            f"--num_machines={run_info.num_nodes}",
            f"--machine_rank={run_info.node_rank}",
            f"--num_processes={run_info.total_gpus}",
            f"--main_process_ip={run_info.master_address}",
            f"--main_process_port={run_info.master_port}",
        ]
        launch_cmd += remaining_args
        _run_subprocess(launch_cmd, rank=run_info.node_rank)
    except Exception:
        logger.exception(f"`accelerate` failed (Rank: {run_info.node_rank})!")
        raise
#
# Helper functions
#
def _detect_process_run_info(env: dict[str, str]) -> _ProcessRunInfo:
    """Detects process run info.

    Uses known environment variables to detect common runtime parameters.

    Args:
        env: All environment variables.

    Returns:
        Process run info.

    Raises:
        ValueError: If any of the required environment variables are missing or
            invalid.
        RuntimeError: If the node list is empty, or there are issues with backend
            detection.
    """
    # Detect the process run info depending on the runtime environment.
    # Each runtime environment is checked in the order of priority.
    process_run_info = _detect_skypilot_process_run_info(env)
    if process_run_info is None:
        process_run_info = _detect_polaris_process_run_info(env)
    if process_run_info is None:
        process_run_info = _detect_local_machine_process_run_info(env)
    if process_run_info is None:
        raise RuntimeError("Failed to detect process run info!")

    # Extra verification logic to make sure that the detected process run info is
    # consistent with the environment variables.
    # Will raise an exception if the detected process run info is not consistent.
    _verify_process_run_info(process_run_info, env)
    return process_run_info


def _run_subprocess(cmds: list[str], *, rank: int) -> None:
    """Runs `cmds` as a child process and waits for it to finish.

    The child receives a copy of the current environment and writes to this
    process's stdout/stderr. On a non-zero exit code the failure is logged and
    this process exits with the same code.

    Args:
        cmds: The command and its arguments.
        rank: Node rank, used only in the success log message.
    """
    env_copy = os.environ.copy()

    start_time = time.perf_counter()
    logger.info(f"Running the command: {cmds}")

    p = Popen(
        cmds,
        env=env_copy,
        stdout=stdout,
        stderr=stderr,
        bufsize=1,
        universal_newlines=True,
    )
    rc = p.wait()
    duration_sec = time.perf_counter() - start_time
    duration_str = f"Duration: {duration_sec:.1f} sec"
    if rc != 0:
        logger.error(
            f"{cmds[0]} failed with exit code: {rc} ({duration_str}). Command: {cmds}"
        )
        sys.exit(rc)

    logger.info(f"Successfully completed! (Rank: {rank}. {duration_str})")


def _verify_process_run_info(run_info: _ProcessRunInfo, env: dict[str, str]) -> None:
    """Cross-checks detected run info against the optional `OUMI_*` variables.

    Args:
        run_info: The detected process run info.
        env: All environment variables.

    Raises:
        ValueError: If 'OUMI_TOTAL_NUM_GPUS' or 'OUMI_NUM_NODES' is not an integer,
            if 'OUMI_MASTER_ADDR' is empty, or if any of these variables is set
            and disagrees with `run_info`.
        RuntimeError: If `run_info` has an empty list of nodes.
    """
    oumi_total_gpus: Optional[int] = _get_optional_int_env_var(
        "OUMI_TOTAL_NUM_GPUS", env
    )
    oumi_num_nodes: Optional[int] = _get_optional_int_env_var("OUMI_NUM_NODES", env)
    oumi_master_address: Optional[str] = env.get("OUMI_MASTER_ADDR", None)
    if oumi_master_address is not None and len(oumi_master_address) == 0:
        raise ValueError("Empty master address in 'OUMI_MASTER_ADDR'!")

    # Raise explicitly rather than `assert`: assertions are stripped under `-O`.
    if len(run_info.node_ips) == 0:
        raise RuntimeError("Empty list of nodes!")

    # Only the first detected inconsistency is reported.
    if oumi_num_nodes is not None and oumi_num_nodes != run_info.num_nodes:
        raise ValueError(
            "Inconsistent number of nodes: "
            f"{run_info.num_nodes} vs {oumi_num_nodes} in 'OUMI_NUM_NODES'."
        )
    elif oumi_total_gpus is not None and (oumi_total_gpus != run_info.total_gpus):
        raise ValueError(
            "Inconsistent total number of GPUs: "
            f"{run_info.total_gpus} vs {oumi_total_gpus} "
            "in 'OUMI_TOTAL_NUM_GPUS'. "
            f"Nodes: {run_info.num_nodes}. GPU-s per node: {run_info.gpus_per_node}."
        )
    elif oumi_master_address and oumi_master_address not in run_info.node_ips:
        raise ValueError(
            f"Master address '{oumi_master_address}' not found in the list of nodes."
        )


#
# Parse environment variables
#
def _detect_polaris_process_run_info(env: dict[str, str]) -> Optional[_ProcessRunInfo]:
    """Builds run info from the Polaris (PBS) environment, if present.

    Args:
        env: All environment variables.

    Returns:
        Process run info, or `None` if 'PBS_NODEFILE' is not defined.

    Raises:
        ValueError: If a required Polaris variable is missing, 'PBS_NODEFILE' is
            empty, or 'PMI_RANK' is not an integer.
        RuntimeError: If the node file lists no nodes.
    """
    polaris_node_file = env.get("PBS_NODEFILE", None)
    if polaris_node_file is None:
        return None

    logger.debug("Running in Polaris environment!")
    # Sorted so that the variable named in the error is deterministic.
    for env_var_name in sorted(_POLARIS_ENV_VARS):
        if env.get(env_var_name, None) is None:
            raise ValueError(
                f"Polaris environment variable '{env_var_name}' is not defined!"
            )
    if not polaris_node_file:
        raise ValueError("Empty value in the 'PBS_NODEFILE' environment variable!")
    with open(polaris_node_file) as f:
        nodes_str = f.read()
    node_ips = _parse_nodes_str(nodes_str)
    if len(node_ips) == 0:
        raise RuntimeError("Empty list of nodes in 'PBS_NODEFILE'!")
    gpus_per_node = 4  # Per Polaris spec.
    node_rank = _get_optional_int_env_var("PMI_RANK", env)
    if node_rank is None:
        node_rank = 0

    return _ProcessRunInfo(
        node_rank=node_rank,
        world_info=_WorldInfo(num_nodes=len(node_ips), gpus_per_node=gpus_per_node),
        master_address=node_ips[0],
        master_port=_DEFAULT_MASTER_PORT,
        node_ips=node_ips,
    )


def _detect_skypilot_process_run_info(env: dict[str, str]) -> Optional[_ProcessRunInfo]:
    """Builds run info from the SkyPilot environment, if present.

    Args:
        env: All environment variables.

    Returns:
        Process run info, or `None` if 'SKYPILOT_NODE_RANK' is not defined.

    Raises:
        ValueError: If a required SkyPilot variable is missing, or an integer
            variable is malformed or not positive where required.
        RuntimeError: If 'SKYPILOT_NODE_IPS' lists no nodes.
    """
    node_rank: Optional[int] = _get_optional_int_env_var("SKYPILOT_NODE_RANK", env)
    if node_rank is None:
        return None

    logger.debug("Running in SkyPilot environment!")
    # Sorted so that the variable named in the error is deterministic.
    for env_var_name in sorted(_SKY_ENV_VARS):
        if env.get(env_var_name, None) is None:
            raise ValueError(
                f"SkyPilot environment variable '{env_var_name}' is not defined!"
            )
    node_ips = _parse_nodes_str(env.get("SKYPILOT_NODE_IPS", ""))
    if len(node_ips) == 0:
        raise RuntimeError("Empty list of nodes in 'SKYPILOT_NODE_IPS'!")
    gpus_per_node = _get_positive_int_env_var("SKYPILOT_NUM_GPUS_PER_NODE", env)

    return _ProcessRunInfo(
        node_rank=node_rank,
        world_info=_WorldInfo(num_nodes=len(node_ips), gpus_per_node=gpus_per_node),
        master_address=node_ips[0],
        master_port=_DEFAULT_MASTER_PORT,
        node_ips=node_ips,
    )


def _detect_local_machine_process_run_info(env: dict[str, str]) -> _ProcessRunInfo:
    """Builds single-node run info from the GPUs visible on the local machine.

    Args:
        env: All environment variables.

    Returns:
        Process run info for one node (rank 0) using all visible CUDA devices.

    Raises:
        RuntimeError: If CUDA is unavailable or no GPUs are found.
        ValueError: If the master port environment variable is not an integer.
    """
    import torch  # Importing torch takes time so only load it in this scenario.

    # Attempt to produce a local configuration
    if not torch.cuda.is_available():
        raise RuntimeError(
            "No supported distributed backends found and no GPUs on local machine!"
        )

    num_gpus_available = torch.cuda.device_count()
    if num_gpus_available <= 0:
        raise RuntimeError("CUDA available but no GPUs found on local machine!")

    oumi_master_address = env.get(_MASTER_ADDR_ENV, _DEFAULT_MASTER_ADDR)
    # Parse via the shared helper so a malformed value names the offending variable.
    oumi_master_port = _get_optional_int_env_var(_MASTER_PORT_ENV, env)
    if oumi_master_port is None:
        oumi_master_port = _DEFAULT_MASTER_PORT
    cli_utils.configure_common_env_vars()

    return _ProcessRunInfo(
        node_rank=0,
        world_info=_WorldInfo(num_nodes=1, gpus_per_node=num_gpus_available),
        master_address=oumi_master_address,
        master_port=oumi_master_port,
        node_ips=[oumi_master_address],
    )


#
# Private helper functions to parse environment variables
#
def _get_optional_int_env_var(var_name: str, env: dict[str, str]) -> Optional[int]:
    """Returns `env[var_name]` as an int, or `None` if the variable is not defined.

    Raises:
        ValueError: If the variable is defined but is not an integer.
    """
    str_value = env.get(var_name, None)
    if str_value is None:
        return None

    try:
        int_value = int(str_value)
    except ValueError as e:
        raise ValueError(f"Environment variable '{var_name}' is not an integer!") from e
    return int_value


def _get_int_env_var(var_name: str, env: dict[str, str]) -> int:
    """Returns `env[var_name]` as an int.

    Raises:
        ValueError: If the variable is not defined or is not an integer.
    """
    int_value = _get_optional_int_env_var(var_name, env)
    if int_value is None:
        raise ValueError(f"Environment variable '{var_name}' is not defined!")
    return int_value


def _get_positive_int_env_var(var_name: str, env: dict[str, str]) -> int:
    """Returns `env[var_name]` as a strictly positive int.

    Raises:
        ValueError: If the variable is not defined, not an integer, or not positive.
    """
    int_value = _get_int_env_var(var_name, env)
    if not (int_value > 0):
        raise ValueError(
            f"Environment variable '{var_name}' is not positive: {int_value}!"
        )
    return int_value


def _parse_nodes_str(nodes_str: str) -> list[str]:
    """Splits a newline-separated node list, dropping blank entries and whitespace."""
    stripped = (x.strip() for x in nodes_str.split("\n"))
    return [x for x in stripped if len(x) > 0]