TensorStorage
- class torchrl.data.replay_buffers.TensorStorage(storage, max_size=None, *, device: device = 'cpu', ndim: int = 1)
A storage for tensors and tensordicts.
- Parameters:
storage (tensor or TensorDict) – the data buffer to be used.
max_size (int) – size of the storage, i.e. maximum number of elements stored in the buffer.
- Keyword Arguments:
device (torch.device, optional) – device where the sampled tensors will be stored and sent. Default is torch.device("cpu"). If "auto" is passed, the device is automatically gathered from the first batch of data passed. This is not enabled by default, to avoid data being placed on GPU by mistake and causing OOM issues.
ndim (int, optional) – the number of dimensions to be accounted for when measuring the storage size. For instance, a storage of shape [3, 4] has capacity 3 if ndim=1 and 12 if ndim=2. Defaults to 1.
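The ndim capacity rule amounts to multiplying the sizes of the leading ndim dimensions of the storage shape. The helper below is a hypothetical illustration of that arithmetic in plain Python; storage_capacity is not part of the torchrl API.

```python
from math import prod


def storage_capacity(shape, ndim=1):
    """Hypothetical helper: how many elements a storage of `shape` can hold
    when its first `ndim` dimensions are treated as indexable."""
    if not 1 <= ndim <= len(shape):
        raise ValueError(f"ndim must be between 1 and {len(shape)}, got {ndim}")
    # Only the leading `ndim` dimensions count; the trailing ones describe
    # the shape of each stored item.
    return prod(shape[:ndim])


print(storage_capacity([3, 4], ndim=1))  # 3
print(storage_capacity([3, 4], ndim=2))  # 12
```

With ndim=2, every (row, column) pair of a [3, 4] storage is a separate slot, which is why the capacity grows from 3 to 12.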
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import TensorStorage
>>> data = TensorDict({
...     "some data": torch.randn(10, 11),
...     ("some", "nested", "data"): torch.randn(10, 11, 12),
... }, batch_size=[10, 11])
>>> storage = TensorStorage(data)
>>> len(storage)  # only the first dimension is considered as indexable
10
>>> storage.get(0)
TensorDict(
    fields={
        some data: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
        some: TensorDict(
            fields={
                nested: TensorDict(
                    fields={
                        data: Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([11]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([11]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> storage.set(0, storage.get(0).zero_())  # zeros the data at index 0
This class also supports tensorclass data.
Examples
>>> import torch
>>> from tensordict import tensorclass
>>> from torchrl.data.replay_buffers import TensorStorage
>>> @tensorclass
... class MyClass:
...     foo: torch.Tensor
...     bar: torch.Tensor
>>> data = MyClass(foo=torch.randn(10, 11), bar=torch.randn(10, 11, 12), batch_size=[10, 11])
>>> storage = TensorStorage(data)
>>> storage.get(0)
MyClass(
    bar=Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
    foo=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
- attach(buffer: Any) → None
This function attaches a sampler to this storage.
Buffers that read from this storage must be registered as attached entities by calling this method. This guarantees that when data in the storage changes, the attached components are notified of the change, even if the storage is shared with other buffers (e.g., priority samplers).
- Parameters:
buffer – the object that reads from this storage.
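The attach mechanism is an instance of the observer pattern: the storage keeps a list of readers and informs each of them whenever it writes. The toy classes below sketch that idea in plain Python. ToyStorage, ToyPrioritySampler, and the mark_update callback name are hypothetical stand-ins chosen for illustration and are not claimed to match torchrl's internal implementation.

```python
class ToyStorage:
    """Hypothetical minimal storage that notifies attached readers on writes."""

    def __init__(self, max_size):
        self._data = [None] * max_size
        self._attached_entities = []

    def attach(self, buffer):
        # Register a reader (e.g. a sampler) so it hears about future writes.
        # Attaching the same object twice has no extra effect.
        if buffer not in self._attached_entities:
            self._attached_entities.append(buffer)

    def set(self, index, value):
        self._data[index] = value
        # Tell every attached reader which slot was overwritten.
        for entity in self._attached_entities:
            entity.mark_update(index)  # hypothetical callback name

    def get(self, index):
        return self._data[index]


class ToyPrioritySampler:
    """Hypothetical sampler that resets the priority of overwritten slots."""

    def __init__(self, max_size, default_priority=1.0):
        self.default_priority = default_priority
        self.priorities = [0.0] * max_size

    def mark_update(self, index):
        # New data at `index`: the old priority no longer applies.
        self.priorities[index] = self.default_priority


# Two samplers share one storage; a single write reaches both.
storage = ToyStorage(max_size=4)
sampler_a = ToyPrioritySampler(4)
sampler_b = ToyPrioritySampler(4)
storage.attach(sampler_a)
storage.attach(sampler_b)
storage.set(2, "transition")
print(sampler_a.priorities)  # [0.0, 0.0, 1.0, 0.0]
print(sampler_b.priorities)  # [0.0, 0.0, 1.0, 0.0]
```

Because the storage pushes notifications instead of each sampler polling for changes, a write made through any one buffer keeps the bookkeeping of every other buffer sharing that storage consistent.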