.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tensordict_keys.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_tensordict_keys.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tensordict_keys.py:


Manipulating the keys of a TensorDict
=====================================
**Author**: `Tom Begley <https://github.com/tcbegley>`_

In this tutorial you will learn how to work with and manipulate the keys in a
``TensorDict``, including getting and setting keys, iterating over keys,
manipulating nested values, and flattening the keys.

.. GENERATED FROM PYTHON SOURCE LINES 12-15

Setting and getting keys
------------------------
We can set and get keys using the same syntax as a Python ``dict``.

.. GENERATED FROM PYTHON SOURCE LINES 15-28

.. code-block:: Python


    import torch
    from tensordict.tensordict import TensorDict

    tensordict = TensorDict()

    # set a key
    a = torch.rand(10)
    tensordict["a"] = a

    # retrieve the value stored under "a"
    assert tensordict["a"] is a

.. GENERATED FROM PYTHON SOURCE LINES 34-41

.. note::

   Unlike a Python ``dict``, all keys in the ``TensorDict`` must be strings.
   However, as we will see, it is also possible to use tuples of strings to
   manipulate nested values.

We can also use the methods ``.get()`` and ``.set()`` to accomplish the same thing.

.. GENERATED FROM PYTHON SOURCE LINES 41-51

.. code-block:: Python


    tensordict = TensorDict()

    # set a key
    a = torch.rand(10)
    tensordict.set("a", a)

    # retrieve the value stored under "a"
    assert tensordict.get("a") is a

.. GENERATED FROM PYTHON SOURCE LINES 52-54

Like ``dict``, we can provide a default value to ``get`` that should be returned
in case the requested key is not found.

.. GENERATED FROM PYTHON SOURCE LINES 54-57

.. code-block:: Python


    assert tensordict.get("banana", a) is a

.. GENERATED FROM PYTHON SOURCE LINES 58-61

Similarly, like ``dict``, we can use :meth:`TensorDict.setdefault` to get the
value of a particular key, returning a default value if that key is not found, and
also setting that value in the :class:`~.TensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 61-66

.. code-block:: Python


    assert tensordict.setdefault("banana", a) is a
    # a is now stored under "banana"
    assert tensordict["banana"] is a

.. GENERATED FROM PYTHON SOURCE LINES 67-70

Deleting keys is also achieved in the same way as with a Python ``dict``, using the
``del`` statement and the chosen key. Equivalently, we could use the
:meth:`TensorDict.del_ <tensordict.TensorDict.del_>` method.

.. GENERATED FROM PYTHON SOURCE LINES 70-73

.. code-block:: Python


    del tensordict["banana"]

.. GENERATED FROM PYTHON SOURCE LINES 74-77

Furthermore, when setting keys with ``.set()`` we can use the keyword argument
``inplace=True`` to make an in-place update, or equivalently use the ``.set_()``
method.

.. GENERATED FROM PYTHON SOURCE LINES 77-90

.. code-block:: Python


    tensordict.set("a", torch.zeros(10), inplace=True)

    # all the entries of the "a" tensor are now zero
    assert (tensordict.get("a") == 0).all()
    # but it's still the same tensor as before
    assert tensordict.get("a") is a

    # we can achieve the same with set_
    tensordict.set_("a", torch.ones(10))
    assert (tensordict.get("a") == 1).all()
    assert tensordict.get("a") is a

.. GENERATED FROM PYTHON SOURCE LINES 91-97

Renaming keys
-------------
To rename a key, simply use the
:meth:`TensorDict.rename_key_ <tensordict.TensorDict.rename_key_>` method. The
value stored under the original key will remain in the :class:`~.TensorDict`, but
the key will be changed to the specified new key.

.. GENERATED FROM PYTHON SOURCE LINES 97-102

.. code-block:: Python


    tensordict.rename_key_("a", "b")
    assert tensordict.get("b") is a
    print(tensordict)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 103-108

Updating multiple values
------------------------
The :meth:`TensorDict.update <tensordict.TensorDict.update>` method can be used to
update a :class:`~.TensorDict` with another one or with a ``dict``. Keys that
already exist are overwritten, and keys that do not already exist are created.

.. GENERATED FROM PYTHON SOURCE LINES 108-116

.. code-block:: Python


    tensordict = TensorDict({"a": torch.rand(10), "b": torch.rand(10)}, [10])
    tensordict.update(TensorDict({"a": torch.zeros(10), "c": torch.zeros(10)}, [10]))
    assert (tensordict["a"] == 0).all()
    assert (tensordict["b"] != 0).all()
    assert (tensordict["c"] == 0).all()
    print(tensordict)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
            c: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([10]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 117-122

Nested values
-------------
The values of a ``TensorDict`` can themselves be a ``TensorDict``. We can add nested
values during instantiation, either by adding a ``TensorDict`` directly or by using
nested dictionaries.

.. GENERATED FROM PYTHON SOURCE LINES 122-132

.. code-block:: Python


    # creating nested values with a nested dict
    nested_tensordict = TensorDict(
        {"a": torch.rand(2, 3), "double_nested": {"a": torch.rand(2, 3)}}, [2, 3]
    )
    # creating nested values with a TensorDict
    tensordict = TensorDict({"a": torch.rand(2), "nested": nested_tensordict}, [2])

    print(tensordict)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    double_nested: TensorDict(
                        fields={
                            a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([2, 3]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([2, 3]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 133-134

To access these nested values, we can use tuples of strings. For example:

.. GENERATED FROM PYTHON SOURCE LINES 134-138

.. code-block:: Python


    double_nested_a = tensordict["nested", "double_nested", "a"]
    nested_a = tensordict.get(("nested", "a"))

.. GENERATED FROM PYTHON SOURCE LINES 139-140

Similarly, we can set nested values using tuples of strings.

.. GENERATED FROM PYTHON SOURCE LINES 140-146

.. code-block:: Python


    tensordict["nested", "double_nested", "b"] = torch.rand(2, 3)
    tensordict.set(("nested", "b"), torch.rand(2, 3))

    print(tensordict)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    double_nested: TensorDict(
                        fields={
                            a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                            b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([2, 3]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([2, 3]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)
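
To build intuition for tuple keys, here is a minimal sketch that uses plain Python
dictionaries as a stand-in for a nested ``TensorDict``. The helpers ``get_nested``
and ``set_nested`` are hypothetical names introduced only for illustration; they are
not part of the ``tensordict`` API and do not reflect how the library is implemented.
They only show the idea of resolving a tuple of strings one level at a time.

.. code-block:: Python


    def get_nested(mapping, key):
        """Resolve a string key, or a tuple of strings, one level at a time."""
        if isinstance(key, str):
            key = (key,)
        value = mapping
        for part in key:
            value = value[part]
        return value


    def set_nested(mapping, key, value):
        """Store ``value`` under a tuple key, creating intermediate dicts as needed."""
        if isinstance(key, str):
            key = (key,)
        *parents, last = key
        for part in parents:
            mapping = mapping.setdefault(part, {})
        mapping[last] = value


    # plain-dict stand-in for the nested structure used above (illustration only)
    data = {"a": 1, "nested": {"a": 2, "double_nested": {"a": 3}}}

    set_nested(data, ("nested", "double_nested", "b"), 4)
    set_nested(data, ("nested", "b"), 5)

    # a tuple key reaches the same entry as chained single-key lookups
    assert get_nested(data, ("nested", "double_nested", "b")) == 4
    assert get_nested(data, ("nested", "b")) == data["nested"]["b"]
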

.. GENERATED FROM PYTHON SOURCE LINES 147-150

Iterating over a TensorDict's contents
--------------------------------------
We can iterate over the keys of a ``TensorDict`` using the ``.keys()`` method.

.. GENERATED FROM PYTHON SOURCE LINES 150-154

.. code-block:: Python


    for key in tensordict.keys():
        print(key)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    a
    nested

.. GENERATED FROM PYTHON SOURCE LINES 155-160

By default this will iterate only over the top-level keys in the ``TensorDict``.
However, it is possible to iterate recursively over all of the keys, including those
of any nested ``TensorDict``, with the keyword argument ``include_nested=True``.
Nested keys are returned as tuples of strings.

.. GENERATED FROM PYTHON SOURCE LINES 160-164

.. code-block:: Python


    for key in tensordict.keys(include_nested=True):
        print(key)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    a
    ('nested', 'a')
    ('nested', 'double_nested', 'a')
    ('nested', 'double_nested', 'b')
    ('nested', 'double_nested')
    ('nested', 'b')
    nested

.. GENERATED FROM PYTHON SOURCE LINES 165-167

If you only want to iterate over keys corresponding to ``Tensor`` values, you can
additionally specify ``leaves_only=True``.

.. GENERATED FROM PYTHON SOURCE LINES 167-171

.. code-block:: Python


    for key in tensordict.keys(include_nested=True, leaves_only=True):
        print(key)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    a
    ('nested', 'a')
    ('nested', 'double_nested', 'a')
    ('nested', 'double_nested', 'b')
    ('nested', 'b')

.. GENERATED FROM PYTHON SOURCE LINES 172-174

Much like ``dict``, there are also ``.values()`` and ``.items()`` methods, which
accept the same keyword arguments.

.. GENERATED FROM PYTHON SOURCE LINES 174-181

.. code-block:: Python


    for key, value in tensordict.items(include_nested=True):
        if isinstance(value, TensorDict):
            print(f"{key} is a TensorDict")
        else:
            print(f"{key} is a Tensor")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    a is a Tensor
    nested is a TensorDict
    ('nested', 'a') is a Tensor
    ('nested', 'double_nested') is a TensorDict
    ('nested', 'double_nested', 'a') is a Tensor
    ('nested', 'double_nested', 'b') is a Tensor
    ('nested', 'b') is a Tensor

.. GENERATED FROM PYTHON SOURCE LINES 182-192

Checking for existence of a key
-------------------------------
To check if a key exists in a ``TensorDict``, use the ``in`` operator in conjunction
with ``.keys()``.

.. note::

   Performing ``key in tensordict.keys()`` does efficient ``dict`` lookups of keys
   (recursively at each level in the nested case), and so performance is not
   negatively impacted when there is a large number of keys in the ``TensorDict``.

.. GENERATED FROM PYTHON SOURCE LINES 192-198

.. code-block:: Python


    assert "a" in tensordict.keys()
    # to check for nested keys, set include_nested=True
    assert ("nested", "a") in tensordict.keys(include_nested=True)
    assert ("nested", "banana") not in tensordict.keys(include_nested=True)

.. GENERATED FROM PYTHON SOURCE LINES 199-203

Flattening and unflattening nested keys
---------------------------------------
We can flatten a ``TensorDict`` with nested values using the ``.flatten_keys()``
method.

.. GENERATED FROM PYTHON SOURCE LINES 203-207

.. code-block:: Python


    print(tensordict, end="\n\n")
    print(tensordict.flatten_keys(separator="."))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    double_nested: TensorDict(
                        fields={
                            a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                            b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([2, 3]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([2, 3]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 208-210

Given a ``TensorDict`` that has been flattened, it is possible to unflatten it again
with the ``.unflatten_keys()`` method.

.. GENERATED FROM PYTHON SOURCE LINES 210-215

.. code-block:: Python


    flattened_tensordict = tensordict.flatten_keys(separator=".")
    print(flattened_tensordict, end="\n\n")
    print(flattened_tensordict.unflatten_keys(separator="."))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.double_nested.a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            nested.double_nested.b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                    double_nested: TensorDict(
                        fields={
                            a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                            b: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([2]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([2]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 216-219

This can be particularly useful when manipulating the parameters of a
:class:`torch.nn.Module`, as we can end up with a :class:`~.TensorDict` whose
structure mimics the module structure.

.. GENERATED FROM PYTHON SOURCE LINES 219-230

.. code-block:: Python


    import torch.nn as nn

    module = nn.Sequential(
        nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 10)),
        nn.Linear(10, 1),
    )
    params = TensorDict(dict(module.named_parameters()), []).unflatten_keys()

    print(params)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            0: TensorDict(
                fields={
                    0: TensorDict(
                        fields={
                            bias: Parameter(shape=torch.Size([50]), device=cpu, dtype=torch.float32, is_shared=False),
                            weight: Parameter(shape=torch.Size([50, 100]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([]),
                        device=None,
                        is_shared=False),
                    1: TensorDict(
                        fields={
                            bias: Parameter(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
                            weight: Parameter(shape=torch.Size([10, 50]), device=cpu, dtype=torch.float32, is_shared=False)},
                        batch_size=torch.Size([]),
                        device=None,
                        is_shared=False)},
                batch_size=torch.Size([]),
                device=None,
                is_shared=False),
            1: TensorDict(
                fields={
                    bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                    weight: Parameter(shape=torch.Size([1, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 231-238

Selecting and excluding keys
----------------------------
We can obtain a new :class:`~.TensorDict` with a subset of the keys by using
:meth:`TensorDict.select <tensordict.TensorDict.select>`, which returns a new
:class:`~.TensorDict` containing only the specified keys, or
:meth:`TensorDict.exclude <tensordict.TensorDict.exclude>`, which returns a new
:class:`~.TensorDict` with the specified keys omitted.

.. GENERATED FROM PYTHON SOURCE LINES 238-243

.. code-block:: Python


    print("Select:")
    print(tensordict.select("a", ("nested", "a")), end="\n\n")
    print("Exclude:")
    print(tensordict.exclude(("nested", "b"), ("nested", "double_nested")))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Select:
    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([2, 3]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)

    Exclude:
    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
            nested: TensorDict(
                fields={
                    a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([2, 3]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False)

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.009 seconds)


.. _sphx_glr_download_tutorials_tensordict_keys.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tensordict_keys.ipynb <tensordict_keys.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tensordict_keys.py <tensordict_keys.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tensordict_keys.zip <tensordict_keys.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_