MPPIPlanner
- class torchrl.modules.MPPIPlanner(*args, **kwargs)
MPPI Planner Module.
- References:
  - Model predictive path integral control using covariance variable importance sampling (Williams, G., Aldrich, A., and Theodorou, E. A.) https://arxiv.org/abs/1509.01149
  - Temporal Difference Learning for Model Predictive Control (Hansen, N., Wang, X., and Su, H.) https://arxiv.org/abs/2203.04955
This module performs an MPPI planning step when given a TensorDict containing the initial states.
A call to the module returns the actions that empirically maximised the returns over the planning horizon.
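A minimal call sketch (assuming an env and planner already built as in the example further down): the planner reads the current state from the input tensordict and writes the planned action under the configured action_key.

>>> tensordict = env.reset()
>>> tensordict = planner(tensordict)  # runs the MPPI optimization loop
>>> tensordict["action"]  # the action to execute in the environment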
- Parameters:
  - env (EnvBase) – The environment to perform the planning step on (can be ModelBasedEnvBase or EnvBase).
  - planning_horizon (int) – The length of the simulated trajectories.
  - optim_steps (int) – The number of optimization steps used by the MPC planner.
  - num_candidates (int) – The number of candidates to sample from the Gaussian distributions.
  - top_k (int) – The number of top candidates used to update the mean and standard deviation of the Gaussian distribution.
  - reward_key (str, optional) – The key in the TensorDict used to retrieve the reward. Defaults to “reward”.
  - action_key (str, optional) – The key in the TensorDict used to store the action. Defaults to “action”.
Examples
>>> from tensordict import TensorDict
>>> from torchrl.data import Composite, Unbounded
>>> from torchrl.envs.model_based import ModelBasedEnvBase
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import ValueOperator
>>> from torchrl.objectives.value import TDLambdaEstimator
>>> class MyMBEnv(ModelBasedEnvBase):
...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
...         self.state_spec = Composite(
...             hidden_observation=Unbounded((4,))
...         )
...         self.observation_spec = Composite(
...             hidden_observation=Unbounded((4,))
...         )
...         self.action_spec = Unbounded((1,))
...         self.reward_spec = Unbounded((1,))
...
...     def _reset(self, tensordict: TensorDict) -> TensorDict:
...         tensordict = TensorDict(
...             {},
...             batch_size=self.batch_size,
...             device=self.device,
...         )
...         tensordict = tensordict.update(
...             self.full_state_spec.rand())
...         tensordict = tensordict.update(
...             self.full_action_spec.rand())
...         tensordict = tensordict.update(
...             self.full_observation_spec.rand())
...         return tensordict
...
>>> from torchrl.modules import MLP, WorldModelWrapper
>>> import torch.nn as nn
>>> world_model = WorldModelWrapper(
...     TensorDictModule(
...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
...         in_keys=["hidden_observation", "action"],
...         out_keys=["hidden_observation"],
...     ),
...     TensorDictModule(
...         nn.Linear(4, 1),
...         in_keys=["hidden_observation"],
...         out_keys=["reward"],
...     ),
... )
>>> env = MyMBEnv(world_model)
>>> value_net = nn.Linear(4, 1)
>>> value_net = ValueOperator(value_net, in_keys=["hidden_observation"])
>>> adv = TDLambdaEstimator(
...     gamma=0.99,
...     lmbda=0.95,
...     value_network=value_net,
... )
>>> # Build a planner and use it as actor
>>> planner = MPPIPlanner(
...     env,
...     adv,
...     temperature=1.0,
...     planning_horizon=10,
...     optim_steps=11,
...     num_candidates=7,
...     top_k=3)
>>> env.rollout(5, planner)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)