
MCTSForest

class torchrl.data.MCTSForest(*, data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, max_size: int | None = None, done_keys: List[NestedKey] | None = None, reward_keys: List[NestedKey] = None, observation_keys: List[NestedKey] = None, action_keys: List[NestedKey] = None, excluded_keys: List[NestedKey] = None, consolidated: bool | None = None)

A collection of MCTS trees.

Warning

This class is currently under active development. Expect frequent API changes.

This class stores rollouts in a storage and produces trees rooted at a given node of that dataset.

Keyword Arguments:
  • data_map (TensorDictMap, optional) – the storage used to hold the data (observations, rewards, states, etc.). If not provided, it is lazily initialized with from_tensordict_pair(), using the list of observation_keys and action_keys as in_keys.

  • node_map (TensorDictMap, optional) – a map from the observation space to the index space. Internally, the node map is used to gather all possible branches coming out of a given node. For example, if an observation has two associated actions and outcomes in the data map, the node_map returns a data structure containing the two data_map indices that correspond to these two outcomes. If not provided, it is lazily initialized with from_tensordict_pair(), using the list of observation_keys as in_keys and the QueryModule as out_keys.

  • max_size (int, optional) – the size of the maps. If not provided, defaults to data_map.max_size if available, otherwise to node_map.max_size. If neither is available, defaults to 1000.

  • done_keys (list of NestedKey, optional) – the done keys of the environment. If not provided, defaults to ("done", "terminated", "truncated"). The get_keys_from_env() method can be used to determine these keys automatically.

  • action_keys (list of NestedKey, optional) – the action keys of the environment. If not provided, defaults to ("action",). The get_keys_from_env() method can be used to determine these keys automatically.

  • reward_keys (list of NestedKey, optional) – the reward keys of the environment. If not provided, defaults to ("reward",). The get_keys_from_env() method can be used to determine these keys automatically.

  • observation_keys (list of NestedKey, optional) – the observation keys of the environment. If not provided, defaults to ("observation",). The get_keys_from_env() method can be used to determine these keys automatically.

  • excluded_keys (list of NestedKey, optional) – a list of keys to exclude from the data storage.

  • consolidated (bool, optional) – if True, the data_map storage will be consolidated on disk. Defaults to False.
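The division of labor between data_map and node_map, and the max_size fallback order, can be pictured with plain Python containers. This is a toy model under stated assumptions, not torchrl's implementation: `ToyForestMaps` and `resolve_max_size` are hypothetical names, dictionaries stand in for TensorDictMap, scalar observations and actions stand in for TensorDict entries, and the data map is keyed on (observation, action) pairs to mirror the observation_keys + action_keys in_keys described above.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def resolve_max_size(max_size=None, data_map=None, node_map=None, default=1000):
    """Hypothetical helper mirroring the documented max_size fallback order."""
    for candidate in (
        max_size,
        getattr(data_map, "max_size", None),
        getattr(node_map, "max_size", None),
    ):
        if candidate is not None:
            return candidate
    return default


class ToyForestMaps:
    """Toy data_map / node_map pair (hypothetical, not the torchrl classes)."""

    def __init__(self, max_size: int = 1000):
        self.max_size = max_size
        # data_map analogue: (observation, action) -> storage index
        self.index_of: Dict[Tuple[int, int], int] = {}
        # storage: index -> next observation (the outcome of the transition)
        self.outcomes: List[int] = []
        # node_map analogue: observation -> indices of every stored outcome
        self.node_map: Dict[int, List[int]] = defaultdict(list)

    def write(self, observation: int, action: int, next_observation: int) -> int:
        key = (observation, action)
        if key in self.index_of:
            # Same (observation, action) pair: overwrite the stored outcome.
            index = self.index_of[key]
            self.outcomes[index] = next_observation
            return index
        if len(self.outcomes) >= self.max_size:
            raise RuntimeError("toy storage is full")
        index = len(self.outcomes)
        self.index_of[key] = index
        self.outcomes.append(next_observation)
        self.node_map[observation].append(index)
        return index

    def branches(self, observation: int) -> List[int]:
        """Gather every outcome leaving `observation` via the node map."""
        return [self.outcomes[i] for i in self.node_map.get(observation, [])]


maps = ToyForestMaps(max_size=resolve_max_size())
# Two rollouts share their first two steps, then branch at observation 392
for obs, act, nxt in [(0, 1, 123), (123, 2, 392), (392, 3, 989), (392, 6, 235)]:
    maps.write(obs, act, nxt)
print(maps.branches(392))  # [989, 235]
print(maps.branches(0))    # [123]
```

Here the node map holds no data of its own; it only records which storage indices leave a given observation, which is what lets a tree builder enumerate branches without scanning the whole storage.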

Examples

>>> from torchrl.envs import GymEnv
>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> from torchrl.data import TensorDictMap, ListStorage
>>> from torchrl.data.map.tree import MCTSForest
>>>
>>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter
>>> # Create the MCTS Forest
>>> forest = MCTSForest()
>>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle)
>>> env = PendulumEnv()
>>> obs_keys = list(env.observation_spec.keys(True, True))
>>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
>>> # Appending transforms to get an "observation" key that concatenates the observations together
>>> env = env.append_transform(
...     UnsqueezeTransform(
...         in_keys=obs_keys,
...         out_keys=[("unsqueeze", key) for key in obs_keys],
...         dim=-1
...     )
... )
>>> env = env.append_transform(
...     CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
... )
>>> env = env.append_transform(StepCounter())
>>> env.set_seed(0)
>>> # Get a reset state, then make a rollout out of it
>>> reset_state = env.reset()
>>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> # Append the rollout to the forest. We're removing the state entries for clarity
>>> rollout0 = rollout0.copy()
>>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout0)
>>> # The forest should have 6 elements (the length of the rollout)
>>> assert len(forest) == 6
>>> # Let's make another rollout from the same reset state
>>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1)
>>> assert len(forest) == 12
>>> # Make a final rollout starting from an intermediate step of the second rollout
>>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
>>> rollout1b.exclude(*state_keys, inplace=True)
>>> rollout1b.get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1b)
>>> assert len(forest) == 18
>>> # Since we have 2 rollouts starting at the same state, our tree should have two
>>> #  branches if we produce it from the reset entry. Take the state, and call `get_tree`:
>>> r = rollout0[0]
>>> # Let's get the compact tree that follows the initial reset. A compact tree is
>>> #  a tree where nodes that have a single child are collapsed.
>>> tree = forest.get_tree(r)
>>> print(tree.max_length())
2
>>> print(list(tree.valid_paths()))
[(0,), (1, 0), (1, 1)]
>>> from tensordict import assert_close
>>> # We can manually rebuild the tree
>>> assert_close(
...     rollout1,
...     torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
...     intersection=True,
... )
True
>>> # Or we can rebuild it using the dedicated method
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0)),
...     intersection=True,
... )
True
>>> tree.plot()
>>> tree = forest.get_tree(r, compact=False)
>>> print(tree.max_length())
9
>>> print(list(tree.valid_paths()))
[(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
...     intersection=True,
... )
True
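The comments in the example above describe a compact tree as one in which nodes with a single child are collapsed. As an illustration of that idea only, the sketch below collapses a one-step-per-node tree into its compact form. The `StepNode`/`CompactNode` classes and the `compact`, `leaf_paths` and `chain` helpers are hypothetical; torchrl's Tree stores rollouts as TensorDicts and may order or index paths differently.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Tuple


@dataclass
class StepNode:
    """Non-compact node: exactly one transition (here, just an action id)."""
    action: int
    children: List["StepNode"] = field(default_factory=list)


@dataclass
class CompactNode:
    """Compact node: a run of consecutive single-child transitions."""
    actions: List[int]
    children: List["CompactNode"] = field(default_factory=list)


def compact(node: StepNode) -> CompactNode:
    actions = [node.action]
    # Absorb the chain of nodes that have exactly one child.
    while len(node.children) == 1:
        node = node.children[0]
        actions.append(node.action)
    return CompactNode(actions, [compact(child) for child in node.children])


def leaf_paths(node, prefix: Tuple[int, ...] = ()) -> Iterator[Tuple[int, ...]]:
    """Yield the child-index path from `node` to every leaf."""
    if not node.children:
        yield prefix
    for i, child in enumerate(node.children):
        yield from leaf_paths(child, prefix + (i,))


def chain(*actions: int) -> StepNode:
    """Build a single-branch chain of StepNodes."""
    head = tail = StepNode(actions[0])
    for a in actions[1:]:
        tail.children.append(StepNode(a))
        tail = tail.children[0]
    return head


# Two rollouts sharing actions 1, 2, then diverging (3, 4, 5 vs. 6, 7)
root = chain(1, 2)
root.children[0].children = [chain(3, 4, 5), chain(6, 7)]
tree = compact(root)
print(tree.actions)                          # [1, 2]
print([c.actions for c in tree.children])    # [[3, 4, 5], [6, 7]]
print(list(leaf_paths(root)))                # [(0, 0, 0, 0), (0, 1, 0)]
print(list(leaf_paths(tree)))                # [(0,), (1,)]
```

Collapsing shortens every path to one index per branching point, which is why the compact tree in the example reports much shorter paths than the one built with compact=False.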
property action_keys: List[NestedKey]

Action Keys.

Returns the keys used to retrieve actions from the environment’s input. The default action key is “action”.

Returns:

A list of strings or tuples representing the action keys.

property done_keys: List[NestedKey]

Done Keys.

Returns the keys used to indicate that an episode has ended. The default done keys are “done”, “terminated”, and “truncated”. These keys can be used in the environment’s output to signal the end of an episode.

Returns:

A list of strings or tuples representing the done keys.
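To illustrate how a list of done keys signals the end of an episode, the sketch below checks whether any of them is set in a step's output. It is a plain-Python assumption, not torchrl code: a nested dict stands in for the environment's output TensorDict, `get_nested` and `episode_ended` are hypothetical helpers, and a tuple key such as ("agent", "done") addresses a nested entry the way a NestedKey does.

```python
from collections import abc
from typing import Any, Mapping, Sequence, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]

DEFAULT_DONE_KEYS: Sequence[NestedKey] = ("done", "terminated", "truncated")


def get_nested(data: Mapping[str, Any], key: NestedKey, default: Any = None) -> Any:
    """Resolve a string key or a tuple of keys in a nested mapping."""
    path = (key,) if isinstance(key, str) else key
    for part in path:
        if not isinstance(data, abc.Mapping) or part not in data:
            return default
        data = data[part]
    return data


def episode_ended(step_output: Mapping[str, Any],
                  done_keys: Sequence[NestedKey] = DEFAULT_DONE_KEYS) -> bool:
    """True if any done key present in the step's output is set."""
    return any(bool(get_nested(step_output, key, False)) for key in done_keys)


step = {"observation": 0.3, "done": False, "terminated": False, "truncated": True}
print(episode_ended(step))  # True: truncation ends the episode
nested = {"agent": {"done": True}}
print(episode_ended(nested, done_keys=[("agent", "done")]))  # True
print(episode_ended({"done": False}))  # False
```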

extend(rollout, *, return_node: bool = False)

Add a rollout to the forest.

Nodes are only added to a tree at points where rollouts diverge from each other and at the endpoints of rollouts.

If there is no existing tree that matches the first steps of the rollout, a new tree is added. Only one node is created, for the final step.

If there is an existing tree that matches, the rollout is added to that tree. If the rollout diverges from all other rollouts in the tree at some step, a new node is created before the step where the rollouts diverge, and a leaf node is created for the final step of the rollout. If all of the rollout’s steps match with a previously added rollout, nothing changes. If the rollout matches up to a leaf node of a tree but continues beyond it, that node is extended to the end of the rollout, and no new nodes are created.
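These node-creation rules behave much like insertion into a radix (compressed prefix) tree. The sketch below models them in plain Python under stated assumptions: each step is an (observation, action) pair, `ToyNode`, `common_prefix` and `extend` are hypothetical names rather than torchrl's Tree API, and children are kept in insertion order. The two rollouts reuse the values from the example further down.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

Step = Tuple[int, int]  # hypothetical: (observation, action)


@dataclass
class ToyNode:
    steps: List[Step]  # consecutive steps held by this node
    children: List["ToyNode"] = field(default_factory=list)


def common_prefix(a: List[Step], b: List[Step]) -> int:
    n = 0
    while n < min(len(a), len(b)) and a[n] == b[n]:
        n += 1
    return n


def extend(roots: List[ToyNode], rollout: List[Step]) -> None:
    siblings, rest = roots, list(rollout)
    while rest:
        match = next((n for n in siblings if n.steps[0] == rest[0]), None)
        if match is None:
            # No stored branch starts with this step: add one new node
            # (a new tree when `siblings` is the list of roots).
            siblings.append(ToyNode(rest))
            return
        k = common_prefix(match.steps, rest)
        if k < len(match.steps):
            if k == len(rest):
                return  # every step is already stored: nothing changes
            # Divergence inside `match`: split it just before step k.
            tail = ToyNode(match.steps[k:], match.children)
            match.steps = match.steps[:k]
            match.children = [tail, ToyNode(rest[k:])]
            return
        rest = rest[k:]
        if rest and not match.children:
            match.steps.extend(rest)  # continues past a leaf: extend it in place
            return
        siblings = match.children


r0 = list(zip([0, 123, 392, 989, 809], [1, 2, 3, 4, 5]))
r1 = list(zip([0, 123, 392, 235], [1, 2, 6, 7]))
forest: List[ToyNode] = []
extend(forest, r0)
extend(forest, r1)
root = forest[0]
print(len(forest), len(root.steps), [len(c.steps) for c in root.children])
# 1 2 [3, 2]
```

The result has one tree whose root holds the two shared steps and two children for the diverging tails, the same shape as the Tree printed in the example below.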

Parameters:
  • rollout (TensorDict) – The rollout to add to the forest.

  • return_node (bool, optional) – If True, the method returns the added node. Default is False.

Returns:

The node that was added to the forest. This is only returned if return_node is True.

Return type:

Tree

Examples

>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> import torch
>>> forest = MCTSForest()
>>> r0 = TensorDict({
...     'action': torch.tensor([1, 2, 3, 4, 5]),
...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
...     'observation': torch.tensor([  0, 123, 392, 989, 809])
... }, [5])
>>> r1 = TensorDict({
...     'action': torch.tensor([1, 2, 6, 7]),
...     'next': {'observation': torch.tensor([123, 392, 235,  38])},
...     'observation': torch.tensor([  0, 123, 392, 235])
... }, [4])
>>> td_root = r0[0].exclude("next")
>>> forest.extend(r0)
>>> forest.extend(r1)
>>> tree = forest.get_tree(td_root)
>>> print(tree)
Tree(
    count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
    index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
    node_data=TensorDict(
        fields={
            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([]),
        device=cpu,
        is_shared=False),
    node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    rollout=TensorDict(
        fields={
            action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
            next: TensorDict(
                fields={
                    observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([2]),
                device=cpu,
                is_shared=False),
            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([2]),
        device=cpu,
        is_shared=False),
    subtree=Tree(
        _parent=NonTensorStack(
            [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
            batch_size=torch.Size([2]),
            device=None),
        count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
        hash=NonTensorStack(
            [4341220243998689835, 6745467818783115365],
            batch_size=torch.Size([2]),
            device=None),
        node_data=LazyStackedTensorDict(
            fields={
                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        node_id=NonTensorStack(
            [1, 2],
            batch_size=torch.Size([2]),
            device=None),
        rollout=LazyStackedTensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
                next: LazyStackedTensorDict(
                    fields={
                        observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, -1]),
                    device=cpu,
                    is_shared=False,
                    stack_dim=0),
                observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, -1]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        index=None,
        subtree=None,
        specs=None,
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False),
    wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
    hash=None,
    _parent=None,
    specs=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
get_keys_from_env(env: EnvBase)

Writes missing done, action and reward keys to the forest, reading them from the given environment.

Existing keys are not overwritten.
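The "fill only what is missing" behavior can be sketched as follows. This is a toy under explicit assumptions, not torchrl source: `ToyForestKeys` is hypothetical, unset keys are stored as None so the documented defaults apply only when read, and a SimpleNamespace stands in for an EnvBase assumed to expose done_keys, action_keys and reward_keys attributes.

```python
from types import SimpleNamespace


class ToyForestKeys:
    """Hypothetical stand-in for the key bookkeeping of a forest."""

    _DEFAULTS = {
        "done_keys": ["done", "terminated", "truncated"],
        "action_keys": ["action"],
        "reward_keys": ["reward"],
    }

    def __init__(self, done_keys=None, action_keys=None, reward_keys=None):
        # None means "not provided by the user"
        self._keys = {
            "done_keys": done_keys,
            "action_keys": action_keys,
            "reward_keys": reward_keys,
        }

    def get(self, name):
        """Return the user-provided keys, or the documented default."""
        value = self._keys[name]
        return list(self._DEFAULTS[name]) if value is None else value

    def get_keys_from_env(self, env):
        # Only fill entries that were left unset; never overwrite.
        for name, value in list(self._keys.items()):
            if value is None:
                self._keys[name] = list(getattr(env, name))


env = SimpleNamespace(
    done_keys=["done", "terminated"],
    action_keys=[("agent", "action")],
    reward_keys=[("agent", "reward")],
)
forest_keys = ToyForestKeys(reward_keys=["my_reward"])
print(forest_keys.get("action_keys"))  # ['action'] (default, nothing set yet)
forest_keys.get_keys_from_env(env)
print(forest_keys.get("action_keys"))  # [('agent', 'action')]
print(forest_keys.get("reward_keys"))  # ['my_reward'] (user value kept)
```

Storing None rather than the default is what allows the method to tell a user-provided key apart from a default one.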

property observation_keys: List[NestedKey]

Observation Keys.

Returns the keys used to retrieve observations from the environment’s output. The default observation key is “observation”.

Returns:

A list of strings or tuples representing the observation keys.

property reward_keys: List[NestedKey]

Reward Keys.

Returns the keys used to retrieve rewards from the environment’s output. The default reward key is “reward”.

Returns:

A list of strings or tuples representing the reward keys.
