.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tensordict_slicing.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_tensordict_slicing.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tensordict_slicing.py:


Slicing, Indexing, and Masking
==============================

**Author**: `Tom Begley <https://github.com/tcbegley>`_

In this tutorial you will learn how to slice, index, and mask a :class:`~.TensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 10-23

As discussed in the tutorial
`Manipulating the shape of a TensorDict <./tensordict_shapes.html>`_, when we create a
:class:`~.TensorDict` we specify a ``batch_size``, which must agree with the leading
dimensions of all entries in the :class:`~.TensorDict`. Because all entries are
guaranteed to share these leading dimensions, we can index and mask the batch
dimensions just as we would index a :class:`torch.Tensor`. The index is applied along
the batch dimensions of every entry in the :class:`~.TensorDict`.

For example, given a :class:`~.TensorDict` with two batch dimensions,
``tensordict[0]`` returns a new :class:`~.TensorDict` with the same structure, whose
values correspond to the first "row" (index ``0`` along the first batch dimension) of
each entry in the original :class:`~.TensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 23-33

.. code-block:: Python


    import torch
    from tensordict import TensorDict

    tensordict = TensorDict(
        {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
    )

    print(tensordict[0])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([4]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 34-36

The same slicing syntax works as for regular tensors. For example, to drop the first
row of each entry we can index as follows:

.. GENERATED FROM PYTHON SOURCE LINES 36-39

.. code-block:: Python


    print(tensordict[1:])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([2, 4]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 40-41

We can index several batch dimensions at once:

.. GENERATED FROM PYTHON SOURCE LINES 41-44

.. code-block:: Python


    print(tensordict[:, 2:])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3, 2]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 45-47

We can also use ``Ellipsis`` (``...``) to stand in for as many ``:`` as are needed to
make the length of the selection tuple equal to ``tensordict.batch_dims``:

.. GENERATED FROM PYTHON SOURCE LINES 47-50

.. code-block:: Python


    print(tensordict[..., 2:])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3, 2]),
        device=None,
        is_shared=False)
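
A quick way to check that indices act only on the batch dimensions is to compare
indexing the :class:`~.TensorDict` with indexing each entry directly. The sketch below
assumes nothing beyond the behaviour described above; the name ``td`` and the use of
:func:`torch.arange` (to get distinct values, so that the equality checks are
meaningful) are illustrative choices.

.. code-block:: Python

    import torch
    from tensordict import TensorDict

    # Illustrative data: distinct values instead of zeros so equality checks mean something.
    td = TensorDict(
        {
            "a": torch.arange(60.0).reshape(3, 4, 5),
            "b": torch.arange(12.0).reshape(3, 4),
        },
        batch_size=[3, 4],
    )

    # Each index below covers at most the two batch dimensions, so indexing the
    # TensorDict and then reading an entry matches indexing that entry directly.
    for index in [0, slice(1, None), (slice(None), slice(2, None))]:
        sub = td[index]
        for key in ("a", "b"):
            assert torch.equal(sub[key], td[key][index])

    # ``...`` is expanded relative to ``td.batch_dims`` (here 2), so
    # ``td[..., 2:]`` selects the same entries as ``td[:, 2:]``.
    assert torch.equal(td[..., 2:]["a"], td[:, 2:]["a"])
    print(td[..., 2:]["a"].shape)  # torch.Size([3, 2, 5])

    # Applied to the entry itself, ``...`` expands over all three of its
    # dimensions, so ``2:`` slices the last (non-batch) dimension instead.
    print(td["a"][..., 2:].shape)  # torch.Size([3, 4, 3])
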
.. GENERATED FROM PYTHON SOURCE LINES 51-64

.. note::

   Remember that all indexing is applied relative to the batch dimensions. In the
   ``tensordict[..., 2:]`` example above, there is a difference between
   ``tensordict["a"][..., 2:]`` and ``tensordict[..., 2:]["a"]``. The first retrieves
   the three-dimensional tensor stored under the key ``"a"`` and applies the index
   ``2:`` to its final dimension, giving a tensor of shape ``[3, 4, 3]``. The second
   applies the index ``2:`` to the final *batch dimension*, which is the second
   dimension, before retrieving the entry, giving a tensor of shape ``[3, 2, 5]``.

Setting Values with Indexing
----------------------------

In general, ``tensordict[index] = new_tensordict`` works as long as the batch size of
``new_tensordict`` is compatible with that of ``tensordict[index]``. In the example
below, ``tensordict[:-1]`` has batch size ``[2, 4]``, matching ``td2``, so the first
two rows of every entry are overwritten with ones.

.. GENERATED FROM PYTHON SOURCE LINES 64-73

.. code-block:: Python


    tensordict = TensorDict(
        {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
    )
    td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
    tensordict[:-1] = td2
    print(tensordict["a"], tensordict["b"])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]],

            [[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]],

            [[0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.]]]) tensor([[1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [0., 0., 0., 0.]])

.. GENERATED FROM PYTHON SOURCE LINES 74-77

Masking
-------

We mask a :class:`~.TensorDict` in the same way as we mask a tensor. Here the boolean
mask has the same shape as the batch dimensions, ``[3, 4]``, and the masked batch
dimensions are collapsed into a single dimension whose size is the number of ``True``
entries in the mask (six in this case).

.. GENERATED FROM PYTHON SOURCE LINES 77-80

.. code-block:: Python


    mask = torch.BoolTensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]])
    tensordict[mask]


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([6]),
        device=None,
        is_shared=False)
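
The masking rule can be verified directly. In this sketch the names ``td``, ``mask``
and ``masked`` and the :func:`torch.arange` values are illustrative; it assumes only
the behaviour shown above, namely that masking the :class:`~.TensorDict` gives the same
values as masking each entry on its own, with the two batch dimensions collapsed into
one of size ``mask.sum()``.

.. code-block:: Python

    import torch
    from tensordict import TensorDict

    # Illustrative data with distinct values.
    td = TensorDict(
        {
            "a": torch.arange(60.0).reshape(3, 4, 5),
            "b": torch.arange(12.0).reshape(3, 4),
        },
        batch_size=[3, 4],
    )

    # A boolean mask with the same shape as the batch dimensions, [3, 4].
    mask = torch.tensor([[True, False, True, False]] * 3)

    masked = td[mask]

    # The two masked batch dimensions collapse into a single dimension whose
    # size is the number of True entries in the mask.
    assert masked.batch_size == torch.Size([int(mask.sum())])

    # Masking the TensorDict is equivalent to masking each entry separately.
    for key in ("a", "b"):
        assert torch.equal(masked[key], td[key][mask])

    print(masked.batch_size)  # torch.Size([6])
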

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.004 seconds)


.. _sphx_glr_download_tutorials_tensordict_slicing.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tensordict_slicing.ipynb <tensordict_slicing.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tensordict_slicing.py <tensordict_slicing.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tensordict_slicing.zip <tensordict_slicing.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_