Tree
- class torchrl.data.Tree(count: 'int' = None, index: 'torch.Tensor | None' = None, hash: 'int | None' = None, node_id: 'int | None' = None, rollout: 'TensorDict | None' = None, node: 'TensorDict | None' = None, subtree: "'Tree'" = None, *, batch_size, device=None, names=None)
- property batch_size: Size
Retrieves the batch size for the tensor class.
- Returns:
batch size (torch.Size)
- classmethod from_tensordict(tensordict, non_tensordict=None, safe=True)
Tensor class wrapper to instantiate a new tensor class object.
- Parameters:
tensordict (TensorDict) – Dictionary of tensor types
non_tensordict (dict) – Dictionary with non-tensor and nested tensor class objects
- get(key: NestedKey, default: Any = _NoDefault.ZERO)
Gets the value stored with the input key.
- Parameters:
key (str, tuple of str) – key to be queried. If a tuple of str, the lookup is equivalent to chained calls of getattr.
default – default value if the key is not found in the tensorclass.
- Returns:
value stored with the input key
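The equivalence between a tuple key and chained getattr calls can be sketched in plain Python. The names `Inner`, `Outer`, `get_nested` and `_MISSING` below are hypothetical stand-ins, not part of torchrl or tensordict; the sketch only mirrors the lookup-with-default behaviour described above, and raising KeyError on a missing key is an assumption of the sketch.

```python
from dataclasses import dataclass
from functools import reduce

_MISSING = object()  # sentinel: distinguishes "no default given" from default=None


@dataclass
class Inner:
    value: int


@dataclass
class Outer:
    inner: Inner


def get_nested(obj, key, default=_MISSING):
    """Resolve a str or tuple-of-str key through chained getattr calls."""
    path = (key,) if isinstance(key, str) else tuple(key)
    try:
        return reduce(getattr, path, obj)
    except AttributeError:
        if default is _MISSING:
            # Assumption of this sketch: a missing key without default is an error.
            raise KeyError(key) from None
        return default


outer = Outer(inner=Inner(value=3))
found = get_nested(outer, ("inner", "value"))  # same as outer.inner.value
fallback = get_nested(outer, ("inner", "nope"), default=None)
```

A sentinel is used instead of `None` so that `None` remains a legitimate default value.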
- get_vertex_by_hash(hash: int) -> Tree
Goes through the tree and returns the node corresponding to the given hash.
- get_vertex_by_id(id: int) -> Tree
Goes through the tree and returns the node corresponding to the given id.
- property is_terminal
Returns True if the tree has no child nodes.
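A lookup by id and the terminal check can be illustrated with a minimal plain-Python sketch. The `Node` class and `find_by_id` function are hypothetical and do not reflect how torchrl implements `Tree`; in particular, returning `None` for an unknown id is an assumption of this sketch.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for a tree vertex; not torchrl's Tree."""
    node_id: int
    children: List["Node"] = field(default_factory=list)

    @property
    def is_terminal(self) -> bool:
        # A vertex with no children is terminal.
        return not self.children


def find_by_id(root: Node, node_id: int) -> Optional[Node]:
    """Depth-first search for the first vertex whose id matches."""
    stack = [root]
    while stack:
        node = stack.pop()
        if node.node_id == node_id:
            return node
        stack.extend(node.children)
    return None  # assumption: unknown ids yield None in this sketch


leaf = Node(2)
root = Node(0, [Node(1, [leaf]), Node(3)])
```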
- classmethod load(prefix: str | pathlib.Path, *args, **kwargs) -> T
Loads a tensordict from disk.
This class method is a proxy to load_memmap().
- load_(prefix: str | pathlib.Path, *args, **kwargs)
Loads a tensordict from disk within the current tensordict.
This method is a proxy to load_memmap_().
- classmethod load_memmap(prefix: str | pathlib.Path, device: Optional[device] = None, non_blocking: bool = False, *, out: Optional[TensorDictBase] = None) -> T
Loads a memory-mapped tensordict from disk.
- Parameters:
prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.
device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.
non_blocking (bool, optional) – if True, synchronize won’t be called after loading tensors on device. Defaults to False.
out (TensorDictBase, optional) – optional tensordict where the data should be written.
Examples
>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()
This method also allows loading nested tensordicts.
Examples
>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor.
Examples
>>> import tempfile
>>> import torch
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)
Attempts to load a state_dict in-place on the destination tensorclass.
- memmap(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) -> T
Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and any writing operation (such as an in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place (e.g., renaming, setting or removing an entry) will throw an exception. Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
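The interplay between num_threads and return_early follows the usual future pattern, which can be sketched with the standard library alone. The `write_all` helper below is hypothetical and writes raw bytes rather than tensors; it is not how tensordict implements memmap, only an illustration of returning either the finished result or a future that resolves to it.

```python
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path


def write_all(prefix, blobs, num_threads=0, return_early=False):
    """Write each named bytes blob to prefix/<name>.bin.

    Hypothetical helper: returns the list of written paths, or a Future
    resolving to that list when return_early=True and num_threads > 0.
    """
    prefix = Path(prefix)

    def write_one(item):
        name, data = item
        path = prefix / f"{name}.bin"
        path.write_bytes(data)
        return path

    if num_threads <= 0:
        # Synchronous path: nothing to wait for.
        return [write_one(item) for item in blobs.items()]

    def run():
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            return list(pool.map(write_one, blobs.items()))

    if not return_early:
        return run()
    # A single coordinating thread lets the caller continue immediately.
    outer = ThreadPoolExecutor(max_workers=1)
    future = outer.submit(run)
    outer.shutdown(wait=False)
    return future


with tempfile.TemporaryDirectory() as tmp:
    pending = write_all(tmp, {"a": b"\x00" * 8, "b": b"\x01" * 4},
                        num_threads=2, return_early=True)
    written = pending.result()  # block until all files are on disk
    sizes = sorted(p.stat().st_size for p in written)
```

Calling `result()` is the point at which the caller pays for the write, so other work can be scheduled between the call and the wait.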
- memmap_(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) -> T
Writes all tensors onto a corresponding memory-mapped Tensor, in-place.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and any writing operation (such as an in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place (e.g., renaming, setting or removing an entry) will throw an exception. Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
self if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_like(prefix: Optional[str] = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) -> T
Creates a contentless memory-mapped tensordict with the same shapes as the original one.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads>0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between the processes, and any writing operation (such as an in-place update or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any write operation that isn’t in-place (e.g., renaming, setting or removing an entry) will throw an exception. Once the tensordict is unlocked, the memory-mapped attribute is set to False, because cross-process identity is no longer guaranteed.
- Returns:
A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.
Note
This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.
Examples
>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
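The idea of reserving a contentless buffer on disk and filling it afterwards can be shown with the standard library's `mmap` module. This is only an analogy under stated assumptions: the file name, `n_items` and `item_size` are hypothetical, and tensordict's own memory-mapped tensors are not built this way in the sketch.

```python
import mmap
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "buffer.bin")
    n_items, item_size = 1000, 16  # hypothetical sizes

    # Step 1: reserve the file at its final size without writing any content.
    with open(path, "wb") as f:
        f.truncate(n_items * item_size)

    # Step 2: map the file and fill one slot in place.
    with open(path, "r+b") as f:
        with mmap.mmap(f.fileno(), 0) as buf:
            buf[0:item_size] = b"\x01" * item_size
            buf.flush()

    file_size = os.path.getsize(path)
    with open(path, "rb") as f:
        first_slot = f.read(item_size)
        second_slot = f.read(item_size)
```

Only the slot that was written carries data; the rest of the file stays zero-filled until something is stored there.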
- memmap_refresh_()
Refreshes the content of the memory-mapped tensordict if it has a saved_path.
This method will raise an exception if no path is associated with it.
- property num_children: int
Number of children of this node.
Equates to the number of elements in the self.subtree stack.
- num_vertices(*, count_repeat: bool = False) -> int
Returns the number of vertices in the Tree; by default, each unique vertex is counted once.
- Keyword Arguments:
count_repeat (bool, optional) – Determines whether to count repeated vertices. If False, counts each unique vertex only once. If True, counts vertices multiple times if they appear in different paths. Defaults to False.
- Returns:
The number of vertices in the Tree, counted according to count_repeat.
- Return type:
int
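The difference between the two counting modes shows up as soon as one vertex is reachable through more than one path. The sketch below uses a hypothetical `Node` class and a standalone `num_vertices` function written for illustration; it is not torchrl's implementation, and identifying vertices by `node_id` is an assumption.

```python
class Node:
    """Hypothetical vertex with an id and a list of children."""

    def __init__(self, node_id, children=()):
        self.node_id = node_id
        self.children = list(children)


def num_vertices(root, count_repeat=False):
    if count_repeat:
        # Count a vertex once per path that reaches it.
        return 1 + sum(num_vertices(c, True) for c in root.children)
    # Count each distinct id once, however many paths lead to it.
    seen, stack = set(), [root]
    while stack:
        node = stack.pop()
        if node.node_id not in seen:
            seen.add(node.node_id)
            stack.extend(node.children)
    return len(seen)


# Vertex 3 is a child of both vertex 1 and vertex 2.
shared = Node(3)
root = Node(0, [Node(1, [shared]), Node(2, [shared])])
unique_count = num_vertices(root)
repeated_count = num_vertices(root, count_repeat=True)
```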
- save(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) -> T
Saves the tensordict to disk.
This function is a proxy to memmap().
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)
Sets a new key-value pair.
- Parameters:
key (str, tuple of str) – name of the key to be set. If tuple of str it is equivalent to chained calls of getattr followed by a final setattr.
value (Any) – value to be stored in the tensorclass
inplace (bool, optional) – if True, set will try to update the value in-place. If False or if the key isn’t present, the value will simply be written at its destination.
- Returns:
self
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) -> dict[str, Any]
Returns a state_dict dictionary that can be used to save and load data from a tensorclass.
- to_tensordict(*, retain_none: Optional[bool] = None) -> TensorDict
Converts the tensorclass into a regular TensorDict.
Makes a copy of all entries. Memmap and shared memory tensors are converted to regular tensors.
- Parameters:
retain_none (bool) – if True, the None values will be written in the tensordict. Otherwise they will be discarded. Default: True.
Note
From v0.8, the default value will be switched to False.
- Returns:
A new TensorDict object containing the same values as the tensorclass.
- unbind(dim: int)
Returns a tuple of indexed tensorclass instances unbound along the indicated dimension.
Resulting tensorclass instances will share the storage of the initial tensorclass instance.
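What "sharing the storage" means can be illustrated without tensors, using a `memoryview` over a `bytearray`. This is an analogy only: the flat six-byte buffer standing in for a batch of shape (3, 2) is an assumption of the sketch, not how tensorclass data is laid out.

```python
# Stand-in for a batch of shape (3, 2) stored contiguously, row-major.
storage = bytearray(6)
view = memoryview(storage)

# "Unbind" along the first dimension: one view per row, no copies made.
rows = tuple(view[i * 2:(i + 1) * 2] for i in range(3))

# Writing through one row is visible in the original storage.
rows[1][0] = 7
```

Because the rows are views, a write through any of them changes the original buffer, which is the behaviour to keep in mind when modifying unbound instances in place.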
- vertices(*, key_type: Literal['id', 'hash', 'path'] = 'hash') -> Dict[Union[int, Tuple[int]], Tree]
Returns a map containing the vertices of the Tree.
- Keyword Arguments:
key_type (Literal["id", "hash", "path"], optional) – Specifies the type of key to use for the vertices. Defaults to "hash".
- "id": Use the vertex ID as the key.
- "hash": Use a hash of the vertex as the key.
- "path": Use the path to the vertex as the key. This may lead to a longer dictionary than when "id" or "hash" are used, as the same node may be part of multiple trajectories.
- Returns:
A dictionary mapping keys to Tree vertices.
- Return type:
Dict[int | Tuple[int], Tree]
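Why "path" keys can produce a longer dictionary than "id" keys is easiest to see on a small example. The `Node` class and `vertices` function below are hypothetical and written only for illustration, assuming a path is the tuple of child indices followed from the root; torchrl's own key format may differ.

```python
class Node:
    """Hypothetical vertex with an id and a list of children."""

    def __init__(self, node_id, children=()):
        self.node_id = node_id
        self.children = list(children)


def vertices(root, key_type="id"):
    out = {}
    stack = [(root, ())]  # (vertex, tuple of child indices from the root)
    while stack:
        node, path = stack.pop()
        key = path if key_type == "path" else node.node_id
        out[key] = node
        for i, child in enumerate(node.children):
            stack.append((child, path + (i,)))
    return out


# Vertex 3 is reachable through two different paths.
shared = Node(3)
root = Node(0, [Node(1, [shared]), Node(2, [shared])])
by_id = vertices(root, "id")
by_path = vertices(root, "path")
```

Here the shared vertex appears once under its id but twice under its two paths, so the path-keyed dictionary has one more entry.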