
MultiSyncDataCollector

class torchrl.collectors.collectors.MultiSyncDataCollector(create_env_fn: Sequence[Callable[[], EnvBase]], policy: Optional[Union[TensorDictModule, Callable[[TensorDictBase], TensorDictBase]]], *, frames_per_batch: int = 200, total_frames: Optional[int] = -1, device: DEVICE_TYPING = None, storing_device: Optional[Union[DEVICE_TYPING, Sequence[DEVICE_TYPING]]] = None, create_env_kwargs: Optional[Sequence[dict]] = None, max_frames_per_traj: int | None = None, init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, exploration_type: ExplorationType = InteractionType.RANDOM, exploration_mode=None, reset_when_done: bool = True, preemptive_threshold: float = None, update_at_each_batch: bool = False, devices=None, storing_devices=None, num_threads: int = None, num_sub_threads: int = 1)

Runs a given number of DataCollectors on separate processes synchronously.


Envs can be identical or different.

The collection starts when the next item of the collector is queried, and no environment step is computed in between the reception of a batch of trajectory and the start of the next collection. This class can be safely used with online RL algorithms.

Examples

>>> from torchrl.collectors import MultiSyncDataCollector
>>> from torchrl.envs import StepCounter, TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_maker = lambda: TransformedEnv(GymEnv("Pendulum-v1", device="cpu"), StepCounter(max_steps=50))
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = MultiSyncDataCollector(
...     create_env_fn=[env_maker, env_maker],
...     policy=policy,
...     total_frames=2000,
...     max_frames_per_traj=50,
...     frames_per_batch=200,
...     init_random_frames=-1,
...     reset_at_each_iter=False,
...     devices="cpu",
...     storing_devices="cpu",
... )
>>> for i, data in enumerate(collector):
...     if i == 2:
...         print(data)
...         break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)
>>> collector.shutdown()
>>> del collector

Runs a given number of DataCollectors on separate processes.

Parameters:
  • create_env_fn (List[Callable]) – list of Callables, each returning an instance of EnvBase.

  • policy (Callable, optional) – Instance of TensorDictModule class. Must accept TensorDictBase object as input. If None is provided, the policy used will be a RandomPolicy instance with the environment action_spec.

  • frames_per_batch (int) – A keyword-only argument representing the total number of elements in a batch.

  • total_frames (int) –

    A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If the total_frames is not divisible by frames_per_batch, an exception is raised.

    Endless collectors can be created by passing total_frames=-1.

  • device (int, str, torch.device or sequence of such, optional) – The device on which the policy will be placed. If it differs from the input policy device, the update_policy_weights_() method should be queried at appropriate times during the training loop to accommodate for the lag between parameter configuration at various times. If necessary, a list of devices can be passed in which case each element will correspond to the designated device of a sub-collector. Defaults to None (i.e. policy is kept on its original device).

  • storing_device (int, str, torch.device or sequence of such, optional) – The device on which the output tensordict.TensorDict will be stored. For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. If necessary, a list of devices can be passed in which case each element will correspond to the designated storing device of a sub-collector. Defaults to "cpu".

  • create_env_kwargs (dict, optional) – A dictionary with the keyword arguments used to create an environment. If a list is provided, each of its elements will be assigned to a sub-collector.

  • max_frames_per_traj (int, optional) – Maximum steps per trajectory. Note that a trajectory can span multiple batches (unless reset_at_each_iter is set to True, see below). Once a trajectory reaches max_frames_per_traj steps, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. Defaults to None (i.e. no maximum number of steps).

  • init_random_frames (int, optional) – Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to None (i.e. no random frames).

  • reset_at_each_iter (bool, optional) – Whether environments should be reset at the beginning of a batch collection. Defaults to False.

  • postproc (Callable, optional) – A post-processing transform, such as a Transform or a MultiStep instance. Defaults to None.

  • split_trajs (bool, optional) – Boolean indicating whether the resulting TensorDict should be split according to the trajectories. See split_trajectories() for more information. Defaults to False.

  • exploration_type (ExplorationType, optional) – Interaction mode to be used when collecting data. Must be one of ExplorationType.RANDOM, ExplorationType.MODE or ExplorationType.MEAN. Defaults to ExplorationType.RANDOM.

  • return_same_td (bool, optional) – if True, the same TensorDict will be returned at each iteration, with its values updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer for instance, the whole content of the buffer will be identical. Default is False.

  • reset_when_done (bool, optional) – if True (default), an environment that returns a True value in its "done" or "truncated" entry will be reset at the corresponding indices.

  • update_at_each_batch (bool, optional) – if True, update_policy_weights_() will be called before (sync) or after (async) each data collection. Defaults to False.

  • preemptive_threshold (float, optional) – a value between 0.0 and 1.0 that specifies the ratio of workers that will be allowed to finish collecting their rollout before the rest are forced to end early.

  • num_threads (int, optional) – number of threads for this process. Defaults to the number of workers.

  • num_sub_threads (int, optional) – number of threads of the subprocesses. Should be equal to one plus the number of processes launched within each subprocess (or one if a single process is launched). Defaults to 1 for safety: if no value is given, launching multiple workers may overload the CPU and harm performance.
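The relationship between total_frames, frames_per_batch and init_random_frames described above comes down to simple arithmetic. The sketch below illustrates only that arithmetic and is not the library's code; the helper names num_batches and rounded_init_random_frames are hypothetical, and the exception type is an assumption (the text only says that an exception is raised).

```python
import math
from typing import Optional


def num_batches(total_frames: int, frames_per_batch: int) -> Optional[int]:
    """Number of batches a finite collector would yield (None if endless)."""
    if total_frames == -1:
        # total_frames=-1 describes an endless collector.
        return None
    if total_frames % frames_per_batch != 0:
        # Assumption: the exception type is chosen here for illustration only.
        raise ValueError("total_frames must be divisible by frames_per_batch")
    return total_frames // frames_per_batch


def rounded_init_random_frames(init_random_frames: int, frames_per_batch: int) -> int:
    """Round up to the closest multiple of frames_per_batch."""
    return math.ceil(init_random_frames / frames_per_batch) * frames_per_batch


print(num_batches(2000, 200))                # 10
print(rounded_init_random_frames(250, 200))  # 400
```

With the values used in the class example (total_frames=2000, frames_per_batch=200), iterating over the collector would produce 10 batches under these assumptions.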

load_state_dict(state_dict: OrderedDict) → None

Loads the state_dict on the workers.

Parameters:

state_dict (OrderedDict) – state_dict of the form {"worker0": state_dict0, "worker1": state_dict1}.
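As an illustration of the expected layout, the sketch below builds a mapping keyed by "worker0", "worker1", and so on. The helper pack_worker_states is hypothetical, and the per-worker dictionaries are placeholders: the actual content of each worker's state_dict is not specified here.

```python
from collections import OrderedDict
from typing import Any, Mapping, Sequence


def pack_worker_states(states: Sequence[Mapping[str, Any]]) -> "OrderedDict[str, Mapping[str, Any]]":
    """Key each per-worker state by "worker<i>", in worker order."""
    return OrderedDict((f"worker{i}", state) for i, state in enumerate(states))


# Placeholder per-worker states, for illustration only.
packed = pack_worker_states([{"placeholder": 0}, {"placeholder": 1}])
print(list(packed))  # ['worker0', 'worker1']
```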

reset(reset_idx: Optional[Sequence[bool]] = None) → None

Resets the environments to a new initial state.

Parameters:

reset_idx – Optional. Sequence indicating which environments have to be reset. If None, all environments are reset.

set_seed(seed: int, static_seed: bool = False) → int

Sets the seeds of the environments stored in the DataCollector.

Parameters:
  • seed – integer representing the seed to be used for the environment.

  • static_seed (bool, optional) – if True, the seed is not incremented. Defaults to False

Returns:

Output seed. This is useful when more than one environment is contained in the DataCollector, as the seed will be incremented for each of these. The resulting seed is the seed of the last environment.

Examples

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = lambda: ParallelEnv(6, env_fn)
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300)
>>> out_seed = collector.set_seed(1)  # out_seed = 6

shutdown()

Shuts down all processes. This operation is irreversible.

state_dict() → OrderedDict

Returns the state_dict of the data collector.

Each field represents a worker containing its own state_dict.

update_policy_weights_(policy_weights: Optional[TensorDictBase] = None) → None

Updates the policy weights if the policy of the data collector and the trained policy live on different devices.

Parameters:

policy_weights (TensorDictBase, optional) – if provided, a TensorDict containing the weights of the policy to be used for the update.
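A minimal sketch of where this call might sit in a training loop, assuming the policy is trained on a different device than the one used for collection. The names collect_and_train and train_step are hypothetical; the sketch relies only on iterating over the collector and on the update_policy_weights_() and shutdown() methods described on this page.

```python
from typing import Any, Callable


def collect_and_train(collector: Any, train_step: Callable[[Any], None]) -> int:
    """Train on each collected batch, then resync the collector's policy weights.

    `collector` is any object that can be iterated over and that exposes
    update_policy_weights_() and shutdown().
    """
    n_batches = 0
    try:
        for batch in collector:
            # Hypothetical user-supplied optimization step on the trained policy.
            train_step(batch)
            # Copy the freshly trained weights to the policy used for collection.
            collector.update_policy_weights_()
            n_batches += 1
    finally:
        # Shutting down is irreversible, so it is done once, at the very end.
        collector.shutdown()
    return n_batches
```

If the collector is built with update_at_each_batch=True, the explicit call inside the loop should not be needed, since the weights are then synced around each data collection.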
