SyncDataCollector
- class torchrl.collectors.SyncDataCollector(create_env_fn: Union[EnvBase, 'EnvCreator', Sequence[Callable[[], EnvBase]]], policy: Optional[Union[TensorDictModule, Callable[[TensorDictBase], TensorDictBase]]], *, frames_per_batch: int, total_frames: int = -1, device: DEVICE_TYPING = None, storing_device: DEVICE_TYPING = None, policy_device: DEVICE_TYPING = None, env_device: DEVICE_TYPING = None, create_env_kwargs: dict | None = None, max_frames_per_traj: int | None = None, init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, split_trajs: bool | None = None, exploration_type: ExplorationType = InteractionType.RANDOM, exploration_mode: str | None = None, return_same_td: bool = False, reset_when_done: bool = True, interruptor=None, set_truncated: bool = False)
Generic data collector for RL problems. Requires an environment constructor and a policy.
- Parameters:
create_env_fn (Callable) – a callable that returns an instance of the EnvBase class.
policy (Callable) – Policy to be executed in the environment. Must accept a tensordict.tensordict.TensorDictBase object as input. If None is provided, the policy used will be a RandomPolicy instance built from the environment's action_spec. Accepted policies are usually subclasses of TensorDictModuleBase; this is the recommended usage of the collector. Other callables are accepted too: if the policy is not a TensorDictModuleBase (e.g., a regular Module instance), it will be wrapped in an nn.Module first. The collector then tries to determine whether the resulting module requires wrapping in a TensorDictModule:
  - If the policy forward signature matches any of forward(self, tensordict), forward(self, td) or forward(self, <anything>: TensorDictBase) (or any signature with a single argument typed as a subclass of TensorDictBase), the policy won't be wrapped in a TensorDictModule.
  - In all other cases, the collector will attempt to wrap it as TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys).
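The signature rule for plain callables can be illustrated with the standard inspect module. The sketch below is a hypothetical helper, accepts_tensordict, and not torchrl's internal check; TensorDictBase here is a local stand-in class so the snippet runs without tensordict installed.

```python
import inspect
from typing import Callable


class TensorDictBase:
    """Local stand-in for tensordict.TensorDictBase (assumption, for a self-contained run)."""


def accepts_tensordict(policy: Callable) -> bool:
    """Hypothetical check: does the policy's signature already take a single TensorDict?"""
    # Module-like objects expose `forward`; plain functions are inspected directly.
    # Inspecting a bound method drops `self`, so only the data argument remains.
    fn = getattr(policy, "forward", policy)
    params = list(inspect.signature(fn).parameters.values())
    if len(params) != 1:
        return False
    param = params[0]
    # Rule 1: the single argument is named `tensordict` or `td`.
    if param.name in ("tensordict", "td"):
        return True
    # Rule 2: the single argument is annotated with a TensorDictBase subclass.
    annotation = param.annotation
    return inspect.isclass(annotation) and issubclass(annotation, TensorDictBase)


def policy_by_name(td):
    return td


def policy_by_type(obs: TensorDictBase):
    return obs


def policy_two_args(observation, hidden):
    return observation


class ModulePolicy:
    def forward(self, tensordict):
        return tensordict


print([accepts_tensordict(p) for p in (policy_by_name, policy_by_type, policy_two_args, ModulePolicy())])
# [True, True, False, True]
```

Per the description above, a callable for which such a check fails (like policy_two_args) would instead be wrapped as TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys).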
- Keyword Arguments:
frames_per_batch (int) – A keyword-only argument representing the total number of elements in a batch.
total_frames (int) – A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If total_frames is not divisible by frames_per_batch, an exception is raised. Endless collectors can be created by passing total_frames=-1. Defaults to -1 (endless collector).
device (int, str or torch.device, optional) – The generic device of the collector. The device argument fills in any unspecified device: if device is not None and any of storing_device, policy_device or env_device is not specified, its value will be set to device. Defaults to None (no default device).
storing_device (int, str or torch.device, optional) – The device on which the output TensorDict will be stored. If device is passed and storing_device is None, it will default to the value indicated by device. For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. Defaults to None (the output tensordict isn't on a specific device; leaf tensors sit on the device where they were created).
env_device (int, str or torch.device, optional) – The device on which the environment should be cast (or executed, if that functionality is supported). If not specified and the env has a non-None device, env_device will default to that value. If device is passed and env_device=None, it will default to device. If the resulting env_device differs from policy_device and one of them is not None, the data will be cast to env_device before being passed to the env (i.e., passing different devices to policy and env is supported). Defaults to None.
policy_device (int, str or torch.device, optional) – The device on which the policy should be cast. If device is passed and policy_device=None, it will default to device. If the resulting policy_device differs from env_device and one of them is not None, the data will be cast to policy_device before being passed to the policy (i.e., passing different devices to policy and env is supported). Defaults to None.
create_env_kwargs (dict, optional) – Dictionary of kwargs for create_env_fn.
max_frames_per_traj (int, optional) – Maximum steps per trajectory. Note that a trajectory can span across multiple batches (unless reset_at_each_iter is set to True, see below). Once a trajectory reaches max_frames_per_traj steps, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. Defaults to None (i.e., no maximum number of steps).
init_random_frames (int, optional) – Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to None (i.e., no random frames).
reset_at_each_iter (bool, optional) – Whether environments should be reset at the beginning of a batch collection. Defaults to False.
postproc (Callable, optional) – A post-processing transform, such as a Transform or a MultiStep instance. Defaults to None.
split_trajs (bool, optional) – Boolean indicating whether the resulting TensorDict should be split according to the trajectories. See split_trajectories() for more information. Defaults to False.
exploration_type (ExplorationType, optional) – interaction mode to be used when collecting data. Must be one of torchrl.envs.utils.ExplorationType.RANDOM, torchrl.envs.utils.ExplorationType.MODE or torchrl.envs.utils.ExplorationType.MEAN. Defaults to torchrl.envs.utils.ExplorationType.RANDOM.
return_same_td (bool, optional) – if True, the same TensorDict will be returned at each iteration, with its values updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer, for instance, the whole content of the buffer will be identical. Defaults to False.
interruptor (_Interruptor, optional) – An _Interruptor object that can be used from outside the class to control rollout collection. The _Interruptor class has methods start_collection and stop_collection, which allow implementing strategies such as preemptively stopping rollout collection. Defaults to None.
set_truncated (bool, optional) – if True, the truncated signals (and the corresponding "done" but not "terminated") will be set to True when the last frame of a rollout is reached. If no "truncated" key is found, an exception is raised. Truncated keys can be set through env.add_truncated_keys. Defaults to False.
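The way device fills in the other device arguments can be summarised as a small function. This is a simplified sketch, assuming plain strings stand in for torch.device objects; resolve_devices is a hypothetical name, and the env's own device (which env_device may inherit when unspecified) is deliberately left out.

```python
from typing import Dict, Optional


def resolve_devices(
    device: Optional[str] = None,
    storing_device: Optional[str] = None,
    policy_device: Optional[str] = None,
    env_device: Optional[str] = None,
) -> Dict[str, object]:
    """Toy model: fill unspecified devices from `device` (hypothetical helper)."""
    # `device` only fills in the slots that were left as None.
    resolved: Dict[str, object] = {
        "storing_device": storing_device if storing_device is not None else device,
        "policy_device": policy_device if policy_device is not None else device,
        "env_device": env_device if env_device is not None else device,
    }
    # Data is cast between policy and env only when their devices differ; two
    # differing values cannot both be None, so "one of them is not None" holds.
    resolved["cast_between_policy_and_env"] = (
        resolved["policy_device"] != resolved["env_device"]
    )
    return resolved


# Policy on a GPU, environment and storage on the generic device.
print(resolve_devices(device="cpu", policy_device="cuda:0"))
# {'storing_device': 'cpu', 'policy_device': 'cuda:0', 'env_device': 'cpu', 'cast_between_policy_and_env': True}
```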
Examples
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(
...     create_env_fn=env_maker,
...     policy=policy,
...     total_frames=2000,
...     max_frames_per_traj=50,
...     frames_per_batch=200,
...     init_random_frames=-1,
...     reset_at_each_iter=False,
...     device="cpu",
...     storing_device="cpu",
... )
>>> for i, data in enumerate(collector):
...     if i == 2:
...         print(data)
...         break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)
>>> del collector
The collector delivers batches of data that are marked with a "time" dimension.
Examples
>>> assert data.names[-1] == "time"
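Each iteration in the examples above yields a new TensorDict by default. The return_same_td caveat can be reproduced with plain Python objects: in this toy model, a hypothetical generator collect yields dicts in place of TensorDicts, and with return_same=True it mutates and re-yields one object, which is what makes every stored entry identical.

```python
from typing import Dict, Iterator


def collect(num_batches: int, return_same: bool) -> Iterator[Dict[str, int]]:
    """Toy collector: dicts stand in for TensorDicts (hypothetical helper)."""
    shared = {"step": -1}
    for i in range(num_batches):
        if return_same:
            # Like return_same_td=True: one object, updated in place each time.
            shared["step"] = i
            yield shared
        else:
            # Default behaviour: a fresh object per batch.
            yield {"step": i}


# Storing the yielded objects directly aliases them all to the same dict.
buffer = list(collect(3, return_same=True))
print(buffer)  # [{'step': 2}, {'step': 2}, {'step': 2}]

# Copying each batch before storing it keeps the values distinct.
safe_buffer = [dict(batch) for batch in collect(3, return_same=True)]
print(safe_buffer)  # [{'step': 0}, {'step': 1}, {'step': 2}]
```

With a real TensorDict, calling .clone() on each batch before storing it plays the role of dict(...) here.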
- iterator() → Iterator[TensorDictBase]
Iterates through the DataCollector.
Yields: TensorDictBase objects containing (chunks of) trajectories.
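The number and kind of batches the iterator yields follow from frames_per_batch, total_frames and init_random_frames. The sketch below models only that bookkeeping, assuming the divisibility check and the rounding of init_random_frames behave as described in the keyword arguments; batch_schedule is a hypothetical helper, not part of torchrl.

```python
import itertools
from typing import Iterator, Optional


def batch_schedule(
    frames_per_batch: int,
    total_frames: int = -1,
    init_random_frames: Optional[int] = None,
) -> Iterator[str]:
    """Yield "random" or "policy" for each batch a collector would produce (toy model)."""
    # A finite frame budget must split into whole batches.
    if total_frames > 0 and total_frames % frames_per_batch != 0:
        raise ValueError("total_frames must be divisible by frames_per_batch")
    # init_random_frames is rounded up to the closest multiple of frames_per_batch.
    random_frames = 0
    if init_random_frames is not None and init_random_frames > 0:
        n_random_batches = (init_random_frames + frames_per_batch - 1) // frames_per_batch
        random_frames = n_random_batches * frames_per_batch
    collected = 0
    # total_frames=-1 means an endless collector.
    while total_frames < 0 or collected < total_frames:
        yield "random" if collected < random_frames else "policy"
        collected += frames_per_batch


# An endless collector never stops on its own, so take a few batches explicitly.
first_three = list(itertools.islice(batch_schedule(200, init_random_frames=250), 3))
print(first_three)  # ['random', 'random', 'policy']
```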
- load_state_dict(state_dict: OrderedDict, **kwargs) → None
Loads a state_dict on the environment and policy.
- Parameters:
state_dict (OrderedDict) – ordered dictionary containing the fields "policy_state_dict" and "env_state_dict".
- rollout() → TensorDictBase
Computes a rollout in the environment using the provided policy.
- Returns:
TensorDictBase containing the computed rollout.
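How max_frames_per_traj and set_truncated shape a rollout can be modelled with per-environment step counters. This is a toy sketch: toy_rollout is hypothetical, natural episode termination is ignored, and the flag records only mimic the "truncated", "done" and "terminated" entries described in the keyword arguments.

```python
from typing import Dict, List, Optional


def toy_rollout(
    num_envs: int,
    steps: int,
    step_counts: List[int],
    max_frames_per_traj: Optional[int] = None,
    set_truncated: bool = False,
) -> List[List[Dict[str, object]]]:
    """Collect `steps` time steps from `num_envs` toy environments (hypothetical helper).

    `step_counts` holds one counter per sub-environment and is mutated in place,
    so a trajectory can continue into the next call (i.e., the next batch).
    With reset_at_each_iter=True, the counters would be zeroed before each call.
    """
    frames = []
    for t in range(steps):
        row = []
        for i in range(num_envs):
            step_counts[i] += 1
            # set_truncated marks the last frame of the rollout as truncated
            # and done, but not terminated.
            truncated = set_truncated and t == steps - 1
            row.append(
                {
                    "step_count": step_counts[i],
                    "truncated": truncated,
                    "done": truncated,
                    "terminated": False,
                }
            )
            # Missing or non-positive limits are ignored in this sketch; otherwise
            # each env is reset once its own trajectory reaches the limit.
            limit_on = max_frames_per_traj is not None and max_frames_per_traj > 0
            if limit_on and step_counts[i] >= max_frames_per_traj:
                step_counts[i] = 0
        frames.append(row)
    return frames


# Env 1 is mid-trajectory (3 steps already taken in a previous batch).
counts = [0, 3]
batch = toy_rollout(2, 4, counts, max_frames_per_traj=4, set_truncated=True)
# Env 0 step counts: 1, 2, 3, 4; env 1 step counts: 4, 1, 2, 3
```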
- set_seed(seed: int, static_seed: bool = False) → int
Sets the seeds of the environments stored in the DataCollector.
- Parameters:
seed (int) – integer representing the seed to be used for the environment.
static_seed (bool, optional) – if True, the seed is not incremented. Defaults to False.
- Returns:
Output seed. This is useful when more than one environment is contained in the DataCollector, as the seed will be incremented for each of these. The resulting seed is the seed of the last environment.
Examples
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = ParallelEnv(6, env_fn)
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100)
>>> out_seed = collector.set_seed(1)  # out_seed = 6
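The seed bookkeeping described for set_seed can be modelled as one seed per environment. This is a simplified sketch, assuming each environment receives the previous seed plus one; the increment actually computed by torchrl environments may differ, so treat assign_seeds (a hypothetical helper) as an illustration of the documented return value only.

```python
from typing import List, Tuple


def assign_seeds(seed: int, num_envs: int, static_seed: bool = False) -> Tuple[List[int], int]:
    """Toy model of set_seed over several environments (hypothetical helper).

    Returns the per-environment seeds and the seed of the last environment,
    which is what set_seed is documented to return.
    """
    # With static_seed=True the seed is not incremented between environments.
    seeds = [seed if static_seed else seed + i for i in range(num_envs)]
    return seeds, seeds[-1]


seeds, out_seed = assign_seeds(1, num_envs=6)
print(seeds, out_seed)  # [1, 2, 3, 4, 5, 6] 6
```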
- state_dict() → OrderedDict
Returns the local state_dict of the data collector (environment and policy).
- Returns:
an ordered dictionary with fields "policy_state_dict" and "env_state_dict".
- update_policy_weights_(policy_weights: Optional[TensorDictBase] = None) → None
Updates the policy weights if the policy of the data collector and the trained policy live on different devices.
- Parameters:
policy_weights (TensorDictBase, optional) – if provided, a TensorDict containing the weights of the policy to be used for the update.