# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

from oumi.core.configs import JobConfig
from oumi.core.launcher import BaseCloud, BaseCluster, JobStatus
from oumi.core.registry import REGISTRY, RegistryType
class Launcher:
    """A class for managing the lifecycle of jobs on different clouds."""

    def __init__(self):
        """Initializes a new instance of the Launcher class.

        Every cloud builder currently registered under ``RegistryType.CLOUD``
        is instantiated immediately.
        """
        self._clouds: dict[str, BaseCloud] = {}
        self._initialize_new_clouds()

    def _initialize_new_clouds(self) -> None:
        """Initializes new clouds. Existing clouds are not re-initialized."""
        for name, builder in REGISTRY.get_all(RegistryType.CLOUD).items():
            if name not in self._clouds:
                self._clouds[name] = builder()

    def _get_cloud_by_name(self, cloud: str) -> BaseCloud:
        """Gets the cloud instance for the specified cloud name.

        The instance is built from the registry on first use and cached in
        ``self._clouds`` for subsequent calls.

        Args:
            cloud: The registered name of the cloud.

        Returns:
            BaseCloud: The (possibly newly built) cloud instance.

        Raises:
            ValueError: If no cloud with that name is registered.
        """
        if cloud not in self._clouds:
            cloud_builder = REGISTRY.get(cloud, RegistryType.CLOUD)
            if not cloud_builder:
                raise ValueError(f"Cloud {cloud} not found in the registry.")
            self._clouds[cloud] = cloud_builder()
        return self._clouds[cloud]

    def _get_cluster(self, cloud: BaseCloud, cluster_name: str) -> BaseCluster:
        """Looks up a cluster on a cloud, raising if it does not exist.

        Shared by the methods that operate on an existing cluster
        (``cancel``, ``down``, ``run``, ``stop``).

        Args:
            cloud: The cloud to query.
            cluster_name: The name of the cluster to look up.

        Returns:
            BaseCluster: The cluster returned by ``cloud.get_cluster``.

        Raises:
            ValueError: If ``cloud.get_cluster`` returns a falsy value.
        """
        cluster = cloud.get_cluster(cluster_name)
        if not cluster:
            raise ValueError(f"Cluster {cluster_name} not found.")
        return cluster

    def cancel(self, job_id: str, cloud_name: str, cluster_name: str) -> JobStatus:
        """Cancels the specified job.

        Args:
            job_id: The ID of the job to cancel.
            cloud_name: The name of the cloud hosting the cluster.
            cluster_name: The name of the cluster the job runs on.

        Returns:
            JobStatus: The status returned by the cluster's ``cancel_job``.

        Raises:
            ValueError: If the cloud is not registered or the cluster is not
                found.
        """
        cloud = self._get_cloud_by_name(cloud_name)
        cluster = self._get_cluster(cloud, cluster_name)
        return cluster.cancel_job(job_id)

    def down(self, cloud_name: str, cluster_name: str) -> None:
        """Turns down the specified cluster.

        Args:
            cloud_name: The name of the cloud hosting the cluster.
            cluster_name: The name of the cluster to turn down.

        Raises:
            ValueError: If the cloud is not registered or the cluster is not
                found.
        """
        cloud = self._get_cloud_by_name(cloud_name)
        cluster = self._get_cluster(cloud, cluster_name)
        cluster.down()

    def get_cloud(self, job_or_cloud: Union[JobConfig, str]) -> BaseCloud:
        """Gets the cloud instance for the specified job or cloud name.

        Args:
            job_or_cloud: Either a cloud name, or a job config whose
                ``resources.cloud`` field names the cloud.

        Returns:
            BaseCloud: The matching cloud instance.

        Raises:
            ValueError: If the named cloud is not registered.
        """
        if isinstance(job_or_cloud, str):
            return self._get_cloud_by_name(job_or_cloud)
        return self._get_cloud_by_name(job_or_cloud.resources.cloud)

    def run(self, job: JobConfig, cluster_name: str) -> JobStatus:
        """Runs the specified job on the specified cluster.

        Args:
            job: The job configuration. Its ``resources.cloud`` selects the
                cloud.
            cluster_name: The name of the cluster to run the job on.

        Returns:
            JobStatus: The status of the job.

        Raises:
            ValueError: If the job's cloud is not registered or the cluster is
                not found.
        """
        cloud = self.get_cloud(job)
        cluster = self._get_cluster(cloud, cluster_name)
        return cluster.run_job(job)

    def status(
        self,
        cloud: Optional[str] = None,
        cluster: Optional[str] = None,
        id: Optional[str] = None,
    ) -> dict[str, list[JobStatus]]:
        """Gets the status of all jobs across all clusters.

        Args:
            cloud: If specified, filters all jobs to only those on the specified
                cloud.
            cluster: If specified, filters all jobs to only those on the
                specified cluster.
            id: If specified, filters all jobs to only those with the specified
                ID.

        Returns:
            dict[str, list[JobStatus]]: The status of all jobs, indexed by cloud
            name. Every cloud that passes the ``cloud`` filter gets an entry,
            which may be an empty list.
        """
        # Pick up any newly registered cloud builders.
        self._initialize_new_clouds()
        statuses: dict[str, list[JobStatus]] = {}
        for cloud_name, target_cloud in self._clouds.items():
            # Ignore clouds not matching the filter criteria.
            if cloud and cloud_name != cloud:
                continue
            statuses[cloud_name] = []
            for target_cluster in target_cloud.list_clusters():
                # Ignore clusters not matching the filter criteria.
                if cluster and target_cluster.name() != cluster:
                    continue
                for job in target_cluster.get_jobs():
                    # Ignore jobs not matching the filter criteria.
                    if id and job.id != id:
                        continue
                    statuses[cloud_name].append(job)
        return statuses

    def stop(self, cloud_name: str, cluster_name: str) -> None:
        """Stops the specified cluster.

        Args:
            cloud_name: The name of the cloud hosting the cluster.
            cluster_name: The name of the cluster to stop.

        Raises:
            ValueError: If the cloud is not registered or the cluster is not
                found.
        """
        cloud = self._get_cloud_by_name(cloud_name)
        cluster = self._get_cluster(cloud, cluster_name)
        cluster.stop()

    def up(
        self, job: JobConfig, cluster_name: Optional[str], **kwargs
    ) -> tuple[BaseCluster, JobStatus]:
        """Creates a new cluster and starts the specified job on it.

        Args:
            job: The job configuration. Its ``resources.cloud`` selects the
                cloud.
            cluster_name: The requested cluster name, or None. It is passed
                through unchanged to the cloud's ``up_cluster``.
            **kwargs: Extra keyword arguments forwarded to ``up_cluster``.

        Returns:
            tuple[BaseCluster, JobStatus]: The cluster named by the returned
            job status, and that job status.

        Raises:
            RuntimeError: If the cluster reported by ``up_cluster`` cannot be
                found afterwards.
        """
        cloud = self.get_cloud(job)
        job_status = cloud.up_cluster(job, cluster_name, **kwargs)
        # Resolve the cluster by the name the cloud actually used, which is
        # carried on the returned status.
        cluster = cloud.get_cluster(job_status.cluster)
        if not cluster:
            raise RuntimeError(f"Cluster {job_status.cluster} not found.")
        return (cluster, job_status)

    def which_clouds(self) -> list[str]:
        """Gets the names of all clouds in the registry.

        Returns:
            list[str]: The names registered under ``RegistryType.CLOUD``,
            including any that have not been instantiated yet.
        """
        return [name for name, _ in REGISTRY.get_all(RegistryType.CLOUD).items()]
# Default, process-wide Launcher that the module-level shortcuts below are
# bound to.
LAUNCHER = Launcher()

# Re-export the public Launcher methods as plain module-level callables so
# callers can use them without handling a Launcher instance themselves.

#: Shortcut bound to :meth:`Launcher.cancel` of the default launcher.
cancel = LAUNCHER.cancel

#: Shortcut bound to :meth:`Launcher.down` of the default launcher.
down = LAUNCHER.down

#: Shortcut bound to :meth:`Launcher.get_cloud` of the default launcher.
get_cloud = LAUNCHER.get_cloud

#: Shortcut bound to :meth:`Launcher.run` of the default launcher.
run = LAUNCHER.run

#: Shortcut bound to :meth:`Launcher.status` of the default launcher.
status = LAUNCHER.status

#: Shortcut bound to :meth:`Launcher.stop` of the default launcher.
stop = LAUNCHER.stop

#: Shortcut bound to :meth:`Launcher.up` of the default launcher.
up = LAUNCHER.up

#: Shortcut bound to :meth:`Launcher.which_clouds` of the default launcher.
which_clouds = LAUNCHER.which_clouds