
Simplifying PyTorch Memory Management with TensorDict

Author: Tom Begley

In this tutorial you will learn how to control where the contents of a TensorDict are stored in memory, either by sending those contents to a device, or by utilizing memory maps.

Devices

When you create a TensorDict, you can specify a device with the device keyword argument. If the device is set, all entries of the TensorDict will be placed on that device. If the device is not set, the entries of the TensorDict are not required to be on the same device.

In this example we instantiate a TensorDict with device="cuda:0". When we print the contents we can see that they have been moved onto the device.

>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"a": torch.rand(10)}, [10], device="cuda:0")
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

If the device of the TensorDict is not None, new entries are also moved onto the device.

>>> tensordict["b"] = torch.rand(10, 10)
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

You can check the current device of the TensorDict with the device attribute.

>>> print(tensordict.device)
cuda:0

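The contents of the TensorDict can be sent to a device just like a PyTorch tensor, with TensorDict.cuda() or TensorDict.to(device), where device is the desired device. As with tensors, these methods return the result instead of modifying the TensorDict in place, so the result must be assigned.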

>>> tensordict = tensordict.to(torch.device("cpu"))
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
>>> tensordict = tensordict.cuda()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

The TensorDict.to method requires a valid device to be passed as the argument. If you want to remove the device from the TensorDict to allow values with different devices, use the TensorDict.clear_device method.

>>> tensordict.clear_device()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

Memory-mapped Tensors

tensordict provides a class MemoryMappedTensor which allows us to store the contents of a tensor on disk, while still supporting fast indexing and loading of the contents in batches. See the ImageNet Tutorial for an example of this in action.

To convert the entries of a TensorDict to memory-mapped tensors in place, use the TensorDict.memmap_ method.

>>> tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
>>> _ = tensordict.memmap_()  # converts the entries in place and returns the same TensorDict
>>> print(tensordict)
TensorDict(
    fields={
        a: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)

Alternatively, you can use the TensorDict.memmap_like method. This creates a new TensorDict with the same structure and MemoryMappedTensor values, but it does not copy the contents of the original tensors into the memory-mapped tensors. This allows you to create the memory-mapped TensorDict first and populate it gradually, and hence it should generally be preferred to memmap_.

>>> tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
>>> mm_tensordict = tensordict.memmap_like()
>>> print(mm_tensordict["a"].contiguous())
MemoryMappedTensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

By default, the contents of the TensorDict are saved to a temporary location on disk. If you would like to control where they are saved, pass the keyword argument prefix="/path/to/root".

The contents of the TensorDict are saved in a directory structure that mimics the structure of the TensorDict itself. The contents of each tensor are saved in a NumPy memmap, and its metadata in an associated PyTorch save file. For example, the TensorDict above is saved as follows:

├── a.memmap
├── a.meta.pt
├── b
│   ├── c.memmap
│   ├── c.meta.pt
│   └── meta.pt
└── meta.pt
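The layout can be reproduced with plain NumPy and PyTorch. The sketch below is a simplified illustration, not tensordict's own serialization code: save_tree and load_leaf are hypothetical helpers. Each tensor's raw data goes into a NumPy memmap. Its shape and dtype go into a torch.save file next to it, and each level gets a meta.pt, here holding only the key names as a stand-in for the TensorDict-level metadata.

```python
import os
import tempfile

import numpy as np
import torch


def save_tree(root, tree):
    """Hypothetical helper: store a nested dict of tensors as memmap + metadata files."""
    os.makedirs(root, exist_ok=True)
    for name, value in tree.items():
        if isinstance(value, dict):
            # Nested entries get their own sub-directory, like "b" in the tree above.
            save_tree(os.path.join(root, name), value)
            continue
        array = value.detach().cpu().numpy()
        # Raw data: a NumPy memmap with the tensor's shape and dtype.
        mm = np.memmap(os.path.join(root, f"{name}.memmap"),
                       dtype=array.dtype, mode="w+", shape=array.shape)
        mm[:] = array
        mm.flush()
        del mm  # close the mapping
        # Per-tensor metadata: what is needed to reinterpret the raw bytes.
        torch.save({"shape": tuple(array.shape), "dtype": array.dtype.str},
                   os.path.join(root, f"{name}.meta.pt"))
    # Per-level metadata (hypothetical content: just the key names).
    torch.save({"keys": sorted(tree)}, os.path.join(root, "meta.pt"))


def load_leaf(root, name):
    """Hypothetical helper: read one tensor back using its metadata file."""
    meta = torch.load(os.path.join(root, f"{name}.meta.pt"), weights_only=True)
    mm = np.memmap(os.path.join(root, f"{name}.memmap"),
                   dtype=np.dtype(meta["dtype"]), mode="r", shape=meta["shape"])
    return torch.from_numpy(np.array(mm))  # copy out of the read-only mapping


tree = {"a": torch.rand(10), "b": {"c": torch.rand(10)}}
with tempfile.TemporaryDirectory() as root:
    save_tree(root, tree)
    layout = sorted(
        os.path.relpath(os.path.join(dirpath, f), root)
        for dirpath, _, files in os.walk(root)
        for f in files
    )
    c = load_leaf(os.path.join(root, "b"), "c")
print(layout)
```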
