Tree¶
- class torchrl.data.Tree(count: 'int | torch.Tensor' = None, wins: 'int | torch.Tensor' = None, index: 'torch.Tensor | None' = None, hash: 'int | None' = None, node_id: 'int | None' = None, rollout: 'TensorDict | None' = None, node_data: 'TensorDict | None' = None, subtree: "'Tree'" = None, _parent: 'weakref.ref | List[weakref.ref] | None' = None, specs: 'Composite | None' = None, *, batch_size, device=None, names=None)[source]¶
- property batch_size: Size¶
Retrieves the batch size for the tensor class.
- Returns:
batch size (torch.Size)
- property branching_action: torch.Tensor | tensordict.base.TensorDictBase | None¶
Returns the action that branched out to this particular node.
- Returns:
a tensor, tensordict or None if the node has no parent.
See also
This will be equal to prev_action whenever the rollout data contains a single step.
See also
All actions associated with a given node (or observation) in the tree.
- edges() List[Tuple[int, int]] [source]¶
Retrieves a list of edges in the tree.
Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID. The tree is traversed using Breadth-First Search (BFS) to ensure all edges are visited.
- Returns:
A list of tuples, where each tuple contains a parent node ID and a child node ID.
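The BFS traversal described above can be sketched in plain Python. This is not the actual Tree internals; the `Node` class below is a hypothetical stand-in holding only a `node_id` and an ordered list of children:

```python
from collections import deque
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class Node:
    # Hypothetical stand-in for a Tree vertex: an integer ID plus children.
    node_id: int
    children: List["Node"] = field(default_factory=list)

def edges(root: Node) -> List[Tuple[int, int]]:
    """Enumerate (parent_id, child_id) pairs via Breadth-First Search."""
    result = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in node.children:
            result.append((node.node_id, child.node_id))
            queue.append(child)
    return result

# root(0) -> 1, 2; node 1 -> 3
tree = Node(0, [Node(1, [Node(3)]), Node(2)])
print(edges(tree))  # BFS order: all of a parent's edges before any grandchild's
```

The queue-based traversal guarantees that edges closer to the root appear before deeper ones, matching the BFS guarantee stated above.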
- classmethod fields()¶
Return a tuple describing the fields of this dataclass.
Accepts a dataclass or an instance of one. Tuple elements are of type Field.
- classmethod from_tensordict(tensordict, non_tensordict=None, safe=True)¶
Tensor class wrapper to instantiate a new tensor class object.
- Parameters:
tensordict (TensorDict) – Dictionary of tensor types
non_tensordict (dict) – Dictionary with non-tensor and nested tensor class objects
- property full_action_spec¶
The action spec of the tree.
This is an alias for Tree.specs[‘input_spec’, ‘full_action_spec’].
- property full_done_spec¶
The done spec of the tree.
This is an alias for Tree.specs[‘output_spec’, ‘full_done_spec’].
- property full_observation_spec¶
The observation spec of the tree.
This is an alias for Tree.specs[‘output_spec’, ‘full_observation_spec’].
- property full_reward_spec¶
The reward spec of the tree.
This is an alias for Tree.specs[‘output_spec’, ‘full_reward_spec’].
- property full_state_spec¶
The state spec of the tree.
This is an alias for Tree.specs[‘input_spec’, ‘full_state_spec’].
- fully_expanded(env: EnvBase) bool [source]¶
Returns True if the number of children is equal to the environment cardinality.
- get(key: NestedKey, default: Any = _NoDefault.ZERO)¶
Gets the value stored with the input key.
- Parameters:
key (str, tuple of str) – key to be queried. If tuple of str it is equivalent to chained calls of getattr.
default – default value if the key is not found in the tensorclass.
- Returns:
value stored with the input key
- get_vertex_by_hash(hash: int) Tree [source]¶
Goes through the tree and returns the node corresponding to the given hash.
- get_vertex_by_id(id: int) Tree [source]¶
Goes through the tree and returns the node corresponding to the given id.
- property is_terminal: bool | torch.Tensor¶
Returns True if the tree has no children nodes.
- classmethod load(prefix: str | pathlib.Path, *args, **kwargs) T ¶
Loads a tensordict from disk.
This class method is a proxy to load_memmap().
- load_(prefix: str | pathlib.Path, *args, **kwargs)¶
Loads a tensordict from disk within the current tensordict.
This method is a proxy to load_memmap_().
- classmethod load_memmap(prefix: str | pathlib.Path, device: Optional[device] = None, non_blocking: bool = False, *, out: Optional[TensorDictBase] = None) T ¶
Loads a memory-mapped tensordict from disk.
- Parameters:
prefix (str or Path to folder) – the path to the folder where the saved tensordict should be fetched.
device (torch.device or equivalent, optional) – if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.
non_blocking (bool, optional) – if True, synchronize won’t be called after loading tensors on device. Defaults to False.
out (TensorDictBase, optional) – optional tensordict where the data should be written.
Examples
>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()
This method also allows loading nested tensordicts.
Examples
>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
A tensordict can also be loaded on “meta” device or, alternatively, as a fake tensor.
Examples
>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...     print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)¶
Attempts to load a state_dict in-place on the destination tensorclass.
- classmethod make_node(data: TensorDictBase, *, device: Optional[device] = None, batch_size: Optional[Size] = None, specs: Optional[Composite] = None) Tree [source]¶
Creates a new node given some data.
- max_length()[source]¶
Returns the maximum length of all valid paths in the tree.
The length of a path is defined as the number of nodes in the path. If the tree is empty, returns 0.
- Returns:
The maximum length of all valid paths in the tree.
- Return type:
int
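The node-counting convention above (longest root-to-leaf path measured in nodes, 0 for an empty tree) can be sketched in plain Python. The `Node` class is a hypothetical stand-in, not the actual Tree type:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    # Hypothetical stand-in for a Tree vertex.
    children: List["Node"] = field(default_factory=list)

def max_length(root: Optional[Node]) -> int:
    """Longest root-to-leaf path, measured in nodes; 0 for an empty tree."""
    if root is None:
        return 0
    if not root.children:
        return 1
    return 1 + max(max_length(child) for child in root.children)

# One branch of depth 3 (root -> child -> grandchild), one of depth 2.
chain = Node([Node([Node()]), Node()])
print(max_length(chain))  # -> 3
```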
- memmap(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T ¶
Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between processes, and a write operation (such as an in-place update or set) on any worker within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set, or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is no longer guaranteed.
- Returns:
A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) T ¶
Writes all tensors onto a corresponding memory-mapped Tensor, in-place.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between processes, and a write operation (such as an in-place update or set) on any worker within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set, or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is no longer guaranteed.
- Returns:
self if return_early=False, otherwise a TensorDictFuture instance.
Note
Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_like(prefix: Optional[str] = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
Creates a contentless Memory-mapped tensordict with the same shapes as the original one.
- Parameters:
prefix (str) – directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
copy_existing (bool) – If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Arguments:
num_threads (int, optional) – the number of threads used to write the memmap tensors. Defaults to 0.
return_early (bool, optional) – if True and num_threads > 0, the method will return a future of the tensordict.
share_non_tensor (bool, optional) – if True, the non-tensor data will be shared between processes, and a write operation (such as an in-place update or set) on any worker within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
existsok (bool, optional) – if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place will throw an exception (e.g., rename, set, or removing an entry). Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is no longer guaranteed.
- Returns:
A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.
Note
This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.
Examples
>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
- memmap_refresh_()¶
Refreshes the content of the memory-mapped tensordict if it has a saved_path.
This method will raise an exception if no path is associated with it.
- property node_observation: torch.Tensor | tensordict.base.TensorDictBase¶
Returns the observation associated with this particular node.
This is the observation (or bag of observations) that defines the node before a branching occurs. If the node contains a rollout attribute, the node observation is typically identical to the observation resulting from the last action undertaken, i.e., node.rollout[..., -1]["next", "observation"].
If more than one observation key is associated with the tree specs, a TensorDict instance is returned instead.
For a more consistent representation, see node_observations.
- property node_observations: torch.Tensor | tensordict.base.TensorDictBase¶
Returns the observations associated with this particular node in a TensorDict format.
This is the observation (or bag of observations) that defines the node before a branching occurs. If the node contains a rollout attribute, the node observation is typically identical to the observation resulting from the last action undertaken, i.e., node.rollout[..., -1]["next", "observation"].
If more than one observation key is associated with the tree specs, a TensorDict instance is returned instead.
For a more consistent representation, see node_observations.
- property num_children: int¶
Number of children of this node.
Equates to the number of elements in the self.subtree stack.
- num_vertices(*, count_repeat: bool = False) int [source]¶
Returns the number of unique vertices in the Tree.
- Keyword Arguments:
count_repeat (bool, optional) – Determines whether to count repeated vertices.
- If False, counts each unique vertex only once.
- If True, counts vertices multiple times if they appear in different paths.
Defaults to False.
- Returns:
The number of unique vertices in the Tree.
- Return type:
int
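The count_repeat distinction can be sketched in plain Python. Here a hypothetical `key` attribute plays the role of the vertex hash used for deduplication; this is an illustrative stand-in, not the actual Tree implementation:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Node:
    # Hypothetical stand-in: `key` plays the role of the vertex hash.
    key: int
    children: List["Node"] = field(default_factory=list)

def num_vertices(root: Node, *, count_repeat: bool = False) -> int:
    """Count vertices, deduplicating by `key` unless count_repeat=True."""
    seen = []
    stack = [root]
    while stack:
        node = stack.pop()
        seen.append(node.key)
        stack.extend(node.children)
    return len(seen) if count_repeat else len(set(seen))

# The vertex with key=2 occurs twice (e.g., two paths reach the same state).
tree = Node(0, [Node(1, [Node(2)]), Node(2)])
print(num_vertices(tree))                     # unique keys -> 3
print(num_vertices(tree, count_repeat=True))  # every occurrence -> 4
```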
- property parent: torchrl.data.map.tree.Tree | None¶
The parent of the node.
If the node has a parent and this object is still present in the python workspace, it will be returned by this property.
For re-branching trees, this property may return a stack of trees where every index of the stack corresponds to a different parent.
Note
the parent attribute will match in content but not in identity: the tensorclass object is reconstructed using the same tensors (i.e., tensors that point to the same memory locations).
- Returns:
A Tree containing the parent data, or None if the parent data is out of scope or the node is the root.
- plot(backend: str = 'plotly', figure: str = 'tree', info: Optional[List[str]] = None, make_labels: Optional[Callable[[Any, ...], Any]] = None)[source]¶
Plots a visualization of the tree using the specified backend and figure type.
- Parameters:
backend – The plotting backend to use. Currently only supports ‘plotly’.
figure – The type of figure to plot. Can be either ‘tree’ or ‘box’.
info – A list of additional information to include in the plot (not currently used).
make_labels – An optional function to generate custom labels for the plot.
- Raises:
NotImplementedError – If an unsupported backend or figure type is specified.
- property prev_action: torch.Tensor | tensordict.base.TensorDictBase | None¶
The action undertaken just before this node’s observation was generated.
- Returns:
a tensor, tensordict or None if the node has no parent.
See also
This will be equal to branching_action whenever the rollout data contains a single step.
See also
All actions associated with a given node (or observation) in the tree.
- rollout_from_path(path: Tuple[int]) tensordict.base.TensorDictBase | None [source]¶
Retrieves the rollout data along a given path in the tree.
The rollout data is concatenated along the last dimension (dim=-1) for each node in the path. If no rollout data is found along the path, returns None.
- Parameters:
path – A tuple of integers representing the path in the tree.
- Returns:
The concatenated rollout data along the path, or None if no data is found.
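The concatenation semantics can be sketched in plain Python, with lists of step labels standing in for per-node rollout stacks. This is an illustrative sketch, not the actual implementation; in particular, whether the root's own rollout segment is included is an assumption made here:

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

@dataclass
class Node:
    # Hypothetical stand-in: `rollout` is a list of steps (None if absent),
    # `children` is the ordered subtree stack.
    rollout: Optional[List[str]] = None
    children: List["Node"] = field(default_factory=list)

def rollout_from_path(root: Node, path: Tuple[int, ...]) -> Optional[List[str]]:
    """Concatenate per-node rollout segments along the walked path."""
    segments: List[str] = []
    node = root
    if node.rollout:
        segments += node.rollout  # assumption: the root's segment is included
    for child_index in path:
        node = node.children[child_index]
        if node.rollout:
            segments += node.rollout
    return segments or None

root = Node(rollout=["s0"], children=[
    Node(rollout=["s1", "s2"]),
    Node(rollout=["s3"]),
])
print(rollout_from_path(root, (0,)))  # -> ['s0', 's1', 's2']
print(rollout_from_path(Node(), ()))  # no rollout anywhere -> None
```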
- save(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) T ¶
Saves the tensordict to disk.
This function is a proxy to memmap().
- property selected_actions: torch.Tensor | tensordict.base.TensorDictBase | None¶
Returns a tensor containing all the selected actions branching out from this node.
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)¶
Sets a new key-value pair.
- Parameters:
key (str, tuple of str) – name of the key to be set. If tuple of str it is equivalent to chained calls of getattr followed by a final setattr.
value (Any) – value to be stored in the tensorclass
inplace (bool, optional) – if True, set will tentatively try to update the value in-place. If False or if the key isn’t present, the value will be simply written at its destination.
- Returns:
self
- state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any] ¶
Returns a state_dict dictionary that can be used to save and load data from a tensorclass.
- to_tensordict(*, retain_none: Optional[bool] = None) TensorDict ¶
Convert the tensorclass into a regular TensorDict.
Makes a copy of all entries. Memmap and shared memory tensors are converted to regular tensors.
- Parameters:
retain_none (bool) – if True, the None values will be written in the tensordict. Otherwise they will be discarded. Default: True.
Note
From v0.8, the default value will be switched to False.
- Returns:
A new TensorDict object containing the same values as the tensorclass.
- unbind(dim: int)¶
Returns a tuple of indexed tensorclass instances unbound along the indicated dimension.
Resulting tensorclass instances will share the storage of the initial tensorclass instance.
- valid_paths()[source]¶
Generates all valid paths in the tree.
A valid path is a sequence of child indices that starts at the root node and ends at a leaf node. Each path is represented as a tuple of integers, where each integer corresponds to the index of a child node.
- Yields:
tuple – A valid path in the tree.
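The path convention above (a tuple of child indices from root to leaf) can be sketched as a depth-first generator in plain Python. The `Node` class is a hypothetical stand-in, and treating a childless root as yielding no path is an assumption of this sketch:

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Tuple

@dataclass
class Node:
    # Hypothetical stand-in for a Tree vertex.
    children: List["Node"] = field(default_factory=list)

def valid_paths(root: Node, prefix: Tuple[int, ...] = ()) -> Iterator[Tuple[int, ...]]:
    """Yield every root-to-leaf path as a tuple of child indices."""
    if not root.children:
        if prefix:  # assumption: a bare root yields no path
            yield prefix
        return
    for i, child in enumerate(root.children):
        yield from valid_paths(child, prefix + (i,))

# root -> [A -> [leaf, leaf], leaf]
tree = Node([Node([Node(), Node()]), Node()])
print(list(valid_paths(tree)))  # -> [(0, 0), (0, 1), (1,)]
```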
- vertices(*, key_type: Literal['id', 'hash', 'path'] = 'hash') Dict[Union[int, Tuple[int]], Tree] [source]¶
Returns a map containing the vertices of the Tree.
- Keyword Arguments:
key_type (Literal["id", "hash", "path"], optional) – Specifies the type of key to use for the vertices.
- "id": use the vertex ID as the key.
- "hash": use a hash of the vertex as the key.
- "path": use the path to the vertex as the key. This may lead to a dictionary with a longer length than when "id" or "hash" are used, as the same node may be part of multiple trajectories.
Defaults to "hash".
- Returns:
A dictionary mapping keys to Tree vertices.
- Return type:
Dict[int | Tuple[int], Tree]
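The difference between the key types can be sketched in plain Python for the "id" and "path" cases (the "hash" case is omitted since it depends on the hashing scheme). The `Node` class is a hypothetical stand-in, not the actual Tree type:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union

@dataclass
class Node:
    # Hypothetical stand-in for a Tree vertex.
    node_id: int
    children: List["Node"] = field(default_factory=list)

def vertices(root: Node, *, key_type: str = "id") -> Dict[Union[int, Tuple[int, ...]], Node]:
    """Map each vertex to a key: its ID, or its path of child indices."""
    out: Dict[Union[int, Tuple[int, ...]], Node] = {}
    stack: List[Tuple[Node, Tuple[int, ...]]] = [(root, ())]
    while stack:
        node, path = stack.pop()
        out[path if key_type == "path" else node.node_id] = node
        for i, child in enumerate(node.children):
            stack.append((child, path + (i,)))
    return out

tree = Node(0, [Node(1), Node(2)])
print(sorted(vertices(tree)))                   # -> [0, 1, 2]
print(sorted(vertices(tree, key_type="path")))  # -> [(), (0,), (1,)]
```

With "path" keys, a node shared by several trajectories would appear once per path, which is why that dictionary can be longer than the "id" or "hash" variants.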
- property visits: int | torch.Tensor¶
Returns the number of visits associated with this particular node.
This is an alias for the count attribute.