SafeProbabilisticTensorDictSequential

class torchrl.modules.tensordict_module.SafeProbabilisticTensorDictSequential(*args, **kwargs)[source]

A tensordict.nn.ProbabilisticTensorDictSequential subclass that accepts a TensorSpec as an argument to control the output domain.

Similar to TensorDictSequential, but enforces that the final module in the sequence is a ProbabilisticTensorDictModule, and also exposes a get_dist method to recover the distribution object from that ProbabilisticTensorDictModule.
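
The contract described above can be illustrated with a small, library-free toy model. Every name in it (ToyNormal, ToyModule, ToyProbabilisticModule, ToyProbabilisticSequential) is hypothetical, plain dicts stand in for tensordicts, and the code only mimics the documented behavior; it is not torchrl's implementation.

```python
from dataclasses import dataclass


@dataclass
class ToyNormal:
    """Hypothetical stand-in for a distribution object."""
    loc: float
    scale: float


class ToyModule:
    """Hypothetical module: reads in_keys from a dict, writes out_keys."""

    def __init__(self, fn, in_keys, out_keys):
        self.fn, self.in_keys, self.out_keys = fn, in_keys, out_keys

    def __call__(self, data):
        outputs = self.fn(*(data[k] for k in self.in_keys))
        if len(self.out_keys) == 1:
            outputs = (outputs,)
        data.update(zip(self.out_keys, outputs))
        return data


class ToyProbabilisticModule:
    """Hypothetical module that builds a distribution from its in_keys."""

    def __init__(self, in_keys, out_keys):
        self.in_keys, self.out_keys = in_keys, out_keys

    def get_dist(self, data):
        loc, scale = (data[k] for k in self.in_keys)
        return ToyNormal(loc, scale)

    def __call__(self, data):
        # Write the mean instead of a random sample so the sketch is reproducible.
        data[self.out_keys[0]] = self.get_dist(data).loc
        return data


class ToyProbabilisticSequential:
    def __init__(self, *modules):
        # Enforce that the sequence terminates in a probabilistic module.
        if not isinstance(modules[-1], ToyProbabilisticModule):
            raise TypeError("the final module must be probabilistic")
        self.modules = modules

    def get_dist(self, data):
        # Run every module except the last, then ask the last one for
        # its distribution rather than for an output value.
        for module in self.modules[:-1]:
            data = module(data)
        return self.modules[-1].get_dist(data)

    def __call__(self, data):
        for module in self.modules:
            data = module(data)
        return data


params = ToyModule(lambda x: (2.0 * x, 1.0), ["obs"], ["loc", "scale"])
head = ToyProbabilisticModule(["loc", "scale"], ["action"])
policy = ToyProbabilisticSequential(params, head)

dist = policy.get_dist({"obs": 1.5})
out = policy({"obs": 1.5})
```

In this sketch, get_dist stops one step short of the full forward pass: it computes the distribution parameters and returns the distribution object, whereas calling the sequence writes an output entry.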

Parameters:
  • modules (iterable of TensorDictModules) – ordered sequence of TensorDictModule instances, terminating in a ProbabilisticTensorDictModule, to be run sequentially.

  • partial_tolerant (bool, optional) – if True, the input tensordict may lack some of the input keys. In that case, only the modules whose input keys are all present are executed. In addition, if the input tensordict is a lazy stack of tensordicts, partial_tolerant is True, and the stack itself does not have the required keys, then TensorDictSequential scans the sub-tensordicts and runs the module on those that have the required keys, if any.
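
The partial_tolerant behavior can be sketched without the library. In the toy model below, plain dicts stand in for tensordicts and a list of dicts stands in for a lazy stack; ToyModule, run_sequence and run_on_stack are hypothetical names introduced for illustration and are not part of torchrl or tensordict.

```python
class ToyModule:
    """Hypothetical module: reads in_keys from a dict, writes one out_key."""

    def __init__(self, fn, in_keys, out_key):
        self.fn, self.in_keys, self.out_key = fn, in_keys, out_key

    def __call__(self, data):
        data[self.out_key] = self.fn(*(data[k] for k in self.in_keys))
        return data


def run_sequence(modules, data, partial_tolerant=False):
    """Run modules in order on a copy of data."""
    data = dict(data)
    for module in modules:
        if all(key in data for key in module.in_keys):
            data = module(data)
        elif not partial_tolerant:
            # Strict mode: a missing input key is an error (in this toy model).
            raise KeyError(f"missing input keys for {module.out_key!r}")
        # Tolerant mode: the module is skipped silently.
    return data


def run_on_stack(modules, stack, partial_tolerant=False):
    # Each entry of the stack is processed on its own, so entries
    # that carry different keys can coexist.
    return [run_sequence(modules, sub, partial_tolerant) for sub in stack]


modules = [
    ToyModule(lambda a: a + 1, ["a"], "b"),
    ToyModule(lambda c: c * 10, ["c"], "d"),
]

# Only the first module can run: "c" is absent, so the second is skipped.
partial = run_sequence(modules, {"a": 1}, partial_tolerant=True)

# Each sub-entry triggers only the module whose inputs it provides.
stacked = run_on_stack(modules, [{"a": 1}, {"c": 2}], partial_tolerant=True)
```

With partial_tolerant left at False, the same call on {"a": 1} raises in this sketch, because the second module cannot find its input key.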
