Source code for oumi.launcher.clouds.polaris_cloud
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional

from oumi.core.configs import JobConfig
from oumi.core.launcher import BaseCloud, BaseCluster, JobStatus
from oumi.core.registry import register_cloud_builder
from oumi.launcher.clients.polaris_client import PolarisClient
from oumi.launcher.clusters.polaris_cluster import PolarisCluster
from oumi.utils.logging import logger


@dataclass
class _ClusterInfo:
    """Dataclass to hold information about a cluster.

    A Polaris cluster is identified by the pair (queue, user); its canonical
    textual name joins the two with a single dot, i.e. ``"<queue>.<user>"``.
    """

    # Name of the Polaris job queue this cluster submits to.
    queue: str
    # Name of the user who owns the cluster.
    user: str

    def name(self):
        """Returns the canonical ``"<queue>.<user>"`` cluster name."""
        return "{}.{}".format(self.queue, self.user)
class PolarisCloud(BaseCloud):
    """A resource pool for managing the Polaris ALCF job queues."""

    def __init__(self):
        """Initializes a new instance of the PolarisCloud class.

        For every user reported by ``PolarisClient.get_active_users()``, one
        cluster per supported Polaris queue is registered eagerly.
        """
        # A mapping from user names to Polaris Clients.
        self._clients: dict[str, PolarisClient] = {}
        # A mapping from cluster names ("<queue>.<user>") to Polaris Cluster instances.
        self._clusters: dict[str, PolarisCluster] = {}
        # Check if any users have open SSH tunnels to Polaris.
        for user in PolarisClient.get_active_users():
            self.initialize_clusters(user)

    def _parse_cluster_name(self, name: str) -> _ClusterInfo:
        """Parses the cluster name into queue and user components.

        Args:
            name: The name of the cluster, in the format ``"queue.user"``.

        Returns:
            _ClusterInfo: The parsed cluster information.

        Raises:
            ValueError: If ``name`` does not consist of exactly two non-empty,
                dot-separated components.
        """
        name_splits = name.split(".")
        # Reject both the wrong number of components and empty components
        # (e.g. ".user" or "queue."), which would otherwise yield an empty
        # queue or user name.
        if len(name_splits) != 2 or not all(name_splits):
            raise ValueError(
                f"Invalid cluster name: {name}. Must be in the format 'queue.user'."
            )
        queue, user = name_splits
        return _ClusterInfo(queue, user)

    def _get_or_create_client(self, user: str) -> PolarisClient:
        """Gets the client for the specified user, or creates one if it doesn't exist.

        Args:
            user: The user to get the client for.

        Returns:
            PolarisClient: The client instance (cached per user).
        """
        if user not in self._clients:
            self._clients[user] = PolarisClient(user)
        return self._clients[user]

    def _get_or_create_cluster(self, name: str) -> PolarisCluster:
        """Gets the cluster with the specified name, or creates one if it doesn't exist.

        Args:
            name: The name of the cluster, in the format ``"queue.user"``.

        Returns:
            PolarisCluster: The cluster instance (cached per name).

        Raises:
            ValueError: If a new cluster must be created and ``name`` is malformed.
        """
        if name not in self._clusters:
            cluster_info = self._parse_cluster_name(name)
            self._clusters[name] = PolarisCluster(
                name, self._get_or_create_client(cluster_info.user)
            )
        return self._clusters[name]

    def initialize_clusters(self, user: str) -> list[BaseCluster]:
        """Initializes clusters for the specified user for all Polaris queues.

        Args:
            user: The user to initialize clusters for.

        Returns:
            list[BaseCluster]: One cluster per distinct supported queue, in the
            declaration order of ``PolarisClient.SupportedQueues``.
        """
        # De-duplicate queue values while preserving declaration order. A plain
        # set of strings would iterate in a hash-randomized, run-dependent order,
        # making the returned list (and cluster registration order) nondeterministic.
        queues = dict.fromkeys(q.value for q in PolarisClient.SupportedQueues)
        return [self._get_or_create_cluster(f"{queue}.{user}") for queue in queues]

    def up_cluster(self, job: JobConfig, name: Optional[str], **kwargs) -> JobStatus:
        """Creates a cluster and starts the provided Job.

        Args:
            job: The job to run. ``job.user`` must be set.
            name: Optional cluster name in the format ``"queue.user"``. When
                omitted, the PROD queue is used for ``job.user``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            JobStatus: The status returned by the cluster for the started job.

        Raises:
            ValueError: If ``job.user`` is missing, ``name`` is malformed, or the
                user in ``name`` differs from ``job.user``.
            RuntimeError: If the cluster does not return a job status.
        """
        if not job.user:
            raise ValueError("User must be provided in the job config.")
        if name:
            cluster_info = self._parse_cluster_name(name)
            if cluster_info.user != job.user:
                raise ValueError(
                    f"Invalid cluster name: {name}. "
                    "User must match the provided job user."
                )
        else:
            # The default queue is PROD.
            default_queue = PolarisClient.SupportedQueues.PROD.value
            logger.warning(
                "No cluster name provided. Using default queue: "
                f"{default_queue}."
            )
            cluster_info = _ClusterInfo(default_queue, job.user)
        cluster = self._get_or_create_cluster(cluster_info.name())
        job_status = cluster.run_job(job)
        if not job_status:
            raise RuntimeError("Failed to start job.")
        return job_status

    def get_cluster(self, name: str) -> Optional[BaseCluster]:
        """Gets the cluster with the specified name, or None if not found."""
        return next(
            (cluster for cluster in self.list_clusters() if cluster.name() == name),
            None,
        )

    def list_clusters(self) -> list[BaseCluster]:
        """Lists the active clusters on this cloud."""
        return list(self._clusters.values())
@register_cloud_builder("polaris")
def polaris_cloud_builder() -> PolarisCloud:
    """Builds a PolarisCloud instance.

    Registered with the cloud-builder registry under the key ``"polaris"``;
    each call returns a newly constructed ``PolarisCloud``.
    """
    cloud = PolarisCloud()
    return cloud