MultiOneHotDiscreteTensorSpec
- class torchrl.data.MultiOneHotDiscreteTensorSpec(nvec: Sequence[int], shape: Optional[torch.Size] = None, device=None, dtype=torch.bool, use_register=False, mask: torch.Tensor | None = None)
A concatenation of one-hot discrete tensor specs.
The last dimension of the shape (domain of the tensor elements) cannot be indexed.
- Parameters:
nvec (iterable of integers) – cardinality of each of the elements of the tensor.
shape (torch.Size, optional) – total shape of the sampled tensors. If provided, the last dimension must match sum(nvec).
device (str, int or torch.device, optional) – device of the tensors.
dtype (str or torch.dtype, optional) – dtype of the tensors.
Examples
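>>> import torch
>>> from torchrl.data import MultiOneHotDiscreteTensorSpec
>>> ts = MultiOneHotDiscreteTensorSpec((3, 2, 3))
>>> ts.is_in(torch.tensor([0, 0, 1,
...                        0, 1,
...                        1, 0, 0], dtype=torch.bool))
True
>>> ts.is_in(torch.tensor([1, 0, 1,
...                        0, 1,
...                        1, 0, 0], dtype=torch.bool))
False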
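To make the layout concrete: the last dimension of a value holds sum(nvec) entries, read as consecutive chunks of sizes nvec[0], nvec[1], and so on, each of which should be one-hot. The plain-Python sketch below illustrates only that slicing; split_segments is a hypothetical helper, not TorchRL's implementation, and a flat list stands in for a torch.Tensor.

```python
from typing import List, Sequence


def split_segments(value: Sequence[int], nvec: Sequence[int]) -> List[List[int]]:
    """Hypothetical helper: cut a flat multi-one-hot vector into per-element chunks."""
    if len(value) != sum(nvec):
        raise ValueError(f"expected {sum(nvec)} entries, got {len(value)}")
    segments: List[List[int]] = []
    start = 0
    for n in nvec:
        # element i of nvec owns the next n consecutive entries
        segments.append(list(value[start:start + n]))
        start += n
    return segments


print(split_segments([0, 0, 1, 0, 1, 1, 0, 0], (3, 2, 3)))
# [[0, 0, 1], [0, 1], [1, 0, 0]]
```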
- assert_is_in(value: Tensor) -> None
Asserts that a tensor belongs to the box and raises an exception otherwise.
- Parameters:
value (torch.Tensor) – value to be checked.
- clear_device_()
A no-op for all leaf specs (which must have a device).
- encode(val: Union[ndarray, Tensor], *, ignore_device: bool = False) -> Tensor
Encodes a value according to the spec and returns the corresponding tensor.
- Parameters:
val (np.ndarray or torch.Tensor) – value to be encoded as tensor.
- Keyword Arguments:
ignore_device (bool, optional) – if True, the spec device will be ignored. This is used to group tensor casting within a call to TensorDict(..., device="cuda"), which is faster.
- Returns:
torch.Tensor matching the required tensor specs.
- expand(*shape)
Returns a new Spec with the extended shape.
- Parameters:
*shape (tuple or iterable of int) – the new shape of the Spec. It must be compatible with the current shape: its length must be at least as long as the current shape length, and its trailing values must match the current ones, i.e. they can only differ where the current dimension is a singleton.
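The compatibility rule above can be written out as a small check on shape tuples. This is a sketch of the stated rule only; can_expand is a hypothetical name and does not call TorchRL. For this spec the current shape ends in sum(nvec), e.g. (8,) for nvec=(3, 2, 3).

```python
from typing import Sequence


def can_expand(current: Sequence[int], new: Sequence[int]) -> bool:
    """Hypothetical checker for the shape rule described for expand()."""
    if len(new) < len(current):
        return False  # the new shape may not drop dimensions
    # align `current` with the trailing entries of `new`
    trailing = new[len(new) - len(current):]
    # each aligned dim must match, unless the current dim is a singleton
    return all(c == n or c == 1 for c, n in zip(current, trailing))


print(can_expand((8,), (4, 8)))    # True: new leading batch dimension
print(can_expand((1, 8), (4, 8)))  # True: singleton dimension expanded
print(can_expand((2, 8), (4, 8)))  # False: non-singleton dimension differs
print(can_expand((4, 8), (8,)))    # False: fewer dimensions than before
```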
- classmethod implements_for_spec(torch_function: Callable) -> Callable
Register a torch function override for TensorSpec.
- index(index: Union[int, Tensor, ndarray, slice, List], tensor_to_index: Tensor) -> Tensor
Indexes the input tensor.
- Parameters:
index (int, torch.Tensor, slice or list) – index of the tensor
tensor_to_index – tensor to be indexed
- Returns:
indexed tensor
- is_in(val: Tensor) -> bool
If the value val is in the box defined by the TensorSpec, returns True, otherwise False.
- Parameters:
val (torch.Tensor) – value to be checked
- Returns:
boolean indicating whether the value belongs to the TensorSpec box
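For this spec, membership amounts to every nvec-sized chunk of the last dimension being a valid one-hot vector. The plain-Python sketch below reproduces the two results of the class example; is_in here is a hypothetical stand-in operating on a flat list, and it ignores dtype, device and batch dimensions, which the real method may also check.

```python
from typing import Sequence


def is_in(value: Sequence[int], nvec: Sequence[int]) -> bool:
    """Hypothetical sketch: True if every nvec-sized chunk is a valid one-hot vector."""
    if len(value) != sum(nvec):
        return False  # wrong size for the last dimension
    start = 0
    for n in nvec:
        segment = value[start:start + n]
        # each chunk must contain only 0/1 entries with exactly one 1
        if any(v not in (0, 1) for v in segment) or sum(segment) != 1:
            return False
        start += n
    return True


print(is_in([0, 0, 1, 0, 1, 1, 0, 0], (3, 2, 3)))  # True
print(is_in([1, 0, 1, 0, 1, 1, 0, 0], (3, 2, 3)))  # False
```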
- project(val: Tensor) -> Tensor
If the input tensor is not in the TensorSpec box, it maps it back to it given some heuristic.
- Parameters:
val (torch.Tensor) – tensor to be mapped to the box.
- Returns:
a torch.Tensor belonging to the TensorSpec box.
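The heuristic is not specified here. One plausible choice for one-hot data, assumed purely for illustration, is to keep the largest entry of each chunk and zero the rest. The hypothetical project function below applies that rule to a flat list; TorchRL's actual heuristic may differ.

```python
from typing import List, Sequence


def project(value: Sequence[float], nvec: Sequence[int]) -> List[int]:
    """Hypothetical: map each chunk to a one-hot vector at its largest entry."""
    if len(value) != sum(nvec):
        raise ValueError(f"expected {sum(nvec)} entries, got {len(value)}")
    out: List[int] = []
    start = 0
    for n in nvec:
        segment = value[start:start + n]
        # index of the largest entry; max() returns the first one on ties
        best = max(range(n), key=lambda i: segment[i])
        out.extend(1 if i == best else 0 for i in range(n))
        start += n
    return out


# the invalid value from the class example is mapped back into the box
print(project([1, 0, 1, 0, 1, 1, 0, 0], (3, 2, 3)))  # [1, 0, 0, 0, 1, 1, 0, 0]
```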
- rand(shape: Optional[Size] = None) -> Tensor
Returns a random tensor in the box. The sampling will be uniform unless the box is unbounded.
- Parameters:
shape (torch.Size, optional) – shape of the random tensor
- Returns:
a random tensor sampled in the TensorSpec box.
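For a multi-one-hot spec the box is bounded, so uniform sampling can be pictured as drawing one category per element of nvec and one-hot encoding it. The sketch below is a hypothetical plain-Python illustration: it produces a single sample, ignores the shape argument and any mask, and uses Python's random module instead of torch.

```python
import random
from typing import List, Optional, Sequence


def rand(nvec: Sequence[int], rng: Optional[random.Random] = None) -> List[int]:
    """Hypothetical sketch: draw one uniform category per chunk, then one-hot encode."""
    rng = rng or random.Random()
    sample: List[int] = []
    for n in nvec:
        k = rng.randrange(n)  # uniform over the n categories of this element
        sample.extend(1 if i == k else 0 for i in range(n))
    return sample


print(rand((3, 2, 3), random.Random(0)))
```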
- squeeze(dim=None)
Returns a new Spec with all the dimensions of size 1 removed. When dim is given, a squeeze operation is done only in that dimension.
- Parameters:
dim (int or None) – the dimension to apply the squeeze operation to
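The resulting shape follows the usual torch.squeeze rule. The hypothetical squeezed_shape function below computes it on a plain tuple; it is only a sketch and does not model any spec-specific restriction, for instance on the last (one-hot) dimension of this spec.

```python
from typing import Optional, Sequence, Tuple


def squeezed_shape(shape: Sequence[int], dim: Optional[int] = None) -> Tuple[int, ...]:
    """Hypothetical sketch of the shape produced by squeeze(dim)."""
    if dim is None:
        # drop every singleton dimension
        return tuple(s for s in shape if s != 1)
    if dim < 0:
        dim += len(shape)  # support negative indices, as torch does
    if shape[dim] != 1:
        return tuple(shape)  # a non-singleton dimension is left unchanged
    return tuple(s for i, s in enumerate(shape) if i != dim)


print(squeezed_shape((1, 4, 1, 8)))         # (4, 8)
print(squeezed_shape((1, 4, 1, 8), dim=2))  # (1, 4, 8)
```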
- to_categorical(val: Tensor, safe: Optional[bool] = None) -> Tensor
Converts a given one-hot tensor to categorical format.
- Parameters:
val (torch.Tensor) – one-hot tensor to convert to categorical format.
safe (bool, optional) – boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the CHECK_SPEC_ENCODE environment variable.
- Returns:
The categorical tensor.
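Conversion to categorical format replaces each one-hot chunk with the index of its active entry, so a value with sum(nvec) entries becomes len(nvec) indices. The plain-Python sketch below assumes a flat list and a hypothetical to_categorical function; the real method works on tensors with batch dimensions and optionally validates the input (safe).

```python
from typing import List, Sequence


def to_categorical(value: Sequence[int], nvec: Sequence[int]) -> List[int]:
    """Hypothetical sketch: index of the active entry in each one-hot chunk."""
    if len(value) != sum(nvec):
        raise ValueError(f"expected {sum(nvec)} entries, got {len(value)}")
    indices: List[int] = []
    start = 0
    for n in nvec:
        segment = value[start:start + n]
        # argmax of the chunk; for a valid one-hot chunk this is the position of the 1
        indices.append(max(range(n), key=lambda i: segment[i]))
        start += n
    return indices


print(to_categorical([0, 0, 1, 0, 1, 1, 0, 0], (3, 2, 3)))  # [2, 1, 0]
```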
- to_categorical_spec() -> MultiDiscreteTensorSpec
Converts the spec to the equivalent categorical spec.
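Assuming the categorical spec keeps the batch dimensions and stores one index per element of nvec, the shape mapping can be sketched as follows. categorical_shape is a hypothetical helper, not part of TorchRL.

```python
from typing import Sequence, Tuple


def categorical_shape(onehot_shape: Sequence[int], nvec: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical: shape of the categorical counterpart of a multi-one-hot spec."""
    if onehot_shape[-1] != sum(nvec):
        raise ValueError("last dimension must equal sum(nvec)")
    # batch dims are kept; the sum(nvec)-sized one-hot dim becomes len(nvec) indices
    return (*onehot_shape[:-1], len(nvec))


print(categorical_shape((4, 8), (3, 2, 3)))  # (4, 3)
```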
- to_numpy(val: Tensor, safe: Optional[bool] = None) -> ndarray
Returns the np.ndarray equivalent of an input tensor.
- Parameters:
val (torch.Tensor) – tensor to be converted to numpy.
safe (bool, optional) – boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the CHECK_SPEC_ENCODE environment variable.
- Returns:
a np.ndarray
- type_check(value: Tensor, key: Optional[str] = None) -> None
Checks the input value dtype against the TensorSpec dtype and raises an exception if they don’t match.
- Parameters:
value (torch.Tensor) – tensor whose dtype has to be checked
key (str, optional) – if the TensorSpec has keys, the value dtype will be checked against the spec pointed by the indicated key.