Source code for torchx.schedulers.ray_scheduler

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import json
import logging
import os
import re
import tempfile
import time
from dataclasses import dataclass, field
from datetime import datetime
from shutil import copy2, rmtree
from typing import Any, cast, Dict, Final, Iterable, List, Optional, Tuple  # noqa

import urllib3

from torchx.schedulers.api import (
    AppDryRunInfo,
    AppState,
    DescribeAppResponse,
    filter_regex,
    ListAppResponse,
    Scheduler,
    split_lines,
    Stream,
)
from torchx.schedulers.ids import make_unique
from torchx.schedulers.ray.ray_common import RayActor, TORCHX_RANK0_HOST
from torchx.specs import AppDef, macros, NONE, ReplicaStatus, Role, RoleStatus, runopts
from torchx.workspace.dir_workspace import TmpDirWorkspaceMixin
from typing_extensions import TypedDict


try:
    from ray.autoscaler import sdk as ray_autoscaler_sdk
    from ray.dashboard.modules.job.common import JobStatus
    from ray.dashboard.modules.job.sdk import JobSubmissionClient

    _has_ray = True

except ImportError:
    _has_ray = False


def has_ray() -> bool:
    """Indicates whether Ray is installed in the current Python environment."""
    return _has_ray


class RayOpts(TypedDict, total=False):
    cluster_config_file: Optional[str]
    cluster_name: Optional[str]
    dashboard_address: Optional[str]
    working_dir: Optional[str]
    requirements: Optional[str]


if _has_ray:
    _logger: logging.Logger = logging.getLogger(__name__)

    _ray_status_to_torchx_appstate: Dict[JobStatus, AppState] = {
        JobStatus.PENDING: AppState.PENDING,
        JobStatus.RUNNING: AppState.RUNNING,
        JobStatus.SUCCEEDED: AppState.SUCCEEDED,
        JobStatus.FAILED: AppState.FAILED,
        JobStatus.STOPPED: AppState.CANCELLED,
    }

    class _EnhancedJSONEncoder(json.JSONEncoder):
        def default(self, o: RayActor):  # pyre-ignore[3]
            if dataclasses.is_dataclass(o):
                return dataclasses.asdict(o)
            return super().default(o)
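    # Illustrative example (not part of the original module): since ``RayOpts``
    # is declared with ``total=False``, any subset of its keys forms a valid
    # config. A run against an existing dashboard might look like:
    #
    #   cfg: RayOpts = {
    #       "dashboard_address": "127.0.0.1:8265",
    #       "requirements": "requirements.txt",
    #   }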
    def serialize(
        actors: List[RayActor], dirpath: str, output_filename: str = "actors.json"
    ) -> None:
        actors_json = json.dumps(actors, cls=_EnhancedJSONEncoder)
        with open(os.path.join(dirpath, output_filename), "w") as tmp:
            json.dump(actors_json, tmp)
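    # A minimal sketch (an assumption, mirroring ``serialize`` above) of how a
    # consumer can read the actors back. ``serialize`` dumps the actor list to
    # a JSON string and then JSON-encodes that string into the file, so loading
    # must decode twice:
    #
    #   with open("actors.json") as f:
    #       actors = json.loads(json.load(f))  # -> List[dict], one per RayActor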
    @dataclass
    class RayJob:
        """Represents a job that should be run on a Ray cluster.

        Attributes:
            app_id:
                The unique ID of the application (a.k.a. job).
            cluster_config_file:
                The Ray cluster configuration file.
            cluster_name:
                The cluster name to use.
            dashboard_address:
                The existing dashboard IP address to connect to.
            working_dir:
                The working directory to copy to the cluster.
            requirements:
                The libraries to install on the cluster per requirements.txt.
            actors:
                The Ray actors which represent the job to be run. This
                attribute is dumped to a JSON file and copied to the cluster,
                where ``ray_driver.py`` uses it to initiate the job.
        """

        app_id: str
        working_dir: str
        cluster_config_file: Optional[str] = None
        cluster_name: Optional[str] = None
        dashboard_address: Optional[str] = None
        requirements: Optional[str] = None
        actors: List[RayActor] = field(default_factory=list)
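    # Illustrative construction (values are assumptions, not from the original
    # source): a RayJob pointing at an existing dashboard with a single actor.
    #
    #   job = RayJob(
    #       app_id="trainer-abc123",
    #       working_dir="/tmp/my_workspace",
    #       dashboard_address="127.0.0.1:8265",
    #   )
    #   job.actors.append(
    #       RayActor(name="trainer", command=["python", "train.py"])
    #   )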
    class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]):
        """
        RayScheduler is a TorchX scheduling interface to Ray. The job def
        workers will be launched as Ray actors.

        The job environment is specified by the TorchX workspace. Any files in
        the workspace will be present in the Ray job unless specified in
        ``.torchxignore``. Python dependencies will be read from the
        ``requirements.txt`` file located at the root of the workspace unless
        it's overridden via ``-c ...,requirements=foo/requirements.txt``.

        **Config Options**

        .. runopts::
            class: torchx.schedulers.ray_scheduler.create_scheduler

        **Compatibility**

        .. compatibility::
            type: scheduler
            features:
                cancel: true
                logs: |
                    Partial support. Ray only supports a single log stream so
                    only a dummy "ray/0" combined log role is supported.
                    Tailing and time seeking are not supported.
                distributed: true
                describe: |
                    Partial support. RayScheduler will return job status but
                    does not provide the complete original AppSpec.
                workspaces: true
                mounts: false
                elasticity: Partial support. Multi role jobs are not supported.
        """

        def __init__(
            self, session_name: str, ray_client: Optional[JobSubmissionClient] = None
        ) -> None:
            # NOTE: make sure any new init options are supported in create_scheduler(...)
            super().__init__("ray", session_name)

            # Declared Final: without it, the None check in _get_ray_client
            # does not narrow the type, since pyre assumes mutability.
            self._ray_client: Final[Optional[JobSubmissionClient]] = ray_client

        def _get_ray_client(
            self, job_submission_netloc: Optional[str] = None
        ) -> JobSubmissionClient:
            if self._ray_client is not None:
                client_netloc = urllib3.util.parse_url(
                    self._ray_client.get_address()
                ).netloc
                if job_submission_netloc and job_submission_netloc != client_netloc:
                    raise ValueError(
                        f"client netloc ({client_netloc}) does not match job netloc ({job_submission_netloc})"
                    )
                return self._ray_client
            elif os.getenv("RAY_ADDRESS"):
                return JobSubmissionClient(os.getenv("RAY_ADDRESS"))
            elif not job_submission_netloc:
                raise Exception(
                    "RAY_ADDRESS env variable or a scheduler with an attached Ray JobSubmissionClient is expected."
                    " See https://docs.ray.io/en/latest/cluster/jobs-package-ref.html#job-submission-sdk for more info"
                )
            return JobSubmissionClient(f"http://{job_submission_netloc}")

        # TODO: Add address as a potential CLI argument after writing
        # ray.status() or passing in a config file
        def _run_opts(self) -> runopts:
            opts = runopts()
            opts.add(
                "cluster_config_file",
                type_=str,
                required=False,
                help="Use CLUSTER_CONFIG_FILE to access or create the Ray cluster.",
            )
            opts.add(
                "cluster_name",
                type_=str,
                help="Override the configured cluster name.",
            )
            opts.add(
                "dashboard_address",
                type_=str,
                required=False,
                default="127.0.0.1:8265",
                help="Use ray status to get the dashboard address you will submit jobs against",
            )
            opts.add("requirements", type_=str, help="Path to requirements.txt")
            return opts
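        # Illustrative usage (an assumption, not from the original source): the
        # scheduler is normally driven through the TorchX CLI/runner rather
        # than instantiated directly, along the lines of
        #
        #   torchx run -s ray -cfg dashboard_address=127.0.0.1:8265 \
        #       dist.ddp -j 2x1 --script train.py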
        def schedule(self, dryrun_info: AppDryRunInfo[RayJob]) -> str:
            cfg: RayJob = dryrun_info.request

            # Create serialized actors for ray_driver.py
            actors = cfg.actors
            dirpath = cfg.working_dir
            serialize(actors, dirpath)

            job_submission_addr: str = ""
            if cfg.cluster_config_file:
                job_submission_addr = ray_autoscaler_sdk.get_head_node_ip(
                    cfg.cluster_config_file
                )  # pragma: no cover
            elif cfg.dashboard_address:
                job_submission_addr = cfg.dashboard_address
            else:
                raise RuntimeError(
                    "Either `dashboard_address` or `cluster_config_file` must be specified"
                )

            # 1. Create the job client
            client = self._get_ray_client(job_submission_netloc=job_submission_addr)

            # 2. Copy the Ray driver utilities
            current_directory = os.path.dirname(os.path.abspath(__file__))
            copy2(os.path.join(current_directory, "ray", "ray_driver.py"), dirpath)
            copy2(os.path.join(current_directory, "ray", "ray_common.py"), dirpath)
            runtime_env = {"working_dir": dirpath}
            if cfg.requirements:
                runtime_env["pip"] = cfg.requirements

            # 3. Submit the job via the Ray Job Submission API
            try:
                job_id: str = client.submit_job(
                    submission_id=cfg.app_id,
                    # The working_dir is packed, hashed, zipped, uploaded, and
                    # registered in the GCS of the Ray cluster, which uses it
                    # to configure the job execution.
                    entrypoint="python3 ray_driver.py",
                    runtime_env=runtime_env,
                )
            finally:
                if dirpath.startswith(tempfile.gettempdir()):
                    rmtree(dirpath)

            # Encode the job submission address in the returned app_id
            return f"{job_submission_addr}-{job_id}"
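        # For example (illustrative values): with dashboard_address
        # "127.0.0.1:8265" and a Ray submission id "raysubmit_abc123",
        # ``schedule`` returns "127.0.0.1:8265-raysubmit_abc123";
        # ``_parse_app_id`` below splits this back into its two parts.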
        def _submit_dryrun(self, app: AppDef, cfg: RayOpts) -> AppDryRunInfo[RayJob]:
            app_id = make_unique(app.name)

            working_dir = app.roles[0].image
            if not os.path.exists(working_dir):
                raise RuntimeError(
                    f"Role image must be a valid directory, got: {working_dir}"
                )

            requirements: Optional[str] = cfg.get("requirements")
            if requirements is None:
                workspace_reqs = os.path.join(working_dir, "requirements.txt")
                if os.path.exists(workspace_reqs):
                    requirements = workspace_reqs

            cluster_cfg = cfg.get("cluster_config_file")
            if cluster_cfg:
                if not isinstance(cluster_cfg, str) or not os.path.isfile(cluster_cfg):
                    raise ValueError(
                        "The cluster configuration file must be a YAML file."
                    )

                job: RayJob = RayJob(
                    app_id,
                    cluster_config_file=cluster_cfg,
                    requirements=requirements,
                    working_dir=working_dir,
                )

            else:  # pragma: no cover
                dashboard_address = cfg.get("dashboard_address")
                job: RayJob = RayJob(
                    app_id=app_id,
                    dashboard_address=dashboard_address,
                    requirements=requirements,
                    working_dir=working_dir,
                )
            job.cluster_name = cfg.get("cluster_name")

            for role in app.roles:
                for replica_id in range(role.num_replicas):
                    # Replace the ${img_root}, ${app_id}, and ${replica_id}
                    # placeholders in arguments and environment variables.
                    replica_role = macros.Values(
                        img_root=role.image,
                        app_id=app_id,
                        replica_id=str(replica_id),
                        rank0_env=TORCHX_RANK0_HOST,
                    ).apply(role)

                    actor = RayActor(
                        name=role.name,
                        min_replicas=role.min_replicas,
                        command=[replica_role.entrypoint] + replica_role.args,
                        env=replica_role.env,
                        num_cpus=max(1, replica_role.resource.cpu),
                        num_gpus=max(0, replica_role.resource.gpu),
                    )

                    job.actors.append(actor)

            if len(app.roles) > 1 and app.roles[0].min_replicas is not None:
                raise ValueError("min_replicas is only supported with single role jobs")

            return AppDryRunInfo(job, repr)

        def _validate(self, app: AppDef, scheduler: str) -> None:
            if scheduler != "ray":
                raise ValueError(
                    f"An unknown scheduler backend '{scheduler}' has been passed to the Ray scheduler."
                )

            if app.metadata:
                _logger.warning("The Ray scheduler does not use metadata information.")

            for role in app.roles:
                if role.resource.capabilities:
                    _logger.warning(
                        "The Ray scheduler does not support custom resource capabilities."
                    )
                    break

            for role in app.roles:
                if role.port_map:
                    _logger.warning("The Ray scheduler does not support port mapping.")
                    break
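        # Illustrative effect of the macro substitution above (an assumption,
        # values invented for the example): a role argument such as
        # "--rank0_host=${rank0_env}" becomes "--rank0_host=TORCHX_RANK0_HOST",
        # i.e. the name of the environment variable from which each actor can
        # resolve the rank-0 host at runtime.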
        def wait_until_finish(self, app_id: str, timeout: int = 30) -> None:
            """
            ``wait_until_finish`` waits until the specified job has finished
            with a given timeout. This is intended for testing. Programmatic
            usage should use the runner wait method instead.
            """

            start = time.time()
            while time.time() - start <= timeout:
                status = self._get_job_status(app_id)
                if status in {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}:
                    break
                time.sleep(1)
        def _parse_app_id(self, app_id: str) -> Tuple[str, str]:
            # Split on the '-' that terminates the first ":<port>" occurrence;
            # fall back to splitting on the first '-' if no port is present.
            m = re.search(r":\d+-", app_id)
            if m:
                sep = m.span()[1]
                addr = app_id[: sep - 1]
                app_id = app_id[sep:]
                return addr, app_id

            addr, _, app_id = app_id.partition("-")
            return addr, app_id

        def _cancel_existing(self, app_id: str) -> None:  # pragma: no cover
            addr, app_id = self._parse_app_id(app_id)
            client = self._get_ray_client(job_submission_netloc=addr)
            client.stop_job(app_id)

        def _get_job_status(self, app_id: str) -> JobStatus:
            addr, app_id = self._parse_app_id(app_id)
            client = self._get_ray_client(job_submission_netloc=addr)
            status = client.get_job_status(app_id)
            if isinstance(status, str):
                return cast(JobStatus, status)
            return status.status
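        # Illustrative behavior of ``_parse_app_id`` (values are assumptions):
        #
        #   _parse_app_id("127.0.0.1:8265-raysubmit_xyz")
        #   # -> ("127.0.0.1:8265", "raysubmit_xyz"), via the ":\d+-" match
        #
        #   _parse_app_id("ray-head.example.com-raysubmit_xyz")
        #   # -> ("ray", "head.example.com-raysubmit_xyz"): with no port, the
        #   #    fallback partitions on the *first* "-" in the string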
        def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
            job_status_info = self._get_job_status(app_id)
            state = _ray_status_to_torchx_appstate[job_status_info]
            roles = [Role(name="ray", num_replicas=1, image="<N/A>")]

            # get ip_address and put it in hostname
            roles_statuses = [
                RoleStatus(
                    role="ray",
                    replicas=[
                        ReplicaStatus(
                            id=0,
                            role="ray",
                            hostname=NONE,
                            state=state,
                        )
                    ],
                )
            ]
            return DescribeAppResponse(
                app_id=app_id,
                state=state,
                msg=job_status_info,
                roles_statuses=roles_statuses,
                roles=roles,
            )
        def log_iter(
            self,
            app_id: str,
            role_name: Optional[str] = None,
            k: int = 0,
            regex: Optional[str] = None,
            since: Optional[datetime] = None,
            until: Optional[datetime] = None,
            should_tail: bool = False,
            streams: Optional[Stream] = None,
        ) -> Iterable[str]:
            # TODO: support tailing, streams etc.
            addr, app_id = self._parse_app_id(app_id)
            client: JobSubmissionClient = self._get_ray_client(
                job_submission_netloc=addr
            )
            logs: str = client.get_job_logs(app_id)
            iterator = split_lines(logs)
            if regex:
                return filter_regex(regex, iterator)
            return iterator
        def list(self) -> List[ListAppResponse]:
            client = self._get_ray_client()
            jobs = client.list_jobs()
            netloc = urllib3.util.parse_url(client.get_address()).netloc
            return [
                ListAppResponse(
                    app_id=f"{netloc}-{details.submission_id}",
                    state=_ray_status_to_torchx_appstate[details.status],
                )
                for details in jobs
            ]

def create_scheduler(
    session_name: str, ray_client: Optional[JobSubmissionClient] = None, **kwargs: Any
) -> "RayScheduler":
    if not has_ray():  # pragma: no cover
        raise ModuleNotFoundError(
            "Ray is not installed in the current Python environment."
        )

    return RayScheduler(session_name=session_name, ray_client=ray_client)
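
# Illustrative end-to-end sketch (assumptions: a reachable Ray dashboard at
# 127.0.0.1:8265 and a workspace directory /tmp/my_workspace containing
# train.py; names and values are invented for the example):
#
#   scheduler = create_scheduler("my_session")
#   app = AppDef(
#       name="trainer",
#       roles=[
#           # The Ray scheduler treats the role image as the workspace directory.
#           Role(
#               name="trainer",
#               image="/tmp/my_workspace",
#               entrypoint="python",
#               args=["train.py"],
#           )
#       ],
#   )
#   dryrun_info = scheduler.submit_dryrun(
#       app, cfg={"dashboard_address": "127.0.0.1:8265"}
#   )
#   app_id = scheduler.schedule(dryrun_info)  # "127.0.0.1:8265-raysubmit_..."
#   scheduler.wait_until_finish(app_id, timeout=300)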
