.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tensordict_keys.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_tensordict_keys.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tensordict_keys.py:


Manipulating the keys of a TensorDict
=====================================
**Author**: `Tom Begley <https://github.com/tcbegley>`_

In this tutorial you will learn how to work with and manipulate the keys in a
``TensorDict``, including getting and setting keys, iterating over keys, manipulating
nested values, and flattening the keys.

.. GENERATED FROM PYTHON SOURCE LINES 12-15

Setting and getting keys
------------------------
We can set and get keys using the same syntax as a Python ``dict``

.. GENERATED FROM PYTHON SOURCE LINES 15-28

.. code-block:: Python


    import torch
    from tensordict import TensorDict

    tensordict = TensorDict()

    # set a key
    a = torch.rand(10)
    tensordict["a"] = a

    # retrieve the value stored under "a"
    assert tensordict["a"] is a








.. GENERATED FROM PYTHON SOURCE LINES 34-41

.. note::

   Unlike a Python ``dict``, all keys in a ``TensorDict`` must be strings. However,
   as we will see, it is also possible to use tuples of strings to manipulate nested
   values.
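
As a quick illustration of the note above, attempting to use a non-string key is
rejected. This is only a minimal sketch; the exact exception type and message may
differ between versions of ``tensordict``.

.. code-block:: Python

    try:
        # the key 123 is purely illustrative; keys must be strings
        tensordict.set(123, torch.rand(10))
    except Exception as err:
        print(f"Non-string key rejected: {err}")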

We can also set and retrieve values with the ``.get()`` and ``.set()`` methods.

.. GENERATED FROM PYTHON SOURCE LINES 41-51

.. code-block:: Python


    tensordict = TensorDict()

    # set a key
    a = torch.rand(10)
    tensordict.set("a", a)

    # retrieve the value stored under "a"
    assert tensordict.get("a") is a








.. GENERATED FROM PYTHON SOURCE LINES 52-54

Like ``dict``, we can provide a default value to ``get`` that is returned if the
requested key is not found.

.. GENERATED FROM PYTHON SOURCE LINES 54-57

.. code-block:: Python


    assert tensordict.get("banana", a) is a








.. GENERATED FROM PYTHON SOURCE LINES 58-61

Similarly to ``dict``, we can use :meth:`TensorDict.setdefault
<tensordict.TensorDict.setdefault>` to get the value of a particular key, returning a
default if that key is not found and also storing that default in the
:class:`~.TensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 61-66

.. code-block:: Python


    assert tensordict.setdefault("banana", a) is a
    # a is now stored under "banana"
    assert tensordict["banana"] is a








.. GENERATED FROM PYTHON SOURCE LINES 67-70

Deleting keys is also achieved in the same way as for a Python ``dict``, using the
``del`` statement and the chosen key. Equivalently, we could use the
:meth:`TensorDict.del_ <tensordict.TensorDict.del_>` method.

.. GENERATED FROM PYTHON SOURCE LINES 70-73

.. code-block:: Python


    del tensordict["banana"]
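
For completeness, here is a minimal sketch of the equivalent deletion with
:meth:`TensorDict.del_ <tensordict.TensorDict.del_>`, re-adding a throwaway entry
first so that there is something to delete.

.. code-block:: Python

    # add a temporary entry, then remove it with del_
    tensordict.set("banana", torch.rand(10))
    tensordict.del_("banana")
    assert "banana" not in tensordict.keys()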








.. GENERATED FROM PYTHON SOURCE LINES 74-77

Furthermore, when setting keys with ``.set()``, we can use the keyword argument
``inplace=True`` to make an in-place update, or equivalently use the ``.set_()``
method.

.. GENERATED FROM PYTHON SOURCE LINES 77-90

.. code-block:: Python


    tensordict.set("a", torch.zeros(10), inplace=True)

    # all the entries of the "a" tensor are now zero
    assert (tensordict.get("a") == 0).all()
    # but it's still the same tensor as before
    assert tensordict.get("a") is a

    # we can achieve the same with set_
    tensordict.set_("a", torch.ones(10))
    assert (tensordict.get("a") == 1).all()
    assert tensordict.get("a") is a








.. GENERATED FROM PYTHON SOURCE LINES 91-97

Renaming keys
-------------
To rename a key, simply use the
:meth:`TensorDict.rename_key_ <tensordict.TensorDict.rename_key_>` method. The value
stored under the original key will remain in the :class:`~.TensorDict`, but the key
will be changed to the specified new key.

.. GENERATED FROM PYTHON SOURCE LINES 97-102

.. code-block:: Python


    tensordict.rename_key_("a", "b")
    assert tensordict.get("b") is a
    print(tensordict)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([]),
        device=None,
        is_shared=False)




.. GENERATED FROM PYTHON SOURCE LINES 103-108

Updating multiple values
------------------------
The :meth:`TensorDict.update <tensordict.TensorDict.update>` method can be used to
update a :class:`~.TensorDict` with another one or with a ``dict``. Keys that already
exist are overwritten, and keys that do not already exist are created.

.. GENERATED FROM PYTHON SOURCE LINES 108-116

.. code-block:: Python


    tensordict = TensorDict({"a": torch.rand(10), "b": torch.rand(10)}, [10])
    tensordict.update(TensorDict({"a": torch.zeros(10), "c": torch.zeros(10)}, [10]))
    assert (tensordict["a"] == 0).all()
    assert (tensordict["b"] != 0).all()
    assert (tensordict["c"] == 0).all()
    print(tensordict)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
            c: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([10]),
        device=None,
        is_shared=False)
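
Since ``update`` also accepts a plain ``dict``, the same pattern works with a
dictionary argument. This is a minimal sketch; the key ``"d"`` is just an
illustrative name.

.. code-block:: Python

    # update with a plain dict; missing keys are created
    tensordict.update({"d": torch.ones(10)})
    assert (tensordict["d"] == 1).all()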




.. GENERATED FROM PYTHON SOURCE LINES 117-122

Nested values
-------------
The values of a ``TensorDict`` can themselves be ``TensorDict`` instances. We can add
nested values during instantiation, either by passing a ``TensorDict`` directly or by
using nested dictionaries.

.. GENERATED FROM PYTHON SOURCE LINES 122-132

.. code-block:: Python


    # creating nested values with a nested dict
    nested_tensordict = TensorDict(
        {"a": torch.rand(2, 3), "double_nested": {"a": torch.rand(2, 3)}}, [2, 3]
    )
    # creating nested values with a TensorDict
    tensordict = TensorDict({"a": torch.rand(2), "nested": nested_tensordict}, [2])

    print(tensordict)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    double_nested: TensorDict(
                        fields={
                            a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([2, 3]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([2, 3]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)




.. GENERATED FROM PYTHON SOURCE LINES 133-134

To access these nested values, we can use tuples of strings. For example:

.. GENERATED FROM PYTHON SOURCE LINES 134-138

.. code-block:: Python


    double_nested_a = tensordict["nested", "double_nested", "a"]
    nested_a = tensordict.get(("nested", "a"))








.. GENERATED FROM PYTHON SOURCE LINES 139-140

Similarly, we can set nested values using tuples of strings.

.. GENERATED FROM PYTHON SOURCE LINES 140-146

.. code-block:: Python


    tensordict["nested", "double_nested", "b"] = torch.rand(2, 3)
    tensordict.set(("nested", "b"), torch.rand(2, 3))

    print(tensordict)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    double_nested: TensorDict(
                        fields={
                            a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                            b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([2, 3]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([2, 3]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)




.. GENERATED FROM PYTHON SOURCE LINES 147-150

Iterating over a TensorDict's contents
--------------------------------------
We can iterate over the keys of a ``TensorDict`` using the ``.keys()`` method.

.. GENERATED FROM PYTHON SOURCE LINES 150-154

.. code-block:: Python


    for key in tensordict.keys():
        print(key)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    a
    nested




.. GENERATED FROM PYTHON SOURCE LINES 155-160

By default this will iterate only over the top-level keys in the ``TensorDict``.
However, it is possible to iterate recursively over all of the keys, including those
of any nested TensorDicts, by passing the keyword argument ``include_nested=True``;
nested keys are returned as tuples of strings.

.. GENERATED FROM PYTHON SOURCE LINES 160-164

.. code-block:: Python


    for key in tensordict.keys(include_nested=True):
        print(key)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    a
    ('nested', 'a')
    ('nested', 'double_nested', 'a')
    ('nested', 'double_nested', 'b')
    ('nested', 'double_nested')
    ('nested', 'b')
    nested




.. GENERATED FROM PYTHON SOURCE LINES 165-167

If you want to iterate only over keys corresponding to ``Tensor`` values, you can
additionally specify ``leaves_only=True``.

.. GENERATED FROM PYTHON SOURCE LINES 167-171

.. code-block:: Python


    for key in tensordict.keys(include_nested=True, leaves_only=True):
        print(key)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    a
    ('nested', 'a')
    ('nested', 'double_nested', 'a')
    ('nested', 'double_nested', 'b')
    ('nested', 'b')




.. GENERATED FROM PYTHON SOURCE LINES 172-174

Much like ``dict``, there are also ``.values()`` and ``.items()`` methods, which
accept the same keyword arguments.

.. GENERATED FROM PYTHON SOURCE LINES 174-181

.. code-block:: Python


    for key, value in tensordict.items(include_nested=True):
        if isinstance(value, TensorDict):
            print(f"{key} is a TensorDict")
        else:
            print(f"{key} is a Tensor")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    a is a Tensor
    nested is a TensorDict
    ('nested', 'a') is a Tensor
    ('nested', 'double_nested') is a TensorDict
    ('nested', 'double_nested', 'a') is a Tensor
    ('nested', 'double_nested', 'b') is a Tensor
    ('nested', 'b') is a Tensor
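
As a quick sketch of the same pattern with ``.values()``, we can iterate over the
leaf tensors only.

.. code-block:: Python

    # values() accepts the same keyword arguments as keys() and items()
    for value in tensordict.values(include_nested=True, leaves_only=True):
        print(value.shape)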




.. GENERATED FROM PYTHON SOURCE LINES 182-192

Checking for existence of a key
-------------------------------
To check if a key exists in a ``TensorDict``, use the ``in`` operator in conjunction
with ``.keys()``.

.. note::

   Performing ``key in tensordict.keys()`` does efficient ``dict`` lookups of keys
   (recursively at each level in the nested case), and so performance is not
   negatively impacted when there is a large number of keys in the ``TensorDict``.

.. GENERATED FROM PYTHON SOURCE LINES 192-198

.. code-block:: Python


    assert "a" in tensordict.keys()
    # to check for nested keys, set include_nested=True
    assert ("nested", "a") in tensordict.keys(include_nested=True)
    assert ("nested", "banana") not in tensordict.keys(include_nested=True)








.. GENERATED FROM PYTHON SOURCE LINES 199-203

Flattening and unflattening nested keys
---------------------------------------
We can flatten a ``TensorDict`` with nested values using the ``.flatten_keys()``
method.

.. GENERATED FROM PYTHON SOURCE LINES 203-207

.. code-block:: Python


    print(tensordict, end="\n\n")
    print(tensordict.flatten_keys(separator="."))





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    double_nested: TensorDict(
                        fields={
                            a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                            b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([2, 3]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([2, 3]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)




.. GENERATED FROM PYTHON SOURCE LINES 208-210

Given a ``TensorDict`` that has been flattened, it is possible to unflatten it again
with the ``.unflatten_keys()`` method.

.. GENERATED FROM PYTHON SOURCE LINES 210-215

.. code-block:: Python


    flattened_tensordict = tensordict.flatten_keys(separator=".")
    print(flattened_tensordict, end="\n\n")
    print(flattened_tensordict.unflatten_keys(separator="."))





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    double_nested: TensorDict(
                        fields={
                            a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                            b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([2]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([2]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)




.. GENERATED FROM PYTHON SOURCE LINES 216-219

This can be particularly useful when manipulating the parameters of a
:class:`torch.nn.Module`, as we can end up with a :class:`~.TensorDict` whose
structure mimics the module structure.

.. GENERATED FROM PYTHON SOURCE LINES 219-230

.. code-block:: Python


    import torch.nn as nn

    module = nn.Sequential(
        nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 10)),
        nn.Linear(10, 1),
    )
    params = TensorDict(dict(module.named_parameters()), []).unflatten_keys()

    print(params)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            0: TensorDict(
                fields={
                    0: TensorDict(
                        fields={
                            bias: Parameter(shape=torch.Size([50]), device=cpu, dtype=torch.float32, is_shared=False),
                            weight: Parameter(shape=torch.Size([50, 100]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([]),
                        device=None,
                        is_shared=False),
                    1: TensorDict(
                        fields={
                            bias: Parameter(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                            weight: Parameter(shape=torch.Size([10, 50]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([]),
                device=None,
                is_shared=False),
            1: TensorDict(
                fields={
                    bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                    weight: Parameter(shape=torch.Size([1, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([]),
        device=None,
        is_shared=False)
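
As a short sketch, individual parameters can then be retrieved with nested keys that
mirror the module hierarchy; the shapes below follow from the layer sizes defined
above.

.. code-block:: Python

    # access parameters of the first linear layer and of the final layer
    assert params["0", "0", "weight"].shape == torch.Size([50, 100])
    assert params["1", "bias"].shape == torch.Size([1])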




.. GENERATED FROM PYTHON SOURCE LINES 231-238

Selecting and excluding keys
----------------------------
We can obtain a new :class:`~.TensorDict` with a subset of the keys by using
:meth:`TensorDict.select <tensordict.TensorDict.select>`, which returns a new
:class:`~.TensorDict` containing only the specified keys, or
:meth:`TensorDict.exclude <tensordict.TensorDict.exclude>`, which returns a new
:class:`~.TensorDict` with the specified keys omitted.

.. GENERATED FROM PYTHON SOURCE LINES 238-243

.. code-block:: Python


    print("Select:")
    print(tensordict.select("a", ("nested", "a")), end="\n\n")
    print("Exclude:")
    print(tensordict.exclude(("nested", "b"), ("nested", "double_nested")))




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Select:
    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([2, 3]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)

    Exclude:
    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([2, 3]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.009 seconds)


.. _sphx_glr_download_tutorials_tensordict_keys.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tensordict_keys.ipynb <tensordict_keys.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tensordict_keys.py <tensordict_keys.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tensordict_keys.zip <tensordict_keys.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_