Source code for torchx.schedulers.aws_batch_scheduler

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


This contains the TorchX AWS Batch scheduler which can be used to run TorchX
components directly on AWS Batch.

This scheduler is in prototype stage and may change without notice.


You'll need to create an AWS Batch queue configured for multi-node parallel jobs.

See the AWS Batch documentation for how to set up a job queue and compute
environment. It needs to be backed by EC2 for multi-node parallel jobs.

See the AWS Batch multi-node parallel jobs documentation for more information on distributed jobs.

If you want to use workspaces and container patching you'll also need to
configure a Docker registry, such as AWS ECR, to store the patched containers
with your changes.

See the AWS ECR documentation for how to create an image repository.
import getpass
import re
import threading
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import auto, Enum
from typing import (

import torchx
import yaml
from torchx.schedulers.api import (

from torchx.schedulers.devices import get_device_mounts
from torchx.schedulers.ids import make_unique
from torchx.specs.api import (
from torchx.specs.named_resources_aws import instance_type_from_resource
from torchx.util.types import none_throws
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
from typing_extensions import TypedDict





    from docker import DockerClient

# Maps the AWS Batch job ``status`` string (as read from describe_jobs results)
# to the corresponding TorchX AppState. describe() and _list_by_queue() index
# this dict directly, so any status missing from it raises KeyError there.
# NOTE(review): this view shows only three entries and no closing brace; the
# full mapping presumably covers the remaining AWS Batch statuses -- confirm.
JOB_STATE: Dict[str, AppState] = {
    "PENDING": AppState.PENDING,
    "RUNNING": AppState.RUNNING,
    "FAILED": AppState.FAILED,

def to_millis_since_epoch(ts: datetime) -> int:
    """Convert ``ts`` to whole milliseconds since the Unix epoch.

    Uses ``datetime.timestamp()`` semantics (naive datetimes are treated as
    local time) and rounds to the nearest millisecond.
    """
    millis = ts.timestamp() * 1000  # timestamp() is fractional seconds
    return int(round(millis))

def to_datetime(ms_since_epoch: int) -> datetime:
    """Inverse of ``to_millis_since_epoch``: milliseconds -> naive local datetime."""
    seconds = ms_since_epoch / 1000
    return datetime.fromtimestamp(seconds)

class ResourceType(Enum):
    """Resource kinds used in AWS Batch ``resourceRequirements`` entries.

    Member names are matched (case-insensitively, via ``from_str``) against the
    ``type`` strings of those entries.
    """

    VCPU = auto()
    GPU = auto()
    MEMORY = auto()

    @staticmethod
    def from_str(resource_type: str) -> "ResourceType":
        """Return the member whose name equals ``resource_type`` (case-insensitive).

        Raises:
            ValueError: if ``resource_type`` does not name a member.
        """
        try:
            return ResourceType[resource_type.upper()]
        except KeyError:
            raise ValueError(
                f"No ResourceType found for `{resource_type}`. Valid types: {[r.name for r in ResourceType]}"
            ) from None

def resource_requirements_from_resource(resource: Resource) -> List[Dict[str, str]]:
    """Translate a TorchX ``Resource`` into AWS Batch ``resourceRequirements``.

    VCPU and MEMORY entries are always emitted (a non-positive ``cpu`` is bumped
    to 1); a GPU entry is only added when ``resource.gpu > 0``. Values are
    stringified because the AWS Batch API takes them as strings.

    Raises:
        AssertionError: if ``resource.memMB`` is not positive.
    """
    cpu = resource.cpu if resource.cpu > 0 else 1
    gpu = resource.gpu
    memMB = resource.memMB
    assert (
        memMB > 0
    ), f"AWSBatchScheduler requires memMB to be set to a positive value, got {memMB}"

    # ``type`` uses the ResourceType member names so that
    # resource_from_resource_requirements() can parse them back via from_str().
    resource_requirements = [
        {"type": ResourceType.VCPU.name, "value": str(cpu)},
        {"type": ResourceType.MEMORY.name, "value": str(memMB)},
    ]
    if gpu > 0:
        resource_requirements.append({"type": ResourceType.GPU.name, "value": str(gpu)})
    return resource_requirements

def resource_from_resource_requirements(
    resource_requirements: List[Dict[str, str]]
) -> Resource:
    """Rebuild a TorchX ``Resource`` from AWS Batch ``resourceRequirements``.

    Each entry's ``type`` is parsed with ``ResourceType.from_str`` (which raises
    ``ValueError`` for unknown types) and its ``value`` with ``int()``. The GPU
    count defaults to 0 when no GPU entry is present.
    """
    # NOTE(review): the closing ``}`` of the comprehension below, the cpu/memMB
    # keyword arguments and the closing ``)`` of ``Resource(`` are not visible in
    # this view -- confirm against the full file.
    # NOTE(review): ``int(r["value"])`` raises on fractional values such as
    # "0.25" -- confirm such values cannot reach this function.
    resrc_req = {
        ResourceType.from_str(r["type"]): int(r["value"]) for r in resource_requirements
    return Resource(
        gpu=resrc_req.get(ResourceType.GPU, 0),
        # TODO kiukchung@ map back capabilities and devices
        # might be better to tag the named resource and finding the resource
        # this requires the named resource to be part of the AppDef spec
        # but today we lose the named resource str at the component level

def _role_to_node_properties(
    role: Role,
    start_idx: int,
    privileged: bool = False,
    job_role_arn: Optional[str] = None,
    execution_role_arn: Optional[str] = None,
) -> Dict[str, object]:
    """Convert ``role`` into one AWS Batch ``nodeRangeProperties`` entry.

    Args:
        role: the TorchX role. NOTE: ``role.mounts`` is extended with the device
            mounts derived from ``role.resource.devices``, i.e. the role object
            passed in is modified.
        start_idx: global node index of the role's first replica.
        privileged: value for the container's ``privileged`` flag.
        job_role_arn: if set, stored as the container's ``jobRoleArn``.
        execution_role_arn: if set, stored as the container's ``executionRoleArn``.

    Returns:
        A dict with ``targetNodes`` (the inclusive node range
        ``start_idx:start_idx + num_replicas - 1``) and the ``container`` spec.

    Raises:
        TypeError: for a mount of an unrecognised type.
    """
    # NOTE(review): this view is missing several lines of this function (the
    # volumes/mount_points/devices append calls, closing brackets and the
    # ``else:`` before ``raise TypeError``) -- confirm against the full file.
    role.mounts += get_device_mounts(role.resource.devices)

    mount_points = []
    volumes = []
    devices = []
    for i, mount in enumerate(role.mounts):
        # every mount gets a synthetic volume name, referenced by its mount point
        name = f"mount_{i}"
        if isinstance(mount, BindMount):
            # host path volume + container mount point
                    "name": name,
                    "host": {
                        "sourcePath": mount.src_path,
                    "containerPath": mount.dst_path,
                    "readOnly": mount.read_only,
                    "sourceVolume": name,
        elif isinstance(mount, VolumeMount):
            # EFS volume (mount.src is the file system id) + container mount point
                    "name": name,
                    "efsVolumeConfiguration": {
                        "fileSystemId": mount.src,
                    "containerPath": mount.dst_path,
                    "readOnly": mount.read_only,
                    "sourceVolume": name,
        elif isinstance(mount, DeviceMount):
            # translate TorchX r/w/m permission chars into AWS Batch device permissions
            perm_map = {
                "r": "READ",
                "w": "WRITE",
                "m": "MKNOD",
                    "hostPath": mount.src_path,
                    "containerPath": mount.dst_path,
                    "permissions": [perm_map[p] for p in mount.permissions],
            raise TypeError(f"unknown mount type {mount}")

    # container properties shared by every replica in this node group
    container = {
        "command": [role.entrypoint] + role.args,
        "image": role.image,
        "environment": [{"name": k, "value": v} for k, v in role.env.items()],
        "privileged": privileged,
        "resourceRequirements": resource_requirements_from_resource(role.resource),
        "linuxParameters": {
            # To support PyTorch dataloaders we need to set /dev/shm to larger
            # than the 64M default.
            "sharedMemorySize": role.resource.memMB,
            "devices": devices,
        "logConfiguration": {
            "logDriver": "awslogs",
        "mountPoints": mount_points,
        "volumes": volumes,
    if job_role_arn:
        container["jobRoleArn"] = job_role_arn
    if execution_role_arn:
        container["executionRoleArn"] = execution_role_arn
    # NOTE(review): instanceType is only pinned for multi-replica roles; the
    # reason is not visible here -- confirm.
    if role.num_replicas > 1:
        instance_type = instance_type_from_resource(role.resource)
        if instance_type is not None:
            container["instanceType"] = instance_type

    return {
        "targetNodes": f"{start_idx}:{start_idx + role.num_replicas - 1}",
        "container": container,

def _job_ui_url(job_arn: str) -> Optional[str]:
    """Best-effort AWS console URL for the multi-node job identified by ``job_arn``.

    Returns ``None`` when ``job_arn`` does not match the expected pattern.
    """
    # NOTE(review): the pattern/arguments of re.match and the expressions that
    # assign ``region`` and ``job_id`` are missing from this view (presumably
    # capture groups of the job ARN -- confirm). The URL host between the two
    # ``{region}`` placeholders also appears truncated.
    match = re.match(
    if match is None:
        return None
    region =
    job_id =
    return f"https://{region}{region}#jobs/mnp-job/{job_id}"

def _parse_num_replicas(target_nodes: str, num_nodes: int) -> int:
    """
    Parses the number of replicas for a role given the target_nodes string
    and total num_nodes. See docstring for ``_parse_start_and_end_idx()``
    for details on the format of ``target_nodes`` string.
    """
    start_idx, end_idx = _parse_start_and_end_idx(target_nodes, num_nodes)
    # both indices are inclusive
    return end_idx - start_idx + 1

def _parse_start_and_end_idx(target_nodes: str, num_nodes: int) -> Tuple[int, int]:
    """
    Takes the ``target_nodes`` str (as required by AWS Batch NodeRangeProperties)
    and parses out the start and end indices (aka global rank) of the replicas in the node group.
    The ``target_nodes`` string is of the form:

    #. ``[start_node_index]:[end_node_index]`` (e.g. ``0:5``)
    #. --or-- ``:[end_node_index]`` (e.g. ``:5``)
    #. --or-- ``[start_node_index]:`` (e.g. ``0:``)
    #. --or-- ``[node_index]`` (e.g. ``0`` - single node multi-node-parallel job)

    Both returned indices are inclusive. An omitted start defaults to ``0`` and an
    omitted end defaults to ``num_nodes - 1``.

    Raises:
        ValueError: if ``target_nodes`` has more than one ``:`` or a non-integer index.
    """
    indices = target_nodes.split(":")
    if len(indices) == 1:
        return int(indices[0]), int(indices[0])
    if len(indices) != 2:
        raise ValueError(
            f"Invalid targetNodes `{target_nodes}`: expected `[start]:[end]` or `[index]`"
        )
    start_idx, end_idx = indices
    return int(start_idx or "0"), int(end_idx or str(num_nodes - 1))

@dataclass
class BatchJob:
    """A rendered AWS Batch submission, produced by ``_submit_dryrun``.

    Attributes:
        name: unique name (also used as the job definition name).
        queue: AWS Batch job queue to submit to.
        share_id: optional fair-share identifier, sent as ``shareIdentifier``.
        job_def: keyword arguments passed to ``register_job_definition``.
        images_to_push: images to push before submission, as returned by
            ``dryrun_push_images``.
    """

    name: str
    queue: str
    share_id: Optional[str]
    job_def: Dict[str, object]
    images_to_push: Dict[str, Tuple[str, str]]

    def __str__(self) -> str:
        # YAML dump of all fields; also used as the dryrun representation
        return yaml.dump(asdict(self))

    def __repr__(self) -> str:
        return str(self)
T = TypeVar("T")


def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]:
    """Memoize a zero-argument factory once per thread.

    The first call of the returned function in a given thread invokes ``f`` and
    stores the result in a ``threading.local``. Later calls in that thread return
    the stored value, even when it is ``None``. Other threads get their own value.
    """
    import functools

    local: threading.local = threading.local()
    key: str = "value"
    missing = object()  # sentinel so that a cached ``None`` still counts as cached

    @functools.wraps(f)
    def wrapper() -> T:
        cached = getattr(local, key, missing)
        if cached is not missing:
            return cached
        v = f()
        setattr(local, key, v)
        return v

    return wrapper


@_thread_local_cache
def _local_session() -> "boto3.session.Session":
    # one boto3 Session per thread (boto3 documents Session objects as not
    # thread-safe); boto3 is imported lazily on first use
    import boto3.session

    return boto3.session.Session()


class AWSBatchOpts(TypedDict, total=False):
    # Scheduler run config. Help text and defaults for most options are declared
    # in AWSBatchScheduler._run_opts().
    queue: str
    user: str
    image_repo: Optional[str]
    privileged: bool
    share_id: Optional[str]
    priority: int
    job_role_arn: Optional[str]
    execution_role_arn: Optional[str]
class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
    """
    AWSBatchScheduler is a TorchX scheduling interface to AWS Batch.

    .. code-block:: bash

        $ pip install torchx[kubernetes]
        $ torchx run --scheduler aws_batch --scheduler_args queue=torchx utils.echo --image alpine:latest --msg hello
        aws_batch://torchx_user/1234
        $ torchx status aws_batch://torchx_user/1234
        ...

    Authentication is loaded from the environment using the ``boto3`` credential
    handling.

    **Config Options**

    .. runopts::
        class: torchx.schedulers.aws_batch_scheduler.create_scheduler

    **Mounts**

    This class supports bind mounting host directories, efs volumes and host
    devices.

    * bind mount: ``type=bind,src=<host path>,dst=<container path>[,readonly]``
    * efs volume: ``type=volume,src=<efs id>,dst=<container path>[,readonly]``
    * devices: ``type=device,src=/dev/infiniband/uverbs0,[dst=<container path>][,perm=rwm]``

    See :py:func:`torchx.specs.parse_mounts` for more info.

    For other filesystems such as FSx, mount them onto the host and then bind
    mount them into your job using the bind mount syntax above.

    For Elastic Fabric Adapter (EFA), use a device mount (syntax above) to
    expose the EFA devices inside the container.

    **Compatibility**

    .. compatibility::
        type: scheduler
        features:
            cancel: true
            logs: true
            distributed: true
            describe: |
                Partial support. AWSBatchScheduler will return job and replica
                status but does not provide the complete original AppSpec.
            workspaces: true
            mounts: true
            elasticity: false
    """

    def __init__(
        self,
        session_name: str,
        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
        client: Optional[Any] = None,
        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
        log_client: Optional[Any] = None,
        docker_client: Optional["DockerClient"] = None,
    ) -> None:
        """
        Args:
            session_name: TorchX session name, forwarded to the base class.
            client: optional pre-built boto3 ``batch`` client.
            log_client: optional pre-built boto3 ``logs`` client.
            docker_client: optional Docker client, forwarded to the base class
                (DockerWorkspaceMixin).
        """
        # NOTE: make sure any new init options are supported in create_scheduler(...)
        super().__init__("aws_batch", session_name, docker_client=docker_client)

        # pyre-fixme[4]: Attribute annotation cannot be `Any`.
        self.__client = client
        # pyre-fixme[4]: Attribute annotation cannot be `Any`.
        self.__log_client = log_client

    @property
    # pyre-fixme[3]: Return annotation cannot be `Any`.
    def _client(self) -> Any:
        """The injected batch client if truthy, otherwise a new ``batch`` client
        created from this thread's cached boto3 session."""
        return self.__client or _local_session().client("batch")

    @property
    # pyre-fixme[3]: Return annotation cannot be `Any`.
    def _log_client(self) -> Any:
        """The injected logs client if truthy, otherwise a new ``logs`` client
        created from this thread's cached boto3 session."""
        return self.__log_client or _local_session().client("logs")
# schedule(): pushes the images recorded at dryrun time, registers the job
# definition from ``req.job_def``, then submits the job to ``req.queue`` with the
# job definition's tags (adding ``shareIdentifier`` only when ``share_id`` is
# set). Returns an app id of the form ``<queue>:<...>``.
# NOTE(review): the values for "jobName"/"jobDefinition" and the second part of
# the returned app id are missing from this view -- presumably the job name,
# since _get_job_id() splits the app id on ":" and filters list_jobs by
# JOB_NAME with the second part; confirm.
[docs] def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str: cfg = dryrun_info._cfg assert cfg is not None, f"{dryrun_info} missing cfg" images_to_push = dryrun_info.request.images_to_push self.push_images(images_to_push) req = dryrun_info.request self._client.register_job_definition(**req.job_def) batch_job_req = { **{ "jobName":, "jobQueue": req.queue, "jobDefinition":, "tags": req.job_def["tags"], }, **({"shareIdentifier": req.share_id} if req.share_id is not None else {}), } self._client.submit_job(**batch_job_req) return f"{req.queue}:{}"
# _submit_dryrun(): raises TypeError unless cfg["queue"] is a str; builds a
# unique name (suffixed with "-<share_id>" when share_id is set); asserts the
# app has at most 5 roles; resolves each role's macros against AWS Batch env
# vars, tags it with its role index/name and converts it to a node group via
# _role_to_node_properties(); then assembles a "multinode" job definition
# tagged with the TorchX version, app name, user and app.metadata.
# ``schedulingPriority`` is only set when share_id is set.
# NOTE(review): several expressions are missing from this view: the name
# references inside make_unique(), after ENV_TORCHX_ROLE_NAME and after
# TAG_TORCHX_APPNAME, and the URL after " See:". Confirm against the full file.
# NOTE(review): retryStrategy.attempts is max(max_retries over roles, 1), i.e.
# max_retries is used directly as the attempt count -- confirm that is intended.
def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJob]: queue = cfg.get("queue") if not isinstance(queue, str): raise TypeError(f"config value 'queue' must be a string, got {queue}") share_id = cfg.get("share_id") priority = cfg["priority"] name_suffix = f"-{share_id}" if share_id is not None else "" name = make_unique(f"{}{name_suffix}") assert len(app.roles) <= 5, ( "AWS Batch only supports <= 5 roles (NodeGroups)." " See:" ) # map any local images to the remote image images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg)) nodes = [] node_idx = 0 for role_idx, role in enumerate(app.roles): values = macros.Values( img_root="", app_id=name, # this only resolves for role.args # if the entrypoint is run with sh or bash # but won't actually work for macros in env vars replica_id="$AWS_BATCH_JOB_NODE_INDEX", rank0_env="AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS", ) role = values.apply(role) role.env[ENV_TORCHX_ROLE_IDX] = str(role_idx) role.env[ENV_TORCHX_ROLE_NAME] = str( nodes.append( _role_to_node_properties( role, start_idx=node_idx, privileged=cfg["privileged"], job_role_arn=cfg.get("job_role_arn"), execution_role_arn=cfg.get("execution_role_arn"), ) ) node_idx += role.num_replicas job_def = { **{ "jobDefinitionName": name, "type": "multinode", "nodeProperties": { "numNodes": node_idx, "mainNode": 0, "nodeRangeProperties": nodes, }, "retryStrategy": { "attempts": max(max(role.max_retries for role in app.roles), 1), "evaluateOnExit": [ {"onExitCode": "0", "action": "EXIT"}, ], }, "tags": { TAG_TORCHX_VER: torchx.__version__, TAG_TORCHX_APPNAME:, TAG_TORCHX_USER: cfg.get("user"), **app.metadata, }, }, **({"schedulingPriority": priority} if share_id is not None else {}), } req = BatchJob( name=name, queue=queue, share_id=share_id, job_def=job_def, images_to_push=images_to_push, ) return AppDryRunInfo(req, repr) def _cancel_existing(self, app_id: str) -> None: job_id = self._get_job_id(app_id)
# _cancel_existing() (continued) terminates the resolved job. _run_opts()
# declares the scheduler config options. _get_job_id() splits "<queue>:<name>"
# and returns the jobArn of the first list_jobs match (or None). _get_job()
# describes that job -- or node "<arn>#<rank>" when ``rank`` is given -- and
# returns None if nothing is found.
self._client.terminate_job( jobId=job_id, reason="killed via torchx CLI", ) def _run_opts(self) -> runopts: opts = runopts() opts.add("queue", type_=str, help="queue to schedule job in", required=True) opts.add( "user", type_=str, default=getpass.getuser(), help="The username to tag the job with. `getpass.getuser()` if not specified.", ) opts.add( "privileged", type_=bool, default=False, help="If true runs the container with elevated permissions." " Equivalent to running with `docker run --privileged`.", ) opts.add( "share_id", type_=str, help="The share identifier for the job. " "This must be set if and only if the job queue has a scheduling policy.", ) opts.add( "priority", type_=int, default=0, help="The scheduling priority for the job within the context of share_id. " "Higher number (between 0 and 9999) means higher priority. " "This will only take effect if the job queue has a scheduling policy.", ) opts.add( "job_role_arn", type_=str, help="The Amazon Resource Name (ARN) of the IAM role that the container can assume for AWS permissions.", ) opts.add( "execution_role_arn", type_=str, help="The Amazon Resource Name (ARN) of the IAM role that the ECS agent can assume for AWS permissions.", ) return opts def _get_job_id(self, app_id: str) -> Optional[str]: queue, name = app_id.split(":") for resp in self._client.get_paginator("list_jobs").paginate( jobQueue=queue, filters=[{"name": "JOB_NAME", "values": [name]}], ): job_summary_list = resp["jobSummaryList"] if job_summary_list: return job_summary_list[0]["jobArn"] return None def _get_job( self, app_id: str, rank: Optional[int] = None ) -> Optional[Dict[str, Any]]: job_id = self._get_job_id(app_id) if not job_id: return None if rank is not None: job_id += f"#{rank}" jobs = self._client.describe_jobs(jobs=[job_id])["jobs"] if len(jobs) == 0: return None return jobs[0]
[docs] def describe(self, app_id: str) -> Optional[DescribeAppResponse]: job = self._get_job(app_id) if job is None: return None # each AppDef.role maps to a batch NodeGroup roles = [] node_properties = job["nodeProperties"] num_nodes = node_properties["numNodes"] for node_group in node_properties["nodeRangeProperties"]: container = node_group["container"] env = {opt["name"]: opt["value"] for opt in container["environment"]} command = container["command"] roles.append( Role( name=env.get(ENV_TORCHX_ROLE_NAME, DEFAULT_ROLE_NAME), num_replicas=_parse_num_replicas( node_group["targetNodes"], num_nodes ), image=container["image"], entrypoint=command[0] if command else MISSING, args=command[1:], env=env, resource=resource_from_resource_requirements( container["resourceRequirements"] ), ) ) return DescribeAppResponse( app_id=app_id, state=JOB_STATE[job["status"]], roles=roles, # TODO: role statuses ui_url=_job_ui_url(job["jobArn"]), )
[docs] def log_iter( self, app_id: str, role_name: str, k: int = 0, regex: Optional[str] = None, since: Optional[datetime] = None, until: Optional[datetime] = None, should_tail: bool = False, streams: Optional[Stream] = None, ) -> Iterable[str]: if streams not in (None, Stream.COMBINED): raise ValueError("AWSBatchScheduler only supports COMBINED log stream") job = self._get_job(app_id) if job is None: return [] node_properties = job["nodeProperties"] nodes = node_properties["nodeRangeProperties"] global_idx = -1 # finds the global idx of the node that matches the role's k'th replica for i, node in enumerate(nodes): container = node["container"] env = {opt["name"]: opt["value"] for opt in container["environment"]} node_role = env.get(ENV_TORCHX_ROLE_NAME, DEFAULT_ROLE_NAME) start_idx, _ = _parse_start_and_end_idx( node["targetNodes"], node_properties["numNodes"], ) # k with the replica idx within the role # so add k to the start index of the node group to get the global idx global_idx = start_idx + k if role_name == node_role: break assert global_idx != -1, ( f"Role `{role_name}`'s replica `{k}` not found in job `{job['jobName']}.\n" f"Inspect the job by running `aws batch describe-jobs --jobs {job['jobId']}`" ) job = self._get_job(app_id, rank=global_idx) if not job: return [] if "status" in job and job["status"] == "RUNNING": stream_name = job["container"]["logStreamName"] else: attempts = job["attempts"] if len(attempts) == 0: return [] attempt = attempts[-1] container = attempt["container"] stream_name = container["logStreamName"] iterator = self._stream_events( app_id, stream_name, since=since, until=until, should_tail=should_tail, ) if regex: return filter_regex(regex, iterator) else: return iterator
[docs] def list(self) -> List[ListAppResponse]: # TODO: get queue name input instead of iterating over all queues? all_apps = [] for resp in self._client.get_paginator("describe_job_queues").paginate(): queue_names = [queue["jobQueueName"] for queue in resp["jobQueues"]] for qn in queue_names: all_apps.extend(self._list_by_queue(qn)) return all_apps
def _list_by_queue(self, queue_name: str) -> List[ListAppResponse]: # By default, only running jobs are listed by batch/boto client's list_jobs API # When 'filters' parameter is specified, jobs with all statuses are listed # So use AFTER_CREATED_AT filter to list jobs in all statuses # milli_seconds_after_epoch can later be used to list jobs by timeframe MS_AFTER_EPOCH = "1" EVERY_STATUS = {"name": "AFTER_CREATED_AT", "values": [MS_AFTER_EPOCH]} jobs = [] for resp in self._client.get_paginator("list_jobs").paginate( jobQueue=queue_name, filters=[EVERY_STATUS], # describe-jobs API can take up to 100 jobIds PaginationConfig={"MaxItems": 100}, ): # tag is used to filter torchx jobs # list_jobs() API only returns a job summary which does not include the job's tag # so we need to call the describe_jobs API. # Ideally batch lets us pass tags as a filter to list_jobs API # but this is currently not supported job_ids = [js["jobId"] for js in resp["jobSummaryList"]] for jobdesc in self._get_torchx_submitted_jobs(job_ids): jobs.append( ListAppResponse( app_id=f"{queue_name}:{jobdesc['jobName']}", state=JOB_STATE[jobdesc["status"]], ) ) return jobs def _get_torchx_submitted_jobs(self, job_ids: List[str]) -> List[Dict[str, Any]]: if not job_ids: return [] return [ jobdesc for jobdesc in self._client.describe_jobs(jobs=job_ids)["jobs"] if TAG_TORCHX_VER in jobdesc["tags"] ] def _stream_events( self, app_id: str, stream_name: str, since: Optional[datetime] = None, until: Optional[datetime] = None, should_tail: bool = False, ) -> Iterable[str]: next_token = None last_event_timestamp: int = 0 # in millis since epoch while True: args = {} if next_token is not None: args["nextToken"] = next_token if until is not None: args["endTime"] = to_millis_since_epoch(until) if since is not None: args["startTime"] = to_millis_since_epoch(since) try: response = self._log_client.get_log_events( logGroupName="/aws/batch/job", logStreamName=stream_name, limit=10000, startFromHead=True, **args, ) 
except self._log_client.exceptions.ResourceNotFoundException: return [] # noqa: B901 if response["nextForwardToken"] == next_token: if ( not until or last_event_timestamp < to_millis_since_epoch(until) ) and should_tail: if not is_terminal(none_throws(self.describe(app_id)).state): since = to_datetime(last_event_timestamp) continue break next_token = response["nextForwardToken"] for event in response["events"]: last_event_timestamp = event["timestamp"] yield event["message"] + "\n"
def create_scheduler(
    session_name: str,
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    client: Optional[Any] = None,
    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    log_client: Optional[Any] = None,
    docker_client: Optional["DockerClient"] = None,
    **kwargs: object,
) -> AWSBatchScheduler:
    """Build an :py:class:`AWSBatchScheduler`.

    ``client``/``log_client`` allow injecting pre-built boto3 clients. Any extra
    ``kwargs`` are accepted and ignored.
    """
    scheduler = AWSBatchScheduler(
        session_name=session_name,
        client=client,
        log_client=log_client,
        docker_client=docker_client,
    )
    return scheduler


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources