.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tensordict_shapes.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_tensordict_shapes.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tensordict_shapes.py:


Manipulating the shape of a TensorDict
======================================

**Author**: `Tom Begley <https://github.com/tcbegley>`_

In this tutorial you will learn how to manipulate the shape of a
:class:`~.TensorDict` and its contents.

.. GENERATED FROM PYTHON SOURCE LINES 11-16

When we create a :class:`~.TensorDict` we specify a ``batch_size``, which must agree
with the leading dimensions of all entries in the :class:`~.TensorDict`. Since we have
a guarantee that all entries share those dimensions in common, :class:`~.TensorDict`
is able to expose a number of methods with which we can manipulate the shape of the
:class:`~.TensorDict` and its contents.

.. GENERATED FROM PYTHON SOURCE LINES 16-20

.. code-block:: Python


    import torch

    from tensordict import TensorDict

.. GENERATED FROM PYTHON SOURCE LINES 26-32

Indexing a ``TensorDict``
-------------------------

Since the batch dimensions are guaranteed to exist on all entries, we can index them
as we please, and each entry of the :class:`~.TensorDict` will be indexed in the same
way.

.. GENERATED FROM PYTHON SOURCE LINES 32-41

.. code-block:: Python


    a = torch.rand(3, 4)
    b = torch.rand(3, 4, 5)
    tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])

    indexed_tensordict = tensordict[:2, 1]
    assert indexed_tensordict["a"].shape == torch.Size([2])
    assert indexed_tensordict["b"].shape == torch.Size([2, 5])

.. GENERATED FROM PYTHON SOURCE LINES 42-49

Reshaping a ``TensorDict``
--------------------------

:meth:`TensorDict.reshape <tensordict.TensorDict.reshape>` works just like
:meth:`torch.Tensor.reshape`. It applies to all of the contents of the
:class:`~.TensorDict` along the batch dimensions - note the shape of ``b`` in the
example below. It also updates the ``batch_size`` attribute.

.. GENERATED FROM PYTHON SOURCE LINES 49-56

.. code-block:: Python


    reshaped_tensordict = tensordict.reshape(-1)
    assert reshaped_tensordict.batch_size == torch.Size([12])
    assert reshaped_tensordict["a"].shape == torch.Size([12])
    assert reshaped_tensordict["b"].shape == torch.Size([12, 5])

.. GENERATED FROM PYTHON SOURCE LINES 57-65

Splitting a ``TensorDict``
--------------------------

:meth:`TensorDict.split <tensordict.TensorDict.split>` is similar to
:meth:`torch.Tensor.split`. It splits the :class:`~.TensorDict` into chunks. Each
chunk is a :class:`~.TensorDict` with the same structure as the original one, but
whose entries are views of the corresponding entries in the original
:class:`~.TensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 65-71

.. code-block:: Python


    chunks = tensordict.split([3, 1], dim=1)
    assert chunks[0].batch_size == torch.Size([3, 3])
    assert chunks[1].batch_size == torch.Size([3, 1])
    torch.testing.assert_close(chunks[0]["a"], tensordict["a"][:, :-1])
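As with :meth:`torch.Tensor.split`, the split sizes should also be expressible as a
single integer chunk size. The snippet below is a small sketch of that usage (the
variable name ``int_chunks`` is illustrative, not part of the original tutorial); it
also reassembles the chunks with :func:`torch.cat`, which is covered later on.

.. code-block:: Python


    # assuming split accepts an integer chunk size, mirroring torch.Tensor.split:
    # a chunk size of 3 along dim=1 yields chunks of batch size [3, 3] and [3, 1]
    int_chunks = tensordict.split(3, dim=1)
    assert int_chunks[0].batch_size == torch.Size([3, 3])
    assert int_chunks[1].batch_size == torch.Size([3, 1])
    # concatenating the chunks along the same dimension recovers the original values
    torch.testing.assert_close(torch.cat(list(int_chunks), dim=1)["b"], tensordict["b"])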
.. GENERATED FROM PYTHON SOURCE LINES 72-104

.. note::

   Whenever a function or method accepts a ``dim`` argument, negative dimensions are
   interpreted relative to the ``batch_size`` of the :class:`~.TensorDict` that the
   function or method is called on. In particular, if there are nested
   :class:`~.TensorDict` values with different batch sizes, the negative dimension is
   always interpreted relative to the batch dimensions of the root.

   >>> tensordict = TensorDict(
   ...     {
   ...         "a": torch.rand(3, 4),
   ...         "nested": TensorDict({"b": torch.rand(3, 4, 5)}, [3, 4, 5])
   ...     },
   ...     [3, 4],
   ... )
   >>> # dim = -2 will be interpreted as the first dimension throughout, as the root
   >>> # TensorDict has 2 batch dimensions, even though the nested TensorDict has 3
   >>> chunks = tensordict.split([2, 1], dim=-2)
   >>> assert chunks[0].batch_size == torch.Size([2, 4])
   >>> assert chunks[0]["nested"].batch_size == torch.Size([2, 4, 5])

   As you can see from this example, the
   :meth:`TensorDict.split <tensordict.TensorDict.split>` method behaves exactly as
   though we had replaced ``dim=-2`` with ``dim=tensordict.batch_dims - 2`` before
   calling.

Unbind
------

:meth:`TensorDict.unbind <tensordict.TensorDict.unbind>` is similar to
:meth:`torch.Tensor.unbind`, and conceptually similar to
:meth:`TensorDict.split <tensordict.TensorDict.split>`. It removes the specified
dimension and returns a ``tuple`` of all slices along that dimension.

.. GENERATED FROM PYTHON SOURCE LINES 104-110

.. code-block:: Python


    slices = tensordict.unbind(dim=1)
    assert len(slices) == 4
    assert all(s.batch_size == torch.Size([3]) for s in slices)
    torch.testing.assert_close(slices[0]["a"], tensordict["a"][:, 0])

.. GENERATED FROM PYTHON SOURCE LINES 111-125

Stacking and concatenating
--------------------------

:class:`~.TensorDict` can be used in conjunction with ``torch.cat`` and
``torch.stack``.

Stacking ``TensorDict``
^^^^^^^^^^^^^^^^^^^^^^^

Stacking can be done lazily or contiguously. A lazy stack is just a list of
tensordicts presented as a stack of tensordicts. It lets users carry a collection of
tensordicts with different content shapes, devices or key sets. Another advantage is
that the stack operation can be expensive; if only a small subset of keys is
required, a lazy stack will be much faster than a proper stack. Lazy stacking relies
on the :class:`~tensordict.LazyStackedTensorDict` class, in which case values are
only stacked on-demand when they are accessed.

.. GENERATED FROM PYTHON SOURCE LINES 125-148

.. code-block:: Python


    from tensordict import LazyStackedTensorDict

    cloned_tensordict = tensordict.clone()
    stacked_tensordict = LazyStackedTensorDict.lazy_stack(
        [tensordict, cloned_tensordict], dim=0
    )
    print(stacked_tensordict)

    # Previously, torch.stack was always returning a lazy stack. For consistency with
    # the regular PyTorch API, this behaviour will soon be adapted to deliver only
    # dense tensordicts. To control which behaviour you are relying on, you can use
    # the :func:`~tensordict.utils.set_lazy_legacy` decorator/context manager:

    from tensordict.utils import set_lazy_legacy

    with set_lazy_legacy(True):  # old behaviour
        lazy_stack = torch.stack([tensordict, cloned_tensordict])
        assert isinstance(lazy_stack, LazyStackedTensorDict)

    with set_lazy_legacy(False):  # new behaviour
        dense_stack = torch.stack([tensordict, cloned_tensordict])
        assert isinstance(dense_stack, TensorDict)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    LazyStackedTensorDict(
        fields={
            a: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
        exclusive_fields={
        },
        batch_size=torch.Size([2, 3, 4]),
        device=None,
        is_shared=False,
        stack_dim=0)
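The printed representation above includes an ``exclusive_fields`` entry. If lazy
stacks behave as described, carrying tensordicts with different key sets, then keys
present in only some members end up there, while shared keys can still be stacked on
access. The sketch below assumes that behaviour; ``td_a``, ``td_b`` and
``heterogeneous`` are illustrative names.

.. code-block:: Python


    # hypothetical sketch: two tensordicts with the same batch size but
    # different key sets gathered in a lazy stack
    td_a = TensorDict({"common": torch.rand(3)}, [3])
    td_b = TensorDict({"common": torch.rand(3), "extra": torch.rand(3, 2)}, [3])
    heterogeneous = LazyStackedTensorDict.lazy_stack([td_a, td_b], dim=0)
    # only keys shared by every member can be stacked and accessed directly
    assert heterogeneous["common"].shape == torch.Size([2, 3])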
.. GENERATED FROM PYTHON SOURCE LINES 149-151

If we index a :class:`~.LazyStackedTensorDict` along the stacking dimension we recover
the original :class:`~.TensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 151-155

.. code-block:: Python


    assert stacked_tensordict[0] is tensordict
    assert stacked_tensordict[1] is cloned_tensordict

.. GENERATED FROM PYTHON SOURCE LINES 156-159

Accessing a key in the :class:`~.LazyStackedTensorDict` results in those values being
stacked. If the key corresponds to a nested :class:`~.TensorDict` then we will recover
another :class:`~.LazyStackedTensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 159-162

.. code-block:: Python


    assert stacked_tensordict["a"].shape == torch.Size([2, 3, 4])

.. GENERATED FROM PYTHON SOURCE LINES 163-186

.. note::

   Since values are stacked on-demand, accessing an item multiple times will mean it
   gets stacked multiple times, which is inefficient. If you need to access a value in
   the stacked :class:`~.TensorDict` more than once, you may want to consider
   converting the :class:`LazyStackedTensorDict` to a contiguous :class:`~.TensorDict`,
   which can be done with the
   :meth:`LazyStackedTensorDict.to_tensordict <tensordict.LazyStackedTensorDict.to_tensordict>`
   or
   :meth:`LazyStackedTensorDict.contiguous <tensordict.LazyStackedTensorDict.contiguous>`
   methods.

   >>> assert isinstance(stacked_tensordict.to_tensordict(), TensorDict)
   >>> assert isinstance(stacked_tensordict.contiguous(), TensorDict)

   After calling either of these methods, we will have a regular :class:`TensorDict`
   containing the stacked values, and no additional computation is performed when
   values are accessed.

Concatenating ``TensorDict``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Concatenation is not done lazily; instead, calling :func:`torch.cat` on a list of
:class:`~.TensorDict` instances simply returns a :class:`~.TensorDict` whose entries
are the concatenated entries of the elements of the list.

.. GENERATED FROM PYTHON SOURCE LINES 186-192

.. code-block:: Python


    concatenated_tensordict = torch.cat([tensordict, cloned_tensordict], dim=0)
    assert isinstance(concatenated_tensordict, TensorDict)
    assert concatenated_tensordict.batch_size == torch.Size([6, 4])
    assert concatenated_tensordict["b"].shape == torch.Size([6, 4, 5])

.. GENERATED FROM PYTHON SOURCE LINES 193-197

Expanding ``TensorDict``
------------------------

We can expand all of the entries of a :class:`~.TensorDict` using
:meth:`TensorDict.expand <tensordict.TensorDict.expand>`.

.. GENERATED FROM PYTHON SOURCE LINES 197-202

.. code-block:: Python


    exp_tensordict = tensordict.expand(2, *tensordict.batch_size)
    assert exp_tensordict.batch_size == torch.Size([2, 3, 4])
    torch.testing.assert_close(exp_tensordict["a"][0], exp_tensordict["a"][1])

.. GENERATED FROM PYTHON SOURCE LINES 203-208

Squeezing and Unsqueezing ``TensorDict``
----------------------------------------

We can squeeze or unsqueeze the contents of a :class:`~.TensorDict` with the
:meth:`~tensordict.TensorDictBase.squeeze` and
:meth:`~tensordict.TensorDictBase.unsqueeze` methods.

.. GENERATED FROM PYTHON SOURCE LINES 208-218

.. code-block:: Python


    tensordict = TensorDict({"a": torch.rand(3, 1, 4)}, [3, 1, 4])
    squeezed_tensordict = tensordict.squeeze()
    assert squeezed_tensordict["a"].shape == torch.Size([3, 4])
    print(squeezed_tensordict, end="\n\n")

    unsqueezed_tensordict = tensordict.unsqueeze(-1)
    assert unsqueezed_tensordict["a"].shape == torch.Size([3, 1, 4, 1])
    print(unsqueezed_tensordict)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3, 4]),
        device=None,
        is_shared=False)

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([3, 1, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3, 1, 4, 1]),
        device=None,
        is_shared=False)
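Both methods should also accept an explicit ``dim`` argument, mirroring
:meth:`torch.Tensor.squeeze` and :meth:`torch.Tensor.unsqueeze`. The short sketch
below assumes that behaviour (``dim_squeezed`` is an illustrative name).

.. code-block:: Python


    # assuming squeeze accepts an explicit dim, as its torch counterpart does,
    # only the singleton batch dimension at position 1 is dropped
    dim_squeezed = tensordict.squeeze(1)
    assert dim_squeezed.batch_size == torch.Size([3, 4])
    assert dim_squeezed["a"].shape == torch.Size([3, 4])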
.. GENERATED FROM PYTHON SOURCE LINES 219-235

.. note::

   Until now, operations like :meth:`~tensordict.TensorDictBase.unsqueeze`,
   :meth:`~tensordict.TensorDictBase.squeeze`, :meth:`~tensordict.TensorDictBase.view`,
   :meth:`~tensordict.TensorDictBase.permute` and
   :meth:`~tensordict.TensorDictBase.transpose` all returned a lazy version of the
   result (i.e., a container in which the original tensordict was stored and the
   operation was applied every time a key was accessed). This behaviour will be
   deprecated in the future and can already be controlled via the
   :func:`~tensordict.utils.set_lazy_legacy` function:

   >>> with set_lazy_legacy(True):
   ...     lazy_unsqueeze = tensordict.unsqueeze(0)
   >>> with set_lazy_legacy(False):
   ...     dense_unsqueeze = tensordict.unsqueeze(0)

Bear in mind that, as ever, these methods apply only to the batch dimensions. Any
non-batch dimensions of the entries will be unaffected.

.. GENERATED FROM PYTHON SOURCE LINES 235-242

.. code-block:: Python


    tensordict = TensorDict({"a": torch.rand(3, 1, 1, 4)}, [3, 1])
    squeezed_tensordict = tensordict.squeeze()
    # only one of the singleton dimensions is dropped as the other
    # is not a batch dimension
    assert squeezed_tensordict["a"].shape == torch.Size([3, 1, 4])

.. GENERATED FROM PYTHON SOURCE LINES 243-247

Viewing a TensorDict
--------------------

:class:`~.TensorDict` also supports ``view``. This creates a ``_ViewedTensorDict``
which lazily creates views on its contents when they are accessed.

.. GENERATED FROM PYTHON SOURCE LINES 247-256

.. code-block:: Python


    tensordict = TensorDict({"a": torch.arange(12)}, [12])
    # no views are created at this step
    viewed_tensordict = tensordict.view((2, 3, 2))

    # the view of "a" is created on-demand when we access it
    assert viewed_tensordict["a"].shape == torch.Size([2, 3, 2])

.. GENERATED FROM PYTHON SOURCE LINES 257-266

Permuting batch dimensions
--------------------------

The :meth:`TensorDict.permute <tensordict.TensorDict.permute>` method can be used to
permute the batch dimensions much like :func:`torch.permute`. Non-batch dimensions are
left untouched. This operation is lazy, so batch dimensions are only permuted when we
try to access the entries. As ever, if you are likely to need to access a particular
entry multiple times, consider converting to a :class:`~.TensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 266-274

.. code-block:: Python


    tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])

    # swap the batch dimensions
    permuted_tensordict = tensordict.permute([1, 0])

    assert permuted_tensordict["a"].shape == torch.Size([4, 3])
    assert permuted_tensordict["b"].shape == torch.Size([4, 3, 5])

.. GENERATED FROM PYTHON SOURCE LINES 275-285

Using tensordicts as decorators
-------------------------------

For a number of reversible operations, tensordicts can be used as decorators or
context managers. These operations include :meth:`~tensordict.TensorDictBase.to_module`
for functional calls, :meth:`~tensordict.TensorDictBase.unlock_` and
:meth:`~tensordict.TensorDictBase.lock_`, or shape operations such as
:meth:`~tensordict.TensorDictBase.view`, :meth:`~tensordict.TensorDictBase.permute`,
:meth:`~tensordict.TensorDictBase.transpose`, :meth:`~tensordict.TensorDictBase.squeeze`
and :meth:`~tensordict.TensorDictBase.unsqueeze`.

Here is a quick example with the ``transpose`` function:

.. GENERATED FROM PYTHON SOURCE LINES 285-296

.. code-block:: Python


    tensordict = TensorDict({"a": torch.rand(3, 4), "b": torch.rand(3, 4, 5)}, [3, 4])

    with tensordict.transpose(1, 0) as tdt:
        tdt.set("c", torch.ones(4, 3))  # we have permuted the dims

    # the ``"c"`` entry is now in the tensordict we used as a decorator:
    # assert (tensordict.get("c") == 1).all()
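The same pattern should carry over to the other operations listed above. As a further
illustration, here is a small, hypothetical sketch with
:meth:`~tensordict.TensorDictBase.lock_`; ``td_lock`` is an illustrative name, and the
assumption is that the lock is reverted when the context exits, as for the other
reversible operations.

.. code-block:: Python


    td_lock = TensorDict({"a": torch.rand(3)}, [3])
    with td_lock.lock_():
        # inside the context the tensordict is locked: writing new keys would fail
        assert td_lock.is_locked
    # assuming the operation is reverted on exit, the tensordict is unlocked again
    assert not td_lock.is_locked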
.. GENERATED FROM PYTHON SOURCE LINES 297-302

Gathering values in ``TensorDict``
----------------------------------

The :meth:`TensorDict.gather <tensordict.TensorDict.gather>` method can be used to
index along the batch dimensions and gather the results into a single dimension much
like :func:`torch.gather`.

.. GENERATED FROM PYTHON SOURCE LINES 302-308

.. code-block:: Python


    index = torch.randint(4, (3, 4))
    gathered_tensordict = tensordict.gather(dim=1, index=index)
    print("index:\n", index, end="\n\n")
    print("tensordict['a']:\n", tensordict["a"], end="\n\n")
    print("gathered_tensordict['a']:\n", gathered_tensordict["a"], end="\n\n")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    index:
     tensor([[2, 1, 1, 1],
            [3, 0, 0, 3],
            [2, 2, 2, 1]])

    tensordict['a']:
     tensor([[0.1095, 0.4947, 0.1063, 0.6060],
            [0.6785, 0.0642, 0.5362, 0.6046],
            [0.9209, 0.7894, 0.6459, 0.2854]])

    gathered_tensordict['a']:
     tensor([[0.1063, 0.4947, 0.4947, 0.4947],
            [0.6046, 0.6785, 0.6785, 0.6046],
            [0.6459, 0.6459, 0.6459, 0.7894]])


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.008 seconds)


.. _sphx_glr_download_tutorials_tensordict_shapes.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tensordict_shapes.ipynb <tensordict_shapes.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tensordict_shapes.py <tensordict_shapes.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tensordict_shapes.zip <tensordict_shapes.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_