Source code for oumi.launcher.clouds.sky_cloud
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, TypeVar
import sky
from oumi.core.configs import JobConfig
from oumi.core.launcher import BaseCloud, BaseCluster, JobStatus
from oumi.core.registry import register_cloud_builder
from oumi.launcher.clients.sky_client import SkyClient
from oumi.launcher.clusters.sky_cluster import SkyCluster
T = TypeVar("T")
[docs]
class SkyCloud(BaseCloud):
"""A resource pool capable of creating clusters using Sky Pilot."""
def __init__(self, cloud_name: str, client: SkyClient):
"""Initializes a new instance of the SkyCloud class."""
self._cloud_name = cloud_name
self._client = client
def _get_clusters_by_class(self, cloud_class: type[T]) -> list[BaseCluster]:
"""Gets the appropriate clusters of type T."""
return [
SkyCluster(cluster["name"], self._client)
for cluster in self._client.status()
if (
isinstance(cluster["handle"].launched_resources.cloud, cloud_class)
and cluster["status"] == sky.ClusterStatus.UP
)
]
[docs]
def up_cluster(self, job: JobConfig, name: Optional[str], **kwargs) -> JobStatus:
"""Creates a cluster and starts the provided Job."""
job_status = self._client.launch(job, name, **kwargs)
cluster = self.get_cluster(job_status.cluster)
if not cluster:
raise RuntimeError(f"Cluster {job_status.cluster} not found.")
return cluster.get_job(job_status.id)
[docs]
def get_cluster(self, name) -> Optional[BaseCluster]:
"""Gets the cluster with the specified name, or None if not found."""
clusters = self.list_clusters()
for cluster in clusters:
if cluster.name() == name:
return cluster
return None
[docs]
def list_clusters(self) -> list[BaseCluster]:
"""Lists the active clusters on this cloud."""
if self._cloud_name == SkyClient.SupportedClouds.GCP.value:
return self._get_clusters_by_class(sky.clouds.GCP)
elif self._cloud_name == SkyClient.SupportedClouds.RUNPOD.value:
return self._get_clusters_by_class(sky.clouds.RunPod)
elif self._cloud_name == SkyClient.SupportedClouds.LAMBDA.value:
return self._get_clusters_by_class(sky.clouds.Lambda)
elif self._cloud_name == SkyClient.SupportedClouds.AWS.value:
return self._get_clusters_by_class(sky.clouds.AWS)
elif self._cloud_name == SkyClient.SupportedClouds.AZURE.value:
return self._get_clusters_by_class(sky.clouds.Azure)
raise ValueError(f"Unsupported cloud: {self._cloud_name}")
@register_cloud_builder("runpod")
def runpod_cloud_builder() -> SkyCloud:
"""Builds a SkyCloud instance for runpod."""
return SkyCloud(SkyClient.SupportedClouds.RUNPOD.value, SkyClient())
@register_cloud_builder("gcp")
def gcp_cloud_builder() -> SkyCloud:
"""Builds a SkyCloud instance for Google Cloud Platform."""
return SkyCloud(SkyClient.SupportedClouds.GCP.value, SkyClient())
@register_cloud_builder("lambda")
def lambda_cloud_builder() -> SkyCloud:
"""Builds a SkyCloud instance for Lambda."""
return SkyCloud(SkyClient.SupportedClouds.LAMBDA.value, SkyClient())
@register_cloud_builder("aws")
def aws_cloud_builder() -> SkyCloud:
"""Builds a SkyCloud instance for AWS."""
return SkyCloud(SkyClient.SupportedClouds.AWS.value, SkyClient())
@register_cloud_builder("azure")
def azure_cloud_builder() -> SkyCloud:
"""Builds a SkyCloud instance for Azure."""
return SkyCloud(SkyClient.SupportedClouds.AZURE.value, SkyClient())