OneHot
- class torchrl.data.OneHot(n: int, shape: torch.Size | None = None, device: DEVICE_TYPING | None = None, dtype: str | torch.dtype | None = torch.bool, use_register: bool = False, mask: torch.Tensor | None = None)[source]
A unidimensional, one-hot discrete tensor spec.
By default, TorchRL assumes that categorical variables are encoded as one-hot encodings of the variable. This allows for simple indexing of tensors, e.g.
>>> import torch
>>> batch, size = 3, 4
>>> action_value = torch.arange(batch*size)
>>> action_value = action_value.view(batch, size).to(torch.float)
>>> action = (action_value == action_value.max(-1,
...    keepdim=True)[0]).to(torch.long)
>>> chosen_action_value = (action * action_value).sum(-1)
>>> print(chosen_action_value)
tensor([ 3.,  7., 11.])
The last dimension of the shape (variable domain) cannot be indexed.
- Parameters:
n (int) – number of possible outcomes.
shape (torch.Size, optional) – total shape of the sampled tensors. If provided, the last dimension must match n.
device (str, int or torch.device, optional) – device of the tensors.
dtype (str or torch.dtype, optional) – dtype of the tensors.
use_register (bool) – experimental feature. If True, every integer will be mapped onto a binary vector in the order in which they appear. This feature is designed for environments with no a priori definition of the number of possible outcomes (e.g. discrete outcomes are sampled from an arbitrary set, whose elements will be mapped in a register to a series of unique one-hot binary vectors).
mask (torch.Tensor or None) – mask some of the possible outcomes when a sample is taken. See update_mask() for more information.
Examples
>>> import torch
>>> from torchrl.data.tensor_specs import OneHot
>>> spec = OneHot(5, shape=(2, 5))
>>> spec.rand()
tensor([[False,  True, False, False, False],
        [False,  True, False, False, False]])
>>> mask = torch.tensor([
...     [False, False, False, False, True],
...     [False, False, False, False, True]
... ])
>>> spec.update_mask(mask)
>>> spec.rand()
tensor([[False, False, False, False,  True],
        [False, False, False, False,  True]])
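The use_register behaviour described in the parameters can be pictured with a small plain-Python sketch. This is not TorchRL's implementation: the Register class and its encode method are hypothetical names, and the sketch assumes the number of slots n is fixed in advance. It only illustrates how previously unseen outcomes could be assigned one-hot vectors in order of appearance.

```python
class Register:
    """Hypothetical helper: maps arbitrary hashable outcomes to one-hot lists."""

    def __init__(self, n: int) -> None:
        self.n = n
        self._index: dict = {}  # outcome -> position, in order of first appearance

    def encode(self, outcome) -> list:
        # A new outcome takes the next free slot.
        if outcome not in self._index:
            if len(self._index) >= self.n:
                raise ValueError("register is full")
            self._index[outcome] = len(self._index)
        position = self._index[outcome]
        return [i == position for i in range(self.n)]


register = Register(n=3)
first = register.encode(42)   # first outcome seen takes slot 0
second = register.encode(7)   # second outcome seen takes slot 1
again = register.encode(42)   # a known outcome keeps its vector
```

The outcomes 42 and 7 are arbitrary; what matters is that the same outcome always maps to the same vector.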
- assert_is_in(value: Tensor) None
Asserts whether a tensor belongs to the box, and raises an exception otherwise.
- Parameters:
value (torch.Tensor) – value to be checked.
- cardinality() int [source]
The cardinality of the spec.
This refers to the number of possible outcomes in a spec. The cardinality of a composite spec is assumed to be the size of the Cartesian product of the possible outcomes of its entries.
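As a rough illustration of the Cartesian-product statement, the plain-Python sketch below counts the joint outcomes of a hypothetical composite made of two discrete entries. The entry names and sizes are assumptions chosen for the example, and no TorchRL call is involved.

```python
import itertools
import math

# Hypothetical composite made of two discrete entries with 3 and 4 outcomes.
outcomes_per_entry = {"action": 3, "mode": 4}

# Every joint outcome is one element of the Cartesian product of the entries.
joint_outcomes = list(
    itertools.product(*(range(n) for n in outcomes_per_entry.values()))
)

# The cardinality is the number of joint outcomes, i.e. the product of the sizes.
cardinality = math.prod(outcomes_per_entry.values())
```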
- clear_device_() T
A no-op for all leaf specs (which must have a device).
For Composite specs, this method will erase the device.
- contains(item: torch.Tensor | TensorDictBase) bool
If the value val could have been generated by the TensorSpec, returns True, otherwise False.
See is_in() for more information.
- device: torch.device | None = None
- encode(val: np.ndarray | torch.Tensor, space: CategoricalBox | None = None, *, ignore_device: bool = False) torch.Tensor [source]
Encodes a value given the specified spec, and returns the corresponding tensor.
This method is to be used in environments that return a value (e.g., a numpy array) that can be easily mapped to the TorchRL required domain. If the value is already a tensor, the spec will not change its value and will return it as-is.
- Parameters:
val (np.ndarray or torch.Tensor) – value to be encoded as tensor.
- Keyword Arguments:
ignore_device (bool, optional) – if True, the spec device will be ignored. This is used to group tensor casting within a call to TensorDict(..., device="cuda") which is faster.
- Returns:
torch.Tensor matching the required tensor specs.
- enumerate(use_mask: bool = False) Tensor [source]
Returns all the samples that can be obtained from the TensorSpec.
The samples will be stacked along the first dimension.
This method is only implemented for discrete specs.
- Parameters:
use_mask (bool, optional) – If True and the spec has a mask, samples that are masked are excluded. Default is False.
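For a one-dimensional one-hot spec, the set of all samples amounts to the rows of an identity matrix. The sketch below uses plain PyTorch to show this, along with one plausible reading of use_mask. It is not the library's implementation, and the mask convention (False masks an outcome) is taken from update_mask().

```python
import torch

n = 4
# Each row is one possible one-hot sample, stacked along the first dimension.
all_samples = torch.eye(n).bool()

# Assumed mask convention: False masks an outcome, True leaves it available.
mask = torch.tensor([True, False, True, True])

# Keeping only the rows of unmasked outcomes drops the sample for outcome 1.
allowed_samples = all_samples[mask]
```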
- expand(*shape)[source]
Returns a new Spec with the expanded shape.
- Parameters:
*shape (tuple or iterable of int) – the new shape of the Spec. Must be broadcastable with the current shape: it must have at least as many dimensions as the current shape, and its trailing values must be compatible too; i.e., they can only differ from the current ones where the current dimension is a singleton.
- flatten(start_dim: int, end_dim: int) T
Flattens a TensorSpec.
Check flatten() for more information on this method.
- classmethod implements_for_spec(torch_function: Callable) Callable
Register a torch function override for TensorSpec.
- index(index: Union[int, Tensor, ndarray, slice, List], tensor_to_index: Tensor) Tensor [source]
Indexes the input tensor.
This method is to be used with specs that encode one or more categorical variables (e.g., OneHot or Categorical), such that indexing of a tensor with a sample can be done without caring about the actual representation of the index.
- Parameters:
index (int, torch.Tensor, slice or list) – index of the tensor
tensor_to_index – tensor to be indexed
- Returns:
indexed tensor
- Examples:
>>> from torchrl.data import OneHot
>>> import torch
>>>
>>> one_hot = OneHot(n=100)
>>> categ = one_hot.to_categorical_spec()
>>> idx_one_hot = torch.zeros((100,), dtype=torch.bool)
>>> idx_one_hot[50] = 1
>>> print(one_hot.index(idx_one_hot, torch.arange(100)))
tensor(50)
>>> idx_categ = one_hot.to_categorical(idx_one_hot)
>>> print(categ.index(idx_categ, torch.arange(100)))
tensor(50)
- is_in(val: Tensor) bool [source]
If the value val could have been generated by the TensorSpec, returns True, otherwise False.
More precisely, the is_in method checks that the value val is within the limits defined by the space attribute (the box), and that the dtype, device, shape and potentially other metadata match those of the spec. If any of these checks fails, the is_in method will return False.
- Parameters:
val (torch.Tensor) – value to be checked.
- Returns:
boolean indicating whether the value belongs to the TensorSpec box.
- one(shape: torch.Size = None) torch.Tensor | TensorDictBase
Returns a one-filled tensor in the box.
Note
Even though there is no guarantee that 1 belongs to the spec domain, this method will not raise an exception when this condition is violated. The primary use case of one is to generate empty data buffers, not meaningful data.
- Parameters:
shape (torch.Size) – shape of the one-tensor
- Returns:
a one-filled tensor sampled in the TensorSpec box.
- ones(shape: torch.Size = None) torch.Tensor | TensorDictBase
Proxy to one().
- project(val: torch.Tensor | TensorDictBase) torch.Tensor | TensorDictBase
If the input tensor is not in the TensorSpec box, maps it back into the box using a predefined heuristic.
- Parameters:
val (torch.Tensor) – tensor to be mapped to the box.
- Returns:
a torch.Tensor belonging to the TensorSpec box.
- rand(shape: Optional[Size] = None) Tensor [source]
Returns a random tensor in the space defined by the spec.
The sampling will be done uniformly over the space, unless the box is unbounded in which case normal values will be drawn.
- Parameters:
shape (torch.Size) – shape of the random tensor
- Returns:
a random tensor sampled in the TensorSpec box.
- reshape(*shape) T
Reshapes a TensorSpec.
Check reshape() for more information on this method.
- sample(shape: torch.Size = None) torch.Tensor | TensorDictBase
Returns a random tensor in the space defined by the spec.
See rand() for details.
- squeeze(dim=None)[source]
Returns a new Spec with all the dimensions of size 1 removed.
When dim is given, a squeeze operation is done only in that dimension.
- Parameters:
dim (int or None) – the dimension to apply the squeeze operation to
- to(dest: torch.dtype | DEVICE_TYPING) OneHot [source]
Casts a TensorSpec to a device or a dtype.
Returns the same spec if no change is made.
- to_categorical(val: Tensor, safe: Optional[bool] = None) Tensor [source]
Converts a given one-hot tensor to categorical format.
- Parameters:
val (torch.Tensor, optional) – one-hot tensor to convert to categorical format.
safe (bool) – boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the CHECK_SPEC_ENCODE environment variable.
- Returns:
The categorical tensor.
Examples
>>> from torchrl.data import OneHot
>>> one_hot = OneHot(3, shape=(2, 3))
>>> one_hot_sample = one_hot.rand()
>>> one_hot_sample
tensor([[False,  True, False],
        [False,  True, False]])
>>> categ_sample = one_hot.to_categorical(one_hot_sample)
>>> categ_sample
tensor([1, 1])
- to_categorical_spec() Categorical [source]
Converts the spec to the equivalent categorical spec.
Examples
>>> from torchrl.data import OneHot
>>> one_hot = OneHot(3, shape=(2, 3))
>>> one_hot.to_categorical_spec()
Categorical(
    shape=torch.Size([2]),
    space=CategoricalBox(n=3),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)
- to_numpy(val: Tensor, safe: Optional[bool] = None) ndarray [source]
Returns the np.ndarray counterpart of an input tensor.
This is intended to be the inverse operation of encode().
- Parameters:
val (torch.Tensor) – tensor to be transformed into a numpy array.
safe (bool) – boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the CHECK_SPEC_ENCODE environment variable.
- Returns:
a np.ndarray.
- type_check(value: Tensor, key: Optional[NestedKey] = None) None
Checks the input value dtype against the TensorSpec dtype and raises an exception if they don’t match.
- Parameters:
value (torch.Tensor) – tensor whose dtype has to be checked.
key (str, optional) – if the TensorSpec has keys, the value dtype will be checked against the spec pointed by the indicated key.
- unflatten(dim: int, sizes: tuple[int]) T
Unflattens a TensorSpec.
Check unflatten() for more information on this method.
- unsqueeze(dim: int)[source]
Returns a new Spec with one more singleton dimension (at the position indicated by dim).
- Parameters:
dim (int or None) – the dimension to apply the unsqueeze operation to.
- update_mask(mask)[source]
Sets a mask to prevent some of the possible outcomes when a sample is taken.
The mask can also be set during initialization of the spec.
- Parameters:
mask (torch.Tensor or None) – boolean mask. If None, the mask is disabled. Otherwise, the shape of the mask must be expandable to the shape of the spec. False masks an outcome and True leaves the outcome unmasked. If all the possible outcomes are masked, then an error is raised when a sample is taken.
Examples
>>> import torch
>>> from torchrl.data import OneHot
>>> mask = torch.tensor([True, False, False])
>>> ts = OneHot(3, (2, 3,), dtype=torch.int64, mask=mask)
>>> # All but one of the three possible outcomes are masked
>>> ts.rand()
tensor([[1, 0, 0],
        [1, 0, 0]])
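One way to picture masked sampling is the plain-PyTorch sketch below. It is an assumption about the effect described above, not TorchRL's implementation: unmasked outcomes get equal weight, masked outcomes get zero weight, and each drawn index is turned into a one-hot vector.

```python
import torch

mask = torch.tensor([True, False, True])   # False masks outcome 1
weights = mask.to(torch.float)             # equal weight for unmasked outcomes

# Draw 1000 indices; entries with zero weight are never selected.
indices = torch.multinomial(weights, num_samples=1000, replacement=True)

# Turn each index into a boolean one-hot vector of length 3.
samples = torch.nn.functional.one_hot(indices, num_classes=mask.numel()).bool()
```

With an all-False mask the weights sum to zero and torch.multinomial raises an error, which mirrors the documented failure when every outcome is masked.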
- view(*shape) T
Reshapes a TensorSpec.
Check reshape() for more information on this method.
- zero(shape: torch.Size = None) torch.Tensor | TensorDictBase
Returns a zero-filled tensor in the box.
Note
Even though there is no guarantee that 0 belongs to the spec domain, this method will not raise an exception when this condition is violated. The primary use case of zero is to generate empty data buffers, not meaningful data.
- Parameters:
shape (torch.Size) – shape of the zero-tensor
- Returns:
a zero-filled tensor sampled in the TensorSpec box.
- zeros(shape: torch.Size = None) torch.Tensor | TensorDictBase
Proxy to zero().