# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from oumi.core.configs import JobConfig
from oumi.core.launcher import BaseCloud, BaseCluster, JobStatus
from oumi.core.registry import register_cloud_builder
from oumi.launcher.clients.slurm_client import SlurmClient
from oumi.launcher.clusters.slurm_cluster import SlurmCluster

class SlurmCloud(BaseCloud):
    """A resource pool for managing jobs in Slurm clusters.

    Cluster objects are cached per name in ``self._clusters`` so that repeated
    requests for the same name return the same ``SlurmCluster`` instance.
    """

    def __init__(self):
        """Creates the cluster cache and fills it via `initialize_clusters`."""
        # Cache of SlurmCluster instances keyed by cluster name.
        self._clusters = {}
        # Register a cluster for every connection that is already known.
        self.initialize_clusters()

    def _get_or_create_cluster(self, name: str) -> SlurmCluster:
        """Returns the cached cluster for `name`, creating it on first request.

        Args:
            name: The name of the cluster.

        Returns:
            SlurmCluster: The cached or newly created cluster instance.
        """
        if name in self._clusters:
            return self._clusters[name]

        # Not cached yet: parse the name to obtain the user and hostname the
        # client needs, then build and remember the cluster.
        info = SlurmCluster.parse_cluster_name(name)
        client = SlurmClient(
            user=info.user,
            slurm_host=info.hostname,
            cluster_name=info.name,
        )
        self._clusters[name] = SlurmCluster(name, client)
        return self._clusters[name]

    def initialize_clusters(self) -> list[BaseCluster]:
        """Gets or creates a cluster for every known Slurm connection.

        The connections come from `SlurmCluster.get_slurm_connections()`.

        Returns:
            List[SlurmCluster]: One cluster per connection, in connection order.
        """
        initialized: list[BaseCluster] = [
            self._get_or_create_cluster(connection.name)
            for connection in SlurmCluster.get_slurm_connections()
        ]
        return initialized

    def up_cluster(self, job: JobConfig, name: Optional[str], **kwargs) -> JobStatus:
        """Gets or creates the named cluster and runs the provided job on it.

        Args:
            job: The job to run. `job.user` must be set.
            name: The cluster name, of the form 'user@hostname'. Required.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            JobStatus: The status returned by the cluster for the started job.

        Raises:
            ValueError: If `job.user` is empty, if `name` is empty, or if the
                user parsed from `name` differs from `job.user`.
            RuntimeError: If the cluster returns a falsy job status.
        """
        if not job.user:
            raise ValueError("User must be provided in the job config.")
        if not name:
            raise ValueError(
                "A cluster name must be provided for Slurm. "
                "Cluster names are of the form 'user@hostname'."
            )

        parsed = SlurmCluster.parse_cluster_name(name)
        if parsed.user != job.user:
            raise ValueError(
                f"Invalid cluster name: `{name}`. "
                f"User must match the provided job user: `{job.user}`."
            )

        # Look the cluster up by the name reported by the parsed info rather
        # than by the raw `name` argument.
        target = self._get_or_create_cluster(parsed.name)
        status = target.run_job(job)
        if not status:
            raise RuntimeError("Failed to start job.")
        return status

    def get_cluster(self, name) -> Optional[BaseCluster]:
        """Gets the cluster with the specified name, or None if not found."""
        matches = (
            candidate
            for candidate in self.list_clusters()
            if candidate.name() == name
        )
        # The first match wins; None signals that no cluster has this name.
        return next(matches, None)

    def list_clusters(self) -> list[BaseCluster]:
        """Lists the clusters currently cached by this cloud."""
        return [*self._clusters.values()]
@register_cloud_builder("slurm")
def slurm_cloud_builder() -> SlurmCloud:
    """Builds a new SlurmCloud; registered as the builder named "slurm"."""
    cloud = SlurmCloud()
    return cloud