.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tensordict_slicing.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_tensordict_slicing.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tensordict_slicing.py:


Slicing, Indexing, and Masking
==============================
**Author**: `Tom Begley <https://github.com/tcbegley>`_

In this tutorial you will learn how to slice, index, and mask a :class:`~.TensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 10-23

As discussed in the tutorial
`Manipulating the shape of a TensorDict <./tensordict_shapes.html>`_, when we create a
:class:`~.TensorDict` we specify a ``batch_size``, which must agree
with the leading dimensions of all entries in the :class:`~.TensorDict`. Since we have
a guarantee that all entries share those dimensions in common, we are able to index
and mask the batch dimensions in the same way that we would index a
:class:`torch.Tensor`. The indices are applied along the batch dimensions to all of
the entries in the :class:`~.TensorDict`.

For example, given a :class:`~.TensorDict` with two batch dimensions,
``tensordict[0]`` returns a new :class:`~.TensorDict` with the same structure, and
whose values correspond to the first "row" of each entry in the original
:class:`~.TensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 23-33

.. code-block:: Python


    import torch
    from tensordict import TensorDict

    tensordict = TensorDict(
        {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
    )

    print(tensordict[0])





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([4]),
        device=None,
        is_shared=False)




.. GENERATED FROM PYTHON SOURCE LINES 34-36

The same indexing syntax applies as for regular tensors. For example, if we wanted
to drop the first row of each entry, we could index as follows:

.. GENERATED FROM PYTHON SOURCE LINES 36-39

.. code-block:: Python


    print(tensordict[1:])





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([2, 4]),
        device=None,
        is_shared=False)




.. GENERATED FROM PYTHON SOURCE LINES 40-41

We can index along multiple dimensions simultaneously:

.. GENERATED FROM PYTHON SOURCE LINES 41-44

.. code-block:: Python


    print(tensordict[:, 2:])





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3, 2]),
        device=None,
        is_shared=False)




.. GENERATED FROM PYTHON SOURCE LINES 45-47

We can also use ``Ellipsis`` (``...``) to represent as many ``:`` as would be needed
to make the selection tuple the same length as ``tensordict.batch_dims``.

.. GENERATED FROM PYTHON SOURCE LINES 47-50

.. code-block:: Python


    print(tensordict[..., 2:])





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3, 2]),
        device=None,
        is_shared=False)




.. GENERATED FROM PYTHON SOURCE LINES 51-64

.. note::

   Remember that all indexing is applied relative to the batch dimensions. In the
   above example there is a difference between ``tensordict["a"][..., 2:]`` and
   ``tensordict[..., 2:]["a"]``. The first retrieves the three-dimensional tensor
   stored under the key ``"a"`` and applies the index ``2:`` to the final dimension.
   The second applies the index ``2:`` to the final *batch dimension*, which is the
   second dimension, before retrieving the result.
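
The two orderings can be compared side by side. A minimal sketch, recreating the
``tensordict`` from above and printing the resulting shapes:

.. code-block:: Python

    import torch
    from tensordict import TensorDict

    tensordict = TensorDict(
        {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
    )

    # Retrieve "a" (shape (3, 4, 5)), then index its final dimension
    print(tensordict["a"][..., 2:].shape)  # torch.Size([3, 4, 3])

    # Index the final batch dimension (the second), then retrieve "a"
    print(tensordict[..., 2:]["a"].shape)  # torch.Size([3, 2, 5])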

Setting Values with Indexing
----------------------------
In general, ``tensordict[index] = new_tensordict`` will work as long as the batch
sizes are compatible.

.. GENERATED FROM PYTHON SOURCE LINES 64-73

.. code-block:: Python


    tensordict = TensorDict(
        {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
    )

    td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
    tensordict[:-1] = td2
    print(tensordict["a"], tensordict["b"])





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]],

            [[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]],

            [[0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.]]]) tensor([[1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [0., 0., 0., 0.]])
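
Note that, as with regular tensors, such assignment copies the new values into the
tensors already stored in the :class:`~.TensorDict` rather than replacing them. A
short sketch, assuming the in-place copy semantics of ``__setitem__``:

.. code-block:: Python

    import torch
    from tensordict import TensorDict

    tensordict = TensorDict(
        {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
    )
    a = tensordict["a"]  # keep a reference to the stored tensor

    td2 = TensorDict(
        {"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4]
    )
    tensordict[:-1] = td2

    # The assignment wrote through to the original storage
    print(bool((a[:-1] == 1).all()))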




.. GENERATED FROM PYTHON SOURCE LINES 74-77

Masking
-------
We can mask a :class:`~.TensorDict` just as we would mask a tensor.

.. GENERATED FROM PYTHON SOURCE LINES 77-80

.. code-block:: Python


    mask = torch.tensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]], dtype=torch.bool)
    tensordict[mask]




.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([6]),
        device=None,
        is_shared=False)
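
As with indexing, the mask is applied along the batch dimensions to every entry, so
masking the :class:`~.TensorDict` and then retrieving an entry agrees with masking
that entry directly. A small sketch checking this equivalence:

.. code-block:: Python

    import torch
    from tensordict import TensorDict

    tensordict = TensorDict(
        {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
    )
    mask = torch.tensor([[1, 0, 1, 0]] * 3, dtype=torch.bool)

    # Masking the TensorDict masks each entry along the batch dimensions
    assert (tensordict[mask]["a"] == tensordict["a"][mask]).all()

    # The six True positions collapse into a single batch dimension
    print(tensordict[mask].batch_size)  # torch.Size([6])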




.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.004 seconds)


.. _sphx_glr_download_tutorials_tensordict_slicing.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tensordict_slicing.ipynb <tensordict_slicing.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tensordict_slicing.py <tensordict_slicing.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tensordict_slicing.zip <tensordict_slicing.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_