CompositeSpec
- class torchrl.data.CompositeSpec(*args, **kwargs)
A composition of TensorSpecs.
- Parameters:
*args – if an unnamed argument is passed, it must be a dictionary with keys matching the expected keys to be found in the CompositeSpec object. This is useful to build nested CompositeSpecs with tuple indices.
**kwargs (key (str) – value (TensorSpec)) – dictionary of tensorspecs to be stored. Values can be None, in which case is_in will be assumed to be True for the corresponding tensors, and project() will have no effect. spec.encode cannot be used with missing values.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import BoundedTensorSpec, CompositeSpec
>>> pixels_spec = BoundedTensorSpec(
...     torch.zeros(3, 32, 32),
...     torch.ones(3, 32, 32))
>>> observation_vector_spec = BoundedTensorSpec(torch.zeros(33),
...     torch.ones(33))
>>> composite_spec = CompositeSpec(
...     pixels=pixels_spec,
...     observation_vector=observation_vector_spec)
>>> td = TensorDict({"pixels": torch.rand(10, 3, 32, 32),
...     "observation_vector": torch.rand(10, 33)}, batch_size=[10])
>>> print("td (rand) is within bounds: ", composite_spec.is_in(td))
td (rand) is within bounds: True
>>> td = TensorDict({"pixels": torch.randn(10, 3, 32, 32),
...     "observation_vector": torch.randn(10, 33)}, batch_size=[10])
>>> print("td (randn) is within bounds: ", composite_spec.is_in(td))
td (randn) is within bounds: False
>>> td_project = composite_spec.project(td)
>>> print("td modification done in place: ", td_project is td)
td modification done in place: True
>>> print("check td is within bounds after projection: ",
...     composite_spec.is_in(td_project))
check td is within bounds after projection: True
>>> print("random td: ", composite_spec.rand([3,]))
random td: TensorDict(
    fields={
        observation_vector: Tensor(torch.Size([3, 33]), dtype=torch.float32),
        pixels: Tensor(torch.Size([3, 3, 32, 32]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
Examples
>>> # we can build a nested composite spec using unnamed arguments
>>> print(CompositeSpec({("a", "b"): None, ("a", "c"): None}))
CompositeSpec(
    a: CompositeSpec(
        b: None,
        c: None))
- CompositeSpec supports nested indexing:
>>> spec = CompositeSpec(obs=None)
>>> spec["nested", "x"] = None
>>> print(spec)
CompositeSpec(
    obs: None,
    nested: CompositeSpec(
        x: None))
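The tuple-key convention can be pictured with plain nested dictionaries. The sketch below is a pure-Python illustration, not torchrl code: `set_nested` and `build_nested` are hypothetical helpers, and `None` stands in for a leaf spec, so it only mirrors how keys are laid out, not any spec behavior.

```python
from typing import Any, Dict, Tuple, Union

# A key is either a plain string or a tuple of strings describing a nested path.
NestedKey = Union[str, Tuple[str, ...]]


def set_nested(tree: Dict[str, Any], key: NestedKey, value: Any) -> None:
    """Hypothetical helper mimicking spec["nested", "x"] = value on nested dicts."""
    path = (key,) if isinstance(key, str) else tuple(key)
    node = tree
    # Create (or reuse) every intermediate level except the last one.
    for part in path[:-1]:
        node = node.setdefault(part, {})
    node[path[-1]] = value


def build_nested(flat: Dict[NestedKey, Any]) -> Dict[str, Any]:
    """Hypothetical helper: expand tuple keys such as ("a", "b") into nested dicts."""
    root: Dict[str, Any] = {}
    for key, value in flat.items():
        set_nested(root, key, value)
    return root


print(build_nested({("a", "b"): None, ("a", "c"): None}))
# {'a': {'b': None, 'c': None}}

spec = {"obs": None}
set_nested(spec, ("nested", "x"), None)
print(spec)
# {'obs': None, 'nested': {'x': None}}
```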
- assert_is_in(value: Tensor) → None
Checks whether a tensor belongs to the box, and raises an exception otherwise.
- Parameters:
value (torch.Tensor) – value to be checked.
- encode(vals: Dict[str, Any], *, ignore_device: bool = False) → Dict[str, Tensor]
Encodes the given values according to the specs, and returns the corresponding tensors.
- Parameters:
vals (dict of np.ndarray or torch.Tensor) – values to be encoded as tensors, one per spec entry.
- Keyword Arguments:
ignore_device (bool, optional) – if True, the spec device will be ignored. This is used to group tensor casting within a call to TensorDict(..., device="cuda"), which is faster.
- Returns:
the encoded values, one tensor per key, matching the required tensor specs.
- expand(*shape)
Returns a new Spec with the expanded shape.
- Parameters:
*shape (tuple or iterable of int) – the new shape of the Spec. Must comply with the current shape: its length must be at least as long as the current shape length, and its last values must be compliant too, i.e. they can only differ from it if the current dimension is a singleton.
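The compliance rule for expand can be written as a small check on plain shape tuples. This is a sketch under that reading of the text, not torchrl's implementation; `expanded_shape` is a hypothetical helper, and it does not cover any additional constraints torchrl may place on nested specs.

```python
from typing import Iterable, Tuple, Union


def expanded_shape(current: Tuple[int, ...], *shape: Union[int, Iterable[int]]) -> Tuple[int, ...]:
    """Hypothetical helper: validate a target shape against the expand() rule."""
    # Accept both expanded_shape(cur, 3, 4) and expanded_shape(cur, (3, 4)).
    if len(shape) == 1 and not isinstance(shape[0], int):
        shape = tuple(shape[0])
    new = tuple(int(s) for s in shape)
    if len(new) < len(current):
        raise ValueError("the new shape must be at least as long as the current one")
    # Only the trailing dimensions are compared with the current shape.
    trailing = new[len(new) - len(current):]
    for new_dim, cur_dim in zip(trailing, current):
        if new_dim != cur_dim and cur_dim != 1:
            raise ValueError(f"cannot expand a dimension of size {cur_dim} to {new_dim}")
    return new


print(expanded_shape((1, 4), 3, 2, 4))  # (3, 2, 4): the singleton 1 may become 2
try:
    expanded_shape((2, 4), (3, 5, 4))
except ValueError as err:
    print(err)  # cannot expand a dimension of size 2 to 5
```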
- classmethod implements_for_spec(torch_function: Callable) → Callable
Registers a torch function override for TensorSpec.
- abstract index(index: Union[int, Tensor, ndarray, slice, List], tensor_to_index: Tensor) → Tensor
Indexes the input tensor.
- Parameters:
index (int, torch.Tensor, slice or list) – index of the tensor
tensor_to_index – tensor to be indexed
- Returns:
indexed tensor
- is_in(val: Union[dict, TensorDictBase]) → bool
If the value val is in the box defined by the TensorSpec, returns True, otherwise False.
- Parameters:
val (dict or TensorDictBase) – value to be checked
- Returns:
boolean indicating if the values belong to the TensorSpec box
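The per-entry logic of is_in can be sketched with scalar boxes, assuming a composite value is in the box only when every non-None entry is, and that None entries always pass, as the class parameters state. `is_in` below is a hypothetical pure-Python function operating on dicts of floats, not the torchrl method.

```python
from typing import Dict, Optional, Sequence, Tuple

# A box is a (low, high) pair, or None when the entry has no spec.
Box = Optional[Tuple[float, float]]


def is_in(spec: Dict[str, Box], values: Dict[str, Sequence[float]]) -> bool:
    """Hypothetical sketch: True only if every bounded entry lies in its box."""
    for key, box in spec.items():
        if box is None:
            # A None spec is treated as always satisfied.
            continue
        low, high = box
        if any(not (low <= v <= high) for v in values[key]):
            return False
    return True


spec = {"pixels": (0.0, 1.0), "extra": None}
print(is_in(spec, {"pixels": [0.2, 0.9], "extra": [42.0]}))  # True
print(is_in(spec, {"pixels": [0.2, 1.5], "extra": [42.0]}))  # False
```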
- items(include_nested: bool = False, leaves_only: bool = False) → ItemsView
Items of the CompositeSpec.
- Parameters:
include_nested (bool, optional) – if False, the returned keys will not be nested. They will represent only the immediate children of the root, and not the whole nested sequence, i.e. a CompositeSpec(next=CompositeSpec(obs=None)) will lead to the keys ["next"]. Default is False, i.e. nested keys will not be returned.
leaves_only (bool, optional) – if False, the values returned will contain every level of nesting, i.e. with include_nested=True, a CompositeSpec(next=CompositeSpec(obs=None)) will lead to the keys ["next", ("next", "obs")]. Default is False.
- keys(include_nested: bool = False, leaves_only: bool = False) → KeysView
Keys of the CompositeSpec.
The arguments of keys() reflect those of tensordict.TensorDict.keys().
- Parameters:
include_nested (bool, optional) – if False, the returned keys will not be nested. They will represent only the immediate children of the root, and not the whole nested sequence, i.e. a CompositeSpec(next=CompositeSpec(obs=None)) will lead to the keys ["next"]. Default is False, i.e. nested keys will not be returned.
leaves_only (bool, optional) – if False, the values returned will contain every level of nesting, i.e. with include_nested=True, a CompositeSpec(next=CompositeSpec(obs=None)) will lead to the keys ["next", ("next", "obs")]. Default is False.
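The interplay of include_nested and leaves_only can be reproduced on plain nested dictionaries. This is a pure-Python sketch, not torchrl's implementation: `nested_keys` is a hypothetical function, a dict stands in for a nested CompositeSpec, and any non-dict value (including None) is assumed to count as a leaf.

```python
from typing import Any, Dict, Iterator, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def nested_keys(
    tree: Dict[str, Any],
    include_nested: bool = False,
    leaves_only: bool = False,
    _prefix: Tuple[str, ...] = (),
) -> Iterator[NestedKey]:
    """Hypothetical sketch of keys(include_nested, leaves_only) on nested dicts."""
    for name, value in tree.items():
        path = _prefix + (name,)
        # Root-level keys are plain strings; deeper keys are full tuple paths.
        key: NestedKey = path if _prefix else name
        is_branch = isinstance(value, dict)
        # With leaves_only=True, intermediate (dict) levels are skipped.
        if not (leaves_only and is_branch):
            yield key
        # Descend into sub-levels only when nested keys are requested.
        if include_nested and is_branch:
            yield from nested_keys(value, include_nested, leaves_only, path)


example = {"next": {"obs": None}}
print(list(nested_keys(example)))                                         # ['next']
print(list(nested_keys(example, include_nested=True)))                    # ['next', ('next', 'obs')]
print(list(nested_keys(example, include_nested=True, leaves_only=True)))  # [('next', 'obs')]
```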
- lock_(recurse=False)
Locks the CompositeSpec and prevents modification of its content.
This is only a first-level lock, unless specified otherwise through the recurse arg. Leaf specs can always be modified in place, but they cannot be replaced in their CompositeSpec parent.
Examples
>>> from torchrl.data import CompositeSpec
>>> shape = [3, 4, 5]
>>> spec = CompositeSpec(
...     a=CompositeSpec(
...         b=CompositeSpec(shape=shape[:3], device="cpu"), shape=shape[:2]
...     ),
...     shape=shape[:1],
... )
>>> spec["a"] = spec["a"].clone()
>>> recurse = False
>>> spec.lock_(recurse=recurse)
>>> try:
...     spec["a"] = spec["a"].clone()
... except RuntimeError:
...     print("failed!")
failed!
>>> try:
...     spec["a", "b"] = spec["a", "b"].clone()
...     print("succeeded!")
... except RuntimeError:
...     print("failed!")
succeeded!
>>> recurse = True
>>> spec.lock_(recurse=recurse)
>>> try:
...     spec["a", "b"] = spec["a", "b"].clone()
...     print("succeeded!")
... except RuntimeError:
...     print("failed!")
failed!
- project(val: TensorDictBase) → TensorDictBase
If the input tensordict is not in the TensorSpec box, maps it back into the box using some heuristic.
- Parameters:
val (TensorDictBase) – tensordict to be mapped to the box.
- Returns:
a tensordict belonging to the TensorSpec box (in the class example above, the input is modified in place).
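The text leaves the heuristic open. For a bounded box, clamping each value to [low, high] is one natural choice, and the sketch below assumes it; it also follows the None rule from the class parameters and mirrors the in-place behavior seen in the class example. `project` here is a hypothetical pure-Python function on dicts of floats, not the torchrl method.

```python
from typing import Dict, List, Optional, Tuple

# A box is a (low, high) pair, or None when the entry has no spec.
Box = Optional[Tuple[float, float]]


def project(spec: Dict[str, Box], values: Dict[str, List[float]]) -> Dict[str, List[float]]:
    """Hypothetical sketch: clamp every bounded entry into its box, in place."""
    for key, box in spec.items():
        if box is None:
            # Entries whose spec is None are left untouched.
            continue
        low, high = box
        # Slice assignment keeps the same list object, i.e. an in-place update.
        values[key][:] = [min(max(v, low), high) for v in values[key]]
    return values


spec = {"x": (0.0, 1.0), "free": None}
data = {"x": [-0.5, 0.3, 2.0], "free": [5.0]}
out = project(spec, data)
print(out is data, out)  # True {'x': [0.0, 0.3, 1.0], 'free': [5.0]}
```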
- rand(shape=None) → TensorDictBase
Returns a random tensordict in the box. The sampling will be uniform unless the box is unbounded.
- Parameters:
shape (torch.Size, optional) – leading shape of the sample, prepended to the spec shape
- Returns:
a random tensordict sampled in the TensorSpec box.
- squeeze(dim: int | None = None)
Returns a new Spec with all the dimensions of size 1 removed. When dim is given, a squeeze operation is done only in that dimension.
- Parameters:
dim (int or None) – the dimension to apply the squeeze operation to
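The resulting shape can be computed independently of the spec. The sketch below works on plain tuples and follows torch.squeeze semantics, under the assumption that spec squeezing matches them, including leaving a non-singleton dim unchanged; `squeezed_shape` is a hypothetical helper, not part of torchrl.

```python
from typing import Optional, Sequence, Tuple


def squeezed_shape(shape: Sequence[int], dim: Optional[int] = None) -> Tuple[int, ...]:
    """Hypothetical helper: the shape a squeeze would produce, torch-style."""
    if dim is None:
        # Drop every singleton dimension.
        return tuple(s for s in shape if s != 1)
    # Normalize negative indices the way torch does.
    if dim < 0:
        dim += len(shape)
    if not 0 <= dim < len(shape):
        raise IndexError(f"dim {dim} out of range for a shape of length {len(shape)}")
    if shape[dim] != 1:
        # Like torch.squeeze, a non-singleton dimension is left as-is.
        return tuple(shape)
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])


print(squeezed_shape((1, 3, 1, 4)))          # (3, 4)
print(squeezed_shape((1, 3, 1, 4), dim=-2))  # (1, 3, 4)
print(squeezed_shape((1, 3, 1, 4), dim=1))   # (1, 3, 1, 4)
```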
- to_numpy(val: TensorDict, safe: Optional[bool] = None) → dict
Returns the np.ndarray equivalents of the entries of an input tensordict.
- Parameters:
val (TensorDict) – tensordict to be converted to numpy.
safe (bool) – boolean value indicating whether a check should be performed on the value against the domain of the spec. Defaults to the value of the CHECK_SPEC_ENCODE environment variable.
- Returns:
a dictionary of np.ndarray
- type_check(value: Union[Tensor, TensorDictBase], selected_keys: Optional[Union[str, Sequence[str]]] = None)
Checks the input value dtype against the TensorSpec dtype and raises an exception if they don’t match.
- Parameters:
value (torch.Tensor or TensorDictBase) – value whose dtype has to be checked
selected_keys (str or sequence of str, optional) – if the TensorSpec has keys, the value dtype will be checked against the spec(s) pointed to by the indicated key(s).
- unlock_(recurse=False)
Unlocks the CompositeSpec and allows modification of its content.
This is only a first-level lock modification, unless specified otherwise through the recurse arg.
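The first-level versus recursive behavior of lock_ and unlock_ can be modeled with a toy container. `ToySpec` and `try_set` are hypothetical stand-ins written for illustration; they only reproduce the rule described in these two entries (a locked level refuses replacement of its entries, and recurse propagates the change to nested levels), not torchrl's implementation.

```python
from typing import Any, Dict


class ToySpec:
    """Hypothetical stand-in for a composite spec with lock_/unlock_ semantics."""

    def __init__(self, **children: Any) -> None:
        self._children: Dict[str, Any] = dict(children)
        self._locked = False

    def __getitem__(self, key: str) -> Any:
        return self._children[key]

    def __setitem__(self, key: str, value: Any) -> None:
        # Replacing (or adding) an entry is forbidden while this level is locked.
        if self._locked:
            raise RuntimeError("cannot modify a locked spec")
        self._children[key] = value

    def _set_lock(self, locked: bool, recurse: bool) -> "ToySpec":
        self._locked = locked
        if recurse:
            # Propagate the new state to every nested ToySpec.
            for child in self._children.values():
                if isinstance(child, ToySpec):
                    child._set_lock(locked, recurse)
        return self

    def lock_(self, recurse: bool = False) -> "ToySpec":
        return self._set_lock(True, recurse)

    def unlock_(self, recurse: bool = False) -> "ToySpec":
        return self._set_lock(False, recurse)


def try_set(spec: ToySpec, key: str, value: Any) -> bool:
    """Return True if the assignment succeeded, False if that level was locked."""
    try:
        spec[key] = value
        return True
    except RuntimeError:
        return False


root = ToySpec(a=ToySpec(b=None))
root.lock_()                          # first-level lock only
print(try_set(root, "a", ToySpec()))  # False
print(try_set(root["a"], "b", 1))     # True
root.lock_(recurse=True)
print(try_set(root["a"], "b", 2))     # False
root.unlock_()                        # root unlocked, child still locked
print(try_set(root, "c", 3))          # True
print(try_set(root["a"], "b", 3))     # False
```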
- values(include_nested: bool = False, leaves_only: bool = False) → ValuesView
Values of the CompositeSpec.
- Parameters:
include_nested (bool, optional) – if False, the returned values will not be nested. They will correspond only to the immediate children of the root, and not to the whole nested sequence, i.e. a CompositeSpec(next=CompositeSpec(obs=None)) will lead to the values stored under the keys ["next"]. Default is False, i.e. nested values will not be returned.
leaves_only (bool, optional) – if False, the values returned will contain every level of nesting, i.e. with include_nested=True, a CompositeSpec(next=CompositeSpec(obs=None)) will lead to the values stored under the keys ["next", ("next", "obs")]. Default is False.