SyncDataCollector
- class torchrl.collectors.collectors.SyncDataCollector(create_env_fn: EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]], policy: TensorDictModule | Callable[[TensorDictBase], TensorDictBase] | None, *, frames_per_batch: int, total_frames: int, device: device | str | int = None, storing_device: device | str | int = None, create_env_kwargs: dict | None = None, max_frames_per_traj: int = -1, init_random_frames: int = -1, reset_at_each_iter: bool = False, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, split_trajs: bool | None = None, exploration_type: InteractionType = InteractionType.RANDOM, exploration_mode=None, return_same_td: bool = False, reset_when_done: bool = True, interruptor=None)
Generic data collector for RL problems. Requires an environment constructor and a policy.
- Parameters:
create_env_fn (Callable) – a callable that returns an instance of the EnvBase class.
policy (Callable) – Policy to be executed in the environment. Must accept a tensordict.tensordict.TensorDictBase object as input. If None is provided, the policy used will be a RandomPolicy instance built from the environment action_spec.
frames_per_batch (int) – A keyword-only argument representing the total number of elements in a batch.
total_frames (int) – A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If total_frames is not divisible by frames_per_batch, an exception is raised. Endless collectors can be created by passing total_frames=-1.
device (int, str or torch.device, optional) – The device on which the policy will be placed. If it differs from the input policy device, the update_policy_weights_() method should be called at appropriate times during the training loop so that the collector's copy of the policy keeps up with the parameters being trained. Defaults to None (i.e. policy is kept on its original device).
storing_device (int, str or torch.device, optional) – The device on which the output tensordict.TensorDict will be stored. For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. Defaults to "cpu".
create_env_kwargs (dict, optional) – Dictionary of kwargs for create_env_fn.
max_frames_per_traj (int, optional) – Maximum steps per trajectory. Note that a trajectory can span multiple batches (unless reset_at_each_iter is set to True, see below). Once a trajectory reaches max_frames_per_traj steps, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. Defaults to -1 (i.e. no maximum number of steps).
init_random_frames (int, optional) – Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. Defaults to -1 (i.e. no random frames).
reset_at_each_iter (bool, optional) – Whether environments should be reset at the beginning of a batch collection. Defaults to False.
postproc (Callable, optional) – A post-processing transform, such as a Transform or a MultiStep instance. Defaults to None.
split_trajs (bool, optional) – Boolean indicating whether the resulting TensorDict should be split according to the trajectories. See split_trajectories() for more information. Defaults to False.
exploration_type (ExplorationType, optional) – Interaction mode to be used when collecting data. Must be one of ExplorationType.RANDOM, ExplorationType.MODE or ExplorationType.MEAN. Defaults to ExplorationType.RANDOM.
return_same_td (bool, optional) – if True, the same TensorDict will be returned at each iteration, with its values updated in place. This feature should be used cautiously: if the same tensordict is added to a replay buffer, for instance, every entry of the buffer will end up identical. Defaults to False.
interruptor (_Interruptor, optional) – An _Interruptor object that can be used from outside the class to control rollout collection. The _Interruptor class has methods start_collection and stop_collection, which allow implementing strategies such as preemptively stopping rollout collection. Defaults to None.
reset_when_done (bool, optional) – if True (default), an environment that returns a True value in its "done" or "truncated" entry will be reset at the corresponding indices.
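The interplay between max_frames_per_traj and reset_when_done can be hard to picture when several environments are wrapped together. The following is a toy pure-Python model, not torchrl's implementation: it assumes each sub-environment keeps its own step counter that goes back to zero when it either hits the limit or reports done. The function simulate_step_counts and its done_at input are hypothetical, introduced only for illustration.

```python
from typing import Dict, List, Optional


def simulate_step_counts(
    n_envs: int,
    n_steps: int,
    max_frames_per_traj: int,
    done_at: Optional[Dict[int, List[int]]] = None,
) -> List[List[int]]:
    """Toy model (hypothetical): per-sub-environment step counters with resets.

    Returns, for every global step, the step count of each sub-environment
    before that step is taken. `done_at` maps an env index to the global
    steps at which that env reports done on its own.
    """
    done_at = done_at or {}
    counts = [0] * n_envs
    history = []
    for t in range(n_steps):
        history.append(list(counts))
        for i in range(n_envs):
            counts[i] += 1
            # A limit <= 0 is ignored in this toy, matching the documented -1 default.
            hit_limit = 0 < max_frames_per_traj <= counts[i]
            # reset_when_done=True: a done signal also triggers a reset.
            if hit_limit or t in done_at.get(i, []):
                counts[i] = 0  # only this sub-environment is reset
    return history


# Two sub-envs with a 3-step limit; env 1 reports done at global step 0.
print(simulate_step_counts(2, 5, 3, {1: [0]}))
# → [[0, 0], [1, 0], [2, 1], [0, 2], [1, 0]]
```

Each column evolves independently, which is the behavior the max_frames_per_traj description states for wrapped environments.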
Examples
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(
...     create_env_fn=env_maker,
...     policy=policy,
...     total_frames=2000,
...     max_frames_per_traj=50,
...     frames_per_batch=200,
...     init_random_frames=-1,
...     reset_at_each_iter=False,
...     split_trajs=True,
...     device="cpu",
...     storing_device="cpu",
... )
>>> for i, data in enumerate(collector):
...     if i == 2:
...         print(data)
...         break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                step_count: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False),
                traj_ids: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([4, 50]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        mask: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                observation: Tensor(shape=torch.Size([4, 50, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4, 50]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([4, 50, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 50]),
    device=cpu,
    is_shared=False)
>>> del collector
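The printed batch has batch_size [4, 50] and a "mask" entry: 200 frames per batch made of 50-step trajectories regroup into 4 rows of 50 once split by trajectory. A rough pure-Python sketch of that regrouping follows, assuming frames are grouped by their traj_ids and right-padded to the longest trajectory, with a mask marking real entries. split_by_traj_ids is a hypothetical helper, not torchrl's split_trajectories(), which operates on TensorDicts.

```python
from typing import Dict, Hashable, List, Sequence, Tuple


def split_by_traj_ids(
    values: Sequence[float],
    traj_ids: Sequence[Hashable],
    pad: float = 0,
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Group a flat batch by trajectory id and right-pad to a common length.

    Returns the padded rows and a boolean mask marking real (non-padded) entries.
    """
    groups: Dict[Hashable, List[float]] = {}
    for value, tid in zip(values, traj_ids):
        groups.setdefault(tid, []).append(value)  # dicts keep first-seen order
    max_len = max(len(g) for g in groups.values())
    rows, mask = [], []
    for g in groups.values():
        n_pad = max_len - len(g)
        rows.append(list(g) + [pad] * n_pad)
        mask.append([True] * len(g) + [False] * n_pad)
    return rows, mask


# 200 frames from back-to-back 50-step trajectories regroup into 4 x 50.
rows, mask = split_by_traj_ids(list(range(200)), [i // 50 for i in range(200)])
print(len(rows), len(rows[0]), all(all(m) for m in mask))  # → 4 50 True

# Unequal lengths are padded and masked.
print(split_by_traj_ids([1, 2, 3, 4, 5], [0, 0, 1, 1, 1]))
# → ([[1, 2, 0], [3, 4, 5]], [[True, True, False], [True, True, True]])
```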
The collector delivers batches of data that are marked with a "time" dimension.
Examples
>>> assert data.names[-1] == "time"
- iterator() → Iterator[TensorDictBase]
Iterates through the DataCollector.
Yields: TensorDictBase objects containing (chunks of) trajectories
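The iteration contract described by frames_per_batch and total_frames (fixed-size batches, a divisibility check, and -1 for an endless stream) can be mimicked with a plain generator. The sketch below is a hypothetical stand-in: toy_collector is not a torchrl function, and each "batch" is just the list of frame indices it would contain.

```python
import itertools
from typing import Iterator, List


def _batches(frames_per_batch: int, total_frames: int) -> Iterator[List[int]]:
    frames_seen = 0
    # A negative total_frames means "endless", as with total_frames=-1.
    while total_frames < 0 or frames_seen < total_frames:
        yield list(range(frames_seen, frames_seen + frames_per_batch))
        frames_seen += frames_per_batch


def toy_collector(frames_per_batch: int, total_frames: int) -> Iterator[List[int]]:
    """Hypothetical stand-in for iterating a SyncDataCollector."""
    # This toy validates eagerly, so the error surfaces when it is created.
    if total_frames > 0 and total_frames % frames_per_batch != 0:
        raise ValueError("total_frames must be divisible by frames_per_batch")
    return _batches(frames_per_batch, total_frames)


batches = list(toy_collector(frames_per_batch=200, total_frames=2000))
print(len(batches), batches[2][0])  # → 10 400

# An endless collector must be stopped by the caller, e.g. with islice.
first_three = list(itertools.islice(toy_collector(200, -1), 3))
print(len(first_three))  # → 3
```

With the settings of the class example above (frames_per_batch=200, total_frames=2000), the third batch (i == 2) starts at frame 400.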
- load_state_dict(state_dict: OrderedDict, **kwargs) → None
Loads a state_dict on the environment and policy.
- Parameters:
state_dict (OrderedDict) – ordered dictionary containing the fields "policy_state_dict" and "env_state_dict".
- rollout() → TensorDictBase
Computes a rollout in the environment using the provided policy.
- Returns:
TensorDictBase containing the computed rollout.
- set_seed(seed: int, static_seed: bool = False) → int
Sets the seeds of the environments stored in the DataCollector.
- Parameters:
seed (int) – integer representing the seed to be used for the environment.
static_seed (bool, optional) – if True, the seed is not incremented. Defaults to False.
- Returns:
Output seed. This is useful when more than one environment is contained in the DataCollector, as the seed will be incremented for each of these. The resulting seed is the seed of the last environment.
Examples
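>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = ParallelEnv(6, env_fn)
>>> collector = SyncDataCollector(
...     env_fn_parallel,
...     policy=None,  # falls back to a RandomPolicy over the env action_spec
...     frames_per_batch=60,  # required keyword-only arguments
...     total_frames=600,
... )
>>> out_seed = collector.set_seed(1)  # out_seed = 6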
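The "# out_seed = 6" comment follows from a simple rule if one assumes the seed is incremented by 1 per sub-environment: six environments seeded from 1 get seeds 1 through 6, and the last one is returned. The sketch below encodes only that assumed rule; assign_seeds is hypothetical, and the library may derive successive seeds differently.

```python
from typing import List, Tuple


def assign_seeds(seed: int, n_envs: int, static_seed: bool = False) -> Tuple[List[int], int]:
    """Toy seeding rule (assumed): +1 per sub-environment unless static_seed is True.

    Returns the per-env seeds and the seed of the last environment.
    """
    if static_seed:
        seeds = [seed] * n_envs  # every sub-env reuses the same seed
    else:
        seeds = [seed + i for i in range(n_envs)]
    return seeds, seeds[-1]


print(assign_seeds(1, 6))                    # → ([1, 2, 3, 4, 5, 6], 6)
print(assign_seeds(1, 6, static_seed=True))  # → ([1, 1, 1, 1, 1, 1], 1)
```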
- state_dict() → OrderedDict
Returns the local state_dict of the data collector (environment and policy).
- Returns:
an ordered dictionary with fields "policy_state_dict" and "env_state_dict".
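Together, state_dict() and load_state_dict() form a checkpoint round trip whose top-level layout has two fields, "policy_state_dict" and "env_state_dict". The toy class below mimics only that layout so the round trip is easy to see; ToyCollector and its plain-dict "parameters" are hypothetical and do not reflect what the real policy or environment store.

```python
from collections import OrderedDict
from typing import Any, Dict


class ToyCollector:
    """Hypothetical stand-in that only mimics the state_dict layout."""

    def __init__(self, policy_params: Dict[str, Any], env_params: Dict[str, Any]):
        self.policy_params = dict(policy_params)
        self.env_params = dict(env_params)

    def state_dict(self) -> "OrderedDict[str, OrderedDict]":
        return OrderedDict(
            [
                ("policy_state_dict", OrderedDict(self.policy_params)),
                ("env_state_dict", OrderedDict(self.env_params)),
            ]
        )

    def load_state_dict(self, state_dict: "OrderedDict[str, OrderedDict]") -> None:
        self.policy_params = dict(state_dict["policy_state_dict"])
        self.env_params = dict(state_dict["env_state_dict"])


# Checkpoint one collector and restore the state into another.
source = ToyCollector({"weight": 0.5}, {"step_count": 3})
checkpoint = source.state_dict()
target = ToyCollector({}, {})
target.load_state_dict(checkpoint)
print(list(checkpoint.keys()))  # → ['policy_state_dict', 'env_state_dict']
print(target.policy_params, target.env_params)  # → {'weight': 0.5} {'step_count': 3}
```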
- update_policy_weights_(policy_weights: TensorDictBase | None = None) → None
Updates the policy weights if the policy of the data collector and the trained policy live on different devices.
- Parameters:
policy_weights (TensorDictBase, optional) – if provided, a TensorDict containing the weights of the policy to be used for the update.
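Why this call matters can be seen in a toy model where the collector holds its own copy of the weights (standing in for a copy on another device): optimizer steps on the trained weights are invisible to collection until the copy is refreshed. ToyPolicyHolder is hypothetical and uses plain dicts instead of TensorDicts; in a real training loop, the analogous step would be calling collector.update_policy_weights_() after optimizer updates.

```python
from typing import Dict, Optional


class ToyPolicyHolder:
    """Hypothetical model of a collector holding its own copy of the policy weights.

    `trained` is the dict the optimizer updates; `local` plays the role of the
    collector's copy (e.g. one living on a different device).
    """

    def __init__(self, trained: Dict[str, float]):
        self.trained = trained  # shared reference to the weights being trained
        self.local = dict(trained)  # independent snapshot used for collection

    def update_policy_weights_(self, policy_weights: Optional[Dict[str, float]] = None) -> None:
        # Use the explicitly passed weights if given, otherwise the trained ones.
        source = policy_weights if policy_weights is not None else self.trained
        self.local.update(source)


trained = {"w": 0.0}
holder = ToyPolicyHolder(trained)
trained["w"] = 1.0                 # an optimizer step on the trained copy
print(holder.local["w"])           # → 0.0 (stale until synced)
holder.update_policy_weights_()
print(holder.local["w"])           # → 1.0
holder.update_policy_weights_({"w": 2.0})
print(holder.local["w"])           # → 2.0
```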