Source code for oumi.cli.launch

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from collections import defaultdict
from multiprocessing.pool import Pool
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Callable, Optional

import typer
from rich.columns import Columns
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

import oumi.cli.cli_utils as cli_utils
from oumi.cli.alias import AliasType, try_get_config_name_for_alias
from oumi.utils.git_utils import get_git_root_dir
from oumi.utils.logging import logger
from oumi.utils.version_utils import is_dev_build

if TYPE_CHECKING:
    from oumi.core.launcher import BaseCluster, JobStatus


def _get_working_dir(current: Optional[str]) -> Optional[str]:
    """Prompts the user to select the working directory, if relevant."""
    if not is_dev_build():
        return current
    oumi_root = get_git_root_dir()
    if current and (not oumi_root or oumi_root == Path(current).resolve()):
        return current
    use_root = typer.confirm(
        "You are using a dev build of oumi. "
        f"Use oumi's root directory ({oumi_root}) as your working directory?",
        abort=False,
        default=True,
    )
    return str(oumi_root) if use_root else current


def _print_and_wait(
    message: str, task: Callable[..., bool], asynchronous: bool = True, **kwargs
) -> None:
    """Prints a message with a loading spinner until the provided task is done."""
    with cli_utils.CONSOLE.status(message):
        if asynchronous:
            with Pool(processes=1) as worker_pool:
                task_done = False
                while not task_done:
                    worker_result = worker_pool.apply_async(task, kwds=kwargs)
                    worker_result.wait()
                    # Call get() to reraise any exceptions that occurred in the worker.
                    task_done = worker_result.get()
        else:
            # Synchronous tasks should be atomic and not block for a significant amount
            # of time. If a task is blocking, it should be run asynchronously.
            while not task(**kwargs):
                sleep_duration = 0.1
                time.sleep(sleep_duration)
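

# Illustrative example (editor's sketch, not part of the original module): a worker
# passed to `_print_and_wait` must return True once its task is done; returning False
# makes the spinner loop invoke it again. Both helpers below are hypothetical and exist
# only to document that contract.
def _example_countdown_worker(remaining: list) -> bool:
    """Decrements a shared counter and reports whether it has reached zero."""
    remaining[0] -= 1
    return remaining[0] <= 0


def _example_print_and_wait_usage() -> None:
    """Polls the hypothetical worker synchronously under the loading spinner."""
    _print_and_wait(
        "Counting down", _example_countdown_worker, asynchronous=False, remaining=[3]
    )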


def _is_job_done(id: str, cloud: str, cluster: str) -> bool:
    """Returns true IFF a job is no longer running."""
    from oumi import launcher

    running_cloud = launcher.get_cloud(cloud)
    running_cluster = running_cloud.get_cluster(cluster)
    if not running_cluster:
        return True
    status = running_cluster.get_job(id)
    return status.done


def _cancel_worker(id: str, cloud: str, cluster: str) -> bool:
    """Cancels a job.

    All workers must return a boolean to indicate whether the task is done.
    Cancel has no intermediate states, so it always returns True.
    """
    from oumi import launcher

    if not cluster:
        return True
    if not id:
        return True
    if not cloud:
        return True
    launcher.cancel(id, cloud, cluster)
    return True  # Always return true to indicate that the task is done.


def _down_worker(cluster: str, cloud: Optional[str]) -> bool:
    """Turns down a cluster.

    All workers must return a boolean to indicate whether the task is done.
    Down has no intermediate states, so it always returns True.
    """
    from oumi import launcher

    if cloud:
        target_cloud = launcher.get_cloud(cloud)
        target_cluster = target_cloud.get_cluster(cluster)
        if target_cluster:
            target_cluster.down()
        else:
            cli_utils.CONSOLE.print(
                f"[red]Cluster [yellow]{cluster}[/yellow] not found.[/red]"
            )
        return True
    # Make a best effort to find a single cluster to turn down without a cloud.
    clusters = []
    for name in launcher.which_clouds():
        target_cloud = launcher.get_cloud(name)
        target_cluster = target_cloud.get_cluster(cluster)
        if target_cluster:
            clusters.append(target_cluster)
    if len(clusters) == 0:
        cli_utils.CONSOLE.print(
            f"[red]Cluster [yellow]{cluster}[/yellow] not found.[/red]"
        )
        return True
    if len(clusters) == 1:
        clusters[0].down()
    else:
        cli_utils.CONSOLE.print(
            f"[red]Multiple clusters found with name [yellow]{cluster}[/yellow]. "
            "Specify a cloud to turn down with `--cloud`.[/red]"
        )
    return True  # Always return true to indicate that the task is done.
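

# Illustrative example (editor's sketch, not part of the original module): the
# best-effort lookup shared by `_down_worker` above and `_stop_worker` below, factored
# into a standalone helper. `_example_find_clusters_by_name` is a hypothetical name; it
# only uses launcher calls already exercised in this file.
def _example_find_clusters_by_name(cluster: str) -> list["BaseCluster"]:
    """Returns every cluster with the given name across all registered clouds."""
    from oumi import launcher

    matches = []
    for cloud_name in launcher.which_clouds():
        target_cluster = launcher.get_cloud(cloud_name).get_cluster(cluster)
        if target_cluster:
            matches.append(target_cluster)
    return matches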


def _stop_worker(cluster: str, cloud: Optional[str]) -> bool:
    """Stops a cluster.

    All workers must return a boolean to indicate whether the task is done.
    Stop has no intermediate states, so it always returns True.
    """
    from oumi import launcher

    if cloud:
        target_cloud = launcher.get_cloud(cloud)
        target_cluster = target_cloud.get_cluster(cluster)
        if target_cluster:
            target_cluster.stop()
        else:
            cli_utils.CONSOLE.print(
                f"[red]Cluster [yellow]{cluster}[/yellow] not found.[/red]"
            )
        return True
    # Make a best effort to find a single cluster to stop without a cloud.
    clusters = []
    for name in launcher.which_clouds():
        target_cloud = launcher.get_cloud(name)
        target_cluster = target_cloud.get_cluster(cluster)
        if target_cluster:
            clusters.append(target_cluster)
    if len(clusters) == 0:
        cli_utils.CONSOLE.print(
            f"[red]Cluster [yellow]{cluster}[/yellow] not found.[/red]"
        )
        return True
    if len(clusters) == 1:
        clusters[0].stop()
    else:
        cli_utils.CONSOLE.print(
            f"[red]Multiple clusters found with name [yellow]{cluster}[/yellow]. "
            "Specify a cloud to stop with `--cloud`.[/red]"
        )
    return True  # Always return true to indicate that the task is done.


def _poll_job(
    job_status: "JobStatus",
    detach: bool,
    cloud: str,
    running_cluster: Optional["BaseCluster"] = None,
) -> None:
    """Polls a job until it is complete.

    If the job is running in detached mode and the job is not on the local cloud,
    the function returns immediately.
    """
    import oumi.launcher.clients.sky_client as sky_client
    from oumi import launcher

    is_local = cloud == "local"
    if detach and not is_local:
        cli_utils.CONSOLE.print(
            f"Running job [yellow]{job_status.id}[/yellow] in detached mode."
        )
        return
    if detach and is_local:
        cli_utils.CONSOLE.print("Cannot detach from jobs in local mode.")

    if not running_cluster:
        running_cloud = launcher.get_cloud(cloud)
        running_cluster = running_cloud.get_cluster(job_status.cluster)

    assert running_cluster

    # Check if this is a SkyPilot job and tail logs automatically.
    if cloud in [cloud.value for cloud in sky_client.SkyClient.SupportedClouds]:
        cli_utils.CONSOLE.print(
            f"Tailing logs for job [yellow]{job_status.id}[/yellow]..."
        )
        # Delay sky import: https://github.com/oumi-ai/oumi/issues/1605
        import sky

        sky.tail_logs(
            cluster_name=job_status.cluster,
            job_id=job_status.id,
        )
    else:
        _print_and_wait(
            f"Running job [yellow]{job_status.id}[/yellow]",
            _is_job_done,
            asynchronous=not is_local,
            id=job_status.id,
            cloud=cloud,
            cluster=job_status.cluster,
        )

    final_status = running_cluster.get_job(job_status.id)
    if final_status:
        cli_utils.CONSOLE.print(
            f"Job [yellow]{final_status.id}[/yellow] finished with "
            f"status [yellow]{final_status.status}[/yellow]"
        )
        cli_utils.CONSOLE.print("Job metadata:")
        cli_utils.CONSOLE.print(f"[yellow]{final_status.metadata}[/yellow]")


# ----------------------------
# Launch CLI subcommands
# ----------------------------


def cancel(
    cloud: Annotated[str, typer.Option(help="Filter results by this cloud.")],
    cluster: Annotated[
        str,
        typer.Option(help="Filter results by clusters matching this name."),
    ],
    id: Annotated[
        str, typer.Option(help="Filter results by jobs matching this job ID.")
    ],
    level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
    """Cancels a job.

    Args:
        cloud: Filter results by this cloud.
        cluster: Filter results by clusters matching this name.
        id: Filter results by jobs matching this job ID.
        level: The logging level for the specified command.
    """
    _print_and_wait(
        f"Canceling job [yellow]{id}[/yellow]",
        _cancel_worker,
        id=id,
        cloud=cloud,
        cluster=cluster,
    )
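

# Illustrative example (editor's sketch, not part of the original module): the direct
# launcher-API equivalent of the `cancel` command above. The cloud, cluster, and job ID
# values are hypothetical placeholders.
def _example_cancel_via_launcher() -> None:
    """Cancels one job without going through the Typer command."""
    from oumi import launcher

    launcher.cancel("my-job-id", "gcp", "my-cluster")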


def down(
    cluster: Annotated[str, typer.Option(help="The cluster to turn down.")],
    cloud: Annotated[
        Optional[str],
        typer.Option(
            help="If specified, only clusters on this cloud will be affected."
        ),
    ] = None,
    level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
    """Turns down a cluster.

    Args:
        cluster: The cluster to turn down.
        cloud: If specified, only clusters on this cloud will be affected.
        level: The logging level for the specified command.
    """
    _print_and_wait(
        f"Turning down cluster [yellow]{cluster}[/yellow]",
        _down_worker,
        cluster=cluster,
        cloud=cloud,
    )
    cli_utils.CONSOLE.print(f"Cluster [yellow]{cluster}[/yellow] turned down!")
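

# Illustrative example (editor's sketch, not part of the original module): turning down
# a cluster directly through the launcher API, mirroring `_down_worker` when a cloud is
# given. The cloud and cluster names are hypothetical placeholders.
def _example_down_via_launcher() -> None:
    """Turns down a named cluster on a known cloud."""
    from oumi import launcher

    target_cluster = launcher.get_cloud("gcp").get_cluster("my-cluster")
    if target_cluster:
        target_cluster.down()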


def run(
    ctx: typer.Context,
    config: Annotated[
        str,
        typer.Option(
            *cli_utils.CONFIG_FLAGS, help="Path to the configuration file for the job."
        ),
    ],
    cluster: Annotated[
        Optional[str],
        typer.Option(
            help=(
                "The cluster to use for this job. If unspecified, a new cluster will "
                "be created."
            )
        ),
    ] = None,
    detach: Annotated[
        bool, typer.Option(help="Run the job in the background.")
    ] = False,
    level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
    """Runs a job on the target cluster.

    Args:
        ctx: The Typer context object.
        config: Path to the configuration file for the job.
        cluster: The cluster to use for this job. If no such cluster exists, a new
            cluster will be created. If unspecified, a new cluster will be created
            with a unique name.
        detach: Run the job in the background.
        level: The logging level for the specified command.
    """
    extra_args = cli_utils.parse_extra_cli_args(ctx)
    config = str(
        cli_utils.resolve_and_fetch_config(
            try_get_config_name_for_alias(config, AliasType.JOB),
        )
    )
    # Delayed imports
    from oumi import launcher
    # End imports

    parsed_config: launcher.JobConfig = launcher.JobConfig.from_yaml_and_arg_list(
        config, extra_args, logger=logger
    )
    parsed_config.finalize_and_validate()
    parsed_config.working_dir = _get_working_dir(parsed_config.working_dir)
    if not cluster:
        raise ValueError("No cluster specified for the `run` action.")
    job_status = launcher.run(parsed_config, cluster)
    cli_utils.CONSOLE.print(
        f"Job [yellow]{job_status.id}[/yellow] queued on cluster "
        f"[yellow]{cluster}[/yellow]."
    )
    _poll_job(job_status=job_status, detach=detach, cloud=parsed_config.resources.cloud)
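

# Illustrative example (editor's sketch, not part of the original module): the
# launcher-API flow that `run` wraps, without Typer. The config path and cluster name
# are hypothetical placeholders.
def _example_run_via_launcher() -> None:
    """Builds a JobConfig from YAML, queues it on an existing cluster, and polls it."""
    from oumi import launcher

    job_config = launcher.JobConfig.from_yaml_and_arg_list(
        "configs/examples/my_job.yaml", [], logger=logger
    )
    job_config.finalize_and_validate()
    job_status = launcher.run(job_config, "my-cluster")
    _poll_job(job_status=job_status, detach=False, cloud=job_config.resources.cloud)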


def status(
    cloud: Annotated[
        Optional[str], typer.Option(help="Filter results by this cloud.")
    ] = None,
    cluster: Annotated[
        Optional[str],
        typer.Option(help="Filter results by clusters matching this name."),
    ] = None,
    id: Annotated[
        Optional[str],
        typer.Option(help="Filter results by jobs matching this job ID."),
    ] = None,
    level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
    """Prints the status of jobs launched from Oumi.

    Optionally, the caller may specify a job id, cluster, or cloud to further filter
    results.

    Args:
        cloud: Filter results by this cloud.
        cluster: Filter results by clusters matching this name.
        id: Filter results by jobs matching this job ID.
        level: The logging level for the specified command.
    """
    # Delayed imports
    from oumi import launcher
    # End imports

    filtered_jobs = launcher.status(cloud=cloud, cluster=cluster, id=id)
    num_jobs = sum(len(cloud_jobs) for cloud_jobs in filtered_jobs.values())
    # Print the filtered jobs.
    if num_jobs == 0 and (cloud or cluster or id):
        cli_utils.CONSOLE.print(
            "[red]No jobs found for the specified filter criteria: [/red]"
        )
        if cloud:
            cli_utils.CONSOLE.print(f"Cloud: [yellow]{cloud}[/yellow]")
        if cluster:
            cli_utils.CONSOLE.print(f"Cluster: [yellow]{cluster}[/yellow]")
        if id:
            cli_utils.CONSOLE.print(f"Job ID: [yellow]{id}[/yellow]")
    for target_cloud, job_list in filtered_jobs.items():
        cli_utils.section_header(f"Cloud: [yellow]{target_cloud}[/yellow]")
        cluster_name_list = [
            c.name() for c in launcher.get_cloud(target_cloud).list_clusters()
        ]
        if len(cluster_name_list) == 0:
            cli_utils.CONSOLE.print("[red]No matching clusters found.[/red]")
            continue
        # Organize all jobs by cluster.
        jobs_by_cluster: dict[str, list[JobStatus]] = defaultdict(list)
        # List all clusters, even if they don't have jobs.
        for cluster_name in cluster_name_list:
            if not cluster or cluster == cluster_name:
                jobs_by_cluster[cluster_name] = []
        for job in job_list:
            jobs_by_cluster[job.cluster].append(job)
        for target_cluster, jobs in jobs_by_cluster.items():
            title = f"[cyan]Cluster: [yellow]{target_cluster}[/yellow][/cyan]"
            if not jobs:
                body = Text("[red]No matching jobs found.[/red]")
            else:
                jobs_table = Table(show_header=True, show_lines=False)
                jobs_table.add_column("Job", justify="left", style="green")
                jobs_table.add_column("Status", justify="left", style="yellow")
                for job in jobs:
                    jobs_table.add_row(job.id, job.status)
                body = jobs_table
            cli_utils.CONSOLE.print(Panel(body, title=title, border_style="blue"))
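

# Illustrative example (editor's sketch, not part of the original module): consuming
# `launcher.status` programmatically instead of rendering tables. The filters shown are
# hypothetical; the returned mapping of cloud name to job list is the same structure
# iterated by `status` above.
def _example_status_via_launcher() -> None:
    """Logs every matching job as `cloud/cluster/id: status`."""
    from oumi import launcher

    filtered = launcher.status(cloud=None, cluster=None, id=None)
    for cloud_name, jobs in filtered.items():
        for job in jobs:
            logger.info(f"{cloud_name}/{job.cluster}/{job.id}: {job.status}")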


def stop(
    cluster: Annotated[str, typer.Option(help="The cluster to stop.")],
    cloud: Annotated[
        Optional[str],
        typer.Option(
            help="If specified, only clusters on this cloud will be affected."
        ),
    ] = None,
    level: cli_utils.LOG_LEVEL_TYPE = None,
) -> None:
    """Stops a cluster.

    Args:
        cluster: The cluster to stop.
        cloud: If specified, only clusters on this cloud will be affected.
        level: The logging level for the specified command.
    """
    _print_and_wait(
        f"Stopping cluster [yellow]{cluster}[/yellow]",
        _stop_worker,
        cluster=cluster,
        cloud=cloud,
    )
    cli_utils.CONSOLE.print(
        f"Cluster [yellow]{cluster}[/yellow] stopped!\n"
        "Use [green]oumi launch down[/green] to turn it down."
    )


def up(
    ctx: typer.Context,
    config: Annotated[
        str,
        typer.Option(
            *cli_utils.CONFIG_FLAGS, help="Path to the configuration file for the job."
        ),
    ],
    cluster: Annotated[
        Optional[str],
        typer.Option(
            help=(
                "The cluster to use for this job. If unspecified, a new cluster will "
                "be created."
            )
        ),
    ] = None,
    detach: Annotated[
        bool, typer.Option(help="Run the job in the background.")
    ] = False,
    level: cli_utils.LOG_LEVEL_TYPE = None,
):
    """Launches a job.

    Args:
        ctx: The Typer context object.
        config: Path to the configuration file for the job.
        cluster: The cluster to use for this job. If no such cluster exists, a new
            cluster will be created. If unspecified, a new cluster will be created
            with a unique name.
        detach: Run the job in the background.
        level: The logging level for the specified command.
    """
    # Delayed imports
    from oumi import launcher
    # End imports

    extra_args = cli_utils.parse_extra_cli_args(ctx)
    config = str(
        cli_utils.resolve_and_fetch_config(
            try_get_config_name_for_alias(config, AliasType.JOB),
        )
    )
    parsed_config: launcher.JobConfig = launcher.JobConfig.from_yaml_and_arg_list(
        config, extra_args, logger=logger
    )
    parsed_config.finalize_and_validate()
    if cluster:
        target_cloud = launcher.get_cloud(parsed_config.resources.cloud)
        target_cluster = target_cloud.get_cluster(cluster)
        if target_cluster:
            cli_utils.CONSOLE.print(
                f"Found an existing cluster: [yellow]{target_cluster.name()}[/yellow]."
            )
            run(ctx, config, cluster, detach)
            return
    parsed_config.working_dir = _get_working_dir(parsed_config.working_dir)
    # Start the job
    running_cluster, job_status = launcher.up(parsed_config, cluster)
    cli_utils.CONSOLE.print(
        f"Job [yellow]{job_status.id}[/yellow] queued on cluster "
        f"[yellow]{running_cluster.name()}[/yellow]."
    )
    _poll_job(
        job_status=job_status,
        detach=detach,
        cloud=parsed_config.resources.cloud,
        running_cluster=running_cluster,
    )
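

# Illustrative example (editor's sketch, not part of the original module): the
# launcher-API flow that `up` wraps, which also provisions the cluster when needed and
# returns it alongside the job status. The config path and cluster name are
# hypothetical placeholders.
def _example_up_via_launcher() -> None:
    """Provisions (or reuses) a cluster, queues the job, and polls it to completion."""
    from oumi import launcher

    job_config = launcher.JobConfig.from_yaml_and_arg_list(
        "configs/examples/my_job.yaml", [], logger=logger
    )
    job_config.finalize_and_validate()
    running_cluster, job_status = launcher.up(job_config, "my-cluster")
    _poll_job(
        job_status=job_status,
        detach=False,
        cloud=job_config.resources.cloud,
        running_cluster=running_cluster,
    )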


def which(level: cli_utils.LOG_LEVEL_TYPE = None) -> None:
    """Prints the available clouds."""
    # Delayed imports
    from oumi import launcher
    # End imports

    clouds = launcher.which_clouds()
    cloud_options = [Text(f"{cloud}", style="bold cyan") for cloud in clouds]
    cli_utils.CONSOLE.print(
        Panel(
            Columns(cloud_options, equal=True, expand=True, padding=(0, 2)),
            title="[yellow]Available Clouds[/yellow]",
            border_style="blue",
        )
    )