MCTSForest
- class torchrl.data.MCTSForest(*, data_map: Optional[TensorDictMap] = None, node_map: Optional[TensorDictMap] = None, max_size: Optional[int] = None, done_keys: Optional[List[NestedKey]] = None, reward_keys: Optional[List[NestedKey]] = None, observation_keys: Optional[List[NestedKey]] = None, action_keys: Optional[List[NestedKey]] = None, excluded_keys: Optional[List[NestedKey]] = None, consolidated: Optional[bool] = None)[source]
A collection of MCTS trees.
Warning
This class is currently under active development. Expect frequent API changes.
The class is aimed at storing rollouts in a storage and producing trees based on a given root in that dataset.
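In practice, the workflow is to extend the forest with rollouts and then query a tree from a chosen root, as the full example below shows in detail. A schematic sketch, where rollout stands for the output of env.rollout(...):

>>> from torchrl.data import MCTSForest
>>> forest = MCTSForest()
>>> forest.extend(rollout)  # store the transitions of the rollout
>>> tree = forest.get_tree(rollout[0])  # tree rooted at the rollout's first step
>>> paths = list(tree.valid_paths())  # enumerate the branches found in storage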
- Keyword Arguments:
  - data_map (TensorDictMap, optional) – the storage used to store the data (observation, reward, states, etc.). If not provided, it is lazily initialized using from_tensordict_pair() with the list of observation_keys and action_keys as in_keys.
  - node_map (TensorDictMap, optional) – a map from the observation space to the index space. Internally, the node map is used to gather all possible branches coming out of a given node. For example, if an observation has two associated actions and outcomes in the data map, then the node_map will return a data structure containing the two indices in the data_map that correspond to these two outcomes. If not provided, it is lazily initialized using from_tensordict_pair() with the list of observation_keys as in_keys and the QueryModule as out_keys.
  - max_size (int, optional) – the size of the maps. If not provided, defaults to data_map.max_size if this can be found, then to node_map.max_size. If neither is provided, defaults to 1000.
  - done_keys (list of NestedKey, optional) – the done keys of the environment. If not provided, defaults to ("done", "terminated", "truncated"). The get_keys_from_env() method can be used to automatically determine the keys.
  - action_keys (list of NestedKey, optional) – the action keys of the environment. If not provided, defaults to ("action",). The get_keys_from_env() method can be used to automatically determine the keys.
  - reward_keys (list of NestedKey, optional) – the reward keys of the environment. If not provided, defaults to ("reward",). The get_keys_from_env() method can be used to automatically determine the keys.
  - observation_keys (list of NestedKey, optional) – the observation keys of the environment. If not provided, defaults to ("observation",). The get_keys_from_env() method can be used to automatically determine the keys.
  - excluded_keys (list of NestedKey, optional) – a list of keys to exclude from the data storage.
  - consolidated (bool, optional) – if True, the data_map storage will be consolidated on disk. Defaults to False.
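For instance, a forest can be configured explicitly rather than relying on the lazy defaults above. A minimal construction sketch; the nested key names are illustrative, not required by the class:

>>> from torchrl.data import MCTSForest
>>> forest = MCTSForest(
...     observation_keys=[("agent", "observation")],  # illustrative nested keys
...     action_keys=[("agent", "action")],
...     reward_keys=["reward"],
...     done_keys=["done", "terminated", "truncated"],
...     excluded_keys=["info"],  # hypothetical entry to drop before storage
...     consolidated=False,
... )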
Examples
>>> from torchrl.envs import GymEnv
>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> from torchrl.data import TensorDictMap, ListStorage
>>> from torchrl.data.map.tree import MCTSForest
>>>
>>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter
>>> # Create the MCTS Forest
>>> forest = MCTSForest()
>>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle)
>>> env = PendulumEnv()
>>> obs_keys = list(env.observation_spec.keys(True, True))
>>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
>>> # Appending transforms to get an "observation" key that concatenates the observations together
>>> env = env.append_transform(
...     UnsqueezeTransform(
...         in_keys=obs_keys,
...         out_keys=[("unsqueeze", key) for key in obs_keys],
...         dim=-1
...     )
... )
>>> env = env.append_transform(
...     CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
... )
>>> env = env.append_transform(StepCounter())
>>> env.set_seed(0)
>>> # Get a reset state, then make a rollout out of it
>>> reset_state = env.reset()
>>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> # Append the rollout to the forest. We're removing the state entries for clarity
>>> rollout0 = rollout0.copy()
>>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout0)
>>> # The forest should have 6 elements (the length of the rollout)
>>> assert len(forest) == 6
>>> # Let's make another rollout from the same reset state
>>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1)
>>> assert len(forest) == 12
>>> # Let's make another final rollout from an intermediate step in the second rollout
>>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
>>> rollout1b.exclude(*state_keys, inplace=True)
>>> rollout1b.get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1b)
>>> assert len(forest) == 18
>>> # Since we have 2 rollouts starting at the same state, our tree should have two
>>> # branches if we produce it from the reset entry. Take the state, and call `get_tree`:
>>> r = rollout0[0]
>>> # Let's get the compact tree that follows the initial reset. A compact tree is
>>> # a tree where nodes that have a single child are collapsed.
>>> tree = forest.get_tree(r)
>>> print(tree.max_length())
2
>>> print(list(tree.valid_paths()))
[(0,), (1, 0), (1, 1)]
>>> from tensordict import assert_close
>>> # We can manually rebuild the tree
>>> assert_close(
...     rollout1,
...     torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
...     intersection=True,
... )
True
>>> # Or we can rebuild it using the dedicated method
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0)),
...     intersection=True,
... )
True
>>> tree.plot()
>>> tree = forest.get_tree(r, compact=False)
>>> print(tree.max_length())
9
>>> print(list(tree.valid_paths()))
[(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
...     intersection=True,
... )
True
- property action_keys: List[NestedKey]
Action Keys.
Returns the keys used to retrieve actions from the environment’s input. The default action key is “action”.
- Returns:
A list of strings or tuples representing the action keys.
- property done_keys: List[NestedKey]
Done Keys.
Returns the keys used to indicate that an episode has ended. The default done keys are “done”, “terminated”, and “truncated”. These keys can be used in the environment’s output to signal the end of an episode.
- Returns:
A list of strings or tuples representing the done keys.
- get_keys_from_env(env: EnvBase)[source]
Writes missing done, action, and reward keys to the Forest for a given environment.
Existing keys are not overwritten.
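A sketch of the intended use, reusing the PendulumEnv from the example above; keys that were set beforehand are left untouched:

>>> from torchrl.envs import PendulumEnv
>>> from torchrl.data import MCTSForest
>>> forest = MCTSForest(reward_keys=["reward"])  # pre-set, so it is not overwritten
>>> forest.get_keys_from_env(PendulumEnv())  # fills in the missing done and action keys
>>> print(forest.done_keys, forest.action_keys, forest.reward_keys)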
- property observation_keys: List[NestedKey]
Observation Keys.
Returns the keys used to retrieve observations from the environment’s output. The default observation key is “observation”.
- Returns:
A list of strings or tuples representing the observation keys.
- property reward_keys: List[NestedKey]
Reward Keys.
Returns the keys used to retrieve rewards from the environment’s output. The default reward key is “reward”.
- Returns:
A list of strings or tuples representing the reward keys.
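Taken together, these properties expose the key configuration of the forest. A small sketch, assuming the documented constructor defaults for a forest built without arguments:

>>> from torchrl.data import MCTSForest
>>> forest = MCTSForest()
>>> forest.observation_keys  # ("observation",) by default
>>> forest.action_keys       # ("action",) by default
>>> forest.done_keys         # ("done", "terminated", "truncated") by default
>>> forest.reward_keys       # ("reward",) by default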