# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import enum
import os
import sys
import time
from subprocess import Popen
from sys import stderr, stdout
from typing import Any, Final, NamedTuple, Optional

import typer

import oumi.cli.cli_utils as cli_utils
from oumi.utils.logging import logger

# Port range [1024, 65535] is generally available
# for application use w/o root permissions (non-privileged).
_MASTER_PORT_MIN_VALID_VALUE: Final[int] = 1024
_MASTER_PORT_MAX_VALID_VALUE: Final[int] = 65535

_SKY_ENV_VARS = {
    "SKYPILOT_NODE_RANK",
    "SKYPILOT_NODE_IPS",
    "SKYPILOT_NUM_GPUS_PER_NODE",
}

_POLARIS_ENV_VARS = {
    "PBS_NODEFILE",
    "PBS_JOBID",
}

_MASTER_ADDR_ENV = "MASTER_ADDRESS"
_MASTER_PORT_ENV = "MASTER_PORT"
_DEFAULT_MASTER_ADDR = "127.0.0.1"
_DEFAULT_MASTER_PORT = 8007


class _RunBackend(str, enum.Enum):
    SKYPILOT = "SkyPilot"
    POLARIS = "Polaris"
    LOCAL_MACHINE = "LocalMachine"


class _WorldInfo(NamedTuple):
    num_nodes: int
    """Total number of nodes (machines)."""

    gpus_per_node: int
    """Number of GPUs per node."""


class _ProcessRunInfo:
    def __init__(
        self,
        node_rank: int,
        world_info: _WorldInfo,
        master_address: str,
        master_port: int,
        node_ips: list[str],
    ):
        """Initializes run info, and validates arguments."""
        if not (world_info.num_nodes > 0 and world_info.gpus_per_node > 0):
            raise ValueError(
                f"Non-positive number of nodes or GPUs per node: {world_info}"
            )
        elif not (node_rank >= 0 and node_rank < world_info.num_nodes):
            raise ValueError(
                f"Node rank {node_rank} is out of range: [0, {world_info.num_nodes})."
            )
        elif len(master_address) == 0:
            raise ValueError(f"Empty master address: {master_address}.")
        elif not (
            master_port >= _MASTER_PORT_MIN_VALID_VALUE
            and master_port <= _MASTER_PORT_MAX_VALID_VALUE
        ):
            raise ValueError(
                f"Master port: {master_port} is outside of valid range: "
                f"[{_MASTER_PORT_MIN_VALID_VALUE}, {_MASTER_PORT_MAX_VALID_VALUE}]."
            )

        self._world_info = world_info
        self._node_rank = int(node_rank)
        self._master_address = master_address
        self._master_port = master_port
        self._node_ips = node_ips

    @property
    def node_rank(self) -> int:
        """Node rank in the [0, num_nodes) range."""
        return self._node_rank

    @property
    def num_nodes(self) -> int:
        """Total number of nodes (machines)."""
        return self._world_info.num_nodes

    @property
    def gpus_per_node(self) -> int:
        """Number of GPUs per node."""
        return self._world_info.gpus_per_node

    @property
    def total_gpus(self) -> int:
        """Total number of GPUs across all nodes."""
        return self._world_info.num_nodes * self._world_info.gpus_per_node

    @property
    def master_address(self) -> str:
        """Master address."""
        return self._master_address

    @property
    def node_ips(self) -> list[str]:
        """List of node IPs."""
        return self._node_ips

    @property
    def master_port(self) -> int:
        """Master port."""
        return self._master_port

    def __repr__(self) -> str:
        """Defines how this class is properly printed."""
        fields_dict: dict[str, Any] = {
            "node_rank": self.node_rank,
            "num_nodes": self.num_nodes,
            "gpus_per_node": self.gpus_per_node,
            "total_gpus": self.total_gpus,
            "master_address": self.master_address,
            "master_port": self.master_port,
            "node_ips": self.node_ips,
        }
        return repr(fields_dict)
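
# Illustrative sketch (hypothetical addresses and sizes, not taken from a real cluster):
# `_ProcessRunInfo` validates its arguments at construction time and derives
# `total_gpus` from the world info.
#
#   run_info = _ProcessRunInfo(
#       node_rank=0,
#       world_info=_WorldInfo(num_nodes=2, gpus_per_node=4),
#       master_address="10.0.0.1",
#       master_port=_DEFAULT_MASTER_PORT,
#       node_ips=["10.0.0.1", "10.0.0.2"],
#   )
#   assert run_info.total_gpus == 8
#   # node_rank=2 would raise ValueError: rank must be in [0, num_nodes) == [0, 2).
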

#
# Commands
#
def torchrun(
    ctx: typer.Context,
    level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
    """Starts `torchrun` sub-process w/ automatically configured common params.

    Args:
        ctx: The Typer context object.
        level: The logging level for the specified command.
    """
    try:
        run_info: _ProcessRunInfo = _detect_process_run_info(os.environ.copy())
    except (ValueError, RuntimeError):
        logger.exception("Failed to detect process run info!")
        raise

    try:
        cmds: list[str] = [
            "torchrun",
            f"--nnodes={run_info.num_nodes}",
            f"--node-rank={run_info.node_rank}",
            f"--nproc-per-node={run_info.gpus_per_node}",
            f"--master-addr={run_info.master_address}",
            f"--master-port={run_info.master_port}",
        ]
        cmds.extend(ctx.args)

        _run_subprocess(cmds, rank=run_info.node_rank)
    except Exception:
        logger.exception(f"`torchrun` failed (Rank: {run_info.node_rank})!")
        raise
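
# Illustrative sketch: on a hypothetical 2-node job with 4 GPUs per node, rank 0 would
# assemble roughly the following command before appending the user's `ctx.args`
# (the trailing `-m oumi train -c config.yaml` stands in for whatever the user passed):
#
#   torchrun --nnodes=2 --node-rank=0 --nproc-per-node=4 \
#       --master-addr=10.0.0.1 --master-port=8007 -m oumi train -c config.yaml
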

def accelerate(
    ctx: typer.Context,
    level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
    """Starts `accelerate` sub-process w/ automatically configured common params.

    Args:
        ctx: The Typer context object.
        level: The logging level for the specified command.
    """
    try:
        run_info: _ProcessRunInfo = _detect_process_run_info(os.environ.copy())
    except (ValueError, RuntimeError):
        logger.exception("Failed to detect process run info!")
        raise

    try:
        accelerate_subcommand: Optional[str] = None
        extra_args = copy.deepcopy(ctx.args)
        if (
            len(extra_args) > 0
            and len(extra_args[0]) > 0
            and not extra_args[0].startswith("-")
        ):
            # Copy sub-commands like "launch" to insert them right after `accelerate`
            # ("accelerate launch ...").
            accelerate_subcommand = extra_args.pop(0)

        cmds: list[str] = (
            ["accelerate"]
            + ([accelerate_subcommand] if accelerate_subcommand is not None else [])
            + [
                f"--num_machines={run_info.num_nodes}",
                f"--machine_rank={run_info.node_rank}",
                f"--num_processes={run_info.total_gpus}",
                f"--main_process_ip={run_info.master_address}",
                f"--main_process_port={run_info.master_port}",
            ]
        )
        cmds.extend(extra_args)

        _run_subprocess(cmds, rank=run_info.node_rank)
    except Exception:
        logger.exception(f"`accelerate` failed (Rank: {run_info.node_rank})!")
        raise
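
# Illustrative sketch: for the same hypothetical 2-node, 4-GPU-per-node job, a user
# invocation whose first extra argument is the "launch" sub-command would expand to
# roughly the following (note `--num_processes` is the total GPU count, 2 * 4 = 8):
#
#   accelerate launch --num_machines=2 --machine_rank=0 --num_processes=8 \
#       --main_process_ip=10.0.0.1 --main_process_port=8007 <remaining user args>
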

#
# Helper functions
#
def _detect_process_run_info(env: dict[str, str]) -> _ProcessRunInfo:
    """Detects process run info.

    Uses known environment variables to detect common runtime parameters.

    Args:
        env: All environment variables.

    Returns:
        Process run info.

    Raises:
        ValueError: If any of the required environment variables are missing or
            invalid.
        RuntimeError: If the node list is empty, or there are issues with backend
            detection.
    """
    # Detect the process run info depending on the runtime environment.
    # Each runtime environment is checked in the order of priority.
    process_run_info = _detect_skypilot_process_run_info(env)
    if process_run_info is None:
        process_run_info = _detect_polaris_process_run_info(env)
    if process_run_info is None:
        process_run_info = _detect_local_machine_process_run_info(env)
    if process_run_info is None:
        raise RuntimeError("Failed to detect process run info!")

    # Extra verification logic to make sure that the detected process run info is
    # consistent with the environment variables.
    # Will raise an exception if the detected process run info is not consistent.
    _verify_process_run_info(process_run_info, env)
    return process_run_info


def _run_subprocess(cmds: list[str], *, rank: int) -> None:
    env_copy = os.environ.copy()
    start_time = time.perf_counter()
    logger.info(f"Running the command: {cmds}")

    p = Popen(
        cmds,
        env=env_copy,
        stdout=stdout,
        stderr=stderr,
        bufsize=1,
        universal_newlines=True,
    )
    rc = p.wait()
    duration_sec = time.perf_counter() - start_time
    duration_str = f"Duration: {duration_sec:.1f} sec"
    if rc != 0:
        logger.error(
            f"{cmds[0]} failed with exit code: {rc} ({duration_str}). Command: {cmds}"
        )
        sys.exit(rc)
    logger.info(f"Successfully completed! (Rank: {rank}. {duration_str})")


def _verify_process_run_info(run_info: _ProcessRunInfo, env: dict[str, str]) -> None:
    oumi_total_gpus: Optional[int] = _get_optional_int_env_var(
        "OUMI_TOTAL_NUM_GPUS", env
    )
    oumi_num_nodes: Optional[int] = _get_optional_int_env_var("OUMI_NUM_NODES", env)
    oumi_master_address: Optional[str] = env.get("OUMI_MASTER_ADDR", None)
    if oumi_master_address is not None and len(oumi_master_address) == 0:
        raise ValueError("Empty master address in 'OUMI_MASTER_ADDR'!")

    assert len(run_info.node_ips) > 0, "Empty list of nodes!"
    assert run_info.node_rank is not None

    if oumi_num_nodes is not None and oumi_num_nodes != run_info.num_nodes:
        raise ValueError(
            "Inconsistent number of nodes: "
            f"{run_info.num_nodes} vs {oumi_num_nodes} in 'OUMI_NUM_NODES'."
        )
    elif oumi_total_gpus is not None and (oumi_total_gpus != run_info.total_gpus):
        raise ValueError(
            "Inconsistent total number of GPUs: "
            f"{run_info.total_gpus} vs {oumi_total_gpus} "
            "in 'OUMI_TOTAL_NUM_GPUS'. "
            f"Nodes: {run_info.num_nodes}. GPUs per node: {run_info.gpus_per_node}."
        )
    elif oumi_master_address and oumi_master_address not in run_info.node_ips:
        raise ValueError(
            f"Master address '{oumi_master_address}' not found in the list of nodes."
        )
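
# Illustrative sketch (hypothetical environment): backends are probed in priority order
# (SkyPilot, then Polaris, then the local machine), and the detected values are
# cross-checked against optional OUMI_* variables. A mismatch raises ValueError:
#
#   env = {
#       "SKYPILOT_NODE_RANK": "0",
#       "SKYPILOT_NODE_IPS": "10.0.0.1\n10.0.0.2",
#       "SKYPILOT_NUM_GPUS_PER_NODE": "4",
#       "OUMI_NUM_NODES": "4",  # claims 4 nodes, but SkyPilot reports only 2
#   }
#   _detect_process_run_info(env)  # raises ValueError (inconsistent number of nodes)
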

#
# Parse environment variables
#
def _detect_polaris_process_run_info(env: dict[str, str]) -> Optional[_ProcessRunInfo]:
    polaris_node_file = env.get("PBS_NODEFILE", None)
    if polaris_node_file is None:
        return None

    logger.debug("Running in Polaris environment!")
    for env_var_name in _POLARIS_ENV_VARS:
        if env.get(env_var_name, None) is None:
            raise ValueError(
                f"Polaris environment variable '{env_var_name}' is not defined!"
            )

    if not polaris_node_file:
        raise ValueError("Empty value in the 'PBS_NODEFILE' environment variable!")
    with open(polaris_node_file) as f:
        nodes_str = f.read()
    node_ips = _parse_nodes_str(nodes_str)
    if len(node_ips) == 0:
        raise RuntimeError("Empty list of nodes in 'PBS_NODEFILE'!")

    gpus_per_node = 4  # Per Polaris spec.
    node_rank = _get_optional_int_env_var("PMI_RANK", env)
    if node_rank is None:
        node_rank = 0

    return _ProcessRunInfo(
        node_rank=node_rank,
        world_info=_WorldInfo(num_nodes=len(node_ips), gpus_per_node=gpus_per_node),
        master_address=node_ips[0],
        master_port=_DEFAULT_MASTER_PORT,
        node_ips=node_ips,
    )


def _detect_skypilot_process_run_info(env: dict[str, str]) -> Optional[_ProcessRunInfo]:
    node_rank: Optional[int] = _get_optional_int_env_var("SKYPILOT_NODE_RANK", env)
    if node_rank is None:
        return None

    logger.debug("Running in SkyPilot environment!")
    for env_var_name in _SKY_ENV_VARS:
        if env.get(env_var_name, None) is None:
            raise ValueError(
                f"SkyPilot environment variable '{env_var_name}' is not defined!"
            )

    node_ips = _parse_nodes_str(env.get("SKYPILOT_NODE_IPS", ""))
    if len(node_ips) == 0:
        raise RuntimeError("Empty list of nodes in 'SKYPILOT_NODE_IPS'!")
    gpus_per_node = _get_positive_int_env_var("SKYPILOT_NUM_GPUS_PER_NODE", env)

    return _ProcessRunInfo(
        node_rank=node_rank,
        world_info=_WorldInfo(num_nodes=len(node_ips), gpus_per_node=gpus_per_node),
        master_address=node_ips[0],
        master_port=_DEFAULT_MASTER_PORT,
        node_ips=node_ips,
    )


def _detect_local_machine_process_run_info(env: dict[str, str]) -> _ProcessRunInfo:
    import torch  # Importing torch takes time, so only load it in this scenario.

    # Attempt to produce a local configuration.
    if not torch.cuda.is_available():
        raise RuntimeError(
            "No supported distributed backends found and no GPUs on local machine!"
        )

    num_gpus_available = torch.cuda.device_count()
    if num_gpus_available > 0:
        oumi_num_nodes = 1
        oumi_master_address = env.get(_MASTER_ADDR_ENV, _DEFAULT_MASTER_ADDR)
        oumi_master_port = int(env.get(_MASTER_PORT_ENV, _DEFAULT_MASTER_PORT))
        node_rank = 0
        gpus_per_node = num_gpus_available
        node_ips = [oumi_master_address]
        cli_utils.configure_common_env_vars()
    else:
        raise RuntimeError("CUDA available but no GPUs found on local machine!")

    return _ProcessRunInfo(
        node_rank=node_rank,
        world_info=_WorldInfo(num_nodes=oumi_num_nodes, gpus_per_node=gpus_per_node),
        master_address=oumi_master_address,
        master_port=oumi_master_port,
        node_ips=node_ips,
    )


#
# Private helper functions to parse environment variables
#
def _get_optional_int_env_var(var_name: str, env: dict[str, str]) -> Optional[int]:
    str_value = env.get(var_name, None)
    if str_value is None:
        return None

    try:
        int_value = int(str_value)
    except ValueError as e:
        raise ValueError(f"Environment variable '{var_name}' is not an integer!") from e
    return int_value


def _get_int_env_var(var_name: str, env: dict[str, str]) -> int:
    int_value = _get_optional_int_env_var(var_name, env)
    if int_value is None:
        raise ValueError(f"Environment variable '{var_name}' is not defined!")
    return int_value
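
# Illustrative behavior of the integer env-var helpers (hypothetical values):
#
#   _get_optional_int_env_var("SKYPILOT_NODE_RANK", {}) -> None
#   _get_optional_int_env_var("SKYPILOT_NODE_RANK", {"SKYPILOT_NODE_RANK": "1"}) -> 1
#   _get_int_env_var("SKYPILOT_NODE_RANK", {}) -> raises ValueError (not defined)
#   _get_positive_int_env_var("SKYPILOT_NUM_GPUS_PER_NODE",
#                             {"SKYPILOT_NUM_GPUS_PER_NODE": "0"}) -> raises ValueError
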
defined!")returnint_valuedef_get_positive_int_env_var(var_name:str,env:dict[str,str])->int:int_value=_get_int_env_var(var_name,env)ifnot(int_value>0):raiseValueError(f"Environment variable '{var_name}' is not positive: {int_value}!")returnint_valuedef_parse_nodes_str(nodes_str:str)->list[str]:node_ips=[x.strip()forxinnodes_str.split("\n")]node_ips=[xforxinnode_ipsiflen(x)>0]returnnode_ips