Source code for torchrl.modules.planners.common
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import abc
from typing import Optional
import torch
from tensordict import TensorDictBase
from torchrl.envs.common import EnvBase
from torchrl.modules import SafeModule
class MPCPlannerBase(SafeModule, metaclass=abc.ABCMeta):
    """MPCPlannerBase abstract Module.

    This class inherits from :obj:`SafeModule`. Provided a :obj:`TensorDict`,
    this module will perform a Model Predictive Control (MPC) planning step.
    At the end of the planning step, the :obj:`MPCPlanner` will return a
    proposed action.

    Subclasses must implement :meth:`planning`. The module's ``in_keys`` are
    the keys of ``env.observation_spec`` (queried with ``keys(True, True)``),
    and its single out key is ``action_key``.

    Args:
        env (EnvBase): The environment to perform the planning step on (Can be
            :obj:`ModelBasedEnvBase` or :obj:`EnvBase`). It must not be
            batch-locked.
        action_key (str, optional): The key that will point to the computed
            action. Defaults to ``"action"``.

    Raises:
        ValueError: if ``env.batch_locked`` is ``True``.
    """

    def __init__(
        self,
        env: EnvBase,
        action_key: str = "action",
    ):
        # Planners need an env that accepts batched inputs with any batch
        # size; a batch-locked env does not, so reject it up front.
        if env.batch_locked:
            raise ValueError(
                "Environment is batch_locked. MPCPlanners need an environment that accepts batched inputs with any batch size"
            )
        out_keys = [action_key]
        # keys(True, True): query the observation spec including nested
        # entries, leaves only, so every observation the env produces is
        # read by the planner.
        in_keys = list(env.observation_spec.keys(True, True))
        super().__init__(env, in_keys=in_keys, out_keys=out_keys)
        self.env = env
        self.action_spec = env.action_spec
        # Keep the planner on the same device as the environment.
        self.to(env.device)

    @abc.abstractmethod
    def planning(self, td: TensorDictBase) -> torch.Tensor:
        """Performs the MPC planning step.

        Args:
            td (TensorDict): The TensorDict to perform the planning step on.

        Returns:
            torch.Tensor: the proposed action. :meth:`forward` maps it onto
            ``self.action_spec`` via ``project`` before writing it out.
        """
        raise NotImplementedError()

    def forward(
        self,
        tensordict: TensorDictBase,
        tensordict_out: Optional[TensorDictBase] = None,
        **kwargs,
    ) -> TensorDictBase:
        """Plans an action for ``tensordict`` and writes it to the output.

        Args:
            tensordict (TensorDictBase): input passed to :meth:`planning`.
            tensordict_out (TensorDictBase, optional): destination forwarded
                to ``_write_to_tensordict``.
            **kwargs: only inspected to reject functional-call arguments.

        Returns:
            TensorDictBase: the tensordict returned by
            ``_write_to_tensordict``, holding the projected action under
            this module's out key.

        Raises:
            ValueError: if ``params`` or ``vmap`` is passed in ``kwargs``.
        """
        # Functional (params / vmap) calls are not supported by planners.
        if "params" in kwargs or "vmap" in kwargs:
            raise ValueError(
                "MPCPlannerBase does not currently support functional programming."
            )
        action = self.planning(tensordict)
        # Map the raw planner output onto the action spec before emitting it.
        action = self.action_spec.project(action)
        tensordict_out = self._write_to_tensordict(
            tensordict,
            (action,),
            tensordict_out,
        )
        return tensordict_out