MultiSyncDataCollector
- class torchrl.collectors.collectors.MultiSyncDataCollector(create_env_fn: Sequence[Callable[[], EnvBase]], policy: Optional[Union[TensorDictModule, Callable[[TensorDictBase], TensorDictBase]]], *, frames_per_batch: int = 200, total_frames: Optional[int] = -1, device: DEVICE_TYPING = None, storing_device: Optional[Union[DEVICE_TYPING, Sequence[DEVICE_TYPING]]] = None, create_env_kwargs: Optional[Sequence[dict]] = None, max_frames_per_traj: int | None = None, init_random_frames: int | None = None, reset_at_each_iter: bool = False, postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, exploration_type: ExplorationType = InteractionType.RANDOM, exploration_mode=None, reset_when_done: bool = True, preemptive_threshold: float = None, update_at_each_batch: bool = False, devices=None, storing_devices=None, num_threads: int = None, num_sub_threads: int = 1)
Runs a given number of DataCollectors on separate processes synchronously.
The environments can be identical or different.
The collection starts when the next item of the collector is queried, and no environment step is computed between the reception of a batch of trajectories and the start of the next collection. Because no data is gathered while the training loop is processing a batch, this class can be safely used with online RL algorithms.
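The synchronous behaviour described above can be illustrated with a small pure-Python model. This is a sketch, not torchrl's implementation: make_worker and sync_collect are hypothetical helpers, each "worker" stands in for a sub-collector, and the even split of frames_per_batch across workers is an assumption made for illustration. Because sync_collect is a generator, no worker runs until the next batch is requested, and every worker must deliver its share before the batch is yielded.

```python
from typing import Callable, Iterator, List

# A "worker" is modelled as a function that runs n environment steps and
# returns the frames it produced (here just integer step indices).
Worker = Callable[[int], List[int]]


def make_worker(worker_id: int, log: List[str]) -> Worker:
    """Hypothetical sub-collector that records every request it receives."""
    step = 0

    def run(n: int) -> List[int]:
        nonlocal step
        log.append(f"worker{worker_id} runs {n} steps")
        frames = list(range(step, step + n))
        step += n
        return frames

    return run


def sync_collect(
    workers: List[Worker], frames_per_batch: int, total_frames: int
) -> Iterator[List[int]]:
    """Toy synchronous collector: one equal share per worker, gathered on demand."""
    if frames_per_batch % len(workers) != 0:
        raise ValueError("this sketch assumes frames_per_batch splits evenly across workers")
    per_worker = frames_per_batch // len(workers)
    collected = 0
    # total_frames=-1 gives an endless iterator, mirroring the documented convention.
    while total_frames < 0 or collected < total_frames:
        batch: List[int] = []
        for worker in workers:  # wait for every worker before yielding
            batch.extend(worker(per_worker))
        collected += len(batch)
        yield batch  # execution pauses here until the consumer asks again


log: List[str] = []
batches = sync_collect([make_worker(0, log), make_worker(1, log)], frames_per_batch=4, total_frames=8)
print(log)  # [] : creating the iterator runs no environment step
first = next(batches)
print(first)  # [0, 1, 0, 1]
print(log)  # ['worker0 runs 2 steps', 'worker1 runs 2 steps']
```

In the real collector each sub-collector runs in its own process and steps its own environment; the model only captures the ordering guarantee that nothing is collected between two requests.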
Examples
>>> from torchrl.collectors import MultiSyncDataCollector
>>> from torchrl.envs import StepCounter, TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_maker = lambda: TransformedEnv(GymEnv("Pendulum-v1", device="cpu"), StepCounter(max_steps=50))
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = MultiSyncDataCollector(
...     create_env_fn=[env_maker, env_maker],
...     policy=policy,
...     total_frames=2000,
...     max_frames_per_traj=50,
...     frames_per_batch=200,
...     init_random_frames=-1,
...     reset_at_each_iter=False,
...     devices="cpu",
...     storing_devices="cpu",
... )
>>> for i, data in enumerate(collector):
...     if i == 2:
...         print(data)
...         break
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([200]),
    device=cpu,
    is_shared=False)
>>> collector.shutdown()
>>> del collector
- Parameters:
create_env_fn (List[Callable]) – list of Callables, each returning an instance of EnvBase.
policy (Callable, optional) – Instance of the TensorDictModule class. Must accept a TensorDictBase object as input. If None is provided, the policy used will be a RandomPolicy instance with the environment action_spec.
frames_per_batch (int) – A keyword-only argument representing the total number of elements in a batch.
total_frames (int) – A keyword-only argument representing the total number of frames returned by the collector during its lifespan. If total_frames is not divisible by frames_per_batch, an exception is raised. Endless collectors can be created by passing total_frames=-1.
device (int, str, torch.device or sequence of such, optional) – The device on which the policy will be placed. If it differs from the input policy device, the update_policy_weights_() method should be called at appropriate times during the training loop so that the collector's copy of the policy does not lag behind the parameters being trained. If necessary, a list of devices can be passed, in which case each element corresponds to the device of a sub-collector. Defaults to None (i.e. the policy is kept on its original device).
storing_device (int, str, torch.device or sequence of such, optional) – The device on which the output tensordict.TensorDict will be stored. For long trajectories, it may be necessary to store the data on a different device than the one where the policy and env are executed. If necessary, a list of devices can be passed, in which case each element corresponds to the storing device of a sub-collector. Defaults to "cpu".
create_env_kwargs (dict, optional) – A dictionary with the keyword arguments used to create an environment. If a list is provided, each of its elements will be assigned to a sub-collector.
max_frames_per_traj (int, optional) – Maximum steps per trajectory. Note that a trajectory can span multiple batches (unless reset_at_each_iter is set to True, see below). Once a trajectory reaches max_frames_per_traj steps, the environment is reset. If the environment wraps multiple environments together, the number of steps is tracked for each environment independently. Negative values are allowed, in which case this argument is ignored. Defaults to None (i.e. no maximum number of steps).
init_random_frames (int, optional) – Number of frames for which the policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to None (i.e. no random frames).
reset_at_each_iter (bool, optional) – Whether environments should be reset at the beginning of a batch collection. Defaults to False.
postproc (Callable, optional) – A post-processing transform, such as a Transform or a MultiStep instance. Defaults to None.
split_trajs (bool, optional) – Boolean indicating whether the resulting TensorDict should be split according to the trajectories. See split_trajectories() for more information. Defaults to False.
exploration_type (ExplorationType, optional) – Interaction mode to be used when collecting data. Must be one of ExplorationType.RANDOM, ExplorationType.MODE or ExplorationType.MEAN. Defaults to ExplorationType.RANDOM.
return_same_td (bool, optional) – If True, the same TensorDict will be returned at each iteration, with its values updated. This feature should be used cautiously: if the same tensordict is added to a replay buffer, for instance, the whole content of the buffer will be identical. Defaults to False.
reset_when_done (bool, optional) – If True (default), an environment that returns a True value in its "done" or "truncated" entry will be reset at the corresponding indices.
update_at_each_batch (bool, optional) – If True, update_policy_weights_() will be called before (sync) or after (async) each data collection. Defaults to False.
preemptive_threshold (float, optional) – A value between 0.0 and 1.0 that specifies the ratio of workers that will be allowed to finish collecting their rollout before the rest are forced to end early.
num_threads (int, optional) – Number of threads for this process. Defaults to the number of workers.
num_sub_threads (int, optional) – Number of threads of the subprocesses. Should be equal to one plus the number of processes launched within each subprocess (or one if a single process is launched). Defaults to 1 for safety: if none is indicated, launching multiple workers may overload the CPU and harm performance.
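Two of the rules above involve frame arithmetic: total_frames must be a multiple of frames_per_batch (or -1 for an endless collector), and init_random_frames is rounded up to the closest multiple of frames_per_batch. The helpers below sketch these rules as documented; num_batches and rounded_random_frames are hypothetical names, not torchrl functions, and the treatment of None and non-positive init_random_frames is an assumption.

```python
from typing import Optional


def num_batches(total_frames: int, frames_per_batch: int) -> Optional[int]:
    """Hypothetical check of the documented total_frames rule.

    Returns the number of batches the collector yields, or None for an
    endless collector (total_frames=-1).
    """
    if total_frames == -1:
        return None
    if total_frames % frames_per_batch != 0:
        raise ValueError(
            f"total_frames={total_frames} is not divisible by frames_per_batch={frames_per_batch}"
        )
    return total_frames // frames_per_batch


def rounded_random_frames(init_random_frames: Optional[int], frames_per_batch: int) -> int:
    """Round init_random_frames up to the closest multiple of frames_per_batch.

    Assumption: None and non-positive values (such as the -1 passed in the
    example above) are treated as "no random frames".
    """
    if init_random_frames is None or init_random_frames <= 0:
        return 0
    # Integer ceiling division avoids floating-point rounding issues.
    n_batches = -(-init_random_frames // frames_per_batch)
    return n_batches * frames_per_batch


print(num_batches(2000, 200))  # 10, as in the example above
print(num_batches(-1, 200))  # None: endless collector
print(rounded_random_frames(250, 200))  # 400: two full random batches
print(rounded_random_frames(-1, 200))  # 0
```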
- load_state_dict(state_dict: OrderedDict) → None
Loads the state_dict on the workers.
- Parameters:
state_dict (OrderedDict) – state_dict of the form {"worker0": state_dict0, "worker1": state_dict1}.
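The per-worker layout expected by load_state_dict (and returned by state_dict) can be sketched with plain dictionaries. ToyWorker, gather_state and scatter_state are hypothetical helpers used only to illustrate the "worker0", "worker1", ... keys; real sub-collectors run in separate processes and are not modelled here.

```python
from collections import OrderedDict
from typing import Any, Dict, List


class ToyWorker:
    """Hypothetical stand-in for a sub-collector holding some state."""

    def __init__(self) -> None:
        self.state: Dict[str, Any] = {}

    def state_dict(self) -> Dict[str, Any]:
        return dict(self.state)  # shallow copy, so later changes do not leak in

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.state = dict(state)


def gather_state(workers: List[ToyWorker]) -> "OrderedDict[str, Dict[str, Any]]":
    # Keys follow the documented {"worker0": ..., "worker1": ...} layout.
    return OrderedDict((f"worker{i}", w.state_dict()) for i, w in enumerate(workers))


def scatter_state(workers: List[ToyWorker], state_dict: "OrderedDict[str, Dict[str, Any]]") -> None:
    # Each worker receives the entry stored under its own key; a missing key raises KeyError.
    for i, worker in enumerate(workers):
        worker.load_state_dict(state_dict[f"worker{i}"])


workers = [ToyWorker(), ToyWorker()]
for worker in workers:
    worker.state["frames"] = 100
snapshot = gather_state(workers)
print(list(snapshot))  # ['worker0', 'worker1']
workers[0].state["frames"] = 300  # the worker keeps running after the snapshot
scatter_state(workers, snapshot)
print(workers[0].state["frames"])  # 100
```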
- reset(reset_idx: Optional[Sequence[bool]] = None) → None
Resets the environments to a new initial state.
- Parameters:
reset_idx (Sequence[bool], optional) – Sequence indicating which environments have to be reset. If None, all environments are reset.
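The reset_idx semantics amount to a boolean mask over the environments. The sketch below models them with a hypothetical ToyEnv that only tracks a step counter; reset_envs is an illustrative helper, not part of torchrl, and the length check is an assumption about sensible input validation.

```python
from typing import List, Optional, Sequence


class ToyEnv:
    """Hypothetical environment that only tracks its step counter."""

    def __init__(self) -> None:
        self.step_count = 0

    def reset(self) -> None:
        self.step_count = 0


def reset_envs(envs: List[ToyEnv], reset_idx: Optional[Sequence[bool]] = None) -> None:
    # None means "reset everything", mirroring the documented default.
    if reset_idx is None:
        reset_idx = [True] * len(envs)
    if len(reset_idx) != len(envs):
        raise ValueError("reset_idx must have one entry per environment")
    for env, flag in zip(envs, reset_idx):
        if flag:
            env.reset()


envs = [ToyEnv(), ToyEnv(), ToyEnv()]
for env, steps in zip(envs, [5, 7, 9]):
    env.step_count = steps
reset_envs(envs, [True, False, True])
print([env.step_count for env in envs])  # [0, 7, 0]
reset_envs(envs)
print([env.step_count for env in envs])  # [0, 0, 0]
```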
- set_seed(seed: int, static_seed: bool = False) → int
Sets the seeds of the environments stored in the DataCollector.
- Parameters:
seed – integer representing the seed to be used for the environment.
static_seed (bool, optional) – If True, the seed is not incremented. Defaults to False.
- Returns:
Output seed. This is useful when more than one environment is contained in the DataCollector, as the seed will be incremented for each of these. The resulting seed is the seed of the last environment.
Examples
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = lambda: ParallelEnv(6, env_fn)
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300)
>>> out_seed = collector.set_seed(1)  # out_seed = 6
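The Returns description can be modelled as a simple counter. This is a simplified sketch assuming a +1 increment per environment, which is consistent with the out_seed = 6 comment in the example above; torchrl may derive the individual environment seeds differently, so assign_seeds is a hypothetical helper and not the library's seeding routine.

```python
from typing import List, Tuple


def assign_seeds(seed: int, num_envs: int, static_seed: bool = False) -> Tuple[List[int], int]:
    """Simplified model of the documented seeding rule (hypothetical helper).

    Each environment receives the current seed; unless static_seed is True,
    the seed is incremented by one before the next environment. Returns the
    per-environment seeds and the seed of the last environment.
    """
    if num_envs < 1:
        raise ValueError("num_envs must be at least 1")
    seeds: List[int] = []
    current = seed
    for _ in range(num_envs):
        seeds.append(current)
        if not static_seed:
            current += 1
    return seeds, seeds[-1]


print(assign_seeds(1, 6))  # ([1, 2, 3, 4, 5, 6], 6), matching out_seed = 6 above
print(assign_seeds(1, 3, static_seed=True))  # ([1, 1, 1], 1)
```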
- state_dict() → OrderedDict
Returns the state_dict of the data collector.
Each field represents a worker containing its own state_dict.
- update_policy_weights_(policy_weights: Optional[TensorDictBase] = None) → None
Updates the policy weights if the policy of the data collector and the trained policy live on different devices.
- Parameters:
policy_weights (TensorDictBase, optional) – If provided, a TensorDict containing the weights of the policy to be used for the update.
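The reason for calling this method is that the collector works on its own copy of the weights when its policy lives on a different device. The pure-Python sketch below models that lag with a hypothetical ToyCollector whose weights are plain floats. It assumes that, when policy_weights is None, the weights are read from the trained policy; it is not torchrl's implementation, and the method name only echoes the documented one.

```python
import copy
from typing import Dict, Optional


class ToyCollector:
    """Hypothetical collector that acts with its own copy of the policy weights.

    This mimics the situation where the collector's policy lives on a different
    device than the trained policy, so the two do not share storage.
    """

    def __init__(self, trained_weights: Dict[str, float]) -> None:
        self._trained = trained_weights  # reference to the weights being optimized
        self.weights = copy.deepcopy(trained_weights)  # separate copy used for collection

    def update_policy_weights_(self, policy_weights: Optional[Dict[str, float]] = None) -> None:
        # Copy from the explicitly provided weights if any, else from the trained policy.
        source = self._trained if policy_weights is None else policy_weights
        self.weights = copy.deepcopy(source)


trained = {"w": 0.0}
collector = ToyCollector(trained)
trained["w"] = 1.0  # an optimizer step updates the trained weights in place
print(collector.weights["w"])  # 0.0: the collector's copy is stale
collector.update_policy_weights_()
print(collector.weights["w"])  # 1.0
collector.update_policy_weights_({"w": 2.0})
print(collector.weights["w"])  # 2.0
```

Setting update_at_each_batch=True amounts to making this call automatically around each collection, so the copy never falls more than one batch behind.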