# Source code for oumi.launcher.clients.polaris_client
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import re
import subprocess
import time
from dataclasses import dataclass
from enum import Enum
from getpass import getpass
from pathlib import Path
from typing import Optional

import pexpect

from oumi.core.launcher import JobStatus
from oumi.utils.logging import logger

# SSH ControlMaster socket options. All commands are multiplexed over a single
# authenticated tunnel identified by host (%h), port (%p), and user (%r).
_CTRL_PATH = "-S ~/.ssh/control-%h-%p-%r"


class _PolarisAuthException(Exception):
    """Raised when the SSH tunnel to Polaris is missing, stale, or unreachable."""

    pass


def _check_connection(user: str):
    """Checks if the connection is still open.

    Probes the SSH ControlMaster socket for `user` with `ssh -O check`.

    Args:
        user: The username whose tunnel should be checked.

    Raises:
        _PolarisAuthException: If the check times out or the tunnel is closed.
    """
    ssh_cmd = f"ssh {_CTRL_PATH} -O check {user}@polaris.alcf.anl.gov"
    try:
        child = subprocess.run(
            ssh_cmd,
            shell=True,
            capture_output=True,
            timeout=10,
        )
    except subprocess.TimeoutExpired as e:
        # Chain the original timeout so the root cause stays in the traceback.
        raise _PolarisAuthException("Timeout while checking connection.") from e
    if child.returncode == 0:
        return
    raise _PolarisAuthException("Connection to Polaris is closed.")
@dataclass
class PolarisResponse:
    """A response from Polaris."""

    # Decoded standard output of the remote command(s).
    stdout: str
    # Decoded standard error of the remote command(s).
    stderr: str
    # Exit code reported by the remote shell (0 on success).
    exit_code: int
def retry_auth(user_function):
    """Decorator to ensure auth is fresh before calling a function.

    The wrapped method's instance must expose `_refresh_creds()`, which is
    invoked before every call so credentials never go stale mid-session.
    """

    @functools.wraps(user_function)
    def _with_fresh_creds(self, *args, **kwargs):
        self._refresh_creds()
        return user_function(self, *args, **kwargs)

    return _with_fresh_creds
class PolarisClient:
    """A client for communicating with Polaris at ALCF."""

    class SupportedQueues(Enum):
        """Enum representing the supported queues on Polaris.

        For more details, see:
        https://docs.alcf.anl.gov/polaris/running-jobs/#queues
        """

        # The demand queue can only be used with explicit permission from ALCF.
        # Do not use this queue unless you have been granted permission.
        DEMAND = "demand"
        DEBUG = "debug"
        DEBUG_SCALING = "debug-scaling"
        PREEMPTABLE = "preemptable"
        PROD = "prod"

    # qstat status code indicating a finished job.
    _FINISHED_STATUS = "F"
    # Concrete queues that jobs submitted to the logical "prod" queue land in.
    _PROD_QUEUES = {
        "small",
        "medium",
        "large",
        "backfill-small",
        "backfill-medium",
        "backfill-large",
    }

    def __init__(self, user: str):
        """Initializes a new instance of the PolarisClient class.

        Args:
            user: The user to act as.
        """
        self._user = user
        self._refresh_creds()

    def _split_status_line(self, line: str, metadata: str) -> JobStatus:
        """Splits a status line into a JobStatus object.

        The expected order of job fields is:
        0. Job ID
        1. User
        2. Queue
        3. Job Name
        4. Session ID
        5. Node Count
        6. Tasks
        7. Required Memory
        8. Required Time
        9. Status
        10. Elapsed Time

        Args:
            line: The line to split.
            metadata: Additional metadata to attach to the job status.

        Returns:
            A JobStatus object.

        Raises:
            ValueError: If the line does not contain exactly 11 fields.
        """
        # Collapse runs of spaces so the fixed-width qstat output splits cleanly.
        fields = re.sub(" +", " ", line.strip()).split(" ")
        if len(fields) != 11:
            raise ValueError(
                f"Invalid status line: {line}. "
                f"Expected 11 fields, but found {len(fields)}."
            )
        return JobStatus(
            id=self._get_short_job_id(fields[0]),
            name=fields[3],
            status=fields[9],
            cluster=fields[2],
            metadata=metadata,
            done=fields[9] == self._FINISHED_STATUS,
        )

    def _get_short_job_id(self, job_id: str) -> str:
        """Gets the short form of the job ID.

        Polaris Job IDs should be of the form:
        `2037042.polaris-pbs-01.hsn.cm.polaris.alcf.anl.gov`
        where the shortened ID is `2037042`.

        Args:
            job_id: The job ID to shorten.

        Returns:
            The short form of the job ID.
        """
        if "." not in job_id:
            return job_id
        return job_id.split(".")[0]

    def _refresh_creds(self):
        """Refreshes the credentials for the client.

        If the existing ControlMaster tunnel is still alive this is a no-op;
        otherwise a new tunnel is established, prompting for the passcode.

        Raises:
            RuntimeError: If a new tunnel cannot be established.
        """
        try:
            _check_connection(self._user)
            # We have fresh credentials, so we return early.
            return
        except _PolarisAuthException:
            logger.warning("No connection found. Establishing a new SSH tunnel...")
        # -f -N -M: background, no remote command, become the ControlMaster.
        # ControlPersist keeps the tunnel alive for 4 hours after last use.
        ssh_cmd = (
            f'ssh -f -N -M {_CTRL_PATH} -o "ControlPersist 4h" '
            f"{self._user}@polaris.alcf.anl.gov"
        )
        child = pexpect.spawn(ssh_cmd)
        child.expect("Password:")
        child.sendline(getpass(prompt=f"Polaris passcode for {self._user}: "))
        child.expect([pexpect.EOF, pexpect.TIMEOUT], timeout=10)
        output = child.before
        child.close()
        exit_code = child.exitstatus
        if exit_code != 0:
            logger.error(f"Credential error: {output}")
            raise RuntimeError("Failed to refresh Polaris credentials.")

    @staticmethod
    def get_active_users() -> list[str]:
        """Gets the list of users with an open SSH tunnel to Polaris.

        Returns:
            A list of users.
        """
        # List all active users with an open SSH tunnel to Polaris.
        command = "ls ~/.ssh/ | egrep 'control-polaris.alcf.anl.gov-.*-.*'"
        result = subprocess.run(command, shell=True, capture_output=True)
        if result.returncode != 0:
            return []
        # The username is the final dash-delimited component of the socket name.
        ssh_tunnel_pattern = r"control-polaris.alcf.anl.gov-[^-]*-(.*)"
        lines = result.stdout.decode("utf-8").strip().split("\n")
        users = set()
        for line in lines:
            match = re.match(ssh_tunnel_pattern, line.strip())
            if match:
                users.add(match.group(1))
        return list(users)

    @staticmethod
    def _compute_duration_debug_str(start_time: float) -> str:
        """Returns a human-readable duration since `start_time` for log messages.

        NOTE(review): this helper is called by `run_commands` but its definition
        was missing from the extracted source; restored here — confirm against
        the original file.
        """
        duration_sec = time.perf_counter() - start_time
        return f"Duration: {duration_sec:.2f} sec"

    @retry_auth
    def run_commands(self, commands: list[str]) -> PolarisResponse:
        """Runs the provided commands in a single SSH command.

        Args:
            commands: The commands to run.

        Returns:
            A PolarisResponse carrying stdout, stderr, and the exit code.
        """
        # Fix: a space is required between the control-path option and the
        # user@host target, otherwise the ssh command line is malformed.
        ssh_cmd = f"ssh {_CTRL_PATH} {self._user}@polaris.alcf.anl.gov << 'EOF'"
        eof_suffix = "EOF"
        # Feed all commands to the remote shell via a quoted heredoc.
        new_cmd = "\n".join([ssh_cmd, *commands, eof_suffix])
        start_time: float = time.perf_counter()
        try:
            logger.debug(f"Running commands:\n{new_cmd}")
            child = subprocess.run(
                new_cmd,
                shell=True,
                capture_output=True,
                timeout=180,  # time in seconds
            )
            duration_str = self._compute_duration_debug_str(start_time)
            if child.returncode == 0:
                logger.debug(f"Commands successfully finished! {duration_str}")
            else:
                logger.error(
                    f"Commands failed with code: {child.returncode}! {duration_str}"
                )
            return PolarisResponse(
                stdout=child.stdout.decode("utf-8"),
                stderr=child.stderr.decode("utf-8"),
                exit_code=child.returncode,
            )
        except subprocess.TimeoutExpired:
            # Timeouts are reported as a failed response rather than raised.
            duration_str = self._compute_duration_debug_str(start_time)
            logger.exception(f"Commands timed out ({duration_str})! {new_cmd}")
            return PolarisResponse(
                stdout="",
                stderr=f"Timeout while running command: {new_cmd}",
                exit_code=1,
            )
        except Exception:
            duration_str = self._compute_duration_debug_str(start_time)
            logger.exception(f"Command failed ({duration_str})! {new_cmd}")
            raise

    def submit_job(
        self,
        job_path: str,
        working_dir: str,
        node_count: int,
        queue: SupportedQueues,
        name: Optional[str],
    ) -> str:
        """Submits the specified job script to Polaris.

        Args:
            job_path: The path to the job script to submit.
            working_dir: The working directory to submit the job from.
            node_count: The number of nodes to use for the job.
            queue: The name of the queue to submit the job to.
            name: The name of the job (optional).

        Returns:
            The ID of the submitted job.

        Raises:
            RuntimeError: If the qsub command fails.
        """
        optional_name_args = ""
        if name:
            optional_name_args = f"-N {name}"
        # Fix: a space is required between the optional name flag and the
        # script path, otherwise qsub receives a fused, invalid argument.
        qsub_cmd = (
            f"qsub -l select={node_count}:system=polaris -q {queue.value}"
            f" {optional_name_args} {job_path}"
        )
        result = self.run_commands([f"cd {working_dir}", qsub_cmd])
        if result.exit_code != 0:
            raise RuntimeError(f"Failed to submit job. stderr: {result.stderr}")
        return self._get_short_job_id(result.stdout.strip())

    def list_jobs(self, queue: SupportedQueues) -> list[JobStatus]:
        """Lists the job statuses for the given queue.

        Args:
            queue: The queue to query.

        Returns:
            A list of job statuses for this user in the given queue.

        Raises:
            RuntimeError: If qstat fails or its output cannot be fully parsed.
        """
        command = f"qstat -s -x -w -u {self._user}"
        result = self.run_commands([command])
        if result.exit_code != 0:
            raise RuntimeError(f"Failed to list jobs. stderr: {result.stderr}")
        # Parse STDOUT to retrieve job statuses.
        lines = result.stdout.strip().split("\n")
        jobs = []
        # Non-empty responses should have at least 4 lines.
        if len(lines) < 4:
            return jobs
        metadata_header = lines[1:4]
        job_lines = lines[4:]
        line_number = 0
        while line_number < len(job_lines) - 1:
            line = job_lines[line_number]
            # Every second line is metadata.
            metadata_line = job_lines[line_number + 1]
            job_metadata = "\n".join(metadata_header + [line, metadata_line])
            status = self._split_status_line(line, job_metadata)
            if status.cluster == queue.value:
                jobs.append(status)
            elif (
                queue == self.SupportedQueues.PROD
                and status.cluster in self._PROD_QUEUES
            ):
                # "prod" is a routing queue; jobs surface in one of its
                # concrete sub-queues.
                jobs.append(status)
            line_number += 2
        if line_number != len(job_lines):
            raise RuntimeError("At least one job status was not parsed.")
        return jobs

    def get_job(self, job_id: str, queue: SupportedQueues) -> Optional[JobStatus]:
        """Gets the specified job's status.

        Args:
            job_id: The ID of the job to get.
            queue: The name of the queue to search.

        Returns:
            The job status if found, None otherwise.
        """
        job_list = self.list_jobs(queue)
        for job in job_list:
            if job.id == job_id:
                return job
        return None

    def cancel(self, job_id, queue: SupportedQueues) -> Optional[JobStatus]:
        """Cancels the specified job.

        Args:
            job_id: The ID of the job to cancel.
            queue: The name of the queue to search.

        Returns:
            The job status if found, None otherwise.

        Raises:
            RuntimeError: If the qdel command fails.
        """
        command = f"qdel {job_id}"
        result = self.run_commands([command])
        if result.exit_code != 0:
            raise RuntimeError(f"Failed to cancel job. stderr: {result.stderr}")
        return self.get_job(job_id, queue)

    @retry_auth
    def put_recursive(self, source: str, destination: str) -> None:
        """Puts the specified file/directory to the remote path using rsync.

        Args:
            source: The local file/directory to write.
            destination: The remote path to write the file/directory to.

        Raises:
            RuntimeError: If rsync fails or times out.
        """
        if Path(source).is_dir():
            self.run_commands([f"mkdir -p {destination}"])
        tests_dir = Path(source) / "tests"
        git_ignore = Path(source) / ".gitignore"
        # rsync reuses the existing ControlMaster tunnel via -e.
        rsync_cmd_list = [f'rsync -e "ssh {_CTRL_PATH}" -avz --delete ']
        if git_ignore.is_file():
            rsync_cmd_list.append(f"--exclude-from {str(git_ignore)} ")
        if tests_dir.is_dir():
            rsync_cmd_list.append(f"--exclude {str(tests_dir)} ")
        rsync_cmd_list.append(f"{source} ")
        rsync_cmd_list.append(f"{self._user}@polaris.alcf.anl.gov:{destination}")
        rsync_cmd = "".join(rsync_cmd_list)
        logger.info(f"Running rsync command: {rsync_cmd} ...")
        try:
            child = subprocess.run(
                rsync_cmd,
                shell=True,
                capture_output=True,
                timeout=300,
            )
            logger.info(f"Rsync command completed with exit code: {child.returncode}")
            if child.returncode != 0:
                parsed_stderr = child.stderr.decode("utf-8") if child.stderr else ""
                raise RuntimeError(f"Rsync failed. stderr: {parsed_stderr}")
        except subprocess.TimeoutExpired as e:
            # Chain the timeout so the root cause stays in the traceback.
            raise RuntimeError("Timeout while running rsync command.") from e

    def put(self, file_contents: str, destination: str) -> None:
        """Puts the specified file contents to the remote path.

        Args:
            file_contents: The contents of the file to write.
            destination: The remote path to write the file to.

        Raises:
            RuntimeError: If the remote write fails.
        """
        destination_path = Path(destination)
        parent_dir = destination_path.parent
        dir_cmd = f"mkdir -p {parent_dir}"
        create_cmd = f"touch {destination}"
        # Quoted heredoc tag prevents the remote shell from expanding the
        # file contents.
        write_command = f'cat <<"SCRIPTFILETAG" > {destination}'
        file_suffix = "SCRIPTFILETAG"
        cmds = [dir_cmd, create_cmd, write_command, file_contents, file_suffix]
        result = self.run_commands(cmds)
        if result.exit_code != 0:
            raise RuntimeError(f"Failed to write file. stderr: {result.stderr}")