MCTSForest
- class torchrl.data.MCTSForest(*, data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, max_size: int | None = None, done_keys: List[NestedKey] | None = None, reward_keys: List[NestedKey] | None = None, observation_keys: List[NestedKey] | None = None, action_keys: List[NestedKey] | None = None, excluded_keys: List[NestedKey] | None = None, consolidated: bool | None = None)[source]
A collection of MCTS trees.
Warning
This class is currently under active development. Expect frequent API changes.
This class stores rollouts in a storage and produces trees rooted at any given node of that dataset.
- Keyword Arguments:
  - data_map (TensorDictMap, optional) – the storage used to store the data (observation, reward, states, etc.). If not provided, it is lazily initialized using from_tensordict_pair() with the list of observation_keys and action_keys as in_keys.
  - node_map (TensorDictMap, optional) – a map from the observation space to the index space. Internally, the node map is used to gather all possible branches coming out of a given node. For example, if an observation has two associated actions and outcomes in the data map, then the node_map will return a data structure containing the two indices in the data_map that correspond to these two outcomes. If not provided, it is lazily initialized using from_tensordict_pair() with the list of observation_keys as in_keys and the QueryModule as out_keys.
  - max_size (int, optional) – the size of the maps. If not provided, defaults to data_map.max_size if this can be found, then node_map.max_size. If none of these are provided, defaults to 1000.
  - done_keys (list of NestedKey, optional) – the done keys of the environment. If not provided, defaults to ("done", "terminated", "truncated"). The get_keys_from_env() method can be used to automatically determine the keys.
  - action_keys (list of NestedKey, optional) – the action keys of the environment. If not provided, defaults to ("action",). The get_keys_from_env() method can be used to automatically determine the keys.
  - reward_keys (list of NestedKey, optional) – the reward keys of the environment. If not provided, defaults to ("reward",). The get_keys_from_env() method can be used to automatically determine the keys.
  - observation_keys (list of NestedKey, optional) – the observation keys of the environment. If not provided, defaults to ("observation",). The get_keys_from_env() method can be used to automatically determine the keys.
  - excluded_keys (list of NestedKey, optional) – a list of keys to exclude from the data storage.
  - consolidated (bool, optional) – if True, the data_map storage will be consolidated on disk. Defaults to False.
Examples
>>> from torchrl.envs import GymEnv
>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> from torchrl.data import TensorDictMap, ListStorage
>>> from torchrl.data.map.tree import MCTSForest
>>>
>>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter
>>> # Create the MCTS Forest
>>> forest = MCTSForest()
>>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle)
>>> env = PendulumEnv()
>>> obs_keys = list(env.observation_spec.keys(True, True))
>>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
>>> # Appending transforms to get an "observation" key that concatenates the observations together
>>> env = env.append_transform(
...     UnsqueezeTransform(
...         in_keys=obs_keys,
...         out_keys=[("unsqueeze", key) for key in obs_keys],
...         dim=-1
...     )
... )
>>> env = env.append_transform(
...     CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
... )
>>> env = env.append_transform(StepCounter())
>>> env.set_seed(0)
>>> # Get a reset state, then make a rollout out of it
>>> reset_state = env.reset()
>>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> # Append the rollout to the forest. We're removing the state entries for clarity
>>> rollout0 = rollout0.copy()
>>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout0)
>>> # The forest should have 6 elements (the length of the rollout)
>>> assert len(forest) == 6
>>> # Let's make another rollout from the same reset state
>>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1)
>>> assert len(forest) == 12
>>> # Let's make another final rollout from an intermediate step in the second rollout
>>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
>>> rollout1b.exclude(*state_keys, inplace=True)
>>> rollout1b.get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1b)
>>> assert len(forest) == 18
>>> # Since we have 2 rollouts starting at the same state, our tree should have two
>>> # branches if we produce it from the reset entry. Take the state, and call `get_tree`:
>>> r = rollout0[0]
>>> # Let's get the compact tree that follows the initial reset. A compact tree is
>>> # a tree where nodes that have a single child are collapsed.
>>> tree = forest.get_tree(r)
>>> print(tree.max_length())
2
>>> print(list(tree.valid_paths()))
[(0,), (1, 0), (1, 1)]
>>> from tensordict import assert_close
>>> # We can manually rebuild the tree
>>> assert_close(
...     rollout1,
...     torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
...     intersection=True,
... )
True
>>> # Or we can rebuild it using the dedicated method
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0)),
...     intersection=True,
... )
True
>>> tree.plot()
>>> tree = forest.get_tree(r, compact=False)
>>> print(tree.max_length())
9
>>> print(list(tree.valid_paths()))
[(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
...     intersection=True,
... )
True
- property action_keys: List[NestedKey]
Action Keys.
Returns the keys used to retrieve actions from the environment’s input. The default action key is “action”.
- Returns:
A list of strings or tuples representing the action keys.
- property done_keys: List[NestedKey]
Done Keys.
Returns the keys used to indicate that an episode has ended. The default done keys are “done”, “terminated”, and “truncated”. These keys can be used in the environment’s output to signal the end of an episode.
- Returns:
A list of strings representing the done keys.
- extend(rollout, *, return_node: bool = False)[source]
Add a rollout to the forest.
Nodes are only added to a tree at points where rollouts diverge from each other and at the endpoints of rollouts.
If there is no existing tree that matches the first steps of the rollout, a new tree is added. Only one node is created, for the final step.
If there is an existing tree that matches, the rollout is added to that tree. If the rollout diverges from all other rollouts in the tree at some step, a new node is created before the step where the rollouts diverge, and a leaf node is created for the final step of the rollout. If all of the rollout’s steps match with a previously added rollout, nothing changes. If the rollout matches up to a leaf node of a tree but continues beyond it, that node is extended to the end of the rollout, and no new nodes are created.
- Parameters:
rollout (TensorDict) – The rollout to add to the forest.
return_node (bool, optional) – If True, the method returns the added node. Defaults to False.
- Returns:
  The node that was added to the forest. This is only returned if return_node is True.
- Return type:
Examples
>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> import torch
>>> forest = MCTSForest()
>>> r0 = TensorDict({
...     'action': torch.tensor([1, 2, 3, 4, 5]),
...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
...     'observation': torch.tensor([  0, 123, 392, 989, 809])
... }, [5])
>>> r1 = TensorDict({
...     'action': torch.tensor([1, 2, 6, 7]),
...     'next': {'observation': torch.tensor([123, 392, 235, 38])},
...     'observation': torch.tensor([  0, 123, 392, 235])
... }, [4])
>>> td_root = r0[0].exclude("next")
>>> forest.extend(r0)
>>> forest.extend(r1)
>>> tree = forest.get_tree(td_root)
>>> print(tree)
Tree(
    count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
    index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
    node_data=TensorDict(
        fields={
            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([]),
        device=cpu,
        is_shared=False),
    node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    rollout=TensorDict(
        fields={
            action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
            next: TensorDict(
                fields={
                    observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([2]),
                device=cpu,
                is_shared=False),
            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([2]),
        device=cpu,
        is_shared=False),
    subtree=Tree(
        _parent=NonTensorStack(
            [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
            batch_size=torch.Size([2]),
            device=None),
        count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
        hash=NonTensorStack(
            [4341220243998689835, 6745467818783115365],
            batch_size=torch.Size([2]),
            device=None),
        node_data=LazyStackedTensorDict(
            fields={
                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        node_id=NonTensorStack(
            [1, 2],
            batch_size=torch.Size([2]),
            device=None),
        rollout=LazyStackedTensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
                next: LazyStackedTensorDict(
                    fields={
                        observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, -1]),
                    device=cpu,
                    is_shared=False,
                    stack_dim=0),
                observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, -1]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        index=None,
        subtree=None,
        specs=None,
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False),
    wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
    hash=None,
    _parent=None,
    specs=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
- get_keys_from_env(env: EnvBase)[source]
Given an environment, writes any missing done, action, and reward keys to the forest. Existing keys are not overwritten.