Source code for oumi.launcher.clusters.slurm_cluster
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import time
import uuid
from dataclasses import dataclass
from datetime import datetime
from functools import reduce
from pathlib import Path
from typing import Any, Optional

from oumi.core.configs import JobConfig
from oumi.core.launcher import BaseCluster, JobStatus
from oumi.launcher.clients.slurm_client import SlurmClient
from oumi.utils.logging import logger

_OUMI_SLURM_CONNECTIONS = "OUMI_SLURM_CONNECTIONS"


def _format_date(date: datetime) -> str:
    """Formats the provided date as a string.

    Args:
        date: The date to format.

    Returns:
        The formatted date.
    """
    return date.strftime("%Y%m%d_%H%M%S%f")


def _last_sbatch_line(script: list[str]) -> int:
    """Finds the last SBATCH instruction line in the script.

    Args:
        script: The lines of the script.

    Returns:
        The index of the last SBATCH instruction line. -1 if not found.
    """
    return reduce(
        lambda acc, val: val[0] if val[1].startswith("#SBATCH") else acc,
        enumerate(script),
        -1,
    )


def _create_job_script(job: JobConfig) -> str:
    """Creates a job script for the specified job.

    Args:
        job: The job to create a script for.

    Returns:
        The script as a string.
    """
    setup_lines = [] if not job.setup else job.setup.strip().split("\n")
    run_lines = job.run.strip().split("\n")
    # Find the last SBATCH instruction line.
    last_run_sbatch = _last_sbatch_line(run_lines) + 1
    last_setup_sbatch = _last_sbatch_line(setup_lines) + 1
    # Inject environment variables into the script after SBATCH instructions.
    env_lines = [f"export {key}={value}" for key, value in job.envs.items()]
    # Pad the environment variables with newlines.
    env_lines = [""] + env_lines + [""] if env_lines else []
    # Generate the job script.
    # The script should have the following structure:
    # 1. SBATCH instructions from Setup and Run commands (in that order).
    # 2. Environment variables.
    # 3. Setup commands.
    # 4. Run commands.
    output_lines = (
        setup_lines[:last_setup_sbatch]
        + run_lines[:last_run_sbatch]
        + env_lines
        + setup_lines[last_setup_sbatch:]
        + run_lines[last_run_sbatch:]
    )
    # Always start the script with #!/bin/bash.
    script_prefix = "#!/bin/bash"
    if len(output_lines) > 0:
        if not output_lines[0].startswith(script_prefix):
            output_lines.insert(0, script_prefix)
    # Join each line. Always end the script with a new line.
    return "\n".join(output_lines) + "\n"


def _validate_job_config(job: JobConfig) -> None:
    """Validates the provided job configuration.

    Args:
        job: The job to validate.
    """
    if not job.user:
        raise ValueError("User must be provided for Slurm jobs.")
    if not job.working_dir:
        raise ValueError("Working directory must be provided for Slurm jobs.")
    if not job.run:
        raise ValueError("Run script must be provided for Slurm jobs.")
    if job.num_nodes < 1:
        raise ValueError("Number of nodes must be at least 1.")
    if job.resources.cloud != "slurm":
        raise ValueError(
            "`Resources.cloud` must be `slurm`. "
            f"Unsupported cloud: {job.resources.cloud}"
        )
"f"Unsupported cloud: {job.resources.cloud}")# Warn that other resource parameters are unused for Slurm.ifjob.resources.region:logger.warning("Region is unused for Slurm jobs.")ifjob.resources.zone:logger.warning("Zone is unused for Slurm jobs.")ifjob.resources.accelerators:logger.warning("Accelerators are unused for Slurm jobs.")ifjob.resources.cpus:logger.warning("CPUs are unused for Slurm jobs.")ifjob.resources.memory:logger.warning("Memory is unused for Slurm jobs.")ifjob.resources.instance_type:logger.warning("Instance type is unused for Slurm jobs.")ifjob.resources.disk_size:logger.warning("Disk size is unused for Slurm jobs.")ifjob.resources.instance_type:logger.warning("Instance type is unused for Slurm jobs.")# Warn that storage mounts are currently unsupported.iflen(job.storage_mounts.items())>0:logger.warning("Storage mounts are currently unsupported for Slurm jobs.")
class SlurmCluster(BaseCluster):
    """A cluster implementation backed by a Slurm scheduler."""
    @dataclass
    class ConnectionInfo:
        """Dataclass to hold information about a connection."""

        hostname: str
        user: str

        @property
        def name(self):
            """Gets the name of the connection in the form user@hostname."""
            return f"{self.user}@{self.hostname}"
    def __init__(self, name: str, client: SlurmClient) -> None:
        """Initializes a new instance of the SlurmCluster class."""
        self._client = client
        self._connection = self.parse_cluster_name(name)
    def __eq__(self, other: Any) -> bool:
        """Checks if two SlurmClusters are equal."""
        if not isinstance(other, SlurmCluster):
            return False
        return self.name() == other.name()
    @staticmethod
    def parse_cluster_name(name: str) -> ConnectionInfo:
        """Parses the cluster name into user and hostname components.

        Args:
            name: The name of the cluster.

        Returns:
            ConnectionInfo: The parsed cluster information.
        """
        # Expected format: <user>@<hostname>
        connection_regex = r"^([a-zA-Z0-9\.\-\_]+)\@([a-zA-Z0-9\.\-\_]+$)"
        match = re.match(connection_regex, name)
        if not match:
            raise ValueError(
                f"Invalid cluster name: {name}. Must be in the format 'user@hostname'."
            )
        return SlurmCluster.ConnectionInfo(hostname=match.group(2), user=match.group(1))
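
    # Illustrative example (not part of the original module):
    #   SlurmCluster.parse_cluster_name("alice@login01.example.edu")
    #   -> SlurmCluster.ConnectionInfo(hostname="login01.example.edu", user="alice")
    # A name that does not match the 'user@hostname' form raises a ValueError.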
    @staticmethod
    def get_slurm_connections() -> list[ConnectionInfo]:
        """Gets Slurm connections from the OUMI_SLURM_CONNECTIONS env variable."""
        connections_str = os.getenv(_OUMI_SLURM_CONNECTIONS, "")
        if not connections_str:
            return []
        valid_connections = []
        for connection in [h.strip() for h in connections_str.split(",")]:
            try:
                valid_connections.append(SlurmCluster.parse_cluster_name(connection))
            except ValueError:
                logger.warning(
                    f"Invalid Slurm connection string: {connection}. Skipping."
                )
        return valid_connections
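
    # Illustrative example (not part of the original module): with
    # OUMI_SLURM_CONNECTIONS="alice@host1.example.edu, bob@host2.example.edu",
    # get_slurm_connections() returns a ConnectionInfo entry for each well-formed
    # "user@hostname" item in the comma-separated list; malformed items are
    # logged and skipped.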
    def name(self) -> str:
        """Gets the name of the cluster."""
        return self._connection.name
    def get_job(self, job_id: str) -> Optional[JobStatus]:
        """Gets the specified job on this cluster if it exists, else returns None."""
        for job in self.get_jobs():
            if job.id == job_id:
                return job
        return None
    def get_jobs(self) -> list[JobStatus]:
        """Lists the jobs on this cluster."""
        jobs = self._client.list_jobs()
        for job in jobs:
            job.cluster = self._connection.name
        return jobs
    def cancel_job(self, job_id: str) -> JobStatus:
        """Cancels the specified job on this cluster."""
        self._client.cancel(job_id)
        job = self.get_job(job_id)
        if job is None:
            raise RuntimeError(f"Job {job_id} not found.")
        return job
    def run_job(self, job: JobConfig) -> JobStatus:
        """Runs the specified job on this cluster.

        For Slurm this method consists of 5 parts:

        1. Copy the working directory to ~/oumi_launcher/$JOB_NAME.
        2. Check if there is a conda installation at
           /home/$USER/miniconda3/envs/oumi. If not, install it.
        3. Copy all file mounts.
        4. Create a job script with all env vars, setup, and run commands.
        5. CD into the working directory and submit the job.

        Args:
            job: The job to run.

        Returns:
            JobStatus: The job status.
        """
        _validate_job_config(job)
        job_name = job.name or uuid.uuid1().hex
        submission_time = _format_date(datetime.now())
        remote_working_dir = Path(f"~/oumi_launcher/{submission_time}")
        # Copy the working directory to ~/oumi_launcher/...
        self._client.put_recursive(job.working_dir, str(remote_working_dir))
        # Copy all file mounts.
        for remote_path, local_path in job.file_mounts.items():
            self._client.put_recursive(local_path, remote_path)
        # Create the job script by merging envs, setup, and run commands.
        job_script = _create_job_script(job)
        script_path = remote_working_dir / "oumi_job.sh"
        self._client.put(job_script, str(script_path))
        # Set the proper CHMOD permissions.
        self._client.run_commands([f"chmod +x {script_path}"])
        # Submit the job.
        job_id = self._client.submit_job(
            str(script_path),
            str(remote_working_dir),
            job.num_nodes,
            job_name,
        )
        max_retries = 3
        wait_time = 5
        for _ in range(max_retries):
            job_status = self.get_job(job_id)
            if job_status is not None:
                return job_status
            logger.info(f"Job {job_id} not found. Retrying in {wait_time} seconds.")
            time.sleep(wait_time)
        job_status = self.get_job(job_id)
        if job_status is None:
            raise RuntimeError(f"Job {job_id} not found after submission.")
        return job_status
    def stop(self) -> None:
        """This is a no-op for Slurm clusters."""
        pass
    def down(self) -> None:
        """This is a no-op for Slurm clusters."""
        pass
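
# Usage sketch (illustrative, not part of the original module). Assumes `client` is
# an already-constructed SlurmClient for the target host and `job` is a JobConfig
# with resources.cloud == "slurm":
#
#     cluster = SlurmCluster("alice@login01.example.edu", client)
#     status = cluster.run_job(job)
#     print(status.id, status.cluster)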