Simplifying PyTorch Memory Management with TensorDict
Author: Tom Begley
In this tutorial you will learn how to control where the contents of a TensorDict are stored in memory, either by sending those contents to a device, or by utilizing memory maps.
Devices
When you create a TensorDict, you can specify a device with the device keyword argument. If the device is set, then all entries of the TensorDict will be placed on that device. If the device is not set, then there is no requirement that entries in the TensorDict be on the same device.
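For instance, a TensorDict created without a device can hold entries on different devices. A minimal sketch, assuming a CUDA device is available:

import torch
from tensordict import TensorDict

# No device is set, so entries may live on different devices
mixed = TensorDict(
    {"on_cpu": torch.rand(10), "on_gpu": torch.rand(10, device="cuda:0")},
    batch_size=[10],
)
print(mixed.device)  # None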
In this example we instantiate a TensorDict with device="cuda:0". When we print the contents we can see that they have been moved onto the device.
>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"a": torch.rand(10)}, [10], device="cuda:0")
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)
If the device of the TensorDict is not None, new entries are also moved onto the device.
>>> tensordict["b"] = torch.rand(10, 10)
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)
You can check the current device of the TensorDict with the device attribute.
>>> print(tensordict.device)
cuda:0
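Since the device was set at construction, every entry lives on it as well:

# Entries inherit the TensorDict's device
print(tensordict["a"].device)  # cuda:0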
The contents of the TensorDict can be sent to a device like a PyTorch tensor, using TensorDict.cuda() or TensorDict.to(device), with device being the desired device. Note that both methods return a new TensorDict rather than modifying the contents in place.
>>> tensordict = tensordict.to(torch.device("cpu"))
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
>>> tensordict = tensordict.cuda()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)
The TensorDict.to method requires a valid device to be passed as the argument. If you want to remove the device from the TensorDict to allow values with different devices, you should use the TensorDict.clear_device_ method.
>>> tensordict.clear_device_()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
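With the device cleared, new entries are left wherever they were created, so values on different devices can now coexist. A short sketch continuing the example above:

# device is None, so this CPU entry is not moved to cuda:0
tensordict["c"] = torch.rand(10)
print(tensordict["a"].device)  # cuda:0
print(tensordict["c"].device)  # cpu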
Memory-mapped Tensors
tensordict provides a class MemoryMappedTensor which allows us to store the contents of a tensor on disk, while still supporting fast indexing and loading of the contents in batches. See the ImageNet Tutorial for an example of this in action.
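A MemoryMappedTensor can also be created directly from an existing tensor. A minimal sketch, using MemoryMappedTensor.from_tensor to copy the data into a memory-mapped file on disk:

import torch
from tensordict import MemoryMappedTensor

# Copy an in-memory tensor into a memory-mapped file on disk
mm_tensor = MemoryMappedTensor.from_tensor(torch.rand(10))
print(mm_tensor[:3])  # indexing reads just the requested slice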
To convert the TensorDict to a collection of memory-mapped tensors, use the TensorDict.memmap_ method.
tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
tensordict.memmap_()
print(tensordict)
TensorDict(
    fields={
        a: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
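Because the values are now memory-mapped, indexing the TensorDict loads only the selected rows from disk. A short sketch:

# Slice as usual; only the selected rows are read from disk
batch = tensordict[2:5]
print(batch["a"].shape)  # torch.Size([3])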
Alternatively one can use the TensorDict.memmap_like method. This will create a new TensorDict of the same structure with MemoryMappedTensor values; however, it will not copy the contents of the original tensors to the memory-mapped tensors. This allows you to create the memory-mapped TensorDict and then populate it slowly, and hence should generally be preferred to memmap_.
tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
mm_tensordict = tensordict.memmap_like()
print(mm_tensordict["a"].contiguous())
MemoryMappedTensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
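The unpopulated entries can then be filled in incrementally, for example one batch element at a time. A minimal sketch continuing from the example above:

# Each assignment writes directly to the on-disk buffers
for i in range(10):
    mm_tensordict[i] = TensorDict({"a": torch.rand(()), "b": {"c": torch.rand(())}}, [])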
By default the contents of the TensorDict will be saved to a temporary location on disk; however, if you would like to control where they are saved, you can use the keyword argument prefix="/path/to/root".
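For example, a minimal sketch that uses a temporary directory in place of a hard-coded path:

import tempfile

# Save the memory-mapped contents under an explicit root directory
root = tempfile.mkdtemp()
tensordict.memmap_(prefix=root)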
The contents of the TensorDict are saved in a directory structure that mimics the structure of the TensorDict itself. The contents of each tensor are saved in a NumPy memmap, and the metadata in an associated PyTorch save file. For example, the above TensorDict is saved as follows:
├── a.memmap
├── a.meta.pt
├── b
│   ├── c.memmap
│   ├── c.meta.pt
│   └── meta.pt
└── meta.pt
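A TensorDict saved this way can be reloaded later from the same directory with TensorDict.load_memmap. A sketch, assuming the root directory from above; the tensors are mapped back from disk rather than copied into memory:

# Reload the memory-mapped TensorDict from its root directory
loaded = TensorDict.load_memmap(root)
print(loaded)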