
Source code for torchrl.modules.planners.cem

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from tensordict import TensorDict, TensorDictBase

from torchrl.envs.common import EnvBase
from torchrl.modules.planners.common import MPCPlannerBase


class CEMPlanner(MPCPlannerBase):
    """CEMPlanner Module.

    Reference: The cross-entropy method for optimization, Botev et al. 2013

    This module will perform a CEM planning step when given a TensorDict
    containing initial states.
    The CEM planning step is performed by sampling actions from a Gaussian
    distribution with zero mean and unit variance.
    The sampled actions are then used to perform a rollout in the environment.
    The cumulative rewards obtained with the rollouts are then ranked.
    We select the top-k episodes and use their actions to update the mean and
    standard deviation of the actions distribution.
    The CEM planning step is repeated for a specified number of steps.

    A call to the module returns the actions that empirically maximised the
    returns given a planning horizon.

    Args:
        env (EnvBase): The environment to perform the planning step on (can be
            `ModelBasedEnv` or :obj:`EnvBase`).
        planning_horizon (int): The length of the simulated trajectories.
        optim_steps (int): The number of optimization steps used by the MPC
            planner.
        num_candidates (int): The number of candidates to sample from the
            Gaussian distributions.
        top_k (int): The number of top candidates to use to update the mean and
            standard deviation of the Gaussian distribution.
        reward_key (str, optional): The key in the TensorDict to use to retrieve
            the reward. Defaults to ("next", "reward").
        action_key (str, optional): The key in the TensorDict to use to store
            the action. Defaults to "action".

    Examples:
        >>> from tensordict import TensorDict
        >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
        >>> from torchrl.envs.model_based import ModelBasedEnvBase
        >>> from torchrl.modules import SafeModule
        >>> class MyMBEnv(ModelBasedEnvBase):
        ...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
        ...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
        ...         self.state_spec = CompositeSpec(
        ...             hidden_observation=UnboundedContinuousTensorSpec((4,))
        ...         )
        ...         self.observation_spec = CompositeSpec(
        ...             hidden_observation=UnboundedContinuousTensorSpec((4,))
        ...         )
        ...         self.action_spec = UnboundedContinuousTensorSpec((1,))
        ...         self.reward_spec = UnboundedContinuousTensorSpec((1,))
        ...
        ...     def _reset(self, tensordict: TensorDict) -> TensorDict:
        ...         tensordict = TensorDict(
        ...             {},
        ...             batch_size=self.batch_size,
        ...             device=self.device,
        ...         )
        ...         tensordict = tensordict.update(
        ...             self.full_state_spec.rand())
        ...         tensordict = tensordict.update(
        ...             self.full_action_spec.rand())
        ...         tensordict = tensordict.update(
        ...             self.full_observation_spec.rand())
        ...         return tensordict
        ...
        >>> from torchrl.modules import MLP, WorldModelWrapper
        >>> import torch.nn as nn
        >>> world_model = WorldModelWrapper(
        ...     SafeModule(
        ...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
        ...         in_keys=["hidden_observation", "action"],
        ...         out_keys=["hidden_observation"],
        ...     ),
        ...     SafeModule(
        ...         nn.Linear(4, 1),
        ...         in_keys=["hidden_observation"],
        ...         out_keys=["reward"],
        ...     ),
        ... )
        >>> env = MyMBEnv(world_model)
        >>> # Build a planner and use it as actor
        >>> planner = CEMPlanner(env, 10, 11, 7, 3)
        >>> env.rollout(5, planner)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([5]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False)
    """

    def __init__(
        self,
        env: EnvBase,
        planning_horizon: int,
        optim_steps: int,
        num_candidates: int,
        top_k: int,
        reward_key: str = ("next", "reward"),
        action_key: str = "action",
    ):
        super().__init__(env=env, action_key=action_key)
        self.planning_horizon = planning_horizon
        self.optim_steps = optim_steps
        self.num_candidates = num_candidates
        self.top_k = top_k
        self.reward_key = reward_key
    def planning(self, tensordict: TensorDictBase) -> torch.Tensor:
        batch_size = tensordict.batch_size
        # Candidate actions are laid out as
        # [*batch, num_candidates, planning_horizon, *action_dims].
        action_shape = (
            *batch_size,
            self.num_candidates,
            self.planning_horizon,
            *self.action_spec.shape,
        )
        action_stats_shape = (
            *batch_size,
            1,
            self.planning_horizon,
            *self.action_spec.shape,
        )
        action_topk_shape = (
            *batch_size,
            self.top_k,
            self.planning_horizon,
            *self.action_spec.shape,
        )
        # Negative indices of the time (planning_horizon) and candidate
        # dimensions in the candidate-action layout above.
        TIME_DIM = len(self.action_spec.shape) - 3
        K_DIM = len(self.action_spec.shape) - 4
        # Replicate the input state once per candidate trajectory.
        expanded_original_tensordict = (
            tensordict.unsqueeze(-1)
            .expand(*batch_size, self.num_candidates)
            .to_tensordict()
        )
        # Sampling distribution: zero mean, unit std, refined at each optim step.
        _action_means = torch.zeros(
            *action_stats_shape,
            device=tensordict.device,
            dtype=self.env.action_spec.dtype,
        )
        _action_stds = torch.ones_like(_action_means)
        container = TensorDict(
            {
                "tensordict": expanded_original_tensordict,
                "stats": TensorDict(
                    {
                        "_action_means": _action_means,
                        "_action_stds": _action_stds,
                    },
                    [*batch_size, 1, self.planning_horizon],
                ),
            },
            batch_size,
        )

        for _ in range(self.optim_steps):
            # Sample candidate action sequences and project them onto the action spec.
            actions_means = container.get(("stats", "_action_means"))
            actions_stds = container.get(("stats", "_action_stds"))
            actions = actions_means + actions_stds * torch.randn(
                *action_shape,
                device=actions_means.device,
                dtype=actions_means.dtype,
            )
            actions = self.env.action_spec.project(actions)
            # Evaluate each candidate by rolling out its precomputed action sequence.
            optim_tensordict = container.get("tensordict").clone()
            policy = _PrecomputedActionsSequentialSetter(actions)
            optim_tensordict = self.env.rollout(
                max_steps=self.planning_horizon,
                policy=policy,
                auto_reset=False,
                tensordict=optim_tensordict,
            )
            # Keep the top-k candidates by cumulative reward and refit the Gaussian.
            sum_rewards = optim_tensordict.get(self.reward_key).sum(
                dim=TIME_DIM, keepdim=True
            )
            _, top_k = sum_rewards.topk(self.top_k, dim=K_DIM)
            top_k = top_k.expand(action_topk_shape)
            best_actions = actions.gather(K_DIM, top_k)
            container.set_(
                ("stats", "_action_means"),
                best_actions.mean(dim=K_DIM, keepdim=True),
            )
            container.set_(
                ("stats", "_action_stds"),
                best_actions.std(dim=K_DIM, keepdim=True),
            )
        # Return the first planned action of the refined mean sequence.
        action_means = container.get(("stats", "_action_means"))
        return action_means[..., 0, 0, :]
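Note (illustrative, not part of the module above): the loop in `planning` is an instance of the generic cross-entropy method. Below is a minimal, self-contained sketch of the same update on a toy objective, using only torch and hypothetical names (`cem_sketch`, `objective`); it samples candidates, keeps the top-k by score, and refits the Gaussian, just as the planner does with summed rollout rewards in place of the score.

import torch


def cem_sketch(objective, dim, optim_steps=10, num_candidates=64, top_k=6):
    # Start from a standard normal proposal, as CEMPlanner does for actions.
    means = torch.zeros(dim)
    stds = torch.ones(dim)
    for _ in range(optim_steps):
        # Sample candidates, score them (higher is better), keep the elite set.
        candidates = means + stds * torch.randn(num_candidates, dim)
        scores = objective(candidates)
        elite = candidates[scores.topk(top_k).indices]
        # Refit the Gaussian to the elite candidates.
        means = elite.mean(dim=0)
        stds = elite.std(dim=0)
    return means


# Toy usage: maximise -(x - 3)^2, so the refined mean approaches 3 in each coordinate.
best = cem_sketch(lambda x: -((x - 3.0) ** 2).sum(-1), dim=2)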
class _PrecomputedActionsSequentialSetter:
    def __init__(self, actions):
        self.actions = actions
        self.cmpt = 0

    def __call__(self, tensordict):
        # Check that the requested step does not exceed the precomputed horizon
        if self.cmpt >= self.actions.shape[-2]:
            raise ValueError("Precomputed actions sequence is too short")
        tensordict = tensordict.set("action", self.actions[..., self.cmpt, :])
        self.cmpt += 1
        return tensordict
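Note (illustrative, not part of the module above): `_PrecomputedActionsSequentialSetter` follows the usual torchrl policy convention, a callable that reads and writes a TensorDict once per rollout step. A hypothetical standalone replay of a precomputed sequence, assuming a 5-step horizon and a 1-dimensional action, could look like this:

import torch
from tensordict import TensorDict

# Hypothetical sequence of 5 one-dimensional actions: [planning_horizon, action_dim].
actions = torch.randn(5, 1)
policy = _PrecomputedActionsSequentialSetter(actions)

td = TensorDict({}, batch_size=[])
for step in range(actions.shape[-2]):
    # Each call writes actions[..., step, :] under the "action" key and advances
    # the internal counter; a real rollout would step the environment with td here.
    td = policy(td)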
