# Source code for oumi.launcher.launcher
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from oumi.core.configs import JobConfig
from oumi.core.launcher import BaseCloud, BaseCluster, JobStatus
from oumi.core.registry import REGISTRY, RegistryType
[docs]
class Launcher:
    """Manages the lifecycle of jobs and clusters across registered clouds.

    Cloud instances are created from the builders registered under
    ``RegistryType.CLOUD`` and cached by name in ``self._clouds``, so a given
    cloud is built at most once per launcher.
    """

    def __init__(self):
        """Initializes the launcher and builds every currently registered cloud."""
        self._clouds: dict[str, BaseCloud] = {}
        self._initialize_new_clouds()

    def _initialize_new_clouds(self) -> None:
        """Builds clouds that are registered but not yet cached.

        Clouds already present in the cache are not re-initialized.
        """
        for name, builder in REGISTRY.get_all(RegistryType.CLOUD).items():
            if name not in self._clouds:
                self._clouds[name] = builder()

    def _get_cloud_by_name(self, cloud: str) -> BaseCloud:
        """Returns the cached cloud named ``cloud``, building it on first use.

        Args:
            cloud: The registry name of the cloud.

        Returns:
            The cloud instance cached under ``cloud``.

        Raises:
            ValueError: If the registry has no builder for ``cloud``.
        """
        if cloud not in self._clouds:
            cloud_builder = REGISTRY.get(cloud, RegistryType.CLOUD)
            if not cloud_builder:
                raise ValueError(f"Cloud {cloud} not found in the registry.")
            self._clouds[cloud] = cloud_builder()
        return self._clouds[cloud]

    @staticmethod
    def _require_cluster(cloud: BaseCloud, cluster_name: str) -> BaseCluster:
        """Returns the named cluster on ``cloud`` or fails loudly.

        Shared by every operation that targets an existing cluster so that the
        lookup and its error message stay identical across them.

        Args:
            cloud: The cloud to search.
            cluster_name: The name of the cluster to look up.

        Returns:
            The cluster returned by ``cloud.get_cluster(cluster_name)``.

        Raises:
            ValueError: If the cloud returns no cluster for ``cluster_name``.
        """
        cluster = cloud.get_cluster(cluster_name)
        if not cluster:
            raise ValueError(f"Cluster {cluster_name} not found.")
        return cluster

    def cancel(self, job_id: str, cloud_name: str, cluster_name: str) -> JobStatus:
        """Cancels the specified job.

        Args:
            job_id: The ID of the job to cancel.
            cloud_name: The name of the cloud hosting the cluster.
            cluster_name: The name of the cluster running the job.

        Returns:
            The value returned by the cluster's ``cancel_job``.

        Raises:
            ValueError: If the cloud or the cluster cannot be found.
        """
        cloud = self._get_cloud_by_name(cloud_name)
        return self._require_cluster(cloud, cluster_name).cancel_job(job_id)

    def down(self, cloud_name: str, cluster_name: str) -> None:
        """Turns down the specified cluster.

        Args:
            cloud_name: The name of the cloud hosting the cluster.
            cluster_name: The name of the cluster to turn down.

        Raises:
            ValueError: If the cloud or the cluster cannot be found.
        """
        cloud = self._get_cloud_by_name(cloud_name)
        self._require_cluster(cloud, cluster_name).down()

    def get_cloud(self, job_or_cloud: Union[JobConfig, str]) -> BaseCloud:
        """Gets the cloud instance for the specified job or cloud name.

        Args:
            job_or_cloud: Either a cloud name, or a job configuration whose
                ``resources.cloud`` field holds the cloud name.

        Returns:
            The matching cloud instance.

        Raises:
            ValueError: If the cloud is not found in the registry.
        """
        if isinstance(job_or_cloud, str):
            return self._get_cloud_by_name(job_or_cloud)
        return self._get_cloud_by_name(job_or_cloud.resources.cloud)

    def run(self, job: JobConfig, cluster_name: str) -> JobStatus:
        """Runs the specified job on the specified cluster.

        Args:
            job: The job configuration.
            cluster_name: The name of the cluster to run the job on.

        Returns:
            JobStatus: The status of the job.

        Raises:
            ValueError: If the job's cloud or the cluster cannot be found.
        """
        cloud = self.get_cloud(job)
        return self._require_cluster(cloud, cluster_name).run_job(job)

    def status(
        self,
        cloud: Optional[str] = None,
        cluster: Optional[str] = None,
        id: Optional[str] = None,
    ) -> dict[str, list[JobStatus]]:
        """Gets the status of all jobs across all clusters.

        Args:
            cloud: If specified, filters all jobs to only those on the specified
                cloud.
            cluster: If specified, filters all jobs to only those on the
                specified cluster.
            id: If specified, filters all jobs to only those with the specified
                ID.

        Returns:
            dict[str, list[JobStatus]]: The status of all jobs, indexed by cloud
            name. Every cloud that passes the ``cloud`` filter gets an entry,
            even when no job on it matches.
        """
        # Pick up any newly registered cloud builders.
        self._initialize_new_clouds()
        statuses: dict[str, list[JobStatus]] = {}
        for cloud_name, target_cloud in self._clouds.items():
            # Ignore clouds not matching the filter criteria.
            if cloud and cloud_name != cloud:
                continue
            matching_jobs: list[JobStatus] = []
            statuses[cloud_name] = matching_jobs
            for target_cluster in target_cloud.list_clusters():
                # Ignore clusters not matching the filter criteria.
                if cluster and target_cluster.name() != cluster:
                    continue
                for job in target_cluster.get_jobs():
                    # Ignore jobs not matching the filter criteria.
                    if id and job.id != id:
                        continue
                    matching_jobs.append(job)
        return statuses

    def stop(self, cloud_name: str, cluster_name: str) -> None:
        """Stops the specified cluster.

        Args:
            cloud_name: The name of the cloud hosting the cluster.
            cluster_name: The name of the cluster to stop.

        Raises:
            ValueError: If the cloud or the cluster cannot be found.
        """
        cloud = self._get_cloud_by_name(cloud_name)
        self._require_cluster(cloud, cluster_name).stop()

    def up(
        self, job: JobConfig, cluster_name: Optional[str], **kwargs
    ) -> tuple[BaseCluster, JobStatus]:
        """Creates a new cluster and starts the specified job on it.

        Args:
            job: The job configuration.
            cluster_name: The name passed to the cloud's ``up_cluster``; may be
                ``None``.
            **kwargs: Extra keyword arguments forwarded to ``up_cluster``.

        Returns:
            A ``(cluster, job_status)`` tuple, where ``cluster`` is looked up
            using the cluster name reported by the returned job status.

        Raises:
            ValueError: If the job's cloud is not found in the registry.
            RuntimeError: If the cluster reported by the job status cannot be
                found on the cloud afterwards.
        """
        cloud = self.get_cloud(job)
        job_status = cloud.up_cluster(job, cluster_name, **kwargs)
        cluster = cloud.get_cluster(job_status.cluster)
        if not cluster:
            raise RuntimeError(f"Cluster {job_status.cluster} not found.")
        return (cluster, job_status)

    def which_clouds(self) -> list[str]:
        """Gets the names of all clouds in the registry."""
        # Iterating the mapping yields its keys, i.e. the registered cloud names.
        return list(REGISTRY.get_all(RegistryType.CLOUD))
# Default module-level launcher. It is created at import time, and its
# constructor builds every cloud currently returned by the registry.
LAUNCHER = Launcher()
# Explicitly expose the public methods of the default Launcher instance.
# Each alias below is a bound method of LAUNCHER, so all of them share the
# same cache of cloud instances.
#: A convenience method for Launcher.cancel (cancels a job on a named cluster).
cancel = LAUNCHER.cancel
#: A convenience method for Launcher.down (turns down a named cluster).
down = LAUNCHER.down
#: A convenience method for Launcher.get_cloud (resolves a cloud from a job
#: configuration or a cloud name).
get_cloud = LAUNCHER.get_cloud
#: A convenience method for Launcher.run (runs a job on an existing cluster).
run = LAUNCHER.run
#: A convenience method for Launcher.status (lists jobs, grouped by cloud name).
status = LAUNCHER.status
#: A convenience method for Launcher.stop (stops a named cluster).
stop = LAUNCHER.stop
#: A convenience method for Launcher.up (creates a cluster and starts a job on
#: it).
up = LAUNCHER.up
#: A convenience method for Launcher.which_clouds (names of all clouds in the
#: registry).
which_clouds = LAUNCHER.which_clouds