
MultiaSyncDataCollector

class torchrl.collectors.collectors.MultiaSyncDataCollector(*args, **kwargs)

Runs a given number of DataCollectors on separate processes asynchronously.


Environment types can be identical or different.

Collection continues on all processes even between the time a batch of rollouts is returned and the next call to the iterator, so a batch may have been gathered with slightly outdated policy weights. This class can therefore be safely used with offline RL algorithms.
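The asynchronous behaviour described above can be pictured with a small toy model. This is a sketch, not TorchRL's implementation: it uses Python threads and a queue.Queue in place of worker processes, and run_async_collection is a hypothetical helper. Each worker keeps producing batches without waiting for the consumer, and the consumer receives batches in whatever order they become ready.

```python
import queue
import threading


def run_async_collection(num_workers, batches_per_worker):
    """Toy model of asynchronous collection (hypothetical helper).

    Worker threads push batches into a shared queue on their own schedule;
    the consumer takes whichever batch is ready first.
    """
    ready = queue.Queue()

    def worker(worker_id):
        for batch_idx in range(batches_per_worker):
            # Stand-in for a rollout: record which worker produced which batch.
            # The worker never waits for the consumer before starting the next one.
            ready.put({"worker": worker_id, "batch": batch_idx})

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for t in threads:
        t.start()

    collected = []
    for _ in range(num_workers * batches_per_worker):
        # Blocks until any worker has a batch available.
        collected.append(ready.get())

    for t in threads:
        t.join()
    return collected


for data in run_async_collection(num_workers=2, batches_per_worker=3):
    print(data)
```

Within a single worker the batches keep their order, but batches from different workers can interleave in any order.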

Examples

>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> from torchrl.collectors import MultiaSyncDataCollector
>>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = MultiaSyncDataCollector(
...     create_env_fn=[env_maker, env_maker],
...     policy=policy,
...     total_frames=2000,
...     max_frames_per_traj=50,
...     frames_per_batch=200,
...     init_random_frames=-1,
...     reset_at_each_iter=False,
...     device="cpu",
...     storing_device="cpu",
... )
>>> for i, data in enumerate(collector):
...     if i == 2:
...         print(data)
...         break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)
>>> collector.shutdown()
>>> del collector


Parameters:
  • create_env_fn (List[Callable]) – list of Callables, each returning an instance of EnvBase.

  • policy (Callable) –

    Policy to be executed in the environment. Must accept a tensordict.tensordict.TensorDictBase object as input. If None is provided, the policy used will be a RandomPolicy instance with the environment action_spec. Accepted policies are usually subclasses of TensorDictModuleBase; this is the recommended usage of the collector. Other callables are accepted too: if the policy is not a TensorDictModuleBase (e.g., a regular Module instance), it will be wrapped in an nn.Module first. The collector will then try to determine whether the module needs to be wrapped in a TensorDictModule:

    • If the policy forward signature matches any of forward(self, tensordict), forward(self, td) or forward(self, <anything>: TensorDictBase) (or any typing with a single argument typed as a subclass of TensorDictBase), the policy won’t be wrapped in a TensorDictModule.

    • In all other cases, the collector will attempt to wrap it as TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys).
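The signature rules above can be approximated with inspect.signature. The following is a simplified sketch rather than the collector's actual check: TensorDictBase is a local placeholder class standing in for the real tensordict type, and takes_tensordict is a hypothetical name.

```python
import inspect


class TensorDictBase:
    """Local placeholder for tensordict's TensorDictBase (assumption for this sketch)."""


def takes_tensordict(forward):
    """Return True if `forward` seems to accept a single TensorDictBase argument.

    Approximates the documented rules: exactly one argument (besides `self`)
    that is either named `tensordict` or `td`, or annotated with a subclass
    of TensorDictBase.
    """
    params = [
        p for p in inspect.signature(forward).parameters.values() if p.name != "self"
    ]
    if len(params) != 1:
        return False
    (param,) = params
    if param.name in ("tensordict", "td"):
        return True
    annotation = param.annotation
    return isinstance(annotation, type) and issubclass(annotation, TensorDictBase)


def forward_obs(self, observation):
    return observation


# A policy whose forward takes a plain tensor would be wrapped in a TensorDictModule.
print(takes_tensordict(forward_obs))  # prints False
```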

Keyword Arguments:
  • frames_per_batch (int) – A keyword-only argument representing the total number of elements in a batch.

  • total_frames (int, optional) –

    A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If total_frames is not divisible by frames_per_batch, an exception is raised.

    Endless collectors can be created by passing total_frames=-1. Defaults to -1 (never ending collector).

  • device (int, str or torch.device, optional) – The generic device of the collector. The device argument fills in any unspecified device: if device is not None and any of storing_device, policy_device or env_device is not specified, its value will be set to device. Defaults to None (no default device). Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers.

  • storing_device (int, str or torch.device, optional) – The device on which the output TensorDict will be stored. If device is passed and storing_device is None, it will default to the value indicated by device. For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. Defaults to None (the output tensordict isn’t on a specific device, leaf tensors sit on the device where they were created). Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers.

  • env_device (int, str or torch.device, optional) – The device on which the environment should be cast (or executed, if that functionality is supported). If not specified and the env has a non-None device, env_device will default to that value. If device is passed and env_device=None, it will default to device. If the resulting env_device differs from policy_device and one of them is not None, the data will be cast to env_device before being passed to the env (i.e., passing different devices to policy and env is supported). Defaults to None. Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers.

  • policy_device (int, str or torch.device, optional) – The device on which the policy should be cast. If device is passed and policy_device=None, it will default to device. If the resulting policy_device differs from env_device and one of them is not None, the data will be cast to policy_device before being passed to the policy (i.e., passing different devices to policy and env is supported). Defaults to None. Supports a list of devices if one wishes to indicate a different device for each worker. The list must be as long as the number of workers.

  • create_env_kwargs (dict, optional) – A dictionary with the keyword arguments used to create an environment. If a list is provided, each of its elements will be assigned to a sub-collector.

  • max_frames_per_traj (int, optional) – Maximum steps per trajectory. Note that a trajectory can span across multiple batches (unless reset_at_each_iter is set to True, see below). Once a trajectory reaches max_frames_per_traj steps, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. Defaults to None (i.e. no maximum number of steps).

  • init_random_frames (int, optional) – Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to None (i.e. no random frames).

  • reset_at_each_iter (bool, optional) – Whether environments should be reset at the beginning of a batch collection. Defaults to False.

  • postproc (Callable, optional) – A post-processing transform, such as a Transform or a MultiStep instance. Defaults to None.

  • split_trajs (bool, optional) – Boolean indicating whether the resulting TensorDict should be split according to the trajectories. See split_trajectories() for more information. Defaults to False.

  • exploration_type (ExplorationType, optional) – interaction mode to be used when collecting data. Must be one of torchrl.envs.utils.ExplorationType.RANDOM, torchrl.envs.utils.ExplorationType.MODE or torchrl.envs.utils.ExplorationType.MEAN. Defaults to torchrl.envs.utils.ExplorationType.RANDOM.

  • reset_when_done (bool, optional) – if True (default), an environment that returns a True value in its "done" or "truncated" entry will be reset at the corresponding indices.

  • update_at_each_batch (bool, optional) – if True, update_policy_weights_() will be called before (sync) or after (async) each data collection. Defaults to False.

  • preemptive_threshold (float, optional) – a value between 0.0 and 1.0 that specifies the ratio of workers that will be allowed to finish collecting their rollout before the rest are forced to end early.

  • num_threads (int, optional) – number of threads for this process. Defaults to the number of workers.

  • num_sub_threads (int, optional) – number of threads of the subprocesses. Should be equal to one plus the number of processes launched within each subprocess (or one if a single process is launched). Defaults to 1 for safety: if none is indicated, launching multiple workers may overload the CPU and harm performance.

  • cat_results (str, int or None) –

    (MultiSyncDataCollector exclusively). If "stack", the data collected from the workers will be stacked along the first dimension. This is the preferred behaviour as it is the most compatible with the rest of the library. If 0, results will be concatenated along the first dimension of the outputs, which can be the batched dimension if the environments are batched or the time dimension if not. A cat_results value of -1 will always concatenate results along the time dimension. This should be preferred over the default. Intermediate values are also accepted. Defaults to 0.

    Note

    From v0.5, this argument will default to "stack" for better interoperability with the rest of the library.

  • set_truncated (bool, optional) – if True, the truncated signals (and corresponding "done" but not "terminated") will be set to True when the last frame of a rollout is reached. If no "truncated" key is found, an exception is raised. Truncated keys can be set through env.add_truncated_keys. Defaults to False.
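The frame bookkeeping rules for total_frames and init_random_frames amount to simple integer arithmetic. A minimal sketch that follows the rules exactly as documented above (count_batches and round_init_random_frames are hypothetical helpers, not library functions):

```python
import math


def count_batches(total_frames, frames_per_batch):
    """Number of batches a collector yields, per the documented rules.

    total_frames=-1 means an endless collector; a total_frames value that is
    not a multiple of frames_per_batch is rejected.
    """
    if total_frames == -1:
        return math.inf
    if total_frames % frames_per_batch != 0:
        raise ValueError(
            f"total_frames={total_frames} is not divisible by "
            f"frames_per_batch={frames_per_batch}"
        )
    return total_frames // frames_per_batch


def round_init_random_frames(init_random_frames, frames_per_batch):
    """Round init_random_frames up to the closest multiple of frames_per_batch."""
    if init_random_frames is None:
        return 0  # None means no random frames
    return math.ceil(init_random_frames / frames_per_batch) * frames_per_batch


print(count_batches(2000, 200))            # prints 10
print(round_init_random_frames(250, 200))  # prints 400
```

With the values from the class example (total_frames=2000, frames_per_batch=200), count_batches returns 10.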
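The device fallbacks for a single worker can be summarized as follows. This is a sketch under stated assumptions: devices are plain strings, env_own_device is a hypothetical input standing for the device the environment reports, and an explicit device is assumed to take precedence over the environment's own device (the text above does not state the precedence). The per_worker helper, also hypothetical, mirrors the rule that a list of devices must be as long as the number of workers.

```python
def per_worker(value, num_workers):
    """Broadcast one device to every worker, or validate a per-worker list."""
    if isinstance(value, (list, tuple)):
        if len(value) != num_workers:
            raise ValueError(f"expected {num_workers} devices, got {len(value)}")
        return list(value)
    return [value] * num_workers


def resolve_devices(device=None, storing_device=None, env_device=None,
                    policy_device=None, env_own_device=None):
    """Fill in unspecified devices for one worker (illustrative sketch).

    env_own_device is a hypothetical input: the device the env reports.
    Assumption: an explicit `device` wins over env_own_device.
    """
    if storing_device is None:
        storing_device = device
    if policy_device is None:
        policy_device = device
    if env_device is None:
        env_device = device if device is not None else env_own_device
    return {
        "storing_device": storing_device,
        "policy_device": policy_device,
        "env_device": env_device,
        # Data is moved between policy and env only when their devices differ.
        "cast_between_policy_and_env": env_device != policy_device,
    }


for worker_device in per_worker(["cpu", "cuda:0"], num_workers=2):
    print(resolve_devices(device=worker_device))
```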
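The interplay of max_frames_per_traj and set_truncated can be traced on a single environment that never terminates on its own. In this sketch (step_counter_flags is a hypothetical helper), the step counter resets once it reaches the limit, a None or negative limit is ignored, and set_truncated additionally flags the last frame of the rollout. Since the toy environment never terminates, "done" simply mirrors "truncated".

```python
def step_counter_flags(num_steps, max_frames_per_traj=None, set_truncated=False):
    """(step_count, truncated, done) per frame for one never-terminating env.

    Hypothetical helper: the counter resets once it reaches max_frames_per_traj
    (None or negative means no limit); with set_truncated=True the last frame
    of the rollout is flagged as truncated as well.
    """
    if max_frames_per_traj is not None and max_frames_per_traj < 0:
        max_frames_per_traj = None  # negative values are ignored
    records = []
    step_count = 0
    for t in range(num_steps):
        step_count += 1
        truncated = max_frames_per_traj is not None and step_count >= max_frames_per_traj
        if set_truncated and t == num_steps - 1:
            truncated = True
        # The toy env never terminates, so "done" is set exactly when "truncated" is.
        records.append((step_count, truncated, truncated))
        if max_frames_per_traj is not None and step_count >= max_frames_per_traj:
            step_count = 0  # trajectory hit its limit: the environment is reset
    return records


for record in step_counter_flags(5, max_frames_per_traj=2, set_truncated=True):
    print(record)
```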
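For MultiSyncDataCollector, the effect of cat_results on the merged batch shape can be computed from the shape returned by one worker. The sketch below (merged_batch_shape is a hypothetical helper) assumes all workers return batches of the same shape, such as (num_envs, T) for a batched environment or (T,) for a single one. MultiaSyncDataCollector does not merge results, since each iteration yields one worker's batch.

```python
def merged_batch_shape(worker_shape, num_workers, cat_results):
    """Shape of the merged MultiSyncDataCollector output (illustrative sketch).

    "stack" adds a leading worker dimension; an int concatenates along that
    dimension of the per-worker shape (negative values count from the end).
    """
    shape = list(worker_shape)
    if cat_results == "stack":
        return (num_workers, *shape)
    dim = cat_results if cat_results >= 0 else len(shape) + cat_results
    shape[dim] *= num_workers
    return tuple(shape)


# Two workers, each running a batched env of 4 sub-envs for 50 steps.
print(merged_batch_shape((4, 50), num_workers=2, cat_results="stack"))  # prints (2, 4, 50)
print(merged_batch_shape((4, 50), num_workers=2, cat_results=0))        # prints (8, 50)
print(merged_batch_shape((4, 50), num_workers=2, cat_results=-1))       # prints (4, 100)
```

With a batched environment, cat_results=0 merges the environment dimension, whereas -1 always merges along time.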

load_state_dict(state_dict: OrderedDict) → None

Loads the state_dict on the workers.

Parameters:

state_dict (OrderedDict) – state_dict of the form {"worker0": state_dict0, "worker1": state_dict1}.
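The expected layout of the state_dict keys can be produced and consumed with two small helpers. Both names are hypothetical, and each worker's own state_dict is treated as an opaque value:

```python
from collections import OrderedDict


def merge_worker_state_dicts(per_worker):
    """Build {"worker0": sd0, "worker1": sd1, ...} from a list of state dicts."""
    return OrderedDict((f"worker{i}", sd) for i, sd in enumerate(per_worker))


def split_worker_state_dicts(state_dict):
    """Inverse of merge_worker_state_dicts: list of state dicts in worker order."""
    return [state_dict[f"worker{i}"] for i in range(len(state_dict))]


combined = merge_worker_state_dicts([{"step": 10}, {"step": 12}])
print(list(combined))  # prints ['worker0', 'worker1']
```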

reset(reset_idx: Optional[Sequence[bool]] = None) → None

Resets the environments to a new initial state.

Parameters:

reset_idx – Optional. Sequence indicating which environments have to be reset. If None, all environments are reset.
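One way to read reset_idx is as a boolean mask. In this sketch, envs_to_reset is hypothetical, and the assumption that the mask holds one entry per worker sub-collector is ours:

```python
def envs_to_reset(num_workers, reset_idx=None):
    """Indices selected for reset by an optional boolean mask (hypothetical helper).

    Assumption: the mask holds one flag per worker sub-collector.
    """
    if reset_idx is None:
        return list(range(num_workers))  # None means reset everything
    if len(reset_idx) != num_workers:
        raise ValueError(f"expected {num_workers} flags, got {len(reset_idx)}")
    return [i for i, flag in enumerate(reset_idx) if flag]


print(envs_to_reset(3))                       # prints [0, 1, 2]
print(envs_to_reset(3, [True, False, True]))  # prints [0, 2]
```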

set_seed(seed: int, static_seed: bool = False) → int

Sets the seeds of the environments stored in the DataCollector.

Parameters:
  • seed – integer representing the seed to be used for the environment.

  • static_seed (bool, optional) – if True, the seed is not incremented. Defaults to False.

Returns:

Output seed. This is useful when more than one environment is contained in the DataCollector, as the seed will be incremented for each of these. The resulting seed is the seed of the last environment.

Examples

>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> from torchrl.collectors import SyncDataCollector
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = lambda: ParallelEnv(6, env_fn)
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300)
>>> out_seed = collector.set_seed(1)  # out_seed = 6
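The seeding rule described for set_seed can be modelled as follows. This is a simplified sketch: assign_seeds is hypothetical, and taking "incremented" literally as +1 is an assumption, so the next_seed argument lets another derivation rule be plugged in. With this model, seeding 6 environments from 1 yields a last seed of 6, which matches the comment in the example above.

```python
def assign_seeds(seed, num_envs, static_seed=False, next_seed=lambda s: s + 1):
    """Return (per-env seeds, output seed) under a simplified seeding model.

    Each environment receives the current seed; unless static_seed is True,
    the seed is advanced with next_seed before the next environment. The
    output seed is the seed of the last environment, as documented above.
    """
    seeds = []
    current = seed
    for i in range(num_envs):
        if i > 0 and not static_seed:
            current = next_seed(current)
        seeds.append(current)
    return seeds, seeds[-1]


print(assign_seeds(1, 6))                    # prints ([1, 2, 3, 4, 5, 6], 6)
print(assign_seeds(1, 3, static_seed=True))  # prints ([1, 1, 1], 1)
```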

shutdown()

Shuts down all processes. This operation is irreversible.

state_dict() → OrderedDict

Returns the state_dict of the data collector.

Each field represents a worker containing its own state_dict.

update_policy_weights_(policy_weights: Optional[TensorDictBase] = None) → None

Updates the policy weights if the policy of the data collector and the trained policy live on different devices.

Parameters:

policy_weights (TensorDictBase, optional) – if provided, a TensorDict containing the weights of the policy to be used for the update.
