
ModelBasedEnvBase

torchrl.envs.ModelBasedEnvBase(*args, **kwargs)

Basic environment for model-based RL algorithms.

Wrapper around the model of the MBRL algorithm. It gives a world model (which may include, but is not limited to, observation, reward, done-state and safety-constraint models) an environment interface, so that it can be used like a classical environment.

This is a base class for other environments and it should not be used directly.
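The division of labour described above can be illustrated with a small, library-free sketch. It is not TorchRL's implementation: the class `WorldModelEnvSketch`, the callables `transition_model` and `reward_model`, and the plain-list state are hypothetical. The idea it shows is that the base class implements `step` by querying the world model instead of a simulator, while `_reset` is left abstract, so the base class cannot be instantiated on its own and a subclass must supply it.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, List

State = List[float]


class WorldModelEnvSketch(ABC):
    """Hypothetical stand-in for a model-based env: the world model replaces a simulator."""

    def __init__(
        self,
        transition_model: Callable[[State, float], State],
        reward_model: Callable[[State], float],
    ) -> None:
        self.transition_model = transition_model
        self.reward_model = reward_model
        self.state: State = []

    @abstractmethod
    def _reset(self) -> Dict[str, State]:
        """Subclasses decide how the initial (hidden) observation is produced."""

    def reset(self) -> Dict[str, State]:
        self.state = self._reset()["hidden_observation"]
        return {"hidden_observation": self.state}

    def step(self, action: float) -> Dict[str, object]:
        # Next state and reward are *predicted* by the models, not simulated.
        next_state = self.transition_model(self.state, action)
        reward = self.reward_model(next_state)
        self.state = next_state
        return {"hidden_observation": next_state, "reward": reward}


class ZeroStartEnv(WorldModelEnvSketch):
    """Hypothetical subclass: always starts from an all-zero hidden observation."""

    def _reset(self) -> Dict[str, State]:
        return {"hidden_observation": [0.0, 0.0, 0.0, 0.0]}


env = ZeroStartEnv(
    transition_model=lambda s, a: [max(0.0, x + a) for x in s],  # ReLU-like update
    reward_model=lambda s: sum(s),
)
env.reset()
out = env.step(1.0)
print(out)  # {'hidden_observation': [1.0, 1.0, 1.0, 1.0], 'reward': 4.0}
```

Calling `WorldModelEnvSketch(...)` directly raises a `TypeError` because `_reset` is abstract, which mirrors the rule that the base class is not meant to be used directly.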

Example

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Composite, Unbounded
>>> from torchrl.envs import ModelBasedEnvBase
>>> class MyMBEnv(ModelBasedEnvBase):
...     def __init__(self, world_model, device="cpu", dtype=None, batch_size=None):
...         super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size)
...         self.observation_spec = Composite(
...             hidden_observation=Unbounded((4,))
...         )
...         self.state_spec = Composite(
...             hidden_observation=Unbounded((4,)),
...         )
...         self.action_spec = Unbounded((1,))
...         self.reward_spec = Unbounded((1,))
...
...     def _reset(self, tensordict: TensorDict) -> TensorDict:
...         tensordict = TensorDict({},
...             batch_size=self.batch_size,
...             device=self.device,
...         )
...         tensordict = tensordict.update(self.state_spec.rand())
...         tensordict = tensordict.update(self.observation_spec.rand())
...         return tensordict
>>> # This environment is used as follows:
>>> import torch.nn as nn
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import MLP, WorldModelWrapper
>>> world_model = WorldModelWrapper(
...     TensorDictModule(
...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
...         in_keys=["hidden_observation", "action"],
...         out_keys=["hidden_observation"],
...     ),
...     TensorDictModule(
...         nn.Linear(4, 1),
...         in_keys=["hidden_observation"],
...         out_keys=["reward"],
...     ),
... )
>>> env = MyMBEnv(world_model)
>>> tensordict = env.rollout(max_steps=10)
>>> print(tensordict)
TensorDict(
    fields={
        action: Tensor(torch.Size([10, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([10, 1]), dtype=torch.bool),
        hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32),
        next: LazyStackedTensorDict(
            fields={
                hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False),
        reward: Tensor(torch.Size([10, 1]), dtype=torch.float32)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)

Properties:
  • observation_spec (Composite): sampling spec of the observations;

  • action_spec (TensorSpec): sampling spec of the actions;

  • reward_spec (TensorSpec): sampling spec of the rewards;

  • input_spec (Composite): sampling spec of the inputs;

  • batch_size (torch.Size): batch size to be used by the env. If not set, the env accepts tensordicts of any batch size;

  • device (torch.device): device where the env input and output are expected to live.
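The `batch_size` rule can be summarised with a hypothetical helper, `accepts_batch`, which is not part of TorchRL. Here `None` stands for an unset batch size, and the exact-match requirement for a set batch size is an assumption made for illustration; the checks performed by the library may differ.

```python
from typing import Optional, Tuple


def accepts_batch(
    env_batch_size: Optional[Tuple[int, ...]], td_batch_size: Tuple[int, ...]
) -> bool:
    """Hypothetical check: an unset env batch size accepts any input batch size."""
    if env_batch_size is None:
        return True
    # Assumption: once set, the incoming batch size must match it exactly.
    return tuple(td_batch_size) == tuple(env_batch_size)


print(accepts_batch(None, (10,)))   # True
print(accepts_batch((4,), (4,)))    # True
print(accepts_batch((4,), (10,)))   # False
```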

Parameters:
  • world_model (nn.Module) – model that generates world states and their corresponding rewards;

  • params (List[torch.Tensor], optional) – list of parameters of the world model;

  • buffers (List[torch.Tensor], optional) – list of buffers of the world model;

  • device (torch.device, optional) – device where the env input and output are expected to live;

  • dtype (torch.dtype, optional) – dtype of the env input and output;

  • batch_size (torch.Size, optional) – number of environments contained in the instance;

  • run_type_check (bool, optional) – whether to run type checks on the step of the env.

Methods:
  • step (TensorDict -> TensorDict): step in the environment;

  • reset (TensorDict, optional -> TensorDict): reset the environment;

  • set_seed (int -> int): set the seed of the environment;

  • rand_step (TensorDict, optional -> TensorDict): take a random step sampled from the action spec;

  • rollout (Callable, ... -> TensorDict): execute a rollout in the environment with the given policy (or random steps if no policy is provided).
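The `rollout` behaviour described above (use the policy if one is given, otherwise take random steps sampled from the action spec) can be sketched without TorchRL. This is a hypothetical, simplified loop: the `rollout` function below, the scalar state, the `drift_model` world model and the standard-normal stand-in for an unbounded action spec are all illustrative assumptions, not the library's code.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical world model signature: (state, action) -> (next_state, reward).
WorldModel = Callable[[float, float], Tuple[float, float]]


def drift_model(state: float, action: float) -> Tuple[float, float]:
    """Toy world model: the action shifts the state; reward penalises distance from 0."""
    next_state = state + action
    return next_state, -abs(next_state)


def rollout(
    world_model: WorldModel,
    initial_state: float,
    max_steps: int,
    policy: Optional[Callable[[float], float]] = None,
    seed: int = 0,
) -> List[Dict[str, float]]:
    """Collect max_steps transitions; fall back to random actions when policy is None."""
    rng = random.Random(seed)
    state = initial_state
    trajectory: List[Dict[str, float]] = []
    for _ in range(max_steps):
        if policy is None:
            # Random step: sample from an assumed unbounded (standard normal) action spec.
            action = rng.gauss(0.0, 1.0)
        else:
            action = policy(state)
        next_state, reward = world_model(state, action)
        trajectory.append(
            {"state": state, "action": action, "next_state": next_state, "reward": reward}
        )
        state = next_state
    return trajectory


constant = rollout(drift_model, initial_state=0.0, max_steps=10, policy=lambda s: 1.0)
random_run = rollout(drift_model, initial_state=0.0, max_steps=10)
print(len(constant), constant[-1]["next_state"])  # 10 10.0
```

Passing `policy=None` exercises the random-step fallback. With a learned world model in place of `drift_model`, the same loop produces imagined trajectories without querying the real environment, which is the usual motivation for wrapping a world model as an environment.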
