# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.fromtypingimportOptional,TypeVarfromoumi.core.configsimportJobConfigfromoumi.core.launcherimportBaseCloud,BaseCluster,JobStatusfromoumi.core.registryimportregister_cloud_builderfromoumi.launcher.clients.sky_clientimportSkyClientfromoumi.launcher.clusters.sky_clusterimportSkyClusterT=TypeVar("T")
class SkyCloud(BaseCloud):
    """A cloud backend that creates and enumerates clusters via Sky Pilot.

    The underlying ``SkyClient`` is built on first use rather than in the
    constructor, because instantiating it imports ``sky``.
    """

    def __init__(self, cloud_name: str):
        """Stores the cloud name; the Sky client is created later on demand.

        Args:
            cloud_name: Name of the cloud this instance targets. It is compared
                against the ``SkyClient.SupportedClouds`` values when listing
                clusters.
        """
        self._cloud_name = cloud_name
        self._sky_client: Optional[SkyClient] = None

    @property
    def _client(self) -> SkyClient:
        """Returns the cached SkyClient, constructing it on first access."""
        # Instantiating a SkyClient imports sky.
        # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605
        client = self._sky_client
        if not client:
            client = SkyClient()
            self._sky_client = client
        return client

    def _get_clusters_by_class(self, cloud_class: type[T]) -> list[BaseCluster]:
        """Collects the UP clusters whose launched cloud is a ``cloud_class``.

        Args:
            cloud_class: The ``sky.clouds`` class used to filter the clusters
                reported by the client's ``status()`` call.

        Returns:
            One ``SkyCluster`` per status entry that was launched on
            ``cloud_class`` and whose status equals ``sky.ClusterStatus.UP``.
        """
        # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605
        import sky

        matching = []
        for entry in self._client.status():
            launched_on = entry["handle"].launched_resources.cloud
            if not isinstance(launched_on, cloud_class):
                continue
            # The status is only inspected for entries on the requested cloud.
            if entry["status"] == sky.ClusterStatus.UP:
                matching.append(SkyCluster(entry["name"], self._client))
        return matching

    def up_cluster(self, job: JobConfig, name: Optional[str], **kwargs) -> JobStatus:
        """Launches ``job`` on a cluster and returns that job's status.

        Args:
            job: The job configuration to launch.
            name: Optional cluster name, forwarded to the client's ``launch``.
            **kwargs: Extra keyword arguments forwarded to the client's
                ``launch``.

        Returns:
            The result of looking up the launched job on its cluster.

        Raises:
            RuntimeError: If the cluster reported by the launch cannot be found
                among this cloud's listed clusters.
        """
        launched = self._client.launch(job, name, **kwargs)
        cluster = self.get_cluster(launched.cluster)
        if cluster:
            return cluster.get_job(launched.id)
        raise RuntimeError(f"Cluster {launched.cluster} not found.")

    def get_cluster(self, name) -> Optional[BaseCluster]:
        """Finds the listed cluster called ``name``.

        Args:
            name: The cluster name to search for.

        Returns:
            The first cluster from ``list_clusters()`` whose ``name()`` equals
            ``name``, or None when there is no such cluster.
        """
        return next(
            (
                candidate
                for candidate in self.list_clusters()
                if candidate.name() == name
            ),
            None,
        )

    def list_clusters(self) -> list[BaseCluster]:
        """Lists the active clusters on this cloud.

        Returns:
            The UP clusters launched on the ``sky.clouds`` class that
            corresponds to this instance's cloud name.

        Raises:
            ValueError: If the cloud name matches none of the supported clouds.
        """
        # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605
        import sky

        supported = SkyClient.SupportedClouds
        # Ordered (supported cloud, sky.clouds attribute name) pairs. The
        # attribute is resolved only for the matching entry, so no other
        # sky.clouds class is touched.
        lookup_order = (
            (supported.GCP, "GCP"),
            (supported.RUNPOD, "RunPod"),
            (supported.LAMBDA, "Lambda"),
            (supported.AWS, "AWS"),
            (supported.AZURE, "Azure"),
        )
        for member, class_name in lookup_order:
            if self._cloud_name == member.value:
                return self._get_clusters_by_class(getattr(sky.clouds, class_name))
        raise ValueError(f"Unsupported cloud: {self._cloud_name}")
@register_cloud_builder("runpod")
def runpod_cloud_builder() -> SkyCloud:
    """Creates the SkyCloud registered under the name "runpod"."""
    cloud_name = SkyClient.SupportedClouds.RUNPOD.value
    return SkyCloud(cloud_name)


@register_cloud_builder("gcp")
def gcp_cloud_builder() -> SkyCloud:
    """Creates the SkyCloud registered under the name "gcp"."""
    cloud_name = SkyClient.SupportedClouds.GCP.value
    return SkyCloud(cloud_name)


@register_cloud_builder("lambda")
def lambda_cloud_builder() -> SkyCloud:
    """Creates the SkyCloud registered under the name "lambda"."""
    cloud_name = SkyClient.SupportedClouds.LAMBDA.value
    return SkyCloud(cloud_name)


@register_cloud_builder("aws")
def aws_cloud_builder() -> SkyCloud:
    """Creates the SkyCloud registered under the name "aws"."""
    cloud_name = SkyClient.SupportedClouds.AWS.value
    return SkyCloud(cloud_name)


@register_cloud_builder("azure")
def azure_cloud_builder() -> SkyCloud:
    """Creates the SkyCloud registered under the name "azure"."""
    cloud_name = SkyClient.SupportedClouds.AZURE.value
    return SkyCloud(cloud_name)