class torchrl.collectors.collectors.SyncDataCollector(create_env_fn: EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]], policy: TensorDictModule | Callable[[TensorDictBase], TensorDictBase] | None, *, frames_per_batch: int, total_frames: int, device: device | str | int = None, storing_device: device | str | int = None, create_env_kwargs: dict | None = None, max_frames_per_traj: int = -1, init_random_frames: int = -1, reset_at_each_iter: bool = False, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, split_trajs: bool | None = None, exploration_type: InteractionType = InteractionType.RANDOM, exploration_mode=None, return_same_td: bool = False, reset_when_done: bool = True, interruptor=None)

Generic data collector for RL problems. Requires an environment constructor and a policy.

  • create_env_fn (Callable) – a callable that returns an instance of the EnvBase class. As the signature indicates, an EnvBase instance can also be passed directly.

  • policy (Callable) – Policy to be executed in the environment. Must accept a tensordict.tensordict.TensorDictBase object as input. If None is provided, a RandomPolicy instance built from the environment's action_spec is used.

  • frames_per_batch (int) – A keyword-only argument representing the total number of elements in a batch.

  • total_frames (int) –

    A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If total_frames is not divisible by frames_per_batch, an exception is raised.

    Endless collectors can be created by passing total_frames=-1.

  • device (int, str or torch.device, optional) – The device on which the policy will be placed. If it differs from the device of the input policy, update_policy_weights_() should be called at appropriate times during the training loop so that the collector's copy of the policy does not lag behind the parameters being trained. Defaults to None (i.e. the policy is kept on its original device).

  • storing_device (int, str or torch.device, optional) – The device on which the output tensordict.TensorDict will be stored. For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. Defaults to "cpu".

  • create_env_kwargs (dict, optional) – Dictionary of kwargs for create_env_fn.

  • max_frames_per_traj (int, optional) – Maximum number of steps per trajectory. Note that a trajectory can span multiple batches (unless reset_at_each_iter is set to True, see below). Once a trajectory reaches max_frames_per_traj steps, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. Defaults to -1 (i.e. no maximum number of steps).

  • init_random_frames (int, optional) – Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. Defaults to -1 (i.e. no random frames).

  • reset_at_each_iter (bool, optional) – Whether environments should be reset at the beginning of a batch collection. Defaults to False.

  • postproc (Callable, optional) – A post-processing transform, such as a Transform or a MultiStep instance. Defaults to None.

  • split_trajs (bool, optional) – Boolean indicating whether the resulting TensorDict should be split according to the trajectories. See split_trajectories() for more information. Defaults to False.

  • exploration_type (ExplorationType, optional) – interaction mode to be used when collecting data. Must be one of ExplorationType.RANDOM, ExplorationType.MODE or ExplorationType.MEAN. Defaults to ExplorationType.RANDOM.

  • return_same_td (bool, optional) – if True, the same TensorDict will be returned at each iteration, with its values updated in place. This feature should be used cautiously: for instance, if that tensordict is repeatedly added to a replay buffer, every entry of the buffer will be identical. Defaults to False.

  • interruptor (_Interruptor, optional) – An _Interruptor object that can be used from outside the class to control rollout collection. The _Interruptor class has the methods start_collection and stop_collection, which make it possible to implement strategies such as preemptively stopping rollout collection. Defaults to None.

  • reset_when_done (bool, optional) – if True (default), an environment that returns a True value in its "done" or "truncated" entry will be reset at the corresponding indices.
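The way max_frames_per_traj and reset_when_done decide which sub-environments get reset can be pictured with the plain-Python sketch below. It is not TorchRL's implementation: the helper name should_reset and its list-based bookkeeping are hypothetical. It only encodes what the parameter descriptions above state: step counts are tracked per environment, and either a done flag or reaching the step limit triggers a reset at that index.

```python
from typing import List


def should_reset(
    step_counts: List[int],
    done: List[bool],
    max_frames_per_traj: int = -1,
    reset_when_done: bool = True,
) -> List[bool]:
    """Hypothetical helper: which environment indices need a reset.

    Step counts are tracked independently for each sub-environment.
    In this sketch a non-positive max_frames_per_traj disables the limit.
    """
    flags = []
    for count, is_done in zip(step_counts, done):
        hit_limit = max_frames_per_traj > 0 and count >= max_frames_per_traj
        flags.append(hit_limit or (reset_when_done and is_done))
    return flags


# Three sub-environments with a limit of 50 steps per trajectory.
print(should_reset([50, 12, 3], [False, True, False], max_frames_per_traj=50))
# [True, True, False]
# Without a step limit, only the environment reporting done is reset.
print(should_reset([50, 12, 3], [False, True, False]))
# [False, True, False]
```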


>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(
...     create_env_fn=env_maker,
...     policy=policy,
...     total_frames=2000,
...     max_frames_per_traj=50,
...     frames_per_batch=200,
...     init_random_frames=-1,
...     reset_at_each_iter=False,
...     device="cpu",
...     storing_device="cpu",
... )
>>> for i, data in enumerate(collector):
...     if i == 2:
...         print(data)
...         break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                step_count: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False),
                traj_ids: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([4, 50]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        mask: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                observation: Tensor(shape=torch.Size([4, 50, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4, 50]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([4, 50, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 50]),
    device=cpu,
    is_shared=False)
>>> del collector

The collector delivers batches of data that are marked with a "time" dimension.


>>> assert data.names[-1] == "time"
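The leading trajectory dimension and the mask entry in the example output above correspond to the layout obtained when a flat batch is split by trajectory (see split_trajs and split_trajectories()). The plain-Python sketch below illustrates that idea only: the helper split_by_traj is hypothetical, it groups frames by their traj_ids value, pads shorter trajectories with an arbitrary value (0.0 here), and returns a mask that is True for real frames. It is not TorchRL's implementation.

```python
from typing import Dict, List, Tuple


def split_by_traj(
    values: List[float], traj_ids: List[int], pad: float = 0.0
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Hypothetical helper: group a flat batch of frames by trajectory id.

    Returns one padded row per trajectory plus a boolean mask that is
    True for real frames and False for padding.
    """
    groups: Dict[int, List[float]] = {}
    for value, tid in zip(values, traj_ids):
        groups.setdefault(tid, []).append(value)
    max_len = max(len(g) for g in groups.values())
    rows, mask = [], []
    for tid in sorted(groups):
        g = groups[tid]
        rows.append(g + [pad] * (max_len - len(g)))
        mask.append([True] * len(g) + [False] * (max_len - len(g)))
    return rows, mask


# A flat batch of 5 frames spanning two trajectories (ids 0 and 1).
rows, mask = split_by_traj([1.0, 2.0, 3.0, 4.0, 5.0], [0, 0, 0, 1, 1])
print(rows)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
print(mask)  # [[True, True, True], [True, True, False]]
```

In the example above, 200 frames cut into trajectories of at most 50 steps give 4 rows of length 50, which matches the [4, 50] batch size.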

iterator() → Iterator[TensorDictBase]

Iterates through the DataCollector.

Yields: TensorDictBase objects containing (chunks of) trajectories
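How frames_per_batch and total_frames bound the iteration can be sketched with a plain-Python generator. This is an illustration under the assumptions stated in the parameter descriptions (a non-divisible budget raises, total_frames=-1 never stops); the generator batch_sizes is hypothetical and yields batch sizes rather than TensorDictBase objects.

```python
from typing import Iterator


def batch_sizes(frames_per_batch: int, total_frames: int) -> Iterator[int]:
    """Hypothetical generator mirroring the collector's frame budget.

    Yields the size of each batch until total_frames frames have been
    delivered; total_frames=-1 never stops (an endless collector).
    """
    if total_frames != -1 and total_frames % frames_per_batch != 0:
        # Mirrors the documented behaviour: a non-divisible budget is an error.
        raise ValueError("total_frames must be divisible by frames_per_batch")
    collected = 0
    while total_frames == -1 or collected < total_frames:
        collected += frames_per_batch
        yield frames_per_batch


# With the values from the example above: 2000 / 200 = 10 batches.
print(sum(1 for _ in batch_sizes(200, 2000)))  # 10
```

Because the endless case is a generator that never terminates, a caller has to stop it explicitly, much like the `break` in the collector example.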

load_state_dict(state_dict: OrderedDict, **kwargs) → None

Loads a state_dict on the environment and policy.


state_dict (OrderedDict) – ordered dictionary containing the fields "policy_state_dict" and "env_state_dict".

reset(index=None, **kwargs) → None

Resets the environments to a new initial state.

rollout() → TensorDictBase

Computes a rollout in the environment using the provided policy.


Returns: a TensorDictBase containing the computed rollout.
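The loop behind a rollout can be sketched in plain Python as below. Everything here is hypothetical: CounterEnv is a toy environment, and rollout's signature (including frames_so_far) is invented for illustration. The sketch resets the environment at the start of every call, so it behaves like reset_at_each_iter=True, and during the init_random_frames warm-up it substitutes a uniform random action for the policy output rather than using TorchRL's RandomPolicy.

```python
import random
from typing import Callable, Dict, List, Tuple


class CounterEnv:
    """Hypothetical toy environment whose observation is a step counter."""

    def __init__(self, horizon: int = 5) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: float) -> Tuple[int, float, bool]:
        self.t += 1
        done = self.t >= self.horizon
        # Toy reward: simply echo the action back.
        return self.t, float(action), done


def rollout(
    env: CounterEnv,
    policy: Callable[[int], float],
    frames_per_batch: int,
    frames_so_far: int = 0,
    init_random_frames: int = -1,
) -> List[Dict[str, object]]:
    """Collect exactly frames_per_batch transitions, resetting on done."""
    obs = env.reset()  # behaves like reset_at_each_iter=True
    batch = []
    for i in range(frames_per_batch):
        # While the global frame count is below init_random_frames,
        # a uniform random action replaces the policy output.
        if frames_so_far + i < init_random_frames:
            action = random.uniform(-1.0, 1.0)
        else:
            action = policy(obs)
        next_obs, reward, done = env.step(action)
        batch.append(
            {"observation": obs, "action": action, "reward": reward,
             "done": done, "next_observation": next_obs}
        )
        obs = env.reset() if done else next_obs
    return batch


batch = rollout(CounterEnv(horizon=5), policy=lambda obs: 0.5, frames_per_batch=12)
print(len(batch))  # 12
print(sum(t["done"] for t in batch))  # 2
```

Trajectories of 5 steps mean that 12 frames contain two complete trajectories and the start of a third, which is why a trajectory can span batches unless the environment is reset at each iteration.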

set_seed(seed: int, static_seed: bool = False) → int

Sets the seeds of the environments stored in the DataCollector.

  • seed (int) – integer representing the seed to be used for the environment.

  • static_seed (bool, optional) – if True, the seed is not incremented. Defaults to False.


Returns: the output seed. This is useful when more than one environment is contained in the DataCollector, as the seed is incremented for each of them. The returned value is the seed of the last environment.


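>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = ParallelEnv(6, env_fn)
>>> # policy=None falls back to a RandomPolicy; frames_per_batch and total_frames are required
>>> collector = SyncDataCollector(env_fn_parallel, policy=None, frames_per_batch=60, total_frames=600)
>>> out_seed = collector.set_seed(1)  # out_seed = 6
>>> collector.shutdown()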
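The seeding rule described above (one seed per sub-environment, incremented unless static_seed is True, with the last seed returned) can be illustrated in plain Python. The function seed_envs is hypothetical and assumes a simple seed + i scheme; the actual increments TorchRL applies internally may differ. Under that assumption, six environments seeded from 1 end on 6, consistent with the comment in the example above.

```python
from typing import List, Tuple


def seed_envs(seed: int, n_envs: int, static_seed: bool = False) -> Tuple[List[int], int]:
    """Hypothetical per-environment seeding with optional incrementing."""
    # Without static_seed, environment i receives seed + i.
    seeds = [seed if static_seed else seed + i for i in range(n_envs)]
    # The returned value is the seed of the last environment.
    return seeds, seeds[-1]


print(seed_envs(1, 6))                    # ([1, 2, 3, 4, 5, 6], 6)
print(seed_envs(1, 6, static_seed=True))  # ([1, 1, 1, 1, 1, 1], 1)
```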

shutdown() → None

Shuts down all workers and/or closes the local environment.

state_dict() → OrderedDict

Returns the local state_dict of the data collector (environment and policy).


Returns: an ordered dictionary with the fields "policy_state_dict" and "env_state_dict".
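The shape of the dictionary exchanged by state_dict() and load_state_dict() can be shown with a toy stand-in. ToyCollector is hypothetical and holds plain dictionaries instead of real policy parameters and environment state; it only mirrors the two documented field names, "policy_state_dict" and "env_state_dict", and a save/restore round trip between two instances.

```python
from collections import OrderedDict
from typing import Any, Dict


class ToyCollector:
    """Hypothetical stand-in exposing the same two state_dict fields."""

    def __init__(self) -> None:
        self.policy_params: Dict[str, Any] = {"weight": 0.0}
        self.env_state: Dict[str, Any] = {"step_count": 0}

    def state_dict(self) -> "OrderedDict[str, Dict[str, Any]]":
        # Copies, so later mutations do not leak into a saved checkpoint.
        return OrderedDict(
            policy_state_dict=dict(self.policy_params),
            env_state_dict=dict(self.env_state),
        )

    def load_state_dict(self, state_dict: "OrderedDict[str, Dict[str, Any]]") -> None:
        self.policy_params = dict(state_dict["policy_state_dict"])
        self.env_state = dict(state_dict["env_state_dict"])


source = ToyCollector()
source.policy_params["weight"] = 1.5
target = ToyCollector()
target.load_state_dict(source.state_dict())
print(target.policy_params)  # {'weight': 1.5}
```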

update_policy_weights_(policy_weights: TensorDictBase | None = None) → None

Updates the policy weights if the policy of the data collector and the trained policy live on different devices.


policy_weights (TensorDictBase, optional) – if provided, a TensorDict containing the weights of the policy to be used for the update.
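The reason this method exists, a collector-side copy of the policy that goes stale while the trained policy keeps changing, can be illustrated with plain dictionaries. ToyPolicyCopy is hypothetical and stands in for a policy copy living on another device; its update_policy_weights_ merely copies values, either from the trained parameters or from explicitly provided weights, which is the behaviour described above rather than TorchRL's actual code.

```python
from typing import Dict, Optional


class ToyPolicyCopy:
    """Hypothetical collector-side copy of a policy's parameters."""

    def __init__(self, source: Dict[str, float]) -> None:
        self.source = source         # parameters of the policy being trained
        self.weights = dict(source)  # separate copy used for collection

    def update_policy_weights_(self, policy_weights: Optional[Dict[str, float]] = None) -> None:
        # Copy either the explicitly provided weights or the trained ones.
        new = self.source if policy_weights is None else policy_weights
        self.weights.update(new)


trained = {"w": 0.0}
collector_copy = ToyPolicyCopy(trained)
trained["w"] = 1.0             # e.g. an optimizer step on the trained policy
print(collector_copy.weights)  # {'w': 0.0}  (stale copy)
collector_copy.update_policy_weights_()
print(collector_copy.weights)  # {'w': 1.0}
```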

