# Source code for oumi.launcher.clusters.modal_cluster

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Modal-backed cluster implementation.

Modal has no native cluster concept — every job is a single ``Sandbox``.
``ModalCluster`` is a thin façade that maps a logical cluster name (the
SkyPilot-style identifier callers like the Oumi worker pass to
``oumi.launcher.up``) onto sandbox lookups by ``object_id``. Job lookups
use the ``job_id`` argument directly so callers don't need to know the
mapping.

``stop()`` and ``down()`` cancel every sandbox the in-process
``ModalClient`` has launched under this cluster name. Across worker
restarts the mapping is lost; cleanup at that point should fall back
to per-sandbox ``cancel_job`` using the ``job_id`` persisted by the
caller alongside the cluster name.
"""

from __future__ import annotations

from typing import Any

from oumi.core.configs import JobConfig
from oumi.core.launcher import BaseCluster, ClusterNotFoundError, JobStatus
from oumi.launcher.clients.modal_client import ModalClient, ModalLogStream


class ModalCluster(BaseCluster):
    """A cluster implementation backed by Modal sandboxes.

    The cluster name is a logical label only. Per-job operations go straight
    to the client using the sandbox ``job_id``; cluster-wide operations ask
    the client which sandboxes it has tracked under this name.
    """

    def __init__(self, name: str, client: ModalClient) -> None:
        """Initializes a new instance of the ModalCluster class.

        Args:
            name: Logical cluster name (typically the
                ``cluster-job-{project}-{op}-...`` style identifier the caller
                used when invoking ``oumi.launcher.up``).
            client: A configured ``ModalClient``.
        """
        self._name = name
        self._client = client

    def __eq__(self, other: Any) -> bool:
        """Checks if two ModalClusters are equal.

        Two clusters are equal when both are ``ModalCluster`` instances with
        the same name; the client is not part of the comparison.
        """
        if not isinstance(other, ModalCluster):
            return False
        return self.name() == other.name()

    def __hash__(self) -> int:
        """Hashes by cluster name so instances can live in sets/dicts."""
        return hash(self._name)

    def name(self) -> str:
        """Gets the cluster name."""
        return self._name

    def get_job(self, job_id: str) -> JobStatus | None:
        """Gets the status of the sandbox identified by ``job_id``.

        ``job_id`` is the opaque ``Sandbox.object_id`` returned at launch time
        (and persisted by the caller). The cluster name is purely logical, so
        this method ignores ``self._name`` and goes straight to the sandbox
        lookup.

        Args:
            job_id: The sandbox identifier to look up.

        Returns:
            The status reported by the client, or ``None`` if the client
            raises ``ClusterNotFoundError`` for ``job_id``.
        """
        try:
            return self._client.get_status(job_id)
        except ClusterNotFoundError:
            return None

    def get_jobs(self) -> list[JobStatus]:
        """Lists the jobs spawned under this cluster name in this process.

        Returns:
            One status per tracked sandbox, in the order the client returns
            them. Sandboxes whose lookup raises ``ClusterNotFoundError`` are
            omitted.
        """
        statuses: list[JobStatus] = []
        for sandbox_id in self._client.find_sandboxes_for_cluster(self._name):
            try:
                statuses.append(self._client.get_status(sandbox_id))
            except ClusterNotFoundError:
                # The sandbox is tracked but can no longer be looked up; leave
                # it out rather than failing the whole listing.
                continue
        return statuses

    def cancel_job(self, job_id: str) -> JobStatus:
        """Cancels the sandbox identified by ``job_id`` and returns its status.

        Args:
            job_id: The sandbox identifier to cancel.

        Returns:
            The status the client reports for ``job_id`` after the cancel call.

        Raises:
            ClusterNotFoundError: Propagated unchanged if the client raises it
                during the cancel or the follow-up status lookup.
        """
        self._client.cancel(job_id)
        return self._client.get_status(job_id)

    def run_job(self, job: JobConfig) -> JobStatus:
        """Re-running on a Modal cluster is unsupported.

        Modal jobs are 1:1 with sandboxes. To run a new job, allocate a new
        sandbox via ``ModalCloud.up_cluster``.

        Args:
            job: Ignored; accepted for interface compatibility.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Modal does not support re-running jobs on an existing cluster. "
            "Call ModalCloud.up_cluster(...) to spawn a new sandbox."
        )

    def stop(self) -> None:
        """Best-effort cancel of every sandbox tracked under this cluster name.

        If the client raises ``ClusterNotFoundError`` for one sandbox, that
        sandbox is skipped and the remaining ones are still cancelled. Any
        other exception propagates to the caller.
        """
        # Snapshot the ids before cancelling, in case the client prunes its
        # tracking list as sandboxes are cancelled (not visible from this file).
        sandbox_ids = list(self._client.find_sandboxes_for_cluster(self._name))
        for sandbox_id in sandbox_ids:
            try:
                self._client.cancel(sandbox_id)
            except ClusterNotFoundError:
                # Nothing left to cancel for this id; one missing sandbox must
                # not stop the rest from being cancelled.
                continue

    def down(self) -> None:
        """Alias for ``stop`` — Modal is serverless, nothing else to tear down."""
        self.stop()

    def get_logs_stream(
        self, cluster_name: str, job_id: str | None = None
    ) -> ModalLogStream:
        """Returns a stream of logs for ``job_id`` (sandbox object_id).

        ``cluster_name`` is accepted for interface compatibility and ignored.
        ``job_id`` is the canonical handle. If ``job_id`` is omitted, falls
        back to the most recently launched sandbox under this cluster name
        (in this process).

        Args:
            cluster_name: Ignored.
            job_id: The sandbox identifier to stream logs for, or ``None`` to
                use the last sandbox the client tracks under this cluster.

        Returns:
            The log stream the client returns for the chosen sandbox.

        Raises:
            ClusterNotFoundError: If ``job_id`` is ``None`` and the client
                tracks no sandboxes under this cluster name.
        """
        target_sandbox = job_id
        if target_sandbox is None:
            tracked = self._client.find_sandboxes_for_cluster(self._name)
            if not tracked:
                raise ClusterNotFoundError(
                    f"No sandboxes tracked for cluster '{self._name}' "
                    "and no job_id provided."
                )
            # The last tracked entry is treated as the most recent launch.
            target_sandbox = tracked[-1]
        return self._client.get_logs_stream(target_sandbox)