
Training Session Manager (TSM)

Training Session Manager (TSM) is a set of programmatic APIs that helps you launch your distributed (PyTorch) applications onto the supported schedulers. Whereas torchelastic is deployed per container and manages worker processes and coordinates restart behaviors, TSM provides a way to launch the distributed job while natively supporting jobs that are (locally) managed by torchelastic.

Note

TSM is currently an experimental module and is subject to change in future releases of torchelastic. At the moment, TSM ships only with a LocalScheduler, which allows the user to run the distributed application locally on a dev host.

Usage Overview

Below is a simple program that locally launches a multi-role (trainer, parameter server, reader) distributed application. Each Role runs multiple replicas. In practice, each replica runs in its own container on a host. An Application is made up of one or more such Roles.

import getpass
import torchelastic.tsm.driver as tsm

username = getpass.getuser()
train_project_dir = tsm.Container(image=f"/home/{username}/pytorch_trainer")
reader_project_dir = tsm.Container(image=f"/home/{username}/pytorch_reader")

# the trainer is an ElasticRole, so its entrypoint runs under the torchelastic agent
trainer = (
    tsm.ElasticRole(name="trainer", nproc_per_node=2, nnodes="4:4")
    .runs("train_main.py", "--epochs", "50", MY_ENV_VAR="foobar")
    .on(train_project_dir)
    .replicas(4)
)

ps = (
    tsm.Role(name="parameter_server")
    .runs("ps_main.py")
    .on(train_project_dir)
    .replicas(10)
)

reader = (
    tsm.Role(name="reader")
    .runs("reader/reader_main.py", "--buffer_size", "1024")
    .on(reader_project_dir)
    .replicas(1)
)

app = tsm.Application(name="my_train_job").of(trainer, ps, reader)

session = tsm.session(name="my_session")
app_id = session.run(app, scheduler="local")
session.wait(app_id)

In the example above, we have done a few things:

  1. Created and ran a distributed training application that runs a total of 4 + 10 + 1 = 15 containers (just processes since we used a local scheduler).

  2. Wrapped the trainer with torchelastic by declaring it as an ElasticRole.

  3. The trainer and ps run from the same image (but different containers): /home/$USER/pytorch_trainer and the reader runs from the image: /home/$USER/pytorch_reader. The images map to a local directory because we are using a local scheduler. For other non-trivial schedulers a container could map to a Docker image, tarball, rpm, etc.

  4. The main entrypoints are relative to the container image’s root dir. For example, the trainer runs /home/$USER/pytorch_trainer/train_main.py.

  5. Arguments to each role entrypoint are passed as *args after the entrypoint CMD.

  6. Environment variables to each role entrypoint are passed as **kwargs after the arguments.

  7. The session object has action APIs on the app (see Session).
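
To make points 1, 5 and 6 concrete, here is a self-contained sketch of the fluent builder style used above. The Role and Application classes below are hypothetical stand-ins written for illustration, not the torchelastic classes: they only record what runs(), on() and replicas() were given (on() takes a plain image string instead of a Container) and count one container per replica.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Role:
    """Hypothetical stand-in for tsm.Role; NOT the torchelastic class."""

    name: str
    entrypoint: str = "<MISSING>"
    args: List[str] = field(default_factory=list)
    env: Dict[str, str] = field(default_factory=dict)
    image: str = "<MISSING>"  # simplification: the real .on() takes a Container
    num_replicas: int = 1

    def runs(self, entrypoint: str, *args: str, **env: str) -> "Role":
        # positional args become CLI arguments, keyword args become env vars
        self.entrypoint = entrypoint
        self.args = list(args)
        self.env = dict(env)
        return self

    def on(self, image: str) -> "Role":
        self.image = image
        return self

    def replicas(self, n: int) -> "Role":
        self.num_replicas = n
        return self


@dataclass
class Application:
    """Hypothetical stand-in for tsm.Application."""

    name: str
    roles: List[Role] = field(default_factory=list)

    def of(self, *roles: Role) -> "Application":
        self.roles = list(roles)
        return self

    def total_containers(self) -> int:
        # one container per replica (one process under a local scheduler)
        return sum(role.num_replicas for role in self.roles)


trainer = (
    Role("trainer")
    .runs("train_main.py", "--epochs", "50", MY_ENV_VAR="foobar")
    .on("pytorch_trainer")
    .replicas(4)
)
ps = Role("parameter_server").runs("ps_main.py").on("pytorch_trainer").replicas(10)
reader = (
    Role("reader")
    .runs("reader/reader_main.py", "--buffer_size", "1024")
    .on("pytorch_reader")
    .replicas(1)
)
app = Application("my_train_job").of(trainer, ps, reader)
print(app.total_containers())  # 15
```

Each builder method returns self, which is what allows the calls to be chained; wrapping the chain in parentheses lets it span several lines.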

API Documentation

Session

class torchelastic.tsm.driver.api.Session(name: str)

Entrypoint and client-facing API for TSM. Has the methods for the user to define and act upon Applications. The Session is stateful and represents a logical workspace of the user. It can be backed by a service (e.g. TSM server) for persistence, or it can be standalone with no persistence, meaning that the Session lasts only for the duration of the hosting process (see the attach() API for instructions on re-parenting apps between sessions).

abstract describe(app_handle: str) → Optional[torchelastic.tsm.driver.api.Application]

Reconstructs the application (to the extent possible) given the app handle. Note that the reconstructed application may not be the complete app as it was submitted via the run API. How much of the app can be reconstructed is scheduler dependent.

Returns

Application or None if the app does not exist anymore or if the scheduler does not support describing the app handle

dryrun(app: torchelastic.tsm.driver.api.Application, scheduler: str = 'default', cfg: Optional[torchelastic.tsm.driver.api.RunConfig] = None) → torchelastic.tsm.driver.api.AppDryRunInfo

Dry runs an app on the given scheduler with the provided run configs. Does not actually submit the app but rather returns what would have been submitted. The returned AppDryRunInfo is pretty-formatted and can be printed or logged directly.

Usage:

dryrun_info = session.dryrun(app, scheduler="local", cfg=cfg)
print(dryrun_info)

abstract list() → Dict[str, torchelastic.tsm.driver.api.Application]

Returns the applications that were run with this session mapped by the app handle. The persistence of the session is implementation dependent.

abstract log_lines(app_handle: str, role_name: str, k: int = 0, regex: Optional[str] = None, since: Optional[datetime.datetime] = None, until: Optional[datetime.datetime] = None) → Iterable

Returns an iterator over the log lines of the specified job container.

Note

  1. k is the node (host) id NOT the rank.

  2. since and until need not always be honored (depends on scheduler).

Warning

The semantics and guarantees of the returned iterator are highly scheduler dependent. See torchelastic.tsm.driver.api.Scheduler.log_iter for the high-level semantics of this log iterator. For this reason it is HIGHLY DISCOURAGED to use this method for generating output to pass to downstream functions/dependencies. This method DOES NOT guarantee that 100% of the log lines are returned. It is valid for this method to return no log lines, or only some of them, if the scheduler has already totally or partially purged log records for the application.

Usage:

from torchelastic.tsm.driver.api import RunConfig

app_handle = session.run(app, scheduler="local", cfg=RunConfig())

print("== trainer node 0 logs ==")
for line in session.log_lines(app_handle, "trainer", k=0):
    print(line)

Discouraged anti-pattern:

# DO NOT DO THIS!
# parses accuracy metric from log and reports it for this experiment run
accuracy = -1
for line in session.log_lines(app_handle, "trainer", k=0):
   if matches_regex(line, "final model_accuracy:[0-9]*"):
       accuracy = parse_accuracy(line)
       break
report(experiment_name, accuracy)
Parameters
  • app_handle – application handle

  • role_name – role within the app (e.g. trainer)

  • k – k-th replica of the role to fetch the logs for

  • regex – optional regex filter, returns all lines if left empty

  • since – datetime-based start cursor. If left empty, begins from the first log line (start of job).

  • until – datetime-based end cursor. If left empty, follows the log output until the job completes and all log lines have been consumed.

Returns

An iterator over the log lines of the k-th replica of the specified role.

Raises
  • UnknownAppException – if the app does not exist in the scheduler

  • SessionMismatchException – if the app handle does not belong to this session

name() → str
Returns

The name of this session.

run(app: torchelastic.tsm.driver.api.Application, scheduler: str = 'default', cfg: Optional[torchelastic.tsm.driver.api.RunConfig] = None) → str

Runs the given application on the specified scheduler.

Note

Subclasses of Session should implement the schedule method rather than overriding this method directly.

Returns

An application handle that is used to call other action APIs on the app.

Raises

AppNotReRunnableException – if the session/scheduler does not support re-running attached apps

run_opts() → Dict[str, torchelastic.tsm.driver.api.runopts]

Returns the runopts for the supported scheduler backends.

Usage:

local_runopts = session.run_opts()["local"]
print(f"local scheduler run options: {local_runopts}")

Returns

A map of scheduler backend to its runopts
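
Conceptually, a runopts object declares the config keys a scheduler accepts, with types, defaults and help text. The sketch below is a hypothetical stand-in (RunOpts, RunOpt and the log_dir option are invented for illustration, not the torchelastic runopts API) showing how one such declaration can both print a --help style listing and validate a user-supplied config.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class RunOpt:
    opt_type: type
    default: Any
    help_text: str


class RunOpts:
    """Hypothetical stand-in for a scheduler's runopts (its "--help" for run)."""

    def __init__(self) -> None:
        self._opts: Dict[str, RunOpt] = {}

    def add(self, name: str, opt_type: type, help_text: str, default: Any = None) -> None:
        self._opts[name] = RunOpt(opt_type, default, help_text)

    def resolve(self, cfg: Dict[str, Any]) -> Dict[str, Any]:
        # start from the declared defaults, then apply the user's values
        resolved = {name: opt.default for name, opt in self._opts.items()}
        for key, value in cfg.items():
            if key not in self._opts:
                raise KeyError(f"unknown run option: {key}")
            expected = self._opts[key].opt_type
            if not isinstance(value, expected):
                raise TypeError(f"{key} expects {expected.__name__}")
            resolved[key] = value
        return resolved

    def __str__(self) -> str:
        # one line per option, sorted by name
        return "\n".join(
            f"{name} ({opt.opt_type.__name__}, default={opt.default!r}): {opt.help_text}"
            for name, opt in sorted(self._opts.items())
        )


local_opts = RunOpts()
local_opts.add("log_dir", str, "directory to write replica logs to", default="/tmp")
print(local_opts)
print(local_opts.resolve({"log_dir": "/var/tmp/tsm"}))
```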

abstract schedule(dryrun_info: torchelastic.tsm.driver.api.AppDryRunInfo) → str

Actually runs the application from the given dryrun info. Useful when one needs to overwrite a parameter in the scheduler request that is not configurable from one of the object APIs.

Warning

Use sparingly since abusing this method to overwrite many parameters in the raw scheduler request may lead to your usage of TSM going out of compliance in the long term. This method is intended to unblock the user from experimenting with certain scheduler-specific features in the short term without having to wait until TSM exposes scheduler features in its APIs.

Note

It is recommended that sub-classes of Session implement this method instead of directly implementing the run method.

Usage:

dryrun_info = session.dryrun(app, scheduler="default", cfg=cfg)

# overwrite parameter "foo" to "bar"
dryrun_info.request.foo = "bar"

app_handle = session.schedule(dryrun_info)
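
The dryrun-then-schedule flow can be sketched without a real backend. In this hypothetical example, ToyScheduler and the queue config key are invented for illustration, and the request is a plain dict rather than a scheduler-specific object; it shows a request being built, edited, and only then submitted.

```python
import pprint
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class AppDryRunInfo:
    """Hypothetical stand-in: wraps the raw, scheduler-specific request."""

    request: Dict[str, Any]

    def __str__(self) -> str:
        # pretty-format so the info can be printed or logged directly
        return pprint.pformat(self.request)


class ToyScheduler:
    """Hypothetical scheduler that records requests instead of launching them."""

    def __init__(self) -> None:
        self.submitted: List[Dict[str, Any]] = []

    def submit_dryrun(self, app_name: str, cfg: Dict[str, Any]) -> AppDryRunInfo:
        # build the request that *would* be sent, without sending it
        return AppDryRunInfo(request={"name": app_name, **cfg})

    def schedule(self, dryrun_info: AppDryRunInfo) -> str:
        # submit the (possibly edited) request and hand back an app handle
        self.submitted.append(dryrun_info.request)
        return f"toy://{dryrun_info.request['name']}/{len(self.submitted) - 1}"


sched = ToyScheduler()
info = sched.submit_dryrun("my_train_job", {"queue": "default"})
info.request["foo"] = "bar"  # override a field the object APIs do not expose
app_handle = sched.schedule(info)
print(info)
print(app_handle)  # toy://my_train_job/0
```

Nothing reaches the (toy) backend until schedule() is called, which is what makes the edit in between possible.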

abstract scheduler_backends() → List[str]

Returns a list of all supported scheduler backends. All session implementations must support a “default” scheduler backend and document what the default scheduler is.

abstract status(app_handle: str) → Optional[torchelastic.tsm.driver.api.AppStatus]
Returns

The status of the application, or None if the app does not exist anymore (e.g. was stopped in the past and removed from the scheduler’s backend).

abstract stop(app_handle: str) → None

Stops the application, effectively directing the scheduler to cancel the job. Does nothing if the app does not exist.

Note

This method returns as soon as the cancel request has been submitted to the scheduler. The application will be in a RUNNING state until the scheduler actually terminates the job. If the scheduler successfully interrupts and terminates the job, the final state will be CANCELLED; otherwise, it will be FAILED.

Raises

SessionMismatchException – if the app handle does not belong to this session

abstract wait(app_handle: str) → Optional[torchelastic.tsm.driver.api.AppStatus]

Blocks (indefinitely) until the application completes. Possible implementation:

import time

while True:
    app_status = self.status(app_handle)
    if app_status is None or app_status.is_terminal():
        return app_status
    time.sleep(10)

Returns

The terminal status of the application, or None if the app does not exist anymore
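
A runnable version of the polling loop above, assuming the status is exposed through a zero-argument callable. The AppState enum is a hypothetical subset built from the states named on this page (RUNNING, CANCELLED, FAILED) plus an assumed SUCCEEDED state; the fake status source mimics an app that stays RUNNING for a while after stop() was called.

```python
import enum
import time
from typing import Callable, Optional


class AppState(enum.Enum):
    # hypothetical subset of states; SUCCEEDED is an assumed name
    RUNNING = "RUNNING"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"


TERMINAL_STATES = {AppState.SUCCEEDED, AppState.FAILED, AppState.CANCELLED}


def wait(get_state: Callable[[], Optional[AppState]], poll_interval: float = 10.0) -> Optional[AppState]:
    """Poll until the app reaches a terminal state, or return None if it is gone."""
    while True:
        state = get_state()
        if state is None or state in TERMINAL_STATES:
            return state
        time.sleep(poll_interval)


# fake status source: still RUNNING for two polls after stop() was requested,
# then the scheduler reports the app as CANCELLED
states = iter([AppState.RUNNING, AppState.RUNNING, AppState.CANCELLED])
final = wait(lambda: next(states), poll_interval=0.0)
print(final.value)  # CANCELLED
```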

Containers and Resources

class torchelastic.tsm.driver.api.Container(image: str, resources: torchelastic.tsm.driver.api.Resource = Resource(cpu=-1, gpu=-1, memMB=-1, capabilities={}), port_map: Dict[str, int] = <factory>)

Represents the specifications of the container that instances of Roles run on. Maps to the container abstraction that the underlying scheduler supports. This could be an actual container (e.g. Docker) or a physical instance depending on the scheduler.

An image is a software bundle that is installed on a Container. The container on the scheduler dictates what an image actually is. An image could be as simple as a tarball or map to a Docker image. The scheduler typically knows how to “pull” the image given an image name (str), which could be a simple name (e.g. Docker image) or a URL (e.g. s3://path/my_image.tar).

A Resource can be bound to a specific scheduler backend or SchedulerBackend.ALL (default) to specify that the same Resource is to be used for all schedulers.

Usage:

from torchelastic.tsm.driver.api import Container, Resource

# define resource for all schedulers
my_container = (
    Container(image="pytorch/torch:1")
    .require(Resource(cpu=1, gpu=1, memMB=500))
    .ports(tcp_store=8080, tensorboard=8081)
)

# define resource for a specific scheduler
my_container = (
    Container(image="pytorch/torch:1")
    .require(Resource(cpu=1, gpu=1, memMB=500), "custom_scheduler")
    .ports(tcp_store=8080, tensorboard=8081)
)

class torchelastic.tsm.driver.api.Resource(cpu: int, gpu: int, memMB: int, capabilities: Dict[str, Any] = <factory>)

Represents resource requirements for a Container.

Parameters
  • cpu – number of CPU cores (note: not hyperthreads)

  • gpu – number of GPUs

  • memMB – amount of RAM in MB

  • capabilities – additional hardware specs (interpreted by scheduler)
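
One plausible way to model binding a Resource to a specific scheduler, with a fallback to the binding for all schedulers, is a dict keyed by scheduler name. This is a hypothetical sketch, not how torchelastic stores resources (its Container holds a single resources field); ALL, resource_for and the 8-CPU resource for custom_scheduler are illustrative assumptions.

```python
from dataclasses import dataclass, field
from typing import Dict

ALL = "all"  # hypothetical key playing the role of SchedulerBackend.ALL


@dataclass(frozen=True)
class Resource:
    cpu: int
    gpu: int
    memMB: int


@dataclass
class Container:
    """Hypothetical model of a container with per-scheduler resource bindings."""

    image: str
    resources: Dict[str, Resource] = field(default_factory=dict)
    port_map: Dict[str, int] = field(default_factory=dict)

    def require(self, resource: Resource, scheduler: str = ALL) -> "Container":
        self.resources[scheduler] = resource
        return self

    def ports(self, **port_map: int) -> "Container":
        self.port_map.update(port_map)
        return self

    def resource_for(self, scheduler: str) -> Resource:
        # a scheduler-specific binding wins; otherwise fall back to ALL
        if scheduler in self.resources:
            return self.resources[scheduler]
        return self.resources[ALL]  # raises KeyError if nothing was bound


container = (
    Container(image="pytorch/torch:1")
    .require(Resource(cpu=1, gpu=1, memMB=500))
    .require(Resource(cpu=8, gpu=2, memMB=16000), "custom_scheduler")
    .ports(tcp_store=8080, tensorboard=8081)
)
print(container.resource_for("local"))
print(container.resource_for("custom_scheduler"))
```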

Roles and Applications

class torchelastic.tsm.driver.api.Role(name: str, entrypoint: str = '<MISSING>', args: List[str] = <factory>, env: Dict[str, str] = <factory>, container: torchelastic.tsm.driver.api.Container = Container(image='<MISSING>', resources=Resource(cpu=-1, gpu=-1, memMB=-1, capabilities={}), port_map={}), num_replicas: int = 1, max_retries: int = 0, retry_policy: torchelastic.tsm.driver.api.RetryPolicy = <RetryPolicy.APPLICATION: 'APPLICATION'>)

A set of nodes that perform a specific duty within the Application. Examples:

  1. Distributed data parallel app - made up of a single role (trainer).

  2. App with parameter server - made up of multiple roles (trainer, ps).

Usage:

trainer = (
    Role(name="trainer")
    .runs("my_trainer.py", "--arg", "foo", ENV_VAR="FOOBAR")
    .on(container)
    .replicas(4)
)

Parameters
  • name – name of the role

  • entrypoint – command (within the container) to invoke the role

  • args – commandline arguments to the entrypoint cmd

  • env – environment variable mappings

  • container – container to run in

  • num_replicas – number of container replicas to run

  • max_retries – max number of retries before giving up

  • retry_policy – retry behavior upon replica failures

  • deployment_preference – hint to the scheduler on how to best deploy and manage replicas of this role

class torchelastic.tsm.driver.api.ElasticRole(name: str, **launch_kwargs)

A Role for which the user-provided entrypoint is executed with the torchelastic agent (in the container). Note that the torchelastic agent invokes multiple copies of the entrypoint.

For more information about torchelastic, see the torchelastic quickstart docs.

Important

It is the responsibility of the user to ensure that the container’s image includes torchelastic. Since TSM has no control over the build process of the image, it cannot automatically include torchelastic in the container’s image.

The following example launches 2 replicas (nodes) of an elastic my_train_script.py that is allowed to scale between 2 and 4 nodes. Each node runs 8 workers which are allowed to fail and restart a maximum of 3 times.

Warning

replicas MUST be an integer within the (inclusive) range given by nnodes. That is, ElasticRole("trainer", nnodes="2:4").replicas(5) is invalid and will result in undefined behavior.

elastic_trainer = (
    ElasticRole("trainer", nproc_per_node=8, nnodes="2:4", max_restarts=3)
    .runs("my_train_script.py", "--script_arg", "foo", "--another_arg", "bar")
    .on(container)
    .replicas(2)
)
# effectively runs:
#    python -m torchelastic.distributed.launch
#        --nproc_per_node 8
#        --nnodes 2:4
#        --max_restarts 3
#        my_train_script.py --script_arg foo --another_arg bar
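
The translation from launch kwargs to the launcher command, and the replicas-within-nnodes rule from the warning above, can be sketched as follows. elastic_launch_cmd and parse_nnodes are hypothetical helpers; treating a bare nnodes value such as "4" as "4:4", and defaulting nnodes to "1" when it is absent, are assumptions.

```python
from typing import List, Tuple


def parse_nnodes(nnodes: str) -> Tuple[int, int]:
    """Parse "MIN:MAX" (or a single "N", assumed to mean "N:N") into two ints."""
    min_str, _, max_str = nnodes.partition(":")
    min_nodes = int(min_str)
    max_nodes = int(max_str) if max_str else min_nodes
    return min_nodes, max_nodes


def elastic_launch_cmd(entrypoint: str, args: List[str], replicas: int, **launch_kwargs: object) -> List[str]:
    """Build the launcher command line, rejecting replicas outside nnodes."""
    nnodes = str(launch_kwargs.get("nnodes", "1"))  # assumed default
    min_nodes, max_nodes = parse_nnodes(nnodes)
    if not min_nodes <= replicas <= max_nodes:
        raise ValueError(f"replicas={replicas} is outside nnodes={nnodes}")
    cmd = ["python", "-m", "torchelastic.distributed.launch"]
    for key, value in launch_kwargs.items():
        # each launch kwarg is forwarded as a "--flag value" pair
        cmd += [f"--{key}", str(value)]
    return cmd + [entrypoint, *args]


cmd = elastic_launch_cmd(
    "my_train_script.py",
    ["--script_arg", "foo", "--another_arg", "bar"],
    replicas=2,
    nproc_per_node=8,
    nnodes="2:4",
    max_restarts=3,
)
print(" ".join(cmd))
```

Because keyword arguments keep the order in which they were passed, the printed command lists the flags in the same order as the "effectively runs" comment above.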

class torchelastic.tsm.driver.api.macros

Defines macros that can be used with Role.entrypoint and Role.args. The macros will be substituted at runtime to their actual values.

Available macros:

  1. img_root - root directory of the pulled image on the container

  2. app_id - application id as assigned by the scheduler

  3. replica_id - unique id for each instance of a replica of a Role. For instance, a role with 3 replicas could have 0, 1, 2 as replica ids. Note that when the container fails and is replaced, the new container will have the same replica_id as the one it is replacing. For instance, if node 1 failed and was replaced by the scheduler, the replacing node will also have replica_id=1.

Example:

# runs: hello_world.py --app_id ${app_id}
trainer = Role(name="trainer").runs("hello_world.py", "--app_id", macros.app_id)
app = Application("train_app").of(trainer)
app_handle = session.run(app, scheduler="local", cfg=RunConfig())
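
Macro substitution amounts to a template expansion over the entrypoint and args. The sketch below assumes the macros are ${name}-style placeholders, as the comment in the example suggests, and uses Python's string.Template; the macros class and the substituted values here are hypothetical, not torchelastic's implementation.

```python
from string import Template
from typing import Dict, List


class macros:
    # hypothetical placeholder spellings, modeled on the ${app_id} comment above
    img_root = "${img_root}"
    app_id = "${app_id}"
    replica_id = "${replica_id}"


def substitute(args: List[str], values: Dict[str, str]) -> List[str]:
    # safe_substitute leaves placeholders without a value untouched
    return [Template(arg).safe_substitute(values) for arg in args]


args = ["hello_world.py", "--app_id", macros.app_id, "--replica", macros.replica_id]
print(substitute(args, {"app_id": "local_app_0", "replica_id": "1"}))
# ['hello_world.py', '--app_id', 'local_app_0', '--replica', '1']
```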

class torchelastic.tsm.driver.api.RetryPolicy(value)

Defines the retry policy for the Roles in the Application. The policy defines the behavior when the role replica encounters a failure:

  1. unsuccessful (non-zero) exit code

  2. hardware/host crashes

  3. preemption

  4. eviction

Note

Not all retry policies are supported by all schedulers. However all schedulers must support RetryPolicy.APPLICATION. Please refer to the scheduler’s documentation for more information on the retry policies they support and behavior caveats (if any).

  1. REPLICA: Replaces the replica instance. Surviving replicas are untouched.

    Use with ElasticRole to have torchelastic coordinate restarts and membership changes. Otherwise, it is up to the application to deal with failed replica departures and replacement replica admittance.

  2. APPLICATION: Restarts the entire application.
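
The difference between the two policies can be sketched as a function that decides which replicas of a single role get relaunched after a failure. This is a hypothetical simplification: it ignores max_retries and, for APPLICATION, shows only one role even though the whole application restarts.

```python
import enum
from typing import List


class RetryPolicy(enum.Enum):
    # hypothetical mirror of the two policies described above
    REPLICA = "REPLICA"
    APPLICATION = "APPLICATION"


def replicas_to_relaunch(policy: RetryPolicy, num_replicas: int, failed: List[int]) -> List[int]:
    """Return the replica ids of one role that get (re)launched after a failure."""
    if not failed:
        return []
    if policy is RetryPolicy.REPLICA:
        # only failed replicas are replaced; each replacement keeps its replica_id
        return sorted(set(failed))
    # APPLICATION: everything is restarted, surviving replicas included
    return list(range(num_replicas))


print(replicas_to_relaunch(RetryPolicy.REPLICA, 4, [1]))      # [1]
print(replicas_to_relaunch(RetryPolicy.APPLICATION, 4, [1]))  # [0, 1, 2, 3]
```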

class torchelastic.tsm.driver.api.Application(name: str, roles: List[torchelastic.tsm.driver.api.Role] = <factory>)

Represents a distributed application made up of multiple Roles. Contains the necessary information for the driver to submit this app to the scheduler.

Extending TSM

TSM is built in a “plug-n-play” manner. While it ships out-of-the-box with certain schedulers and session implementations, you can implement your own to fit the needs of your PyTorch application and infrastructure. This section introduces the interfaces that are meant to be subclassed and extended.

Scheduler

class torchelastic.tsm.driver.api.Scheduler(session_name: str)

An interface abstracting functionalities of a scheduler. Implementors need only implement those methods annotated with @abc.abstractmethod.

cancel(app_id: str) → None

Cancels/kills the application. This method is idempotent within the same thread and is safe to call on the same application multiple times. However, when called from multiple threads/processes on the same app, the exact semantics of this method depend on the idempotency guarantees of the underlying scheduler API.

Note

This method does not block for the application to reach a cancelled state. To ensure that the application reaches a terminal state use the wait API.

abstract describe(app_id: str) → Optional[torchelastic.tsm.driver.api.DescribeAppResponse]

Describes the specified application.

Returns

Application description or None if the app does not exist.

exists(app_id: str)
Returns

True if the app exists (was submitted), False otherwise

log_iter(app_id: str, role_name: str, k: int = 0, regex: Optional[str] = None, since: Optional[datetime.datetime] = None, until: Optional[datetime.datetime] = None) → Iterable

Returns an iterator over the log lines of the k-th replica of the role. The iterator ends when all qualifying log lines have been read.

If the scheduler supports time-based cursors for fetching log lines in custom time ranges, then the since and until fields are honored; otherwise they are ignored. Not specifying since and until is equivalent to getting all available log lines. If until is empty, the iterator behaves like tail -f, following the log output until the job reaches a terminal state.

The exact definition of what constitutes a log is scheduler specific. Some schedulers may consider stderr or stdout as the log, others may read the logs from a log file.

Behaviors and assumptions:

  1. Produces undefined behavior if called on an app that does not exist. The caller should check that the app exists using exists(app_id) prior to calling this method.

  2. Is not stateful: calling this method twice with the same parameters returns a new iterator. Prior iteration progress is lost.

  3. Does not always support log-tailing. Not all schedulers support live log iteration (e.g. tailing logs while the app is running). Refer to the specific scheduler’s documentation for the iterator’s behavior.

  4. Does not guarantee log retention. It is possible that by the time this method is called, the underlying scheduler may have purged the log records for this application. If so, this method raises an arbitrary exception.

  5. Only raises a StopIteration exception when the accessible log lines have been fully exhausted and the app has reached a final state. For instance, if the app gets stuck and does not produce any log lines, then the iterator blocks until the app eventually gets killed (either via timeout or manually) at which point it raises a StopIteration.

  6. Need not be supported by all schedulers.

  7. Some schedulers may support line cursors by supporting __getitem__ (e.g. iter[50] seeks to the 50th log line).

Returns

An Iterator over log lines of the specified role replica

Raises

NotImplementedError - if the scheduler does not support log iteration
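
The filtering contract of log_iter (a regex plus since/until cursors, with a fresh iterator per call) can be sketched as a Python generator over timestamped records. The LogRecord shape and the assumption that records arrive in time order are illustrative, not part of the Scheduler interface; until is treated as inclusive here.

```python
import re
from datetime import datetime
from typing import Iterable, Iterator, Optional, Tuple

LogRecord = Tuple[datetime, str]  # (timestamp, line), assumed time-ordered


def log_iter(
    records: Iterable[LogRecord],
    regex: Optional[str] = None,
    since: Optional[datetime] = None,
    until: Optional[datetime] = None,
) -> Iterator[str]:
    """Yield lines in [since, until] that match regex (all lines if regex is None)."""
    pattern = re.compile(regex) if regex else None
    for ts, line in records:
        if since is not None and ts < since:
            continue
        if until is not None and ts > until:
            break  # records are time-ordered, so nothing later can qualify
        if pattern is None or pattern.search(line):
            yield line


records = [
    (datetime(2021, 1, 1, 0, 0, 0), "starting trainer"),
    (datetime(2021, 1, 1, 0, 5, 0), "epoch 1 loss=0.9"),
    (datetime(2021, 1, 1, 0, 10, 0), "epoch 2 loss=0.7"),
]
print(list(log_iter(records, regex=r"^epoch", since=datetime(2021, 1, 1, 0, 6, 0))))
# ['epoch 2 loss=0.7']
```

Because log_iter is a generator function, each call returns a new iterator, which matches behavior 2: progress from a previous iteration is not carried over.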

run_opts() → torchelastic.tsm.driver.api.runopts

Returns the run configuration options expected by the scheduler. Basically a --help for the run API.

abstract schedule(dryrun_info: torchelastic.tsm.driver.api.AppDryRunInfo) → str

Same as submit, except that it takes an AppDryRunInfo. Implementors are encouraged to implement this method rather than directly implementing submit, since submit can be trivially implemented by:

dryrun_info = self.submit_dryrun(app, cfg)
return self.schedule(dryrun_info)

submit(app: torchelastic.tsm.driver.api.Application, cfg: torchelastic.tsm.driver.api.RunConfig) → str

Submits the application to be run by the scheduler.

Returns

The application id that uniquely identifies the submitted app.

submit_dryrun(app: torchelastic.tsm.driver.api.Application, cfg: torchelastic.tsm.driver.api.RunConfig) → torchelastic.tsm.driver.api.AppDryRunInfo

Rather than submitting the request to run the app, returns the request object that would have been submitted to the underlying service. The type of the request object is scheduler dependent. This method can be used to dry-run an application. Please refer to the scheduler implementation’s documentation regarding the actual return type.
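
A minimal sketch of extending the interface: the abstract base below mirrors a trimmed-down subset of the methods documented above, and submit is derived from submit_dryrun and schedule as suggested. Which methods are abstract, implementing exists via describe, and the InMemoryScheduler backend are all simplifying assumptions for illustration.

```python
import abc
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class AppDryRunInfo:
    request: Any


class Scheduler(abc.ABC):
    """Hypothetical, trimmed-down mirror of the Scheduler interface."""

    def __init__(self, session_name: str) -> None:
        self.session_name = session_name

    def submit(self, app: Any, cfg: Dict[str, Any]) -> str:
        # submit is derived from the two primitives below
        return self.schedule(self.submit_dryrun(app, cfg))

    def exists(self, app_id: str) -> bool:
        return self.describe(app_id) is not None

    @abc.abstractmethod
    def submit_dryrun(self, app: Any, cfg: Dict[str, Any]) -> AppDryRunInfo:
        ...

    @abc.abstractmethod
    def schedule(self, dryrun_info: AppDryRunInfo) -> str:
        ...

    @abc.abstractmethod
    def describe(self, app_id: str) -> Optional[Dict[str, Any]]:
        ...


class InMemoryScheduler(Scheduler):
    """Toy backend that records requests in a dict instead of launching them."""

    def __init__(self, session_name: str) -> None:
        super().__init__(session_name)
        self._apps: Dict[str, Any] = {}

    def submit_dryrun(self, app: Any, cfg: Dict[str, Any]) -> AppDryRunInfo:
        return AppDryRunInfo(request={"app": app, "cfg": dict(cfg)})

    def schedule(self, dryrun_info: AppDryRunInfo) -> str:
        app_id = f"{self.session_name}-{len(self._apps)}"
        self._apps[app_id] = dryrun_info.request
        return app_id

    def describe(self, app_id: str) -> Optional[Dict[str, Any]]:
        return self._apps.get(app_id)


sched = InMemoryScheduler("my_session")
app_id = sched.submit("my_train_job", {"queue": "default"})
print(app_id, sched.exists(app_id), sched.exists("unknown"))  # my_session-0 True False
```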

class torchelastic.tsm.driver.local_scheduler.LocalScheduler(session_name: str, cache_size: int = 100)

Schedules on localhost. Containers are modeled as processes and certain properties of the container that are either not relevant or that cannot be enforced for localhost runs are ignored. Properties that are ignored:

  1. Resource requirements

  2. Container limit enforcements

  3. Retry policies

  4. Retry counts (no retries supported)

  5. Deployment preferences

Note

Use this scheduler sparingly, since an application that runs successfully on a session backed by this scheduler may not work on an actual production cluster using a different scheduler.
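
The "containers are processes" model can be illustrated with the standard library alone. run_role_locally is a hypothetical helper, not LocalScheduler's code: it starts one subprocess per replica, passes the role's env plus an assumed REPLICA_ID variable, and, like the LocalScheduler, applies no resource limits and no retries.

```python
import os
import subprocess
import sys
from typing import Dict, List


def run_role_locally(cmd: List[str], num_replicas: int, env: Dict[str, str]) -> List[int]:
    """Start one local process per replica and return their exit codes."""
    procs = []
    for replica_id in range(num_replicas):
        # the role's env vars plus a hypothetical REPLICA_ID variable
        replica_env = {**os.environ, **env, "REPLICA_ID": str(replica_id)}
        procs.append(subprocess.Popen(cmd, env=replica_env))
    # no resource limits and no retries: just wait for every process to exit
    return [proc.wait() for proc in procs]


# each replica exits with (replica_id % 2), so odd replicas "fail"
script = "import os, sys; sys.exit(int(os.environ['REPLICA_ID']) % 2)"
print(run_role_locally([sys.executable, "-c", script], 3, {"MY_ENV_VAR": "foobar"}))  # [0, 1, 0]
```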

describe(app_id: str) → Optional[torchelastic.tsm.driver.api.DescribeAppResponse]

Describes the specified application.

Returns

Application description or None if the app does not exist.

log_iter(app_id: str, role_name: str, k: int = 0, regex: Optional[str] = None, since: Optional[datetime.datetime] = None, until: Optional[datetime.datetime] = None) → Iterable

Returns an iterator over the log lines of the k-th replica of the role. The iterator ends when all qualifying log lines have been read.

If the scheduler supports time-based cursors for fetching log lines in custom time ranges, then the since and until fields are honored; otherwise they are ignored. Not specifying since and until is equivalent to getting all available log lines. If until is empty, the iterator behaves like tail -f, following the log output until the job reaches a terminal state.

The exact definition of what constitutes a log is scheduler specific. Some schedulers may consider stderr or stdout as the log, others may read the logs from a log file.

Behaviors and assumptions:

  1. Produces undefined behavior if called on an app that does not exist. The caller should check that the app exists using exists(app_id) prior to calling this method.

  2. Is not stateful: calling this method twice with the same parameters returns a new iterator. Prior iteration progress is lost.

  3. Does not always support log-tailing. Not all schedulers support live log iteration (e.g. tailing logs while the app is running). Refer to the specific scheduler’s documentation for the iterator’s behavior.

  4. Does not guarantee log retention. It is possible that by the time this method is called, the underlying scheduler may have purged the log records for this application. If so, this method raises an arbitrary exception.

  5. Only raises a StopIteration exception when the accessible log lines have been fully exhausted and the app has reached a final state. For instance, if the app gets stuck and does not produce any log lines, then the iterator blocks until the app eventually gets killed (either via timeout or manually) at which point it raises a StopIteration.

  6. Need not be supported by all schedulers.

  7. Some schedulers may support line cursors by supporting __getitem__ (e.g. iter[50] seeks to the 50th log line).

Returns

An Iterator over log lines of the specified role replica

Raises

NotImplementedError - if the scheduler does not support log iteration

run_opts() → torchelastic.tsm.driver.api.runopts

Returns the run configuration options expected by the scheduler. Basically a --help for the run API.

schedule(dryrun_info: torchelastic.tsm.driver.api.AppDryRunInfo) → str

Same as submit, except that it takes an AppDryRunInfo. Implementors are encouraged to implement this method rather than directly implementing submit, since submit can be trivially implemented by:

dryrun_info = self.submit_dryrun(app, cfg)
return self.schedule(dryrun_info)
