RandomCropTensorDict
- class torchrl.envs.transforms.RandomCropTensorDict(sub_seq_len: int, sample_dim: int = -1, mask_key: Optional[NestedKey] = None)
A trajectory sub-sampler for ReplayBuffer and modules.
Gathers a sub-sequence of a fixed length (sub_seq_len) along the sample_dim dimension of the input tensordict, which is the last dimension by default. This can be used to crop trajectories sampled from a ReplayBuffer into shorter sub-trajectories.
This transform is primarily designed to be used with replay buffers and modules. Currently, it cannot be used as an environment transform. Do not hesitate to request this behaviour through an issue if you need it.
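The crop itself can be sketched in NumPy as an illustration of the idea, not of TorchRL's implementation. The helper `random_crop` is hypothetical; it fixes the crop to the last axis (the default `sample_dim=-1`) and assumes each trajectory in the batch receives its own uniformly drawn start index.

```python
import numpy as np


def random_crop(x: np.ndarray, sub_seq_len: int, rng: np.random.Generator) -> np.ndarray:
    """Return a random window of `sub_seq_len` steps along the last axis of `x`.

    Hypothetical helper: every leading index (e.g. every trajectory in a batch)
    gets its own start index.
    """
    traj_len = x.shape[-1]
    if traj_len < sub_seq_len:
        raise ValueError("trajectories are shorter than sub_seq_len")
    # One start per trajectory, uniform over all valid window positions.
    starts = np.asarray(rng.integers(0, traj_len - sub_seq_len + 1, size=x.shape[:-1]))
    # Window indices start, start + 1, ..., start + sub_seq_len - 1.
    idx = starts[..., None] + np.arange(sub_seq_len)
    # Gather each trajectory's window along the last (time) axis.
    return np.take_along_axis(x, idx, axis=-1)


rng = np.random.default_rng(0)
batch = np.arange(20).reshape(2, 10)  # 2 trajectories, 10 steps each
cropped = random_crop(batch, sub_seq_len=4, rng=rng)
print(cropped.shape)  # (2, 4)
```

Each output row is a contiguous slice of its own trajectory, so temporal order inside the window is preserved while the window position varies between trajectories.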
- Parameters:
sub_seq_len (int) – the length of the sub-trajectory to sample
sample_dim (int, optional) – the dimension along which the cropping should occur. Negative dimensions should be preferred to make the transform robust to tensordicts of varying batch dimensions. Defaults to -1 (the default time dimension in TorchRL).
mask_key (NestedKey, optional) – if provided, the key of the mask to look up when sampling; only valid elements will then be returned. The mask is assumed to be a boolean tensor made of a run of True values followed by a run of False values, never interleaved. RandomCropTensorDict will NOT check that this holds, so any error caused by an improper mask may go unnoticed. Defaults to None (no mask key).
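The advice to prefer a negative `sample_dim` follows from plain shape arithmetic. In the sketch below, `time_axis` is a hypothetical helper that resolves an axis index the way Python indexing does, and each batch shape is assumed to keep time as its last dimension, as TorchRL does by default.

```python
def time_axis(batch_shape: tuple, sample_dim: int) -> int:
    """Resolve `sample_dim` to a non-negative axis index for `batch_shape` (hypothetical helper)."""
    ndim = len(batch_shape)
    if not -ndim <= sample_dim < ndim:
        raise IndexError(f"sample_dim={sample_dim} is out of range for {ndim} batch dims")
    # Python's modulo maps -1 to ndim - 1, -2 to ndim - 2, and so on.
    return sample_dim % ndim


# Time is the last batch dimension (length 10) in each layout.
shapes = [(10,), (4, 10), (3, 4, 10)]
print([s[time_axis(s, -1)] for s in shapes])  # [10, 10, 10]
# sample_dim=1 only hits the time axis of the 2-D layout: it selects the
# length-4 axis of (3, 4, 10), and it is out of range for (10,).
print((4, 10)[time_axis((4, 10), 1)], (3, 4, 10)[time_axis((3, 4, 10), 1)])  # 10 4
```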
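The layout requirement on `mask_key` can be illustrated with a NumPy sketch. `masked_random_crop` is a hypothetical helper, not TorchRL's code; it assumes the valid length of each trajectory equals its number of True entries, which only holds when the mask is a True prefix followed by False values. Raising an error when a trajectory is too short is a choice made for this sketch only.

```python
import numpy as np


def masked_random_crop(x, mask, sub_seq_len, rng):
    """Return a random window of `sub_seq_len` steps inside the valid part of `mask`.

    Hypothetical helper. Assumes each row of `mask` is a run of True values
    followed by False values; like the transform, it does NOT check this.
    """
    # Under the prefix assumption, counting True entries gives the valid length.
    valid_len = np.asarray(mask).sum(axis=-1)
    if (valid_len < sub_seq_len).any():
        # Sketch-only choice; not necessarily what TorchRL does in this case.
        raise ValueError("a trajectory has fewer valid steps than sub_seq_len")
    # Per-row start drawn uniformly from [0, valid_len - sub_seq_len].
    starts = np.asarray(rng.integers(0, valid_len - sub_seq_len + 1))
    idx = starts[..., None] + np.arange(sub_seq_len)
    return np.take_along_axis(x, idx, axis=-1)


rng = np.random.default_rng(0)
obs = np.arange(16).reshape(2, 8)
mask = np.array([
    [True] * 6 + [False] * 2,  # first row: 6 valid steps (values 0..5)
    [True] * 4 + [False] * 4,  # second row: 4 valid steps (values 8..11)
])
out = masked_random_crop(obs, mask, sub_seq_len=3, rng=rng)
print(out.shape)  # (2, 3)

# A mixed mask such as [True, False, True, True] also counts 3 valid steps, so a
# window starting at 0 would silently include the invalid step 1.
```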