ActionMask
- class torchrl.envs.transforms.ActionMask(action_key: NestedKey = 'action', mask_key: NestedKey = 'action_mask')
An adaptive action masker.
This transform reads the mask from the input tensordict after each step is executed, and updates the mask of the one-hot / categorical action spec accordingly.
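The transform automates an update-then-sample cycle. After each step the environment writes a boolean mask, and the categorical action spec is then restricted to the indices that the mask allows. The library-free sketch below illustrates only that cycle. `MaskedCategoricalSketch` and its `update_mask` / `rand` methods are hypothetical names chosen for illustration, not a confirmed torchrl API.

```python
import random
from typing import List, Optional


class MaskedCategoricalSketch:
    """Hypothetical stand-in for a categorical action spec with an optional mask."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.mask: Optional[List[bool]] = None  # None means every action is allowed

    def update_mask(self, mask: List[bool]) -> None:
        # Replace the current mask with the one read after the latest step.
        if len(mask) != self.n:
            raise ValueError(f"expected a mask of length {self.n}, got {len(mask)}")
        self.mask = list(mask)

    def rand(self, rng: random.Random) -> int:
        # Sample uniformly among the actions that the mask allows.
        allowed = [i for i in range(self.n) if self.mask is None or self.mask[i]]
        if not allowed:
            raise RuntimeError("the mask forbids every action")
        return rng.choice(allowed)


rng = random.Random(0)
spec = MaskedCategoricalSketch(4)
spec.update_mask([True, False, True, False])  # e.g. the mask written by the last step
samples = {spec.rand(rng) for _ in range(100)}
print(sorted(samples))  # only indices 0 and 2 can appear
```

Raising on an all-False mask is a choice made for this sketch. In the example further down, the environment instead signals `done` once no action is left.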
Note
This transform will fail when used without an environment.
- Parameters:
action_key (NestedKey, optional) – the key where the action tensor can be found. Defaults to "action".
mask_key (NestedKey, optional) – the key where the action mask can be found. Defaults to "action_mask".
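Both keys are `NestedKey`s. A `NestedKey` is either a plain string such as `"action"` or a tuple of strings such as `("agent", "action_mask")`, which addresses an entry inside nested tensordicts. The pure-Python helper below is a sketch of how such a key is resolved against nested dictionaries. `get_nested` is a hypothetical name, and in torchrl the lookup is performed by the tensordict itself.

```python
from typing import Any, Mapping, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def get_nested(data: Mapping[str, Any], key: NestedKey) -> Any:
    """Hypothetical helper: resolve a str or tuple-of-str key in nested mappings."""
    path = (key,) if isinstance(key, str) else key
    value: Any = data
    for part in path:
        value = value[part]  # raises KeyError if a level is missing
    return value


data = {
    "action_mask": [True, True, False],
    "agent": {"action_mask": [False, True, True]},
}
print(get_nested(data, "action_mask"))             # [True, True, False]
print(get_nested(data, ("agent", "action_mask")))  # [False, True, True]
```

A tuple such as `("agent", "action_mask")` is the kind of value you would pass as `mask_key` when the mask lives in a sub-tensordict rather than at the root.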
Examples
>>> import torch
>>> from torchrl.data.tensor_specs import Categorical, Binary, Unbounded, Composite
>>> from torchrl.envs.transforms import ActionMask, TransformedEnv
>>> from torchrl.envs.common import EnvBase
>>> class MaskedEnv(EnvBase):
...     def __init__(self, *args, **kwargs):
...         super().__init__(*args, **kwargs)
...         self.action_spec = Categorical(4)
...         self.state_spec = Composite(action_mask=Binary(4, dtype=torch.bool))
...         self.observation_spec = Composite(obs=Unbounded(3))
...         self.reward_spec = Unbounded(1)
...
...     def _reset(self, tensordict=None):
...         td = self.observation_spec.rand()
...         td.update(torch.ones_like(self.state_spec.rand()))
...         return td
...
...     def _step(self, data):
...         td = self.observation_spec.rand()
...         mask = data.get("action_mask")
...         action = data.get("action")
...         mask = mask.scatter(-1, action.unsqueeze(-1), 0)
...
...         td.set("action_mask", mask)
...         td.set("reward", self.reward_spec.rand())
...         td.set("done", ~mask.any().view(1))
...         return td
...
...     def _set_seed(self, seed):
...         return seed
...
>>> torch.manual_seed(0)
>>> base_env = MaskedEnv()
>>> env = TransformedEnv(base_env, ActionMask())
>>> r = env.rollout(10)
>>> env = TransformedEnv(base_env, ActionMask())
>>> r = env.rollout(10)
>>> r["action_mask"]
tensor([[ True,  True,  True,  True],
        [ True,  True, False,  True],
        [ True,  True, False, False],
        [ True, False, False, False]])
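The rollout above stops after four steps even though `rollout(10)` was requested. Each `_step` clears the mask entry of the action just taken, and `done` becomes true once no entry is left. Because sampling is restricted to allowed actions, no action is repeated. The pure-Python sketch below replays that logic without torchrl. `masked_rollout` is a hypothetical helper, and it assumes a plain uniform choice over the allowed indices.

```python
import random
from typing import List


def masked_rollout(n_actions: int, max_steps: int, seed: int = 0) -> List[List[bool]]:
    """Hypothetical replay of MaskedEnv: returns the mask observed before each step."""
    rng = random.Random(seed)
    mask = [True] * n_actions         # _reset: every action starts allowed
    history = []
    for _ in range(max_steps):
        history.append(list(mask))    # what r["action_mask"] records at this step
        allowed = [i for i, ok in enumerate(mask) if ok]
        action = rng.choice(allowed)  # masked sampling never picks a cleared action
        mask[action] = False          # _step: clear the entry of the chosen action
        if not any(mask):             # done = ~mask.any()
            break
    return history


masks = masked_rollout(n_actions=4, max_steps=10)
for row in masks:
    print(row)
print(len(masks))  # 4: one step per action, even though up to 10 were allowed
```

The number of `True` entries drops by exactly one per row, which matches the shape of the tensor printed above. Which entry is cleared at each step depends on the random actions.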
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict and, for the selected keys, applies the transform.
By default, this method:
- calls _apply_transform() directly.
- does not call _step() or _call().
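The default behaviour listed above can be pictured with the simplified, library-free class below. The transform is applied to each selected key, and the `_step()` / `_call()` hooks are bypassed. `TransformSketch` and `DoubleTransform` are hypothetical. A real torchrl transform also deals with specs, parent environments and missing-key tolerance, all of which are omitted here.

```python
from typing import Any, Dict, List


class TransformSketch:
    """Hypothetical, simplified stand-in for a transform's default forward()."""

    def __init__(self, in_keys: List[str], out_keys: List[str]) -> None:
        if len(in_keys) != len(out_keys):
            raise ValueError("in_keys and out_keys must have the same length")
        self.in_keys = in_keys
        self.out_keys = out_keys

    def _apply_transform(self, value: Any) -> Any:
        raise NotImplementedError

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Read each selected key, transform its value, write it under the output key.
        # Nothing here calls a _step() or _call() hook.
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            data[out_key] = self._apply_transform(data[in_key])  # KeyError if missing
        return data


class DoubleTransform(TransformSketch):
    # Hypothetical example subclass: doubles the selected value.
    def _apply_transform(self, value: Any) -> Any:
        return value * 2


t = DoubleTransform(in_keys=["obs"], out_keys=["obs_doubled"])
print(t.forward({"obs": 3}))  # {'obs': 3, 'obs_doubled': 6}
```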
This method is not called within env.step at any point. However, it is called within ReplayBuffer.sample().
Note
forward also works with regular keyword arguments, using dispatch to cast the argument names to the keys.
Examples
>>> from tensordict import TensorDict, TensorDictBase
>>> from torchrl.envs.transforms import Transform
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes_in_td  # store the measured size, not the builtin `bytes`
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs
>>> t(TensorDict(a=0))  # works offline too
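The note on forward says that keyword arguments are mapped onto tensordict keys. A rough, library-free picture of that idea is a decorator that packs keyword arguments into a dictionary, calls the dictionary-based function, and returns the requested output entries. `dispatch_sketch` and `add_one` are hypothetical names, and this sketch is far simpler than the actual dispatch mechanism.

```python
import functools
from typing import Any, Callable, Dict, Optional, Sequence


def dispatch_sketch(out_keys: Sequence[str]) -> Callable:
    """Hypothetical decorator: call a dict-based function with keyword arguments."""

    def decorator(fn: Callable[[Dict[str, Any]], Dict[str, Any]]) -> Callable:
        @functools.wraps(fn)
        def wrapper(data: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
            if data is not None:
                return fn(data)  # regular call with a dictionary
            result = fn(dict(kwargs))  # keyword names become the keys
            outputs = tuple(result[key] for key in out_keys)
            return outputs[0] if len(outputs) == 1 else outputs

        return wrapper

    return decorator


@dispatch_sketch(out_keys=["b"])
def add_one(data: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical dict-based function: writes data["a"] + 1 under "b".
    data["b"] = data["a"] + 1
    return data


print(add_one({"a": 1}))  # {'a': 1, 'b': 2}
print(add_one(a=1))       # 2
```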