torchrl.collectors package

Data collectors are somewhat equivalent to PyTorch dataloaders, except that (1) they collect data from non-static data sources and (2) the data is collected using a model (likely a version of the model that is being trained).

TorchRL’s data collectors accept two main arguments: an environment (or a list of environment constructors) and a policy. They will iteratively execute an environment step and a policy query over a defined number of steps before delivering a stack of the data collected to the user. Environments will be reset whenever they reach a done state, and/or after a predefined number of steps.
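The loop described above can be made concrete with a minimal pure-Python sketch. It is only a toy model of the behaviour, not TorchRL's implementation: ToyEnv and toy_collector are hypothetical names, and real collectors operate on TensorDicts and batched environments. The sketch alternates a policy query and an environment step, resets on a done state or once max_frames_per_traj frames have been collected in the current trajectory, and yields batches of frames_per_batch transitions.

```python
from typing import Callable, Dict, Iterator, List, Tuple


class ToyEnv:
    """Hypothetical 1-D environment: the episode ends once |position| reaches 3."""

    def __init__(self) -> None:
        self.pos = 0

    def reset(self) -> int:
        self.pos = 0
        return self.pos

    def step(self, action: int) -> Tuple[int, bool]:
        self.pos += action
        return self.pos, abs(self.pos) >= 3


def toy_collector(
    env: ToyEnv,
    policy: Callable[[int], int],
    frames_per_batch: int,
    max_frames_per_traj: int,
) -> Iterator[List[Dict[str, object]]]:
    """Hypothetical collector: endlessly yield batches of frames_per_batch transitions."""
    obs = env.reset()
    traj_id, steps_in_traj = 0, 0
    while True:
        batch: List[Dict[str, object]] = []
        for _ in range(frames_per_batch):
            action = policy(obs)               # policy query
            next_obs, done = env.step(action)  # environment step
            steps_in_traj += 1
            truncated = steps_in_traj >= max_frames_per_traj
            batch.append({"traj_id": traj_id, "obs": obs, "action": action,
                          "next_obs": next_obs, "done": done, "truncated": truncated})
            if done or truncated:
                # Reset on a done state or once the trajectory hits its frame budget
                obs = env.reset()
                traj_id, steps_in_traj = traj_id + 1, 0
            else:
                obs = next_obs
        yield batch  # deliver the whole batch to the training loop
```

With a policy that always returns 1, the first batch of 5 frames carries trajectory ids [0, 0, 0, 1, 1]: the first episode reaches a done state after three steps and the environment is reset. With TorchRL itself, one simply iterates over the collector object and receives one TensorDict per batch.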

Because data collection is a potentially compute-heavy process, it is crucial to configure the execution hyperparameters appropriately. The first parameter to take into consideration is whether the data collection should occur serially with the optimization step or in parallel. The SyncDataCollector class will execute the data collection on the training worker. The MultiSyncDataCollector will split the workload across a number of workers and aggregate the results that will be delivered to the training worker. Finally, the MultiaSyncDataCollector will execute the data collection on several workers and deliver the first batch of results that it can gather. This execution will occur continuously and concomitantly with the training of the networks: this implies that the weights of the policy used for data collection may slightly lag behind those of the policy on the training worker. Therefore, although this class may be the fastest way to collect data, it is only suitable in settings where gathering data asynchronously is acceptable (e.g. off-policy RL or curriculum RL). For remotely executed rollouts (MultiSyncDataCollector or MultiaSyncDataCollector) it is necessary to synchronise the weights of the remote policy with the weights from the training worker, either by calling collector.update_policy_weights_() or by setting update_at_each_batch=True in the constructor.
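The difference between synchronous aggregation and "first batch available" delivery, and the resulting policy lag, can be illustrated with a toy thread-based sketch. Everything here is hypothetical: collect_chunk, sync_collect, ToyAsyncCollector and update_policy_version are illustrative names, threads stand in for worker processes, and an integer version number stands in for the policy weights. update_policy_version only mimics the role of collector.update_policy_weights_(); it is not its signature.

```python
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from typing import Dict, List, Tuple

Frame = Tuple[int, int, int]  # (worker_id, policy_version, step)


def collect_chunk(worker_id: int, n_frames: int, policy_version: int) -> List[Frame]:
    # Hypothetical worker rollout: each frame records the policy version that produced it
    return [(worker_id, policy_version, t) for t in range(n_frames)]


def sync_collect(pool: ThreadPoolExecutor, num_workers: int,
                 frames_per_batch: int, policy_version: int) -> List[Frame]:
    # MultiSync-style: split the batch across workers, wait for all, then concatenate
    per_worker = frames_per_batch // num_workers
    futures = [pool.submit(collect_chunk, w, per_worker, policy_version)
               for w in range(num_workers)]
    batch: List[Frame] = []
    for fut in futures:  # result() blocks until that worker is done
        batch.extend(fut.result())
    return batch


class ToyAsyncCollector:
    """MultiaSync-style toy: workers run ahead; each call returns the first finished chunk."""

    def __init__(self, pool: ThreadPoolExecutor, num_workers: int, frames_per_worker: int):
        self.pool = pool
        self.frames_per_worker = frames_per_worker
        self.policy_version = 0
        # Every worker starts collecting immediately with the initial weights
        self.pending: Dict[Future, int] = {
            pool.submit(collect_chunk, w, frames_per_worker, 0): w
            for w in range(num_workers)
        }

    def update_policy_version(self, version: int) -> None:
        # Stand-in for a weight update: only chunks launched after this call
        # see the new version; chunks already in flight keep their old one.
        self.policy_version = version

    def next_batch(self) -> List[Frame]:
        done, _ = wait(list(self.pending), return_when=FIRST_COMPLETED)
        fut = next(iter(done))
        worker_id = self.pending.pop(fut)
        # Relaunch this worker right away with the latest known weights
        new_fut = self.pool.submit(collect_chunk, worker_id,
                                   self.frames_per_worker, self.policy_version)
        self.pending[new_fut] = worker_id
        return fut.result()
```

In this model, a synchronous batch requested after an update contains only the new version, whereas the first asynchronous batch returned after update_policy_version(1) was still produced with version 0, which is the lag described above.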

The second parameter to consider (in remote settings) is the device where the data will be collected and the device where the environment and policy operations will be executed. For instance, a policy executed on CPU may be slower than one executed on CUDA. When multiple inference workers run concomitantly, dispatching the compute workload across the available devices may speed up the collection or avoid OOM errors. Finally, the choice of the batch size and of the passing device (i.e. the device where the data will be stored while waiting to be passed to the collection worker) may also impact memory management. The key parameters to control are devices, which controls the execution devices (i.e. the device of the policy), and storing_devices, which controls the device where the environment and data are stored during a rollout. A good heuristic is usually to use the same device for storage and compute, which is the default behaviour when only the devices argument is passed.
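The default described above (storage follows compute when only devices is given) can be sketched with a small helper. resolve_devices is hypothetical and is not part of TorchRL; it assumes that a single device string is shared by every worker, that a list gives one device per worker, and that compute falls back to "cpu" when nothing is passed.

```python
from typing import List, Optional, Sequence, Tuple, Union

DeviceArg = Optional[Union[str, Sequence[str]]]


def resolve_devices(num_workers: int, devices: DeviceArg = None,
                    storing_devices: DeviceArg = None) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (compute_devices, storing_devices), one entry per worker."""

    def per_worker(arg: DeviceArg, default: List[str]) -> List[str]:
        if arg is None:
            return list(default)
        if isinstance(arg, str):
            # Assumption: a single device is shared by every worker
            return [arg] * num_workers
        if len(arg) != num_workers:
            raise ValueError(f"expected {num_workers} devices, got {len(arg)}")
        return list(arg)

    # Assumption: compute defaults to CPU when no device is given
    compute = per_worker(devices, ["cpu"] * num_workers)
    # Heuristic from the text: store data where it is computed unless told otherwise
    storing = per_worker(storing_devices, compute)
    return compute, storing
```

For example, passing devices=["cuda:0", "cuda:1"] alone spreads two workers over two GPUs and keeps each worker's data on its own GPU, while adding storing_devices="cpu" moves storage to host memory.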

Besides those compute parameters, users may choose to configure the following parameters:

  • max_frames_per_traj: the number of frames after which an env.reset() is called

  • frames_per_batch: the number of frames delivered at each iteration over the collector

  • init_random_frames: the number of initial random steps (steps where env.rand_step() is called instead of querying the policy)

  • reset_at_each_iter: if True, the environment(s) will be reset after each batch collection

  • split_trajs: if True, the trajectories will be split and delivered in a padded tensordict along with a "mask" key that will point to a boolean mask representing the valid values.

  • exploration_type: the exploration strategy to be used with the policy.

  • reset_when_done: whether environments should be reset when reaching a done state.
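The effect of split_trajs (one padded row per trajectory plus a boolean "mask") can be illustrated with a pure-Python toy. split_and_pad is a hypothetical function working on plain lists and explicit trajectory ids; TorchRL performs the equivalent operation on TensorDicts.

```python
from typing import Dict, List


def split_and_pad(values: List[float], traj_ids: List[int],
                  pad_value: float = 0.0) -> Dict[str, list]:
    """Hypothetical toy version of split_trajs: padded rows plus a boolean validity mask."""
    # Group values by trajectory, keeping the order in which trajectories first appear
    groups: Dict[int, List[float]] = {}
    for value, tid in zip(values, traj_ids):
        groups.setdefault(tid, []).append(value)
    if not groups:
        return {"values": [], "mask": []}
    max_len = max(len(g) for g in groups.values())
    rows, mask = [], []
    for g in groups.values():
        pad = max_len - len(g)
        rows.append(g + [pad_value] * pad)            # pad short trajectories
        mask.append([True] * len(g) + [False] * pad)  # True marks real frames
    return {"values": rows, "mask": mask}
```

For a flat batch of five frames belonging to trajectories [0, 0, 0, 1, 1], the result has two rows of length 3, and the last entry of the second row is flagged False in the mask because it is padding.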

Single node data collectors

DataCollectorBase(*args, **kwds)

Base class for data collectors.

SyncDataCollector(create_env_fn, policy, *, ...)

Generic data collector for RL problems.

MultiSyncDataCollector(create_env_fn, policy, *)

Runs a given number of DataCollectors on separate processes synchronously.

MultiaSyncDataCollector(*args, **kwargs)

Runs a given number of DataCollectors on separate processes asynchronously.


A random policy for data collectors.

aSyncDataCollector(create_env_fn[, policy, ...])

Runs a single DataCollector on a separate process.

Distributed data collectors

TorchRL provides a set of distributed data collectors. These tools support multiple backends ('gloo', 'nccl', 'mpi' with the DistributedDataCollector or PyTorch RPC with RPCDataCollector) and launchers ('ray', submitit or torch.multiprocessing). They can be efficiently used in synchronous or asynchronous mode, on a single node or across multiple nodes.

Resources: Find examples for these collectors in the dedicated folder.


Choosing the sub-collector: All distributed collectors support the various single machine collectors. One may wonder whether to use a MultiSyncDataCollector or a ParallelEnv instead. In general, multiprocessed collectors have a lower IO footprint than parallel environments, which need to communicate at each step. However, the model characteristics pull in the opposite direction: with parallel environments, the policy (and/or transforms) executes faster because these operations are vectorized across environments.
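The trade-off above can be put into rough numbers with a back-of-the-envelope sketch. The model is deliberately simplified and its assumptions are ours, not TorchRL's: one environment per worker, one message per action or observation sent between processes, and one message per finished sub-batch. Real implementations rely on shared memory and batched signalling, so absolute counts will differ. Both function names are hypothetical.

```python
from typing import Tuple


def _steps(frames_per_batch: int, num_workers: int) -> int:
    if frames_per_batch % num_workers:
        raise ValueError("frames_per_batch must be divisible by num_workers")
    return frames_per_batch // num_workers  # frames contributed by each env / worker


def ipc_messages_per_batch(frames_per_batch: int, num_workers: int, design: str) -> int:
    """Rough count of inter-process messages needed to gather one batch."""
    steps = _steps(frames_per_batch, num_workers)
    if design == "parallel_env":
        # Policy runs in the main process: at every step one action goes out
        # to each env process and one observation comes back.
        return 2 * steps * num_workers
    if design == "multiprocess_collector":
        # Each worker runs its own policy locally and ships its sub-batch once.
        return num_workers
    raise ValueError(f"unknown design: {design!r}")


def policy_calls_per_batch(frames_per_batch: int, num_workers: int,
                           design: str) -> Tuple[int, int]:
    """Return (total policy forward calls, batch size of each call) under the same model."""
    steps = _steps(frames_per_batch, num_workers)
    if design == "parallel_env":
        return steps, num_workers     # one vectorized call per step over all envs
    if design == "multiprocess_collector":
        return frames_per_batch, 1    # every worker queries the policy per frame
    raise ValueError(f"unknown design: {design!r}")
```

With 1000 frames per batch and 8 workers, this model gives 2000 messages for the parallel environment versus 8 for the multiprocessed collector, but 125 policy calls on batches of 8 versus 1000 calls on single observations: the IO-versus-vectorization tension described above.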


Choosing the device of a collector (or a parallel environment): Sharing data among processes is achieved via shared-memory buffers with parallel environments and multiprocessed collectors executed on CPU. Depending on the capabilities of the machine being used, this may be prohibitively slow compared to sharing data on GPU, which is natively supported by CUDA drivers. In practice, this means that using the device="cpu" keyword argument when building a parallel environment or collector can result in a slower collection than using device="cuda" when available.

DistributedDataCollector(create_env_fn, ...)

A distributed data collector with torch.distributed backend.

RPCDataCollector(create_env_fn, policy, *, ...)

An RPC-based distributed data collector.

DistributedSyncDataCollector(create_env_fn, ...)

A distributed synchronous data collector with torch.distributed backend.

submitit_delayed_launcher(num_jobs[, ...])

Delayed launcher for submitit.

RayCollector(create_env_fn, policy, ...[, ...])

Distributed data collector with Ray backend.

Helper functions

split_trajectories(rollout_tensordict[, prefix])

A util function for trajectory separation.


