Source code for torchrl.data.map.tree
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from collections import deque
from typing import Any, Callable, Dict, List, Literal, Tuple
import torch
from tensordict import (
merge_tensordicts,
NestedKey,
TensorClass,
TensorDict,
TensorDictBase,
)
from torchrl.data.map.tdstorage import TensorDictMap
from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree
from torchrl.data.replay_buffers.storages import ListStorage
from torchrl.envs.common import EnvBase
class Tree(TensorClass["nocast"]):
"""Representation of a single MCTS (Monte Carlo Tree Search) Tree.
This class encapsulates the data and behavior of a tree node in an MCTS algorithm.
It includes attributes for storing information about the node, such as its children,
visit count, and rollout data. Methods are provided for traversing the tree,
computing statistics, and visualizing the tree structure.
A ``Tree`` is largely interchangeable with a node or a vertex: we use the term "tree" for a node
with children, and "node" or "vertex" for a place in the tree where branching occurs.
A node in the tree is defined primarily by its ``hash`` value. Usually, a ``hash`` is determined by a unique
combination of state (or observation) and action. If one observation (found in the ``node`` attribute) has more than
one action associated, each branch will be stored in the ``subtree`` attribute as a stack of ``Tree`` instances.
Attributes:
count (int): The number of visits to this node.
index (torch.Tensor): Indices of the child nodes in the data map.
hash (torch.Tensor): A hash value for this node.
It may be the case that ``hash`` is ``None`` in the specific case where the root of the tree
has more than one action associated. In that case, each subtree branch will have a different action
associated and a hash corresponding to the ``(observation, action)`` pair.
node_id (int): A unique identifier for this node.
rollout (TensorDict): Rollout data following the observation encoded in this node, in a TED format.
If there are multiple actions taken at this node, subtrees are stored in the corresponding
entry. Rollouts can be reconstructed using the :meth:`~.rollout_from_path` method.
node (TensorDict): Data defining this node (e.g., observations) before the next branching.
Entries usually match the ``in_keys`` in ``MCTSForest.node_map``.
subtree (Tree): A stack of subtrees produced when actions are taken.
num_children (int): The number of child nodes (read-only).
is_terminal (bool): whether the tree has no child nodes (read-only).
If the tree is compact, ``is_terminal == False`` implies that ``self.subtree`` contains
more than one child node, since chains of single-child nodes are collapsed.
Methods:
__contains__: Whether another tree can be found in the tree.
vertices: Returns a dictionary containing all vertices in the tree, keyed by id, hash or path.
num_vertices: Returns the total number of vertices in the tree, with or without duplicates.
edges: Returns a list of edges in the tree.
valid_paths: Yields all valid paths in the tree.
max_length: Returns the maximum length of any path in the tree.
rollout_from_path: Reconstructs a rollout from a given path.
plot: Visualizes the tree using a specified backend and figure type.
get_vertex_by_id: returns the vertex with the given id in the tree.
get_vertex_by_hash: returns the vertex with the given hash in the tree.
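Examples:
>>> # A minimal usage sketch: trees are usually produced by ``MCTSForest.get_tree``
>>> # rather than built by hand. ``rollout`` is assumed to be a TED-formatted
>>> # rollout collected beforehand.
>>> from torchrl.data.map.tree import MCTSForest
>>> forest = MCTSForest()
>>> forest.extend(rollout)
>>> tree = forest.get_tree(rollout[0])
>>> assert tree.num_vertices() > 0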
"""
count: int | None = None
index: torch.Tensor | None = None
# The hash is None if the node has more than one action associated
hash: int | None = None
node_id: int | None = None
# rollout following the observation encoded in node, in a TorchRL (TED) format
rollout: TensorDict | None = None
# The data specifying the node
node: TensorDict | None = None
# Stack of subtrees. A subtree is produced when an action is taken.
subtree: "Tree" = None
@property
def num_children(self) -> int:
"""Number of children of this node.
Equates to the number of elements in the ``self.subtree`` stack.
"""
return len(self.subtree) if self.subtree is not None else 0
@property
def is_terminal(self):
"""Returns True if the the tree has no children nodes."""
return self.subtree is None
def get_vertex_by_id(self, id: int) -> Tree:
"""Goes through the tree and returns the vertex corresponding to the given id."""
q = deque()
q.append(self)
while len(q):
tree = q.popleft()
if tree.node_id == id:
return tree
if tree.subtree is not None:
q.extend(tree.subtree.unbind(0))
raise ValueError(f"Node with id {id} not found.")
def get_vertex_by_hash(self, hash: int) -> Tree:
"""Goes through the tree and returns the vertex corresponding to the given hash."""
q = deque()
q.append(self)
while len(q):
tree = q.popleft()
if tree.hash == hash:
return tree
if tree.subtree is not None:
q.extend(tree.subtree.unbind(0))
raise ValueError(f"Node with hash {hash} not found.")
def __contains__(self, other: Tree) -> bool:
hash = other.hash
for vertex in self.vertices().values():
if vertex.hash == hash:
return True
return False
def vertices(
self, *, key_type: Literal["id", "hash", "path"] = "hash"
) -> Dict[int | Tuple[int], Tree]:
"""Returns a map containing the vertices of the Tree.
Keyword args:
key_type (Literal["id", "hash", "path"], optional): Specifies the type of key to use for the vertices.
- "id": Use the vertex ID as the key.
- "hash": Use a hash of the vertex as the key.
- "path": Use the path to the vertex as the key. This may lead to a dictionary with a longer length than
when ``"id"`` or ``"hash"`` are used as the same node may be part of multiple trajectories.
Defaults to ``"hash"``.
Defaults to an empty string, which may imply a default behavior.
Returns:
Dict[int | Tuple[int], Tree]: A dictionary mapping keys to Tree vertices.
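Examples:
>>> # A sketch of the three key types, assuming ``tree`` was obtained through
>>> # ``MCTSForest.get_tree``:
>>> by_hash = tree.vertices()  # unique vertices keyed by hash
>>> by_id = tree.vertices(key_type="id")  # unique vertices keyed by node id
>>> by_path = tree.vertices(key_type="path")  # one entry per root-to-vertex path
>>> assert len(by_path) >= len(by_hash)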
"""
memo = set()
result = {}
q = deque()
cur_path = ()
q.append((self, cur_path))
use_hash = key_type == "hash"
use_id = key_type == "id"
use_path = key_type == "path"
while len(q):
tree, cur_path = q.popleft()
h = tree.hash
if h in memo and not use_path:
continue
memo.add(h)
if use_path:
result[cur_path] = tree
elif use_id:
result[tree.node_id] = tree
elif use_hash:
result[h] = tree
else:
raise ValueError(
f"key_type must be either 'hash', 'id' or 'path'. Got {key_type}."
)
n = int(tree.num_children)
for i in range(n):
cur_path_tree = cur_path + (i,)
q.append((tree.subtree[i], cur_path_tree))
return result
def num_vertices(self, *, count_repeat: bool = False) -> int:
"""Returns the number of vertices in the Tree.
Keyword Args:
count_repeat (bool, optional): Determines whether to count repeated vertices.
- If ``False``, counts each unique vertex only once.
- If ``True``, counts vertices multiple times if they appear in different paths.
Defaults to ``False``.
Returns:
int: The number of vertices in the Tree.
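Examples:
>>> # With ``count_repeat=True``, a vertex reached through several paths is
>>> # counted once per path. Assuming ``tree`` holds two rollouts sharing a prefix:
>>> assert tree.num_vertices() <= tree.num_vertices(count_repeat=True)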
"""
return len(
{
v.node_id
for v in self.vertices(
key_type="hash" if not count_repeat else "path"
).values()
}
)
def edges(self) -> List[Tuple[int, int]]:
"""Returns the edges of the tree as a list of ``(parent_id, child_id)`` pairs."""
result = []
q = deque()
parent = self.node_id
q.append((self, parent))
while len(q):
tree, parent = q.popleft()
n = int(tree.num_children)
for i in range(n):
node = tree.subtree[i]
node_id = node.node_id
result.append((parent, node_id))
q.append((node, node_id))
return result
def valid_paths(self):
"""Yields one tuple of child indices for each root-to-leaf path in the tree."""
q = deque()
cur_path = ()
q.append((self, cur_path))
while len(q):
tree, cur_path = q.popleft()
n = int(tree.num_children)
if not n:
yield cur_path
for i in range(n):
cur_path_tree = cur_path + (i,)
q.append((tree.subtree[i], cur_path_tree))
def max_length(self):
"""Returns the length of the longest valid path in the tree."""
return max(len(path) for path in self.valid_paths())
def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None:
"""Reconstructs a contiguous rollout by concatenating the rollouts found along the given path, or returns ``None`` if no rollout data is present."""
r = self.rollout
tree = self
rollouts = []
if r is not None:
rollouts.append(r)
for i in path:
tree = tree.subtree[i]
r = tree.rollout
if r is not None:
rollouts.append(r)
if rollouts:
return torch.cat(rollouts, dim=-1)
return None
@staticmethod
def _label(info: List[str], tree: "Tree", root=False):
labels = []
for key in info:
if key == "hash":
hash = tree.hash
if hash is not None:
hash = hash.item()
v = f"hash={hash}"
elif root:
v = f"{key}=None"
else:
v = f"{key}={tree.rollout[key].mean().item()}"
labels.append(v)
return ", ".join(labels)
def plot(
self: Tree,
backend: str = "plotly",
figure: str = "tree",
info: List[str] | None = None,
make_labels: Callable[[Any], Any] | None = None,
):
"""Visualizes the tree using the given backend (currently only "plotly") and figure type ("tree" or "box")."""
if backend == "plotly":
if figure == "box":
_plot_plotly_box(self)
return
elif figure == "tree":
_plot_plotly_tree(self, make_labels=make_labels)
return
raise NotImplementedError(
f"Unknown plotting backend {backend} with figure {figure}."
)
class MCTSForest:
"""A collection of MCTS trees.
The class is aimed at storing rollouts in a storage, and at producing trees based on a given root
in that dataset.
Keyword Args:
data_map (TensorDictMap, optional): the storage to use to store the data
(observation, reward, states etc). If not provided, it is lazily
initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair`.
node_map (TensorDictMap, optional): the storage used to map nodes (i.e., observations) to
their metadata (hash, index and visit count). If not provided, it is lazily
initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair`.
done_keys (list of NestedKey): the done keys of the environment. If not provided,
defaults to ``("done", "terminated", "truncated")``.
The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
action_keys (list of NestedKey): the action keys of the environment. If not provided,
defaults to ``("action",)``.
The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
reward_keys (list of NestedKey): the reward keys of the environment. If not provided,
defaults to ``("reward",)``.
The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
observation_keys (list of NestedKey): the observation keys of the environment. If not provided,
defaults to ``("observation",)``.
The :meth:`~.get_keys_from_env` can be used to automatically determine the keys.
consolidated (bool, optional): if ``True``, the data_map storage will be consolidated on disk.
Defaults to ``False``.
Examples:
>>> from torchrl.envs import GymEnv
>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> from torchrl.data import TensorDictMap, ListStorage
>>> from torchrl.data.map.tree import MCTSForest
>>>
>>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter
>>> # Create the MCTS Forest
>>> forest = MCTSForest()
>>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle)
>>> env = PendulumEnv()
>>> obs_keys = list(env.observation_spec.keys(True, True))
>>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
>>> # Appending transforms to get an "observation" key that concatenates the observations together
>>> env = env.append_transform(
... UnsqueezeTransform(
... in_keys=obs_keys,
... out_keys=[("unsqueeze", key) for key in obs_keys],
... dim=-1
... )
... )
>>> env = env.append_transform(
... CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
... )
>>> env = env.append_transform(StepCounter())
>>> env.set_seed(0)
>>> # Get a reset state, then make a rollout out of it
>>> reset_state = env.reset()
>>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> # Append the rollout to the forest. We're removing the state entries for clarity
>>> rollout0 = rollout0.copy()
>>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout0)
>>> # The forest should have 6 elements (the length of the rollout)
>>> assert len(forest) == 6
>>> # Let's make another rollout from the same reset state
>>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1)
>>> assert len(forest) == 12
>>> # Let's make another final rollout from an intermediate step in the second rollout
>>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
>>> rollout1b.exclude(*state_keys, inplace=True)
>>> rollout1b.get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1b)
>>> assert len(forest) == 18
>>> # Since we have 2 rollouts starting at the same state, our tree should have two
>>> # branches if we produce it from the reset entry. Take the state, and call `get_tree`:
>>> r = rollout0[0]
>>> # Let's get the compact tree that follows the initial reset. A compact tree is
>>> # a tree where nodes that have a single child are collapsed.
>>> tree = forest.get_tree(r)
>>> print(tree.max_length())
2
>>> print(list(tree.valid_paths()))
[(0,), (1, 0), (1, 1)]
>>> from tensordict import assert_close
>>> # We can manually rebuild the tree
>>> assert_close(
... rollout1,
... torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
... intersection=True,
... )
True
>>> # Or we can rebuild it using the dedicated method
>>> assert_close(
... rollout1,
... tree.rollout_from_path((1, 0)),
... intersection=True,
... )
True
>>> tree.plot()
>>> tree = forest.get_tree(r, compact=False)
>>> print(tree.max_length())
9
>>> print(list(tree.valid_paths()))
[(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
>>> assert_close(
... rollout1,
... tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
... intersection=True,
... )
True
"""
def __init__(
self,
*,
data_map: TensorDictMap | None = None,
node_map: TensorDictMap | None = None,
done_keys: List[NestedKey] | None = None,
reward_keys: List[NestedKey] | None = None,
observation_keys: List[NestedKey] | None = None,
action_keys: List[NestedKey] | None = None,
consolidated: bool | None = None,
):
self.data_map = data_map
self.node_map = node_map
self.done_keys = done_keys
self.action_keys = action_keys
self.reward_keys = reward_keys
self.observation_keys = observation_keys
self.consolidated = consolidated
@property
def done_keys(self):
done_keys = getattr(self, "_done_keys", None)
if done_keys is None:
self._done_keys = done_keys = ("done", "terminated", "truncated")
return done_keys
@done_keys.setter
def done_keys(self, value):
self._done_keys = value
@property
def reward_keys(self):
reward_keys = getattr(self, "_reward_keys", None)
if reward_keys is None:
self._reward_keys = reward_keys = ("reward",)
return reward_keys
@reward_keys.setter
def reward_keys(self, value):
self._reward_keys = value
@property
def action_keys(self):
action_keys = getattr(self, "_action_keys", None)
if action_keys is None:
self._action_keys = action_keys = ("action",)
return action_keys
@action_keys.setter
def action_keys(self, value):
self._action_keys = value
@property
def observation_keys(self):
observation_keys = getattr(self, "_observation_keys", None)
if observation_keys is None:
self._observation_keys = observation_keys = ("observation",)
return observation_keys
@observation_keys.setter
def observation_keys(self, value):
self._observation_keys = value
def get_keys_from_env(self, env: EnvBase):
"""Writes missing done, action, reward and observation keys to the Forest given an environment.
Existing keys are not overwritten.
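Examples:
>>> # A minimal sketch, assuming gymnasium is installed:
>>> from torchrl.envs import GymEnv
>>> env = GymEnv("Pendulum-v1")
>>> forest = MCTSForest()
>>> forest.get_keys_from_env(env)
>>> assert forest.action_keys == env.action_keys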
"""
if getattr(self, "_reward_keys", None) is None:
self.reward_keys = env.reward_keys
if getattr(self, "_done_keys", None) is None:
self.done_keys = env.done_keys
if getattr(self, "_action_keys", None) is None:
self.action_keys = env.action_keys
if getattr(self, "_observation_keys", None) is None:
self.observation_keys = env.observation_keys
@classmethod
def _write_fn_stack(cls, new, old=None):
if old is None:
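# First write for this node: add a leading stack dimension and initialize the visit count to 1.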
result = new.apply(lambda x: x.unsqueeze(0), filter_empty=False)
result.set(
"count", torch.ones(result.shape, dtype=torch.int, device=result.device)
)
else:
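# Subsequent writes: stack the new branch onto the existing ones, dropping duplicate branches, and increment the visit count.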
def cat(name, x, y):
if name == "count":
return x
if y.ndim < x.ndim:
y = y.unsqueeze(0)
result = torch.cat([x, y], 0).unique(dim=0, sorted=False)
return result
result = old.named_apply(cat, new, default=None)
result.set_("count", old.get("count") + 1)
return result
def _make_storage(self, source, dest):
try:
self.data_map = TensorDictMap.from_tensordict_pair(
source,
dest,
in_keys=[*self.observation_keys, *self.action_keys],
consolidated=self.consolidated,
)
except KeyError as err:
raise KeyError(
"A KeyError occurred during data map creation. This could be due to the wrong setting of a key in the MCTSForest constructor. Scroll up for more info."
) from err
def _make_storage_branches(self, source, dest):
self.node_map = TensorDictMap.from_tensordict_pair(
source,
dest,
in_keys=[*self.observation_keys],
out_keys=[
*self.data_map.query_module.out_keys, # hash and index
# *self.action_keys,
# *[("next", rk) for rk in self.reward_keys],
"count",
],
storage_constructor=ListStorage,
collate_fn=TensorDict.lazy_stack,
write_fn=self._write_fn_stack,
)
def extend(self, rollout):
"""Adds a rollout to the forest, writing its steps to the data map and its nodes to the node map."""
source, dest = (
rollout.exclude("next").copy(),
rollout.select("next", *self.action_keys).copy(),
)
if self.data_map is None:
self._make_storage(source, dest)
# We need to set the action somewhere to keep track of what action lead to what child
# # Set the action in the 'next'
# dest[1:] = source[:-1].exclude(*self.done_keys)
self.data_map[source] = dest
value = source
if self.node_map is None:
self._make_storage_branches(source, dest)
self.node_map[source] = TensorDict.lazy_stack(value.unbind(0))
def get_child(self, root: TensorDictBase) -> TensorDictBase:
"""Returns the step(s) following ``root`` as stored in the data map."""
return self.data_map[root]
def _make_local_tree(
self,
root: TensorDictBase,
index: torch.Tensor | None = None,
compact: bool = True,
) -> Tuple[Tree, torch.Tensor | None, torch.Tensor | None]:
root = root.select(*self.node_map.in_keys)
node_meta = None
if root in self.node_map:
node_meta = self.node_map[root]
if index is None:
node_meta = self.node_map[root]
index = node_meta["_index"]
steps = []
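# Walk down chains of single-child nodes: as long as the current node has at
# most one child, accumulate the steps; with ``compact=True`` the whole chain
# is collapsed into a single rollout.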
while index.numel() <= 1:
index = index.squeeze()
d = self.data_map.storage[index]
steps.append(merge_tensordicts(d, root, callback_exist=lambda *x: None))
d = d["next"]
if d in self.node_map:
root = d.select(*self.node_map.in_keys)
node_meta = self.node_map[root]
index = node_meta["_index"]
if not compact:
break
else:
index = None
break
rollout = None
if steps:
rollout = torch.stack(steps, -1)
# The hash and subtree entries are populated later by the caller (see _make_tree_iter)
hash = node_meta["_hash"]
return (
Tree(
rollout=rollout,
count=node_meta["count"],
node=root,
index=index,
hash=None,
subtree=None,
),
index,
hash,
)
# The recursive implementation is slower and less compatible with compile
# def _make_tree(self, root: TensorDictBase, index: torch.Tensor|None=None)->Tree:
# tree, indices = self._make_local_tree(root, index=index)
# subtrees = []
# if indices is not None:
# for i in indices:
# subtree = self._make_tree(tree.node, index=i)
# subtrees.append(subtree)
# subtrees = TensorDict.lazy_stack(subtrees)
# tree.subtree = subtrees
# return tree
def _make_tree_iter(
self, root, index=None, max_depth: int | None = None, compact: bool = True
):
q = deque()
memo = {}
tree, indices, hash = self._make_local_tree(root, index=index)
tree.node_id = 0
result = tree
depth = 0
counter = 1
if indices is not None:
q.append((tree, indices, hash, depth))
del tree, indices
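# Breadth-first expansion of the tree; ``memo`` caches vertices by hash so that
# a node reached through several branches is only built once.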
while len(q):
tree, indices, hash, depth = q.popleft()
extend = max_depth is None or depth < max_depth
subtrees = []
for i, h in zip(indices, hash):
# TODO: remove the .item()
h = h.item()
subtree, subtree_indices, subtree_hash = memo.get(h, (None,) * 3)
if subtree is None:
subtree, subtree_indices, subtree_hash = self._make_local_tree(
tree.node, index=i, compact=compact
)
subtree.node_id = counter
counter += 1
subtree.hash = h
memo[h] = (subtree, subtree_indices, subtree_hash)
subtrees.append(subtree)
if extend and subtree_indices is not None:
q.append((subtree, subtree_indices, subtree_hash, depth + 1))
subtrees = TensorDict.lazy_stack(subtrees)
tree.subtree = subtrees
return result
def get_tree(
self,
root,
*,
max_depth: int | None = None,
compact: bool = True,
) -> Tree:
"""Builds and returns the tree rooted at ``root``, optionally limited to ``max_depth``; when ``compact=True``, chains of single-child nodes are collapsed into single vertices."""
return self._make_tree_iter(root=root, max_depth=max_depth, compact=compact)
@classmethod
def valid_paths(cls, tree: Tree):
"""Yields all valid paths in the given tree."""
yield from tree.valid_paths()
def __len__(self):
return len(self.data_map)