MCTSForest

class torchrl.data.MCTSForest(*, data_map: Optional[TensorDictMap] = None, node_map: Optional[TensorDictMap] = None, max_size: Optional[int] = None, done_keys: Optional[List[NestedKey]] = None, reward_keys: Optional[List[NestedKey]] = None, observation_keys: Optional[List[NestedKey]] = None, action_keys: Optional[List[NestedKey]] = None, excluded_keys: Optional[List[NestedKey]] = None, consolidated: Optional[bool] = None)[source]

A collection of MCTS trees.

Warning

This class is currently under active development. Expect frequent API changes.

The class stores rollouts in a storage and produces trees based on a given root in that dataset.

Keyword Arguments:
  • data_map (TensorDictMap, optional) – the storage to use to store the data (observations, rewards, states, etc.). If not provided, it is lazily initialized using from_tensordict_pair() using the list of observation_keys and action_keys as in_keys.

  • node_map (TensorDictMap, optional) – a map from the observation space to the index space. Internally, the node map is used to gather all possible branches coming out of a given node. For example, if an observation has two associated actions and outcomes in the data map, then the node_map will return a data structure containing the two indices in the data_map that correspond to these two outcomes. If not provided, it is lazily initialized using from_tensordict_pair() using the list of observation_keys as in_keys and the QueryModule as out_keys.

  • max_size (int, optional) – the size of the maps. If not provided, defaults to data_map.max_size if this can be found, then node_map.max_size. If none of these are provided, defaults to 1000.

  • done_keys (list of NestedKey, optional) – the done keys of the environment. If not provided, defaults to ("done", "terminated", "truncated"). The get_keys_from_env() method can be used to automatically determine these keys.

  • action_keys (list of NestedKey, optional) – the action keys of the environment. If not provided, defaults to ("action",). The get_keys_from_env() method can be used to automatically determine these keys.

  • reward_keys (list of NestedKey, optional) – the reward keys of the environment. If not provided, defaults to ("reward",). The get_keys_from_env() method can be used to automatically determine these keys.

  • observation_keys (list of NestedKey, optional) – the observation keys of the environment. If not provided, defaults to ("observation",). The get_keys_from_env() method can be used to automatically determine these keys.

  • excluded_keys (list of NestedKey, optional) – a list of keys to exclude from the data storage.

  • consolidated (bool, optional) – if True, the data_map storage will be consolidated on disk. Defaults to False.

Examples

>>> from torchrl.envs import GymEnv
>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> from torchrl.data import TensorDictMap, ListStorage
>>> from torchrl.data.map.tree import MCTSForest
>>>
>>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter
>>> # Create the MCTS Forest
>>> forest = MCTSForest()
>>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle)
>>> env = PendulumEnv()
>>> obs_keys = list(env.observation_spec.keys(True, True))
>>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
>>> # Appending transforms to get an "observation" key that concatenates the observations together
>>> env = env.append_transform(
...     UnsqueezeTransform(
...         in_keys=obs_keys,
...         out_keys=[("unsqueeze", key) for key in obs_keys],
...         dim=-1
...     )
... )
>>> env = env.append_transform(
...     CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
... )
>>> env = env.append_transform(StepCounter())
>>> env.set_seed(0)
>>> # Get a reset state, then make a rollout out of it
>>> reset_state = env.reset()
>>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> # Append the rollout to the forest. We're removing the state entries for clarity
>>> rollout0 = rollout0.copy()
>>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout0)
>>> # The forest should have 6 elements (the length of the rollout)
>>> assert len(forest) == 6
>>> # Let's make another rollout from the same reset state
>>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1)
>>> assert len(forest) == 12
>>> # Let's make another final rollout from an intermediate step in the second rollout
>>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
>>> rollout1b.exclude(*state_keys, inplace=True)
>>> rollout1b.get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1b)
>>> assert len(forest) == 18
>>> # Since we have 2 rollouts starting at the same state, our tree should have two
>>> #  branches if we produce it from the reset entry. Take the state, and call `get_tree`:
>>> r = rollout0[0]
>>> # Let's get the compact tree that follows the initial reset. A compact tree is
>>> #  a tree where nodes that have a single child are collapsed.
>>> tree = forest.get_tree(r)
>>> print(tree.max_length())
2
>>> print(list(tree.valid_paths()))
[(0,), (1, 0), (1, 1)]
>>> from tensordict import assert_close
>>> # We can manually rebuild the tree
>>> assert_close(
...     rollout1,
...     torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
...     intersection=True,
... )
True
>>> # Or we can rebuild it using the dedicated method
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0)),
...     intersection=True,
... )
True
>>> tree.plot()
>>> tree = forest.get_tree(r, compact=False)
>>> print(tree.max_length())
9
>>> print(list(tree.valid_paths()))
[(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
...     intersection=True,
... )
True
property action_keys: List[NestedKey]

Action Keys.

Returns the keys used to retrieve actions from the environment’s input. The default action key is “action”.

Returns:

A list of strings or tuples representing the action keys.

property done_keys: List[NestedKey]

Done Keys.

Returns the keys used to indicate that an episode has ended. The default done keys are “done”, “terminated”, and “truncated”. These keys can be used in the environment’s output to signal the end of an episode.

Returns:

A list of strings or tuples representing the done keys.

get_keys_from_env(env: EnvBase)[source]

Writes missing done, action and reward keys to the Forest given an environment.

Existing keys are not overwritten.

property observation_keys: List[NestedKey]

Observation Keys.

Returns the keys used to retrieve observations from the environment’s output. The default observation key is “observation”.

Returns:

A list of strings or tuples representing the observation keys.

property reward_keys: List[NestedKey]

Reward Keys.

Returns the keys used to retrieve rewards from the environment’s output. The default reward key is “reward”.

Returns:

A list of strings or tuples representing the reward keys.
