Source code for oumi.launcher.clusters.sky_cluster
# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.fromtypingimportAny,Optionalfromoumi.core.configsimportJobConfigfromoumi.core.launcherimportBaseCluster,JobStatusfromoumi.launcher.clients.sky_clientimportSkyClient
class SkyCluster(BaseCluster):
    """A cluster implementation backed by Sky Pilot."""

    # String forms of the sky job states for which a job is not yet done.
    # See sky job states here:
    # https://skypilot.readthedocs.io/en/latest/reference/cli.html#sky-jobs-queue
    _UNFINISHED_SKY_STATES = frozenset(
        {
            "JobStatus.PENDING",
            "JobStatus.SUBMITTED",
            "JobStatus.STARTING",
            "JobStatus.RUNNING",
            "JobStatus.RECOVERING",
            "JobStatus.CANCELLING",
        }
    )

    def __init__(self, name: str, client: SkyClient) -> None:
        """Initializes a new instance of the SkyCluster class.

        Args:
            name: The name of the cluster.
            client: The SkyClient used for every operation on this cluster.
        """
        # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605
        import sky.exceptions

        self._sky_exceptions = sky.exceptions
        self._name = name
        self._client = client

    def __eq__(self, other: Any) -> bool:
        """Checks if two SkyClusters are equal.

        Two SkyClusters are equal when their names are equal. Any object that
        is not a SkyCluster compares unequal.
        """
        if not isinstance(other, SkyCluster):
            return False
        return self.name() == other.name()

    def __hash__(self) -> int:
        """Hashes the cluster by name, consistent with `__eq__`.

        Defining `__eq__` without `__hash__` makes instances unhashable, so
        this is needed for clusters to be usable in sets and as dict keys.
        """
        return hash(self.name())

    def _convert_sky_job_to_status(self, sky_job: dict) -> JobStatus:
        """Converts a sky job to a JobStatus.

        Args:
            sky_job: A sky job record. It must contain the keys "job_id",
                "job_name" and "status".

        Returns:
            A JobStatus for this cluster. `done` is False only while the job's
            status is one of the unfinished sky job states.

        Raises:
            ValueError: If a required field is missing from `sky_job`.
        """
        required_fields = ["job_id", "job_name", "status"]
        for field in required_fields:
            if field not in sky_job:
                raise ValueError(f"Missing required field: {field}")
        status = str(sky_job["status"])
        return JobStatus(
            id=str(sky_job["job_id"]),
            name=str(sky_job["job_name"]),
            status=status,
            cluster=self.name(),
            metadata="",
            done=status not in self._UNFINISHED_SKY_STATES,
        )

    def name(self) -> str:
        """Gets the name of the cluster."""
        return self._name

    def get_job(self, job_id: str) -> Optional[JobStatus]:
        """Gets the job on this cluster if it exists, else returns None.

        Args:
            job_id: The id of the job to look up.

        Returns:
            The first job whose id equals `job_id`, or None if there is none.
        """
        for job in self.get_jobs():
            if job.id == job_id:
                return job
        return None

    def get_jobs(self) -> list[JobStatus]:
        """Lists the jobs on this cluster.

        Returns:
            The status of every job in the cluster's queue, or an empty list
            when the client raises ClusterNotUpError.
        """
        try:
            return [
                self._convert_sky_job_to_status(job)
                for job in self._client.queue(self.name())
            ]
        except self._sky_exceptions.ClusterNotUpError:
            return []

    def cancel_job(self, job_id: str) -> JobStatus:
        """Cancels the specified job on this cluster.

        Args:
            job_id: The id of the job to cancel.

        Returns:
            The status of the job, looked up after the cancel request.

        Raises:
            RuntimeError: If the job cannot be found after the cancel request.
        """
        self._client.cancel(self.name(), job_id)
        job = self.get_job(job_id)
        if job is None:
            raise RuntimeError(f"Job {job_id} not found.")
        return job

    def run_job(self, job: JobConfig) -> JobStatus:
        """Runs the specified job on this cluster.

        Args:
            job: The configuration of the job to run.

        Returns:
            The status of the submitted job.

        Raises:
            RuntimeError: If the job cannot be found after submission.
        """
        job_id = self._client.exec(job, self.name())
        job_status = self.get_job(job_id)
        if job_status is None:
            raise RuntimeError(f"Job {job_id} not found after submission.")
        return job_status

    def stop(self) -> None:
        """Stops the current cluster."""
        self._client.stop(self.name())

    def down(self) -> None:
        """Tears down the current cluster."""
        self._client.down(self.name())