
Source code for torchrl.collectors.distributed.utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import subprocess
import time

from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.collectors.distributed.default_configs import (
    DEFAULT_SLURM_CONF,
    DEFAULT_SLURM_CONF_MAIN,
    TCP_PORT,
)
from torchrl.collectors.distributed.generic import _distributed_init_delayed
from torchrl.collectors.distributed.rpc import _rpc_init_collection_node

try:
    import submitit

    _has_submitit = True
except ModuleNotFoundError as err:
    _has_submitit = False
    SUBMITIT_ERR = err

class submitit_delayed_launcher:
    """Delayed launcher for submitit.

    In some cases, launched jobs cannot spawn other jobs on their own and this
    can only be done at the jump-host level.

    In these cases, the :class:`submitit_delayed_launcher` can be used to
    pre-launch collector nodes that will wait for the main worker to provide
    the launching instruction.

    Args:
        num_jobs (int): the number of collection jobs to be launched.
        framework (str, optional): the framework to use. Can be either
            ``"distributed"`` or ``"rpc"``. ``"distributed"`` requires a
            :class:`~.DistributedDataCollector` collector whereas ``"rpc"``
            requires a :class:`RPCDataCollector`.
            Defaults to ``"distributed"``.
        backend (str, optional): torch.distributed backend in case
            ``framework`` points to ``"distributed"``. This value must match
            the one passed to the collector, otherwise main and satellite
            nodes will fail to reach the rendezvous and hang forever (i.e. no
            exception will be raised!). Defaults to ``'gloo'``.
        tcpport (int or str, optional): the TCP port to use. Defaults to
            :obj:`torchrl.collectors.distributed.default_configs.TCP_PORT`
        submitit_main_conf (dict, optional): the main node configuration to be
            passed to submitit. Defaults to
            :obj:`torchrl.collectors.distributed.default_configs.DEFAULT_SLURM_CONF_MAIN`
        submitit_collection_conf (dict, optional): the configuration to be
            passed to submitit. Defaults to
            :obj:`torchrl.collectors.distributed.default_configs.DEFAULT_SLURM_CONF`

    Examples:
        >>> num_jobs=2
        >>> @submitit_delayed_launcher(num_jobs=num_jobs)
        ... def main():
        ...     from torchrl.envs.utils import RandomPolicy
        ...     from torchrl.envs.libs.gym import GymEnv
        ...     from torchrl.data import BoundedContinuous
        ...     collector = DistributedDataCollector(
        ...         [EnvCreator(lambda: GymEnv("Pendulum-v1"))] * num_jobs,
        ...         policy=RandomPolicy(BoundedContinuous(-1, 1, shape=(1,))),
        ...         launcher="submitit_delayed",
        ...     )
        ...     for data in collector:
        ...         print(data)
        ...
        >>> if __name__ == "__main__":
        ...     main()
        ...
    """

    _VERBOSE = VERBOSE  # for debugging

    def __init__(
        self,
        num_jobs: int,
        framework: str = "distributed",
        backend: str = "gloo",
        tcpport=TCP_PORT,
        submitit_main_conf: dict = DEFAULT_SLURM_CONF_MAIN,
        submitit_collection_conf: dict = DEFAULT_SLURM_CONF,
    ):
        # The configuration dicts are stored by reference and only ever read
        # (unpacked with ``**``), so the shared defaults are never mutated.
        self.num_jobs = num_jobs
        self.backend = backend
        self.framework = framework
        self.submitit_collection_conf = submitit_collection_conf
        self.submitit_main_conf = submitit_main_conf
        self.tcpport = tcpport

    def __call__(self, main_func):
        """Wrap ``main_func`` so that calling the result launches every job.

        Args:
            main_func (callable): the function to be submitted as the main
                job.

        Returns:
            A zero-argument callable. When called, it submits ``main_func``
            through submitit, looks up the node and address the main job runs
            on, submits ``num_jobs`` collection jobs pointing at that address,
            and blocks until all jobs (collection jobs first, then the main
            job) have returned.

        Raises (when the returned callable is invoked):
            ModuleNotFoundError: if submitit could not be imported.
            NotImplementedError: if ``framework`` is neither
                ``"distributed"`` nor ``"rpc"``.
        """

        def exec_fun():
            if not _has_submitit:
                raise ModuleNotFoundError(
                    "Failed to import submitit. Check installation of the library."
                ) from SUBMITIT_ERR

            framework = self.framework
            # Fail fast: reject an unsupported framework before anything is
            # submitted, so that no main job is left running without workers.
            if framework not in ("distributed", "rpc"):
                raise NotImplementedError(f"Unknown framework {framework}.")

            # submit main
            executor = submitit.AutoExecutor(folder="log_test")
            executor.update_parameters(**self.submitit_main_conf)
            main_job = executor.submit(main_func)

            # query the scheduler for the node the main job runs on
            torchrl_logger.info(f"job id: {main_job.job_id}")
            time.sleep(2.0)
            node = None
            # Poll until squeue reports a non-empty node list for the job.
            # An empty answer cannot be parsed by ``int`` either, so it lands
            # in the ``except`` branch, sleeps, and is retried by the loop.
            while not node:
                cmd = f"squeue -j {main_job.job_id} -o %N | tail -1"
                node = subprocess.check_output(cmd, shell=True, text=True).strip()
                try:
                    node = int(node)
                except ValueError:
                    time.sleep(0.5)
                    continue
            torchrl_logger.info(f"node: {node}")

            # by default, sinfo will truncate the node name at char 20, we
            # increase this to 200
            cmd = f"sinfo -n {node} -O nodeaddr:200 | tail -1"
            rank0_ip = subprocess.check_output(cmd, shell=True, text=True).strip()
            torchrl_logger.info(f"IP: {rank0_ip}")

            # one extra slot for the main job; collection ranks start at 1
            world_size = self.num_jobs + 1

            # submit jobs
            executor = submitit.AutoExecutor(folder="log_test")
            executor.update_parameters(**self.submitit_collection_conf)
            jobs = []
            if framework == "rpc":
                from .rpc import DEFAULT_TENSORPIPE_OPTIONS

                tensorpipe_options = DEFAULT_TENSORPIPE_OPTIONS
            for i in range(self.num_jobs):
                rank = i + 1
                if framework == "distributed":
                    job = executor.submit(
                        _distributed_init_delayed,
                        rank,
                        self.backend,
                        rank0_ip,
                        self.tcpport,
                        world_size,
                        self._VERBOSE,
                    )
                else:  # framework == "rpc", validated above
                    job = executor.submit(
                        _rpc_init_collection_node,
                        rank,
                        rank0_ip,
                        self.tcpport,
                        world_size,
                        None,
                        tensorpipe_options,
                    )
                jobs.append(job)

            # block until every collection job, then the main job, is done
            for job in jobs:
                job.result()
            main_job.result()

        return exec_fun


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources