
MultiDiscreteTensorSpec

class torchrl.data.MultiDiscreteTensorSpec(nvec: Union[Sequence[int], torch.Tensor, int], shape: Optional[torch.Size] = None, device: Optional[DEVICE_TYPING] = None, dtype: Optional[Union[str, torch.dtype]] = torch.int64, mask: torch.Tensor | None = None)

A concatenation of discrete tensor specs.

Parameters:
  • nvec (iterable of integers, torch.Tensor or int) – cardinality of each element of the tensor. Can have several axes.

  • shape (torch.Size, optional) – total shape of the sampled tensors. If provided, its last m dimensions must match nvec.shape, where m is the number of dimensions of nvec.

  • device (str, int or torch.device, optional) – device of the tensors.

  • dtype (str or torch.dtype, optional) – dtype of the tensors.
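
The constraint between shape and nvec amounts to a check on trailing dimensions. The following is a minimal pure-Python sketch, assuming shapes are plain tuples of ints; shape_matches_nvec is a hypothetical helper for illustration, not part of TorchRL:

```python
from typing import Sequence


def shape_matches_nvec(shape: Sequence[int], nvec_shape: Sequence[int]) -> bool:
    """Hypothetical helper: True if the last m dims of `shape` equal `nvec_shape`."""
    m = len(nvec_shape)
    if len(shape) < m:
        return False
    # Slice with len(shape) - m so that m == 0 compares empty tuples correctly.
    return tuple(shape[len(shape) - m:]) == tuple(nvec_shape)


# nvec = (3, 2, 3) has shape (3,), so any spec shape ending in 3 is accepted.
print(shape_matches_nvec((3,), (3,)))     # True
print(shape_matches_nvec((5, 3), (3,)))   # True: one leading batch dimension
print(shape_matches_nvec((5, 4), (3,)))   # False
```

Leading dimensions are free, which is what allows a spec to describe a batch of multi-discrete values.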

Examples

>>> import torch
>>> from torchrl.data import MultiDiscreteTensorSpec
>>> ts = MultiDiscreteTensorSpec((3, 2, 3))
>>> ts.is_in(torch.tensor([2, 0, 1]))
True
>>> ts.is_in(torch.tensor([2, 2, 1]))
False

assert_is_in(value: Tensor) → None

Asserts that a tensor belongs to the box and raises an exception otherwise.

Parameters:

value (torch.Tensor) – value to be checked.
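
As an illustration of the assert-style contract (return silently when the value is inside the box, raise otherwise), here is a pure-Python sketch on flat lists. assert_in_box is a hypothetical helper, and raising AssertionError is an assumption made for the example, not a statement about the exception TorchRL raises:

```python
from typing import Sequence


def assert_in_box(value: Sequence[int], nvec: Sequence[int]) -> None:
    """Hypothetical helper: raise if `value` lies outside the box defined by `nvec`."""
    if len(value) != len(nvec):
        raise AssertionError(f"expected {len(nvec)} elements, got {len(value)}")
    for i, (v, n) in enumerate(zip(value, nvec)):
        # Element i must be an integer in [0, n).
        if not 0 <= v < n:
            raise AssertionError(f"element {i} = {v} is outside [0, {n})")


assert_in_box([2, 0, 1], [3, 2, 3])  # passes silently
try:
    assert_in_box([2, 2, 1], [3, 2, 3])
except AssertionError as err:
    print(err)  # element 1 = 2 is outside [0, 2)
```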

clear_device_()

A no-op for all leaf specs (which must have a device).

encode(val: Union[ndarray, Tensor], *, ignore_device=False) → Tensor

Encodes a value given the specified spec and returns the corresponding tensor.

Parameters:

val (np.ndarray or torch.Tensor) – value to be encoded as tensor.

Keyword Arguments:

ignore_device (bool, optional) – if True, the spec device will be ignored. This is used to group tensor casting within a call to TensorDict(..., device="cuda"), which is faster.

Returns:

a torch.Tensor matching the spec.

expand(*shape)

Returns a new Spec with the expanded shape.

Parameters:

*shape (tuple or iterable of int) – the new shape of the Spec. Must comply with the current shape: its length must be at least that of the current shape, and its trailing values must be compliant too, i.e., they can only differ from the current ones where the current dimension is a singleton.
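
The compatibility rule for expand() can be checked on plain shape tuples. This is a minimal sketch of the rule as stated above; can_expand is a hypothetical helper and does not replace the spec's own validation:

```python
from typing import Sequence


def can_expand(current: Sequence[int], new: Sequence[int]) -> bool:
    """Hypothetical helper: does `new` satisfy the expand() rule for `current`?"""
    if len(new) < len(current):
        return False
    # Align the current shape with the trailing dimensions of the new shape.
    trailing = new[len(new) - len(current):]
    # Each aligned dimension must be equal, unless the current one is a singleton.
    return all(c == n or c == 1 for c, n in zip(current, trailing))


print(can_expand((3,), (4, 3)))     # True: adds a leading dimension
print(can_expand((1, 3), (5, 3)))   # True: singleton dimension grows to 5
print(can_expand((2, 3), (5, 3)))   # False: 2 is not a singleton
print(can_expand((4, 3), (3,)))     # False: new shape is shorter
```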

flatten(start_dim, end_dim)

Flattens a tensorspec.

Check flatten() for more information on this method.
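
The shape arithmetic behind flatten(start_dim, end_dim) follows torch.flatten: dimensions start_dim through end_dim (inclusive) are merged into one. The sketch below covers only that shape computation, with a hypothetical flattened_shape helper; how TorchRL rearranges nvec alongside the shape is not shown:

```python
from math import prod
from typing import Sequence, Tuple


def flattened_shape(shape: Sequence[int], start_dim: int, end_dim: int) -> Tuple[int, ...]:
    """Hypothetical helper: merge dims start_dim..end_dim (inclusive), as torch.flatten does."""
    ndim = len(shape)
    for d in (start_dim, end_dim):
        if not -ndim <= d < ndim:
            raise IndexError(f"dimension {d} out of range for a {ndim}-d shape")
    # Negative indices count from the end, as in PyTorch.
    start, end = start_dim % ndim, end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flattened_shape((2, 5, 3), 0, 1))   # (10, 3)
print(flattened_shape((2, 5, 3), 0, -2))  # (10, 3)
```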

classmethod implements_for_spec(torch_function: Callable) → Callable

Registers a torch function override for TensorSpec.

abstract index(index: Union[int, Tensor, ndarray, slice, List], tensor_to_index: Tensor) → Tensor

Indexes the input tensor.

Parameters:
  • index (int, torch.Tensor, slice or list) – index of the tensor

  • tensor_to_index – tensor to be indexed

Returns:

indexed tensor

is_in(val: Tensor) → bool

Returns True if val is in the box defined by the TensorSpec, and False otherwise.

Parameters:

val (torch.Tensor) – value to be checked

Returns:

a boolean indicating whether the value belongs to the TensorSpec box
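
For a multi-discrete box, membership reduces to an elementwise range test, 0 <= v[i] < nvec[i], applied to every row along any leading batch dimensions. A pure-Python sketch on nested lists, where in_box is a hypothetical helper that does not model the dtype or device checks the real spec may perform:

```python
from typing import Any, Sequence


def in_box(value: Any, nvec: Sequence[int]) -> bool:
    """Hypothetical helper: elementwise membership test for nested lists."""
    # Recurse through leading (batch) dimensions until a row of integers is reached.
    if value and isinstance(value[0], list):
        return all(in_box(row, nvec) for row in value)
    return len(value) == len(nvec) and all(0 <= v < n for v, n in zip(value, nvec))


print(in_box([2, 0, 1], [3, 2, 3]))                # True
print(in_box([[2, 0, 1], [0, 1, 2]], [3, 2, 3]))   # True: batch of two
print(in_box([[2, 0, 1], [0, 2, 2]], [3, 2, 3]))   # False: 2 >= 2 in column 1
```

The first two cases mirror the class-level example, where [2, 0, 1] is accepted and [2, 2, 1] is rejected for nvec = (3, 2, 3).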

project(val: Tensor) → Tensor

If the input tensor is not in the TensorSpec box, maps it back into the box using a heuristic.

Parameters:

val (torch.Tensor) – tensor to be mapped to the box.

Returns:

a torch.Tensor belonging to the TensorSpec box.
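
The heuristic itself is not specified here. One plausible choice for a multi-discrete box is to clamp each element into [0, nvec[i] - 1]. The sketch below uses that assumed heuristic with a hypothetical project_by_clamping helper; TorchRL's actual projection may differ:

```python
from typing import List, Sequence


def project_by_clamping(value: Sequence[int], nvec: Sequence[int]) -> List[int]:
    """Hypothetical helper: map each element into [0, n - 1] by clamping (assumed heuristic)."""
    return [min(max(v, 0), n - 1) for v, n in zip(value, nvec)]


print(project_by_clamping([5, -1, 1], [3, 2, 3]))  # [2, 0, 1]
```

Values already inside the box are left unchanged, which is the property any projection heuristic should preserve.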

rand(shape: Optional[Size] = None) → Tensor

Returns a random tensor in the box. The sampling will be uniform unless the box is unbounded.

Parameters:

shape (torch.Size) – shape of the random tensor

Returns:

a random tensor sampled in the TensorSpec box.
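
Since every element of a multi-discrete box takes finitely many values, the box is bounded and a uniform draw per element is the natural reading of the sampling rule above. A pure-Python sketch, where sample_multi_discrete and its batch argument (standing in for leading shape dimensions) are hypothetical:

```python
import random
from typing import List, Sequence


def sample_multi_discrete(nvec: Sequence[int], batch: int, seed: int = 0) -> List[List[int]]:
    """Hypothetical helper: draw `batch` rows, element i uniform over range(nvec[i])."""
    rng = random.Random(seed)  # seeded for reproducibility
    return [[rng.randrange(n) for n in nvec] for _ in range(batch)]


samples = sample_multi_discrete([3, 2, 3], batch=4)
print(samples)  # four rows; column i only contains values in range(nvec[i])
```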

reshape(*shape)

Reshapes a tensorspec.

Check reshape() for more information on this method.

squeeze(dim: int | None = None)

Returns a new Spec with all the dimensions of size 1 removed.

When dim is given, a squeeze operation is done only in that dimension.

Parameters:

dim (int or None) – the dimension to apply the squeeze operation to
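
The effect of squeeze on a shape can be sketched on plain tuples, following torch.squeeze semantics, where a requested dimension that is not a singleton is left in place. squeezed_shape is a hypothetical helper, and treating the spec's behaviour as identical to torch.squeeze is an assumption:

```python
from typing import Optional, Sequence, Tuple


def squeezed_shape(shape: Sequence[int], dim: Optional[int] = None) -> Tuple[int, ...]:
    """Hypothetical helper: shape after squeeze(), following torch.squeeze semantics."""
    if dim is None:
        # Drop every singleton dimension.
        return tuple(s for s in shape if s != 1)
    dim = dim % len(shape)  # allow negative dims
    # Only the requested dimension is removed, and only if it is a singleton.
    if shape[dim] != 1:
        return tuple(shape)
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])


print(squeezed_shape((1, 4, 1, 3)))         # (4, 3)
print(squeezed_shape((1, 4, 1, 3), dim=2))  # (1, 4, 3)
print(squeezed_shape((1, 4, 1, 3), dim=1))  # (1, 4, 1, 3): not a singleton
```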

to_numpy(val: Tensor, safe: Optional[bool] = None) → dict

Returns the np.ndarray equivalent of an input tensor.

Parameters:
  • val (torch.Tensor) – tensor to be converted to numpy.

  • safe (bool) – boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the CHECK_SPEC_ENCODE environment variable.

Returns:

a np.ndarray

to_one_hot(val: Tensor, safe: Optional[bool] = None) → Union[MultiOneHotDiscreteTensorSpec, Tensor]

Encodes a discrete tensor from the spec domain into its one-hot correspondent.

Parameters:
  • val (torch.Tensor, optional) – Tensor to one-hot encode.

  • safe (bool) – boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the CHECK_SPEC_ENCODE environment variable.

Returns:

The one-hot encoded tensor.
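
A multi-one-hot encoding can be pictured as one one-hot vector per element, concatenated along the last dimension, giving length sum(nvec). The pure-Python sketch below assumes that layout; multi_one_hot is a hypothetical helper, not the TorchRL implementation:

```python
from typing import List, Sequence


def multi_one_hot(value: Sequence[int], nvec: Sequence[int]) -> List[int]:
    """Hypothetical helper: concatenate one one-hot vector per element."""
    encoded: List[int] = []
    for v, n in zip(value, nvec):
        block = [0] * n
        block[v] = 1  # mark the chosen category within this element's block
        encoded.extend(block)
    return encoded


print(multi_one_hot([2, 0, 1], [3, 2, 3]))  # [0, 0, 1, 1, 0, 0, 1, 0]
```

Under this layout, a value in the spec domain maps to a vector with exactly one 1 in each element's block.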

to_one_hot_spec() → MultiOneHotDiscreteTensorSpec

Converts the spec to the equivalent one-hot spec.

type_check(value: Tensor, key: Optional[str] = None) → None

Checks the input value dtype against the TensorSpec dtype and raises an exception if they don’t match.

Parameters:
  • value (torch.Tensor) – tensor whose dtype has to be checked

  • key (str, optional) – if the TensorSpec has keys, the value dtype will be checked against the spec pointed by the indicated key.

unflatten(dim, sizes)

Unflattens a tensorspec.

Check unflatten() for more information on this method.

view(*shape)

Reshapes a tensorspec.

Check reshape() for more information on this method.

zero(shape=None) → Tensor

Returns a zero-filled tensor in the box.

Parameters:

shape (torch.Size) – shape of the zero-tensor

Returns:

a zero-filled tensor belonging to the TensorSpec box.
