
RandomCropTensorDict

class torchrl.envs.transforms.RandomCropTensorDict(sub_seq_len: int, sample_dim: int = -1, mask_key: Optional[Union[str, Tuple[str, ...]]] = None)

A trajectory sub-sampler for ReplayBuffer and modules.

Gathers a sub-sequence of a fixed length along the sampling dimension (by default, the last dimension) of the input tensordict. This can be used to crop trajectories sampled from a ReplayBuffer into shorter sub-trajectories.
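The cropping idea can be sketched in plain Python. This is a hypothetical stand-in, not TorchRL's implementation: a nested list of shape [batch, time] replaces the tensordict, the helper name `random_crop_rows` is invented for illustration, and it assumes (as we understand the transform to behave) that each batch element gets its own independently drawn start index.

```python
import random
from typing import List


def random_crop_rows(
    batch: List[List[int]], sub_seq_len: int, rng: random.Random
) -> List[List[int]]:
    """Crop each trajectory (row) of a [batch, time] nested list to sub_seq_len steps.

    Hypothetical pure-Python sketch of the cropping idea; not TorchRL code.
    """
    cropped = []
    for traj in batch:
        length = len(traj)
        if length < sub_seq_len:
            raise ValueError(
                f"trajectory of length {length} is shorter than sub_seq_len={sub_seq_len}"
            )
        # Draw a start uniformly among the positions that keep the whole
        # window inside the trajectory (randint is inclusive on both ends).
        start = rng.randint(0, length - sub_seq_len)
        cropped.append(traj[start : start + sub_seq_len])
    return cropped


rng = random.Random(0)
batch = [list(range(10)), list(range(100, 110))]
out = random_crop_rows(batch, sub_seq_len=4, rng=rng)
print(out)  # two contiguous windows of 4 steps, one per trajectory
```

In TorchRL itself the input would be a TensorDict whose batch dimensions include the time dimension selected by `sample_dim`.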

This transform is primarily designed to be used with replay buffers and modules. It currently cannot be used as an environment transform; if you need this behaviour, please request it by opening an issue.

Parameters:
  • sub_seq_len (int) – the length of the sub-trajectory to sample

  • sample_dim (int, optional) – the dimension along which the cropping should occur. Negative dimensions should be preferred to make the transform robust to tensordicts of varying batch dimensions. Defaults to -1 (the default time dimension in TorchRL).

  • mask_key (NestedKey, optional) – the key of a boolean mask to use during sampling. If provided, only valid elements are returned. The mask is assumed to consist of True values followed by False values, never interleaved. RandomCropTensorDict does NOT check this, so errors caused by an improperly formed mask may go unnoticed. Defaults to None (no mask key).
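The mask assumption can be illustrated with another plain-Python sketch. It is hypothetical (the helper `masked_crop` is not part of TorchRL) and assumes the valid length is taken as the number of True entries, with the window start drawn so the window stays inside that valid prefix. The second call shows why an improperly formed mask can go unnoticed.

```python
import random
from typing import List, Sequence


def masked_crop(
    traj: List[int], mask: Sequence[bool], sub_seq_len: int, rng: random.Random
) -> List[int]:
    """Crop one trajectory to sub_seq_len steps inside its valid (True) prefix.

    Hypothetical sketch; like the transform, it does NOT verify that the mask
    is a run of True values followed by a run of False values.
    """
    # Under the "True prefix, then False" assumption, counting True entries
    # gives the number of valid steps.
    valid_len = sum(1 for m in mask if m)
    if valid_len < sub_seq_len:
        raise ValueError(f"only {valid_len} valid steps, need {sub_seq_len}")
    # Only starts that keep the whole window inside the valid prefix are drawn.
    start = rng.randint(0, valid_len - sub_seq_len)
    return traj[start : start + sub_seq_len]


rng = random.Random(0)
traj = [0, 1, 2, 3, 4, 5, 6, 7]
mask = [True] * 5 + [False] * 3  # the last 3 steps are padding
print(masked_crop(traj, mask, sub_seq_len=3, rng=rng))

# A mask that violates the assumption is silently accepted: it has 3 True
# entries, so start is always 0 and the window [0, 1, 2] includes step 0,
# which this mask marks as invalid.
bad_mask = [False, True, True, True, False, False, False, False]
print(masked_crop(traj, bad_mask, sub_seq_len=3, rng=rng))
```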

forward(tensordict: TensorDictBase) → TensorDictBase

Reads the input tensordict and, for the selected keys, applies the transform.
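When several entries are cropped, they must stay aligned in time. The following hypothetical sketch uses a dict of [batch, time] nested lists as a stand-in for a tensordict (the helper `crop_all_keys` and the sample data are invented for illustration) and assumes one window is drawn per batch element and shared by every key, so that, for example, an observation and the action taken at the same step remain paired.

```python
import random
from typing import Dict, List


def crop_all_keys(
    data: Dict[str, List[List[int]]], sub_seq_len: int, rng: random.Random
) -> Dict[str, List[List[int]]]:
    """Crop every entry with the same window for each batch element.

    Hypothetical sketch; a dict of nested lists stands in for a tensordict.
    """
    if not data:
        return {}
    keys = list(data)
    batch_size = len(data[keys[0]])
    if any(len(data[k]) != batch_size for k in keys):
        raise ValueError("all entries must share the same batch size")
    out: Dict[str, List[List[int]]] = {k: [] for k in keys}
    for b in range(batch_size):
        length = len(data[keys[0]][b])
        if any(len(data[k][b]) != length for k in keys):
            raise ValueError("all entries must share the same time length")
        if length < sub_seq_len:
            raise ValueError(f"length {length} is shorter than sub_seq_len={sub_seq_len}")
        # One start per batch element, reused for every key to keep them aligned.
        start = rng.randint(0, length - sub_seq_len)
        for k in keys:
            out[k].append(data[k][b][start : start + sub_seq_len])
    return out


obs = [[10 * b + t for t in range(8)] for b in range(3)]
data = {"observation": obs, "action": [[-x for x in row] for row in obs]}
cropped = crop_all_keys(data, sub_seq_len=5, rng=random.Random(0))
print(cropped["observation"])
print(cropped["action"])  # still the negation of the observations, step by step
```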
