DiscreteTensorSpec
- class torchrl.data.DiscreteTensorSpec(n: int, shape: torch.Size | None = None, device: DEVICE_TYPING | None = None, dtype: str | torch.dtype = torch.int64, mask: torch.Tensor | None = None)
A discrete tensor spec.
An alternative to OneHotDiscreteTensorSpec for categorical variables in TorchRL. Instead of multiplying by a one-hot vector, categorical variables use indexing, which can speed up computation and reduce memory usage for large categorical variables. The last dimension of the spec (length n of the binary vector) cannot be indexed.
Example
>>> import torch
>>> batch, size = 3, 4
>>> action_value = torch.arange(batch*size)
>>> action_value = action_value.view(batch, size).to(torch.float)
>>> action = torch.argmax(action_value, dim=-1).to(torch.long)
>>> chosen_action_value = action_value[range(batch), action]
>>> print(chosen_action_value)
tensor([ 3.,  7., 11.])
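To make the indexing-versus-multiplication claim concrete, the sketch below uses plain PyTorch only (no TorchRL classes; all variable names are illustrative). It selects the same action values twice: once by multiplying with a one-hot mask and summing, and once by integer indexing as in the example. The indexing route never materializes the batch-by-size one-hot tensor.

```python
import torch
import torch.nn.functional as F

batch, size = 3, 4
action_value = torch.arange(batch * size, dtype=torch.float).view(batch, size)
action = action_value.argmax(dim=-1)  # integer (categorical) actions, shape (batch,)

# One-hot route: build a (batch, size) binary mask, multiply, then reduce.
one_hot = F.one_hot(action, num_classes=size).to(action_value.dtype)
via_one_hot = (action_value * one_hot).sum(dim=-1)

# Categorical route: index directly with the integer actions, no mask needed.
via_index = action_value[torch.arange(batch), action]

print(via_index)  # tensor([ 3.,  7., 11.])
```

Both routes give the same values; the difference is that the one-hot route allocates and multiplies a tensor whose last dimension grows with the number of categories.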
- Parameters:
n (int) – number of possible outcomes.
shape (torch.Size, optional) – shape of the variable. Defaults to torch.Size([]).
device (str, int or torch.device, optional) – device of the tensors.
dtype (str or torch.dtype, optional) – dtype of the tensors. Defaults to torch.int64.
- assert_is_in(value: Tensor) → None
Asserts that a tensor belongs to the box and raises an exception otherwise.
- Parameters:
value (torch.Tensor) – value to be checked.
- clear_device_()
A no-op for all leaf specs (which must have a device).
- encode(val: Union[ndarray, Tensor], *, ignore_device=False) → Tensor
Encodes a value given the specified spec and returns the corresponding tensor.
- Parameters:
val (np.ndarray or torch.Tensor) – value to be encoded as tensor.
- Keyword Arguments:
ignore_device (bool, optional) – if True, the spec device will be ignored. This is used to group tensor casting within a call to TensorDict(..., device="cuda"), which is faster.
- Returns:
torch.Tensor matching the required tensor specs.
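For a discrete spec, encoding essentially turns an array-like into a tensor carrying the spec's dtype and device. The function below is a hypothetical stand-in (`encode_discrete` is not a TorchRL name) that assumes exactly that behavior; it skips any domain checking and is meant to show the role of `ignore_device`, not TorchRL's implementation.

```python
import numpy as np
import torch


def encode_discrete(val, dtype=torch.int64, device="cpu", ignore_device=False):
    """Hypothetical sketch of encoding for a discrete spec.

    Converts `val` (NumPy array or tensor) to a tensor of the spec dtype. When
    `ignore_device` is True the value is not moved to the spec device, leaving
    the device cast to a later, grouped step (e.g. a TensorDict constructor).
    """
    if ignore_device:
        return torch.as_tensor(val, dtype=dtype)
    return torch.as_tensor(val, dtype=dtype, device=device)


encoded = encode_discrete(np.array([0, 2, 1]))
print(encoded.dtype, encoded.tolist())  # torch.int64 [0, 2, 1]
```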
- expand(*shape)
Returns a new Spec with the extended shape.
- Parameters:
*shape (tuple or iterable of int) – the new shape of the Spec. Must comply with the current shape: its length must be at least as long as the current shape length, and its last values must be compliant too, i.e. they can only differ from the current ones if the current dimension is a singleton.
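The compatibility rule for the new shape can be written as a small pure-Python check. `expand_is_compatible` is a hypothetical helper written for this page: it encodes the rule as stated (the new shape is at least as long as the current one, and each trailing dimension matches unless the current dimension is a singleton). It does not model any other conventions TorchRL may accept, such as placeholder sizes.

```python
def expand_is_compatible(current, new):
    """Return True if `new` is a valid expansion of `current` (hypothetical helper)."""
    current, new = tuple(current), tuple(new)
    if len(new) < len(current):
        return False
    # Compare the trailing dimensions of `new` with those of `current`.
    trailing = new[len(new) - len(current):]
    return all(c == n or c == 1 for c, n in zip(current, trailing))


print(expand_is_compatible((1, 4), (5, 3, 4)))  # True: the singleton 1 becomes 3
print(expand_is_compatible((2, 4), (5, 3, 4)))  # False: 2 != 3 and 2 is not a singleton
```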
- flatten(start_dim, end_dim)
Flattens a tensorspec.
Check torch.flatten() for more information on this method.
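Assuming the flattened spec reports the shape that `torch.flatten(t, start_dim, end_dim)` would give for a tensor `t` of the spec's shape, the new shape follows the usual flattening arithmetic. The hypothetical `flattened_shape` helper below computes it on plain shape tuples.

```python
import math


def flattened_shape(shape, start_dim, end_dim):
    """Shape after merging dims start_dim..end_dim (inclusive), as torch.flatten does."""
    shape = tuple(shape)
    ndim = len(shape)
    # Normalize negative dimensions, e.g. -1 -> ndim - 1.
    start, end = start_dim % ndim, end_dim % ndim
    merged = math.prod(shape[start:end + 1])
    return shape[:start] + (merged,) + shape[end + 1:]


print(flattened_shape((2, 3, 4), 0, 1))   # (6, 4)
print(flattened_shape((2, 3, 4), 1, -1))  # (2, 12)
```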
- classmethod implements_for_spec(torch_function: Callable) → Callable
Registers a torch function override for TensorSpec.
- abstract index(index: Union[int, Tensor, ndarray, slice, List], tensor_to_index: Tensor) → Tensor
Indexes the input tensor.
- Parameters:
index (int, torch.Tensor, slice or list) – index of the tensor
tensor_to_index – tensor to be indexed
- Returns:
indexed tensor
- is_in(val: Tensor) → bool
If the value val is in the box defined by the TensorSpec, returns True, otherwise False.
- Parameters:
val (torch.Tensor) – value to be checked
- Returns:
boolean indicating whether the value belongs to the TensorSpec box
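For a discrete spec with n outcomes, a natural reading of "the box" is the integers from 0 to n - 1, laid out with the spec's shape. The hypothetical `discrete_is_in` below sketches such a membership test under these assumptions: the dtype must match, the trailing dimensions must match the spec shape (leading dimensions are treated as batch dimensions), and every value must lie in [0, n). Masks are ignored, and TorchRL's exact checks may differ.

```python
import torch


def discrete_is_in(val, n, shape=torch.Size([]), dtype=torch.int64):
    """Hypothetical membership test for a discrete spec with `n` outcomes.

    Assumes the box is {0, ..., n - 1} with the given shape; masks are ignored.
    """
    if val.dtype != dtype:
        return False
    # The trailing dimensions must match the spec shape (leading dims are batch dims).
    if len(shape) > 0 and tuple(val.shape[-len(shape):]) != tuple(shape):
        return False
    return bool(((val >= 0) & (val < n)).all())


print(discrete_is_in(torch.tensor([0, 3, 2]), n=4))    # True
print(discrete_is_in(torch.tensor([0, 4]), n=4))       # False: 4 is outside [0, 4)
print(discrete_is_in(torch.tensor([0.0, 1.0]), n=4))   # False: wrong dtype
```

Where this sketch returns False, assert_is_in would be expected to raise instead.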
- project(val: Tensor) → Tensor
If the input tensor is not in the TensorSpec box, maps it back into the box using some heuristic.
- Parameters:
val (torch.Tensor) – tensor to be mapped to the box.
- Returns:
a torch.Tensor belonging to the TensorSpec box.
- rand(shape=None) → Tensor
Returns a random tensor in the box. The sampling will be uniform unless the box is unbounded.
- Parameters:
shape (torch.Size) – shape of the random tensor
- Returns:
a random tensor sampled in the TensorSpec box.
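Uniform sampling from a discrete box reduces to `torch.randint`. The sketch below is hypothetical (`discrete_rand` is not a TorchRL name) and rests on two labeled assumptions: no mask is set, and the requested `shape` acts as a batch shape placed in front of the spec's own shape.

```python
import torch


def discrete_rand(n, spec_shape=torch.Size([]), shape=None, dtype=torch.int64):
    """Hypothetical uniform sampler over {0, ..., n - 1} (no mask)."""
    if shape is None:
        shape = torch.Size([])
    # Assumption: `shape` is a batch shape prepended to the spec shape.
    full_shape = torch.Size([*shape, *spec_shape])
    return torch.randint(0, n, full_shape, dtype=dtype)


sample = discrete_rand(n=5, spec_shape=(2,), shape=(3,))
print(sample.shape)  # torch.Size([3, 2])
```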
- squeeze(dim=None)
Returns a new Spec with all the dimensions of size 1 removed. When dim is given, the squeeze operation is applied only to that dimension.
- Parameters:
dim (int or None) – the dimension to apply the squeeze operation to
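The shape rule above can be sketched on plain tuples. The hypothetical `squeezed_shape` helper follows `torch.squeeze` semantics: with no `dim`, every size-1 dimension is dropped, and a `dim` that points at a non-singleton dimension leaves the shape unchanged. TorchRL's handling of edge cases (for example a spec of shape (1,)) may differ from this sketch.

```python
def squeezed_shape(shape, dim=None):
    """Shape after squeezing, following torch.squeeze semantics (hypothetical helper)."""
    shape = tuple(shape)
    if dim is None:
        # Drop every dimension of size 1.
        return tuple(s for s in shape if s != 1)
    dim = dim % len(shape)  # normalize negative dims
    if shape[dim] != 1:
        return shape  # non-singleton dims are left alone
    return shape[:dim] + shape[dim + 1:]


print(squeezed_shape((1, 3, 1)))          # (3,)
print(squeezed_shape((1, 3, 1), dim=-1))  # (1, 3)
print(squeezed_shape((1, 3, 1), dim=1))   # (1, 3, 1)
```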
- to_numpy(val: Tensor, safe: Optional[bool] = None) → dict
Returns the np.ndarray corresponding to an input tensor.
- Parameters:
val (torch.Tensor) – tensor to be converted to numpy.
safe (bool, optional) – boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the CHECK_SPEC_ENCODE environment variable.
- Returns:
a np.ndarray
- to_one_hot(val: Tensor, safe: Optional[bool] = None) → Tensor
Encodes a discrete tensor from the spec domain into its one-hot correspondent.
- Parameters:
val (torch.Tensor) – tensor to one-hot encode.
safe (bool, optional) – boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the CHECK_SPEC_ENCODE environment variable.
- Returns:
The one-hot encoded tensor.
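One-hot encoding a discrete value adds a trailing dimension of size n. The sketch below uses `torch.nn.functional.one_hot` to show the shape change, a domain check analogous to `safe=True`, and the round trip back through `argmax`. It illustrates the encoding itself; whether TorchRL calls this exact function internally is not stated here.

```python
import torch
import torch.nn.functional as F

n = 4
val = torch.tensor([[0, 3], [2, 1]])  # discrete values, shape (2, 2)

# Check analogous to `safe=True`: every value must lie in [0, n).
assert ((val >= 0) & (val < n)).all(), "value outside the spec domain"

one_hot = F.one_hot(val, num_classes=n)  # adds a trailing dimension of size n
print(one_hot.shape)  # torch.Size([2, 2, 4])

# argmax over the last dimension recovers the discrete values.
recovered = one_hot.argmax(dim=-1)
```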
- to_one_hot_spec() → OneHotDiscreteTensorSpec
Converts the spec to the equivalent one-hot spec.
- type_check(value: Tensor, key: Optional[str] = None) → None
Checks the input value dtype against the TensorSpec dtype and raises an exception if they don’t match.
- Parameters:
value (torch.Tensor) – tensor whose dtype has to be checked
key (str, optional) – if the TensorSpec has keys, the value dtype will be checked against the spec pointed to by the indicated key.
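For a leaf spec without keys, the dtype check amounts to comparing `value.dtype` with the spec dtype. The hypothetical `type_check` function below sketches that comparison for the default `torch.int64` dtype. The exception type (`TypeError`) is an assumption made for this sketch, since the documentation only says an exception is raised.

```python
import torch


def type_check(value, spec_dtype=torch.int64):
    """Hypothetical dtype check: raise if `value` does not carry the spec dtype."""
    if value.dtype != spec_dtype:
        # TypeError is an assumption; the actual exception type may differ.
        raise TypeError(f"expected dtype {spec_dtype}, got {value.dtype}")


type_check(torch.tensor([1, 2]))  # passes: int64 matches the default spec dtype
try:
    type_check(torch.tensor([1.0, 2.0]))
except TypeError as err:
    print(err)  # expected dtype torch.int64, got torch.float32
```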
- unflatten(dim, sizes)
Unflattens a tensorspec.
Check torch.unflatten() for more information on this method.
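Assuming the unflattened spec reports the shape that `torch.unflatten(t, dim, sizes)` would give for a tensor `t` of the spec's shape, dimension `dim` is split into `sizes`, whose product must equal the original size. The hypothetical `unflattened_shape` helper below sketches that rule on plain tuples and, to stay simple, accepts explicit sizes only (no -1 placeholder).

```python
import math


def unflattened_shape(shape, dim, sizes):
    """Shape after splitting dimension `dim` into `sizes`, as torch.unflatten does."""
    shape = tuple(shape)
    dim = dim % len(shape)  # normalize negative dims
    sizes = tuple(sizes)
    # Simplification for this sketch: explicit sizes only, no -1 placeholder.
    if math.prod(sizes) != shape[dim]:
        raise ValueError(f"sizes {sizes} do not multiply to {shape[dim]}")
    return shape[:dim] + sizes + shape[dim + 1:]


print(unflattened_shape((6, 4), 0, (2, 3)))  # (2, 3, 4)
```

Applied after the flattening of the first two dimensions of a (2, 3, 4) shape, this restores the original shape.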