
MPPIPlanner

class torchrl.modules.MPPIPlanner(*args, **kwargs)

MPPI Planner Module.

References:
  • Model predictive path integral control using covariance variable importance sampling (Williams, G., Aldrich, A., and Theodorou, E. A.) https://arxiv.org/abs/1509.01149

  • Temporal Difference Learning for Model Predictive Control (Hansen, N., Wang, X., and Su, H.) https://arxiv.org/abs/2203.04955

This module performs an MPPI planning step when given a TensorDict containing initial states.

A call to the module returns the actions that empirically maximise the return over the given planning horizon.
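MPC planners of this kind are typically used in a receding-horizon loop: at every environment step the planner searches over action sequences of length planning_horizon, only the first action is executed, and planning restarts from the new state (in the example on this page, env.rollout(5, planner) yields one action per step). The sketch below illustrates that control pattern in plain Python on a hypothetical 1-D system. toy_dynamics, toy_reward, random_shooting_plan and mpc_rollout are illustrative names, not torchrl APIs, and the inner search is simple random shooting rather than MPPI.

```python
import random


def toy_dynamics(state: float, action: float) -> float:
    # Hypothetical 1-D system: the action shifts the state directly.
    return state + action


def toy_reward(state: float) -> float:
    # Hypothetical reward: highest when the state sits at 0.
    return -abs(state)


def random_shooting_plan(state, horizon, num_candidates, rng):
    """Return the first action of the best random action sequence."""
    best_return, best_first = float("-inf"), 0.0
    for _ in range(num_candidates):
        actions = [rng.uniform(-1.0, 1.0) for _ in range(horizon)]
        # Simulate the whole sequence and accumulate the reward.
        s, ret = state, 0.0
        for a in actions:
            s = toy_dynamics(s, a)
            ret += toy_reward(s)
        if ret > best_return:
            best_return, best_first = ret, actions[0]
    return best_first


def mpc_rollout(initial_state, steps, horizon=5, num_candidates=64, seed=0):
    rng = random.Random(seed)
    state, trajectory = initial_state, []
    for _ in range(steps):
        # Re-plan from the current state at every step...
        action = random_shooting_plan(state, horizon, num_candidates, rng)
        # ...but execute only the first planned action.
        state = toy_dynamics(state, action)
        trajectory.append((action, state))
    return trajectory


traj = mpc_rollout(initial_state=3.0, steps=5)
```

Re-planning at every step lets the controller correct for model error at the cost of one full search per executed action.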

Parameters:
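  • env (EnvBase) – The environment to perform the planning step on (can be ModelBasedEnv or EnvBase).

  • advantage_module (TensorDictModule) – The module used to score the simulated trajectories, such as the TDLambdaEstimator in the example below.

  • temperature (float) – The temperature used to weight the top-k candidates when the Gaussian distribution is updated.
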
  • planning_horizon (int) – The length of the simulated trajectories.

  • optim_steps (int) – The number of optimization steps used by the MPC planner.

  • num_candidates (int) – The number of candidates to sample from the Gaussian distributions.

  • top_k (int) – The number of top candidates used to update the mean and standard deviation of the Gaussian distribution.

  • reward_key (str, optional) – The key in the TensorDict to use to retrieve the reward. Defaults to “reward”.

  • action_key (str, optional) – The key in the TensorDict to use to store the action. Defaults to “action”.
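The parameters above describe an iterative sampling scheme. The following is a generic NumPy sketch of an MPPI-style update, written under assumptions and not taken from torchrl: score_fn is a hypothetical stand-in for simulating each candidate action sequence and scoring it (the role played here by env and the advantage module), and the weights are exp(score / temperature), the textbook MPPI convention; this page does not say how torchrl applies temperature.

```python
import numpy as np


def mppi_plan(score_fn, planning_horizon, action_dim, optim_steps,
              num_candidates, top_k, temperature, seed=0):
    """Sketch of an MPPI-style planner; returns the first planned action."""
    rng = np.random.default_rng(seed)
    # One Gaussian per time step of the horizon.
    mean = np.zeros((planning_horizon, action_dim))
    std = np.ones((planning_horizon, action_dim))
    for _ in range(optim_steps):
        # Sample num_candidates action sequences from the current Gaussian.
        noise = rng.standard_normal((num_candidates, planning_horizon, action_dim))
        candidates = mean + std * noise
        scores = score_fn(candidates)  # shape: (num_candidates,)
        # Keep the top_k highest-scoring candidates.
        top = np.argsort(scores)[-top_k:]
        elite, elite_scores = candidates[top], scores[top]
        # Exponentiated weights; subtracting the max avoids overflow.
        w = np.exp((elite_scores - elite_scores.max()) / temperature)
        w = w / w.sum()
        # Weighted mean and standard deviation of the elite sequences.
        mean = np.einsum("k,khd->hd", w, elite)
        std = np.sqrt(np.einsum("k,khd->hd", w, (elite - mean) ** 2))
    # Only the first action of the planned sequence is returned.
    return mean[0]


# Toy score (hypothetical): prefer action sequences close to 0.5 everywhere.
target = np.full((3, 1), 0.5)


def score_fn(cands):
    return -((cands - target) ** 2).sum(axis=(1, 2))


action = mppi_plan(score_fn, planning_horizon=3, action_dim=1, optim_steps=15,
                   num_candidates=256, top_k=32, temperature=1.0)
```

Restricting the weighted average to the top_k candidates discards clearly poor samples, while the exponential weights still let the best of them dominate the updated mean.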

Examples

>>> from tensordict import TensorDict
>>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
>>> from torchrl.envs.model_based import ModelBasedEnvBase
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import MPPIPlanner, ValueOperator
>>> from torchrl.objectives.value import TDLambdaEstimator
>>> class MyMBEnv(ModelBasedEnvBase):
...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
...         self.state_spec = CompositeSpec(
...             hidden_observation=UnboundedContinuousTensorSpec((4,))
...         )
...         self.observation_spec = CompositeSpec(
...             hidden_observation=UnboundedContinuousTensorSpec((4,))
...         )
...         self.action_spec = UnboundedContinuousTensorSpec((1,))
...         self.reward_spec = UnboundedContinuousTensorSpec((1,))
...
...     def _reset(self, tensordict: TensorDict) -> TensorDict:
...         tensordict = TensorDict(
...             {},
...             batch_size=self.batch_size,
...             device=self.device,
...         )
...         tensordict = tensordict.update(
...             self.full_state_spec.rand())
...         tensordict = tensordict.update(
...             self.full_action_spec.rand())
...         tensordict = tensordict.update(
...             self.full_observation_spec.rand())
...         return tensordict
...
>>> from torchrl.modules import MLP, WorldModelWrapper
>>> import torch.nn as nn
>>> world_model = WorldModelWrapper(
...     TensorDictModule(
...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
...         in_keys=["hidden_observation", "action"],
...         out_keys=["hidden_observation"],
...     ),
...     TensorDictModule(
...         nn.Linear(4, 1),
...         in_keys=["hidden_observation"],
...         out_keys=["reward"],
...     ),
... )
>>> env = MyMBEnv(world_model)
>>> value_net = nn.Linear(4, 1)
>>> value_net = ValueOperator(value_net, in_keys=["hidden_observation"])
>>> adv = TDLambdaEstimator(
...     gamma=0.99,
...     lmbda=0.95,
...     value_network=value_net,
... )
>>> # Build a planner and use it as actor
>>> planner = MPPIPlanner(
...     env,
...     adv,
...     temperature=1.0,
...     planning_horizon=10,
...     optim_steps=11,
...     num_candidates=7,
...     top_k=3)
>>> env.rollout(5, planner)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                hidden_observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)
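The example passes a TDLambdaEstimator (gamma=0.99, lmbda=0.95) as the second argument, which the planner presumably uses to rank the simulated trajectories. As an illustration of the quantity such an estimator computes, the sketch below evaluates the textbook TD(λ) recursion G_t = r_t + γ[(1 - λ)V(s_{t+1}) + λG_{t+1}], bootstrapped with G_T = V(s_T), and returns the advantages G_t - V(s_t) for one trajectory. It ignores termination flags and is not torchrl's implementation; td_lambda_advantage is a hypothetical name.

```python
def td_lambda_advantage(rewards, values, next_values, gamma=0.99, lmbda=0.95):
    """TD(lambda) advantages for a single, non-terminating trajectory.

    rewards[t] = r_t, values[t] = V(s_t), next_values[t] = V(s_{t+1}).
    """
    T = len(rewards)
    returns = [0.0] * T
    # Bootstrap beyond the last step with the value of the final next state.
    g = next_values[-1]
    for t in reversed(range(T)):
        # Blend the one-step bootstrap with the longer lambda-return.
        g = rewards[t] + gamma * ((1 - lmbda) * next_values[t] + lmbda * g)
        returns[t] = g
    # Advantage = lambda-return minus the value of the current state.
    return [G - v for G, v in zip(returns, values)]


advantages = td_lambda_advantage([1.0, 1.0], [0.0, 0.0], [0.0, 2.0],
                                 gamma=0.5, lmbda=1.0)
# advantages == [2.0, 2.0]
```

Setting lmbda=0.0 reduces each advantage to the one-step TD error r_t + γV(s_{t+1}) - V(s_t), while lmbda=1.0 gives the discounted return bootstrapped only at the end of the trajectory.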
planning(tensordict: TensorDictBase) → Tensor

Performs the MPC planning step and returns the resulting action tensor.

Parameters:

tensordict (TensorDictBase) – The TensorDict to perform the planning step on.
