ConsistentDropout¶
- class torchrl.modules.ConsistentDropout(p: float = 0.5)[source]¶
Implements a `Dropout` variant with consistent dropout.

This method is proposed in "Consistent Dropout for Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022).

This `Dropout` variant attempts to increase training stability and reduce update variance by caching the dropout masks used during rollout and reusing them during the update phase.

The class you are looking at is independent of the rest of TorchRL's API and does not require tensordict to be run.

`ConsistentDropoutModule` is a wrapper around `ConsistentDropout` that capitalizes on the extensibility of `TensorDict`s by storing generated dropout masks in the transition `TensorDict`s themselves. See this class for a detailed explanation as well as usage examples.

There is otherwise little conceptual deviance from the PyTorch `Dropout` implementation.

Note

TorchRL's data collectors perform rollouts in `no_grad()` mode but not in `eval` mode, so the dropout masks will be applied unless the policy passed to the collector is in `eval` mode.
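As a rough sketch of this caching scheme (assuming the `forward()` signature documented below, which in train mode returns the output together with the sampled mask):

```python
import torch
from torchrl.modules import ConsistentDropout

dropout = ConsistentDropout(p=0.5)
dropout.train()  # masks are only sampled and applied in train mode

x = torch.randn(4, 8)

# Rollout: sample a fresh mask and store it with the transition.
y_rollout, mask = dropout(x)

# Update: replay the cached mask so the same units are dropped,
# making the update-time forward pass match the rollout.
y_update, _ = dropout(x, mask=mask)

assert torch.equal(y_rollout, y_update)
```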
Note

Unlike other exploration modules, `ConsistentDropoutModule` uses the `train`/`eval` mode to comply with the regular `Dropout` API in PyTorch. The `set_exploration_type()` context manager will have no effect on this module.

- Parameters:
p (float, optional) – Dropout probability. Defaults to 0.5.
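For illustration only (the import path `torchrl.envs.utils` for `set_exploration_type` is an assumption here), the following sketch shows that the module's train/eval state, not the exploration type, governs masking:

```python
import torch
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import ConsistentDropout

dropout = ConsistentDropout(p=0.5)
dropout.train()
x = torch.ones(2, 4)

# The exploration-type context manager does not switch dropout off;
# only train()/eval() controls whether masks are applied.
with set_exploration_type(ExplorationType.DETERMINISTIC):
    out, mask = dropout(x)  # masks are still sampled and applied
```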
- forward(x: Tensor, mask: Optional[Tensor] = None) → Tensor [source]¶
During training (rollouts & updates), this call applies dropout to a tensor full of ones to generate the mask, then multiplies the input tensor by it.
During evaluation, this call results in a no-op and only the input is returned.
- Parameters:
x (torch.Tensor) – the input tensor.
mask (torch.Tensor, optional) – the optional mask for the dropout.
- Returns: a tensor and a corresponding mask in train mode, and only a tensor in eval mode.
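A small sketch of this return convention; the exact values of the kept mask entries depend on the underlying dropout implementation, but the description above implies the mask is a dropped-out tensor of ones:

```python
import torch
from torchrl.modules import ConsistentDropout

dropout = ConsistentDropout(p=0.5)
x = torch.randn(2, 4)

dropout.train()
out, mask = dropout(x)  # train mode: (tensor, mask) pair
# The mask is a dropped-out tensor of ones, so the output is x * mask.
assert torch.equal(out, x * mask)

dropout.eval()
out = dropout(x)        # eval mode: only the tensor, returned as-is
assert torch.equal(out, x)
```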