
MultiStepActorWrapper

class torchrl.modules.tensordict_module.MultiStepActorWrapper(*args, **kwargs)[source]

A wrapper around a multi-action actor.

This class enables macros to be executed in an environment. The actor's action entry (or entries) must carry an additional time dimension that is consumed one step at a time; this dimension must sit right after the batch dimensions of the input tensordict (i.e. at index tensordict.ndim).

If not provided, the action entry keys are retrieved automatically from the actor using a simple heuristic (any nested key ending with the "action" string).

An "is_init" entry must also be present in the input tensordict to track when the current macro should be interrupted because a "done" state has been encountered. Unlike action_keys, this key must be unique.
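
For illustration, here is a minimal sketch (hypothetical tensors, not produced by the library) of the layout this implies: a tensordict with a single batch dimension, an "is_init" reset flag, and an action entry whose extra time dimension sits at index tensordict.ndim:

>>> import torch
>>> from tensordict import TensorDict
>>> batch, n_steps, n_action = 5, 6, 2
>>> actor_out = TensorDict(
...     {
...         # extra time dimension at index 1 == actor_out.ndim (batch dims come first)
...         "action": torch.randn(batch, n_steps, n_action),
...         # reset indicator, typically written by the InitTracker transform
...         "is_init": torch.zeros(batch, 1, dtype=torch.bool),
...     },
...     batch_size=[batch],
... )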

Parameters:
  • actor (TensorDictModuleBase) – An actor.

  • n_steps (int) – the number of actions the actor outputs at once (lookahead window).

Keyword Arguments:
  • action_keys (list of NestedKeys, optional) – the action keys from the environment. Can be retrieved from env.action_keys. Defaults to all out_keys of the actor which end with the "action" string.

  • init_key (NestedKey, optional) – the key of the entry indicating when the environment has gone through a reset. Defaults to "is_init", which is the out_key from the InitTracker transform.
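
If the heuristic does not pick the right entries, both keys can be passed explicitly. A minimal sketch, assuming a hypothetical actor_base module that writes its actions at a nested ("agent", "action") key:

>>> actor = MultiStepActorWrapper(
...     actor_base,
...     n_steps=6,
...     action_keys=[("agent", "action")],  # hypothetical nested action key
...     init_key="is_init",                 # default key written by InitTracker
... )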

Examples

>>> import torch.nn
>>> from torchrl.modules.tensordict_module.actors import MultiStepActorWrapper, Actor
>>> from torchrl.envs import CatFrames, GymEnv, TransformedEnv, SerialEnv, InitTracker, Compose
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> time_steps = 6
>>> n_obs = 4
>>> n_action = 2
>>> batch = 5
>>>
>>> # Reshape the flat CatFrames output into a stack of frames
>>> def reshape_cat(data: torch.Tensor):
...     return data.unflatten(-1, (time_steps, n_obs))
>>> # an actor that reads `time_steps` frames and outputs one action per frame
>>> # (each action is conditioned on an observation up to `time_steps` steps in the past)
>>> actor_base = Seq(
...     Mod(reshape_cat, in_keys=["obs_cat"], out_keys=["obs_cat_reshape"]),
...     Mod(torch.nn.Linear(n_obs, n_action), in_keys=["obs_cat_reshape"], out_keys=["action"])
... )
>>> # Wrap the actor to dispatch the actions
>>> actor = MultiStepActorWrapper(actor_base, n_steps=time_steps)
>>>
>>> env = TransformedEnv(
...     SerialEnv(batch, lambda: GymEnv("CartPole-v1")),
...     Compose(
...         InitTracker(),
...         CatFrames(N=time_steps, in_keys=["observation"], out_keys=["obs_cat"], dim=-1)
...     )
... )
>>>
>>> print(env.rollout(100, policy=actor, break_when_any_done=False))
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 100, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        action_orig: Tensor(shape=torch.Size([5, 100, 6, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        counter: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.int32, is_shared=False),
        done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5, 100]),
            device=cpu,
            is_shared=False),
        obs_cat: Tensor(shape=torch.Size([5, 100, 24]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([5, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([5, 100]),
    device=cpu,
    is_shared=False)
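
Reading the rollout above: the wrapper writes the single action executed at each step under "action", while the actor's full macro (one action per lookahead step) is kept under "action_orig". Continuing the example:

>>> rollout = env.rollout(100, policy=actor, break_when_any_done=False)
>>> rollout["action"].shape       # one executed action per env step
torch.Size([5, 100, 2])
>>> rollout["action_orig"].shape  # the actor's macro of `time_steps` actions
torch.Size([5, 100, 6, 2])
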
forward(tensordict: TensorDictBase) → TensorDictBase[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
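
As with any torch.nn.Module, prefer calling the wrapper instance over invoking forward() directly; a minimal sketch, with td standing for any suitably prepared input tensordict:

>>> td = actor(td)            # preferred: runs any registered hooks
>>> # td = actor.forward(td)  # also works, but silently skips hooks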

property init_key: Union[str, Tuple[str, ...]]

The indicator of the initial step for a given element of the batch.
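
For the wrapper built in the example above (assuming the default was not overridden), this simply resolves to the InitTracker key:

>>> actor.init_key
'is_init'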
