AWS SageMaker

class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerScheduler(session_name: str, client: Optional[Any] = None, docker_client: Optional[DockerClient] = None)

Bases: DockerWorkspaceMixin, Scheduler[AWSSageMakerOpts]

AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker.

$ torchx run -s aws_sagemaker utils.echo --image alpine:latest --msg hello
$ torchx status aws_sagemaker://torchx_user/1234
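
The handle passed to torchx status has the shape scheduler://session_name/app_id. The following is a minimal sketch of how such a handle splits into its parts. The names AppHandle and split_app_handle are hypothetical and exist only for illustration; TorchX has its own handle parsing, which this does not replace.

```python
import re
from typing import NamedTuple


class AppHandle(NamedTuple):
    scheduler: str
    session_name: str
    app_id: str


# scheduler://session_name/app_id (the session part may be empty)
_HANDLE_RE = re.compile(r"^(?P<scheduler>[^:/]+)://(?P<session>[^/]*)/(?P<app_id>.+)$")


def split_app_handle(handle: str) -> AppHandle:
    """Split an app handle into scheduler, session name, and app id.

    Hypothetical helper for illustration; not part of the TorchX API.
    """
    match = _HANDLE_RE.match(handle)
    if match is None:
        raise ValueError(f"not a valid app handle: {handle!r}")
    return AppHandle(match["scheduler"], match["session"], match["app_id"])


parts = split_app_handle("aws_sagemaker://torchx_user/1234")
print(parts.scheduler, parts.session_name, parts.app_id)
```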

Credentials are loaded from the environment using boto3's credential handling.

Config Options


    required arguments:
        role=ROLE (str)
            an AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource.
        instance_type=INSTANCE_TYPE (str)
            type of EC2 instance to use for training, for example, 'ml.c4.xlarge'

    optional arguments:
        instance_count=INSTANCE_COUNT (int, 1)
            number of Amazon EC2 instances to use for training. Required if instance_groups is not set.
        user=USER (str, runner)
            the username to tag the job with. Defaults to `getpass.getuser()` if not specified.
        keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS (int, None)
            the duration of time in seconds to retain configured resources in a warm pool for subsequent training jobs.
        volume_size=VOLUME_SIZE (int, None)
            size in GB of the storage volume to use for storing input and output data during training (default: 30).
        volume_kms_key=VOLUME_KMS_KEY (str, None)
            KMS key ID for encrypting EBS volume attached to the training instance.
        max_run=MAX_RUN (int, None)
            timeout in seconds for training (default: 24 * 60 * 60).
        input_mode=INPUT_MODE (str, None)
            the input mode that the algorithm supports (default: ‘File’).
        output_path=OUTPUT_PATH (str, None)
            S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket. If the bucket with the specific name does not exist, the estimator creates the bucket during the fit() method execution.
        output_kms_key=OUTPUT_KMS_KEY (str, None)
            KMS key ID for encrypting the training output (default: Your IAM role’s KMS key for Amazon S3).
        base_job_name=BASE_JOB_NAME (str, None)
            prefix for training job name when the fit() method launches. If not specified, the estimator generates a default job name based on the training image name and current timestamp.
        tags=TAGS (typing.List[typing.Dict[str, str]], None)
            list of tags for labeling a training job.
        subnets=SUBNETS (typing.List[str], None)
            list of subnet ids. If not specified, the training job is created without VPC config.
        security_group_ids=SECURITY_GROUP_IDS (typing.List[str], None)
            list of security group ids. If not specified, the training job is created without VPC config.
        model_uri=MODEL_URI (str, None)
            URI where a pre-trained model is stored, either locally or in S3.
        model_channel_name=MODEL_CHANNEL_NAME (str, None)
            name of the channel where ‘model_uri’ will be downloaded (default: ‘model’).
        metric_definitions=METRIC_DEFINITIONS (typing.List[typing.Dict[str, str]], None)
            list of dictionaries that defines the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: ‘Name’ for the name of the metric, and ‘Regex’ for the regular expression used to extract the metric from the logs.
        encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC (bool, None)
            specifies whether traffic between training containers is encrypted for the training job (default: False).
        use_spot_instances=USE_SPOT_INSTANCES (bool, None)
            specifies whether to use SageMaker Managed Spot instances for training. If enabled then the max_wait arg should also be set.
        max_wait=MAX_WAIT (int, None)
            timeout in seconds waiting for spot training job.
        checkpoint_s3_uri=CHECKPOINT_S3_URI (str, None)
            S3 URI in which to persist checkpoints that the algorithm persists (if any) during training.
        checkpoint_local_path=CHECKPOINT_LOCAL_PATH (str, None)
            local path that the algorithm writes its checkpoints to.
        debugger_hook_config=DEBUGGER_HOOK_CONFIG (bool, None)
            configuration for how debugging information is emitted with SageMaker Debugger. If not specified, a default one is created using the estimator’s output_path, unless the region does not support SageMaker Debugger. To disable SageMaker Debugger, set this parameter to False.
        enable_sagemaker_metrics=ENABLE_SAGEMAKER_METRICS (bool, None)
            enable SageMaker Metrics Time Series.
        enable_network_isolation=ENABLE_NETWORK_ISOLATION (bool, None)
            specifies whether container will run in network isolation mode (default: False).
        disable_profiler=DISABLE_PROFILER (bool, None)
            specifies whether Debugger monitoring and profiling will be disabled (default: False).
        environment=ENVIRONMENT (typing.Dict[str, str], None)
            environment variables to be set for use during training job
        max_retry_attempts=MAX_RETRY_ATTEMPTS (int, None)
            number of times to move a job to the STARTING status. You can specify between 1 and 30 attempts.
        source_dir=SOURCE_DIR (str, None)
            absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: current working directory)
        git_config=GIT_CONFIG (typing.Dict[str, str], None)
            git configurations used for cloning files, including repo, branch, commit, 2FA_enabled, username, password, and token.
        hyperparameters=HYPERPARAMETERS (typing.Dict[str, str], None)
            dictionary containing the hyperparameters to initialize this estimator with.
        container_log_level=CONTAINER_LOG_LEVEL (int, None)
            log level to use within the container (default: logging.INFO).
        code_location=CODE_LOCATION (str, None)
            S3 prefix URI where custom code is uploaded.
        dependencies=DEPENDENCIES (typing.List[str], None)
            list of absolute or relative paths to directories with any additional libraries that should be exported to the container.
        training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE (str, None)
            specifies how SageMaker accesses the Docker image that contains the training algorithm.
        training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN (str, None)
            Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your training image is hosted.
        disable_output_compression=DISABLE_OUTPUT_COMPRESSION (bool, None)
            when set to true, the model is uploaded to Amazon S3 without compression after training finishes.
        enable_infra_check=ENABLE_INFRA_CHECK (bool, None)
            specifies whether to run SageMaker built-in infra check jobs.
        image_repo=IMAGE_REPO (str, None)
            (remote jobs) the image repository to use when pushing patched images, must have push access. Ex:
        quiet=QUIET (bool, False)
            whether to suppress verbose output for image building. Defaults to ``False``.
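
As a rough illustration of how the options above reach the scheduler from the command line, the sketch below renders a dictionary of options into a single string for torchx run -cfg. It assumes the usual TorchX convention of comma-separated key=value pairs with list values joined by semicolons. The helper format_cfg, the role name, and the subnet ids are hypothetical placeholders. Options whose values are lists of dictionaries (such as tags) are left out because this simple string form does not cover them.

```python
import shlex
from typing import Dict, List, Union

CfgValue = Union[str, int, bool, List[str]]

# The two options listed under "required arguments".
REQUIRED = ("role", "instance_type")


def format_cfg(cfg: Dict[str, CfgValue]) -> str:
    """Render options as key=value pairs joined by commas (hypothetical helper)."""
    missing = [name for name in REQUIRED if name not in cfg]
    if missing:
        raise ValueError(f"missing required options: {', '.join(missing)}")
    parts = []
    for key, value in cfg.items():
        if isinstance(value, bool):  # check bool before int: bool is a subclass of int
            rendered = "true" if value else "false"
        elif isinstance(value, list):
            rendered = ";".join(value)  # assumed list separator
        else:
            rendered = str(value)
        parts.append(f"{key}={rendered}")
    return ",".join(parts)


cfg = {
    "role": "MySageMakerRole",  # placeholder IAM role name
    "instance_type": "ml.c4.xlarge",
    "instance_count": 2,
    "subnets": ["subnet-aaa", "subnet-bbb"],  # placeholder subnet ids
}
command = [
    "torchx", "run", "-s", "aws_sagemaker", "-cfg", format_cfg(cfg),
    "utils.echo", "--image", "alpine:latest", "--msg", "hello",
]
# shlex.join quotes the -cfg value, which matters because ';' is a shell separator.
print(shlex.join(command))
```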



Scheduler Support

Fetch Logs

Distributed Jobs


Cancel Job


Describe Job

Partial support. AWSSageMakerScheduler returns job and replica status but does not provide the complete original AppSpec.

Workspaces / Patching




describe(app_id: str) → Optional[DescribeAppResponse]

Describes the specified application.


Returns: AppDef description or None if the app does not exist.

list() → List[ListAppResponse]

For apps launched on the scheduler, this API returns a list of ListAppResponse objects, each of which has an app id and its status. Note: This API is in prototype phase and is subject to change.

log_iter(app_id: str, role_name: str, k: int = 0, regex: Optional[str] = None, since: Optional[datetime] = None, until: Optional[datetime] = None, should_tail: bool = False, streams: Optional[Stream] = None) → Iterable[str]

Returns an iterator to the log lines of the k-th replica of the role. The iterator ends when all qualifying log lines have been read.

If the scheduler supports time-based cursors for fetching log lines in custom time ranges, then the since and until fields are honored; otherwise they are ignored. Not specifying since and until is equivalent to getting all available log lines. If until is empty, then the iterator behaves like tail -f, following the log output until the job reaches a terminal state.

The exact definition of what constitutes a log is scheduler specific. Some schedulers may consider stderr or stdout as the log, others may read the logs from a log file.

Behaviors and assumptions:

  1. Produces undefined behavior if called on an app that does not exist. The caller should check that the app exists using exists(app_id) prior to calling this method.

  2. Is not stateful; calling this method twice with the same parameters returns a new iterator. Prior iteration progress is lost.

  3. Does not always support log-tailing. Not all schedulers support live log iteration (e.g. tailing logs while the app is running). Refer to the specific scheduler’s documentation for the iterator’s behavior.

    3.1 If the scheduler supports log-tailing, it should be controlled by the should_tail parameter.

  4. Does not guarantee log retention. It is possible that by the time this method is called, the underlying scheduler may have purged the log records for this application. If so, this method raises an arbitrary exception.

  5. If should_tail is True, the method only raises a StopIteration exception when the accessible log lines have been fully exhausted and the app has reached a final state. For instance, if the app gets stuck and does not produce any log lines, then the iterator blocks until the app eventually gets killed (either via timeout or manually), at which point it raises a StopIteration.

    If should_tail is False, the method raises StopIteration when there are no more logs.

  6. Need not be supported by all schedulers.

  7. Some schedulers may support line cursors by supporting __getitem__ (e.g. iter[50] seeks to the 50th log line).

  8. Whitespace is preserved; each new line should include \n. To support interactive progress bars, the returned lines don’t need to include \n but should then be printed without a newline to correctly handle \r carriage returns.


Parameters: streams – The IO output streams to select. One of: combined, stdout, stderr. If the selected stream isn’t supported by the scheduler, it will throw a ValueError.

Returns: An iterator over log lines of the specified role replica

Raises: NotImplementedError – if the scheduler does not support log iteration
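
The filtering semantics described for log_iter (regex, since, until, and a fresh iterator on every call) can be sketched with an in-memory stand-in. This is not how the scheduler fetches logs from AWS. The function filter_log_lines and the sample records are hypothetical and only illustrate the contract.

```python
import re
from datetime import datetime
from typing import Iterator, List, Optional, Tuple

LogRecord = Tuple[datetime, str]


def filter_log_lines(
    records: List[LogRecord],
    regex: Optional[str] = None,
    since: Optional[datetime] = None,
    until: Optional[datetime] = None,
) -> Iterator[str]:
    """Yield log lines that fall within [since, until] and match regex.

    Hypothetical in-memory helper; each call returns a new generator,
    so progress from an earlier iteration is not carried over.
    """
    pattern = re.compile(regex) if regex else None
    for timestamp, line in records:
        if since is not None and timestamp < since:
            continue
        if until is not None and timestamp > until:
            continue
        if pattern is not None and not pattern.search(line):
            continue
        yield line  # whitespace, including the trailing \n, is preserved


records = [
    (datetime(2024, 1, 1, 12, 0, 0), "epoch 1 loss=0.9\n"),
    (datetime(2024, 1, 1, 12, 0, 5), "warning: slow step\n"),
    (datetime(2024, 1, 1, 12, 0, 10), "epoch 2 loss=0.7\n"),
]

for line in filter_log_lines(records, regex=r"loss="):
    print(line, end="")  # lines already carry their own newline
```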

schedule(dryrun_info: AppDryRunInfo[AWSSageMakerJob]) → str

Same as submit except that it takes an AppDryRunInfo. Implementers are encouraged to implement this method rather than directly implementing submit since submit can be trivially implemented by:

    dryrun_info = self.submit_dryrun(app, cfg)
    return self.schedule(dryrun_info)

class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerJob(job_name: str, job_def: Dict[str, Any], images_to_push: Dict[str, Tuple[str, str]])

Defines the key values that are required to schedule a job. This will be the value of request in the AppDryRunInfo object.

  • job_name: defines the job name shown in SageMaker

  • job_def: defines the job description that will be used to schedule the job on SageMaker

  • images_to_push: used by torchx to push to image_repo
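
To make the relationship between the job object, the dry-run info, and schedule more concrete, here is a toy sketch of the pattern in which submit is built from submit_dryrun followed by schedule. ToyJob, ToyDryRunInfo, and ToyScheduler are hypothetical stand-ins that only mirror the field names listed above. They do not call AWS and do not reflect the real implementation.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass
class ToyJob:
    """Toy stand-in carrying the same three fields as the job described above."""
    job_name: str
    job_def: Dict[str, Any]
    images_to_push: Dict[str, Tuple[str, str]] = field(default_factory=dict)


@dataclass
class ToyDryRunInfo:
    """Toy stand-in for a dry-run result; the job is stored as the request."""
    request: ToyJob


class ToyScheduler:
    def __init__(self) -> None:
        self.submitted: Dict[str, ToyJob] = {}

    def submit_dryrun(self, app_name: str, cfg: Dict[str, Any]) -> ToyDryRunInfo:
        # Build the request without sending anything anywhere.
        job = ToyJob(
            job_name=app_name,
            job_def={"role": cfg["role"], "instance_type": cfg["instance_type"]},
        )
        return ToyDryRunInfo(request=job)

    def schedule(self, dryrun_info: ToyDryRunInfo) -> str:
        # "Launch" the job by recording it, then return its id.
        job = dryrun_info.request
        self.submitted[job.job_name] = job
        return job.job_name

    def submit(self, app_name: str, cfg: Dict[str, Any]) -> str:
        # submit is simply the dry run followed by schedule.
        return self.schedule(self.submit_dryrun(app_name, cfg))


scheduler = ToyScheduler()
app_id = scheduler.submit("echo-job", {"role": "MyRole", "instance_type": "ml.c4.xlarge"})
print(app_id)
```

Keeping the dry run separate means the request can be inspected or logged before anything is launched.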


torchx.schedulers.aws_sagemaker_scheduler.create_scheduler(session_name: str, **kwargs: object) → AWSSageMakerScheduler

