step_mdp
- torchrl.envs.utils.step_mdp(tensordict: TensorDictBase, next_tensordict: Optional[TensorDictBase] = None, keep_other: bool = True, exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, reward_keys: Union[NestedKey, List[NestedKey]] = 'reward', done_keys: Union[NestedKey, List[NestedKey]] = 'done', action_keys: Union[NestedKey, List[NestedKey]] = 'action') → TensorDictBase
Creates a new tensordict that reflects a step in time of the input tensordict.
Given a tensordict retrieved after a step, returns the "next"-indexed tensordict. The arguments allow for precise control over what should be kept and what should be copied from the "next" entry. The default behaviour is: move the observation entries and done states to the root, discard the reward, exclude the current action, and keep all extra keys (non-action, non-done, non-reward).
- Parameters:
tensordict (TensorDictBase) – tensordict with keys to be renamed
next_tensordict (TensorDictBase, optional) – destination tensordict into which the result is written
keep_other (bool, optional) – if True, all extra root keys (those that are not action, reward or done keys) will be kept. Default is True.
exclude_reward (bool, optional) – if True, the "reward" key will be discarded from the resulting tensordict. If False, it will be copied (and replaced) from the "next" entry (if present). Default is True.
exclude_done (bool, optional) – if True, the "done" key will be discarded from the resulting tensordict. If False, it will be copied (and replaced) from the "next" entry (if present). Default is False.
exclude_action (bool, optional) – if True, the "action" key will be discarded from the resulting tensordict. If False, it will be kept in the root tensordict (since it should not be present in the "next" entry). Default is True.
reward_keys (NestedKey or list of NestedKey, optional) – the keys where the reward is written. Defaults to "reward".
done_keys (NestedKey or list of NestedKey, optional) – the keys where the done is written. Defaults to "done".
action_keys (NestedKey or list of NestedKey, optional) – the keys where the action is written. Defaults to "action".
- Returns:
A new tensordict (or next_tensordict) containing the tensors of the t+1 step.
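The key-selection rules set by the parameters above can be modelled with plain Python dictionaries. The sketch below is a hypothetical re-implementation (`step_mdp_sketch`, with nested keys written as tuples such as `("agents", "reward")`), not TorchRL's code. It ignores tensors, batch sizes, devices and `next_tensordict`, and it simply drops a non-excluded reward or done key that is absent from "next". On the data of the examples below it yields the same key sets.

```python
from typing import Any, Dict, List, Set, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]
Flat = Dict[Tuple[str, ...], Any]


def _flatten(d: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Flat:
    # Turn {"agents": {"obs": 1}} into {("agents", "obs"): 1}.
    out: Flat = {}
    for k, v in d.items():
        if isinstance(v, dict):
            out.update(_flatten(v, prefix + (k,)))
        else:
            out[prefix + (k,)] = v
    return out


def _unflatten(flat: Flat) -> Dict[str, Any]:
    # Inverse of _flatten: rebuild the nested dict from tuple paths.
    out: Dict[str, Any] = {}
    for path, v in flat.items():
        node = out
        for k in path[:-1]:
            node = node.setdefault(k, {})
        node[path[-1]] = v
    return out


def _key_set(keys: Union[NestedKey, List[NestedKey]]) -> Set[Tuple[str, ...]]:
    # Accept "reward", ("agents", "reward") or a list of either.
    keys = keys if isinstance(keys, list) else [keys]
    return {(k,) if isinstance(k, str) else tuple(k) for k in keys}


def step_mdp_sketch(
    td: Dict[str, Any],
    keep_other: bool = True,
    exclude_reward: bool = True,
    exclude_done: bool = False,
    exclude_action: bool = True,
    reward_keys: Union[NestedKey, List[NestedKey]] = "reward",
    done_keys: Union[NestedKey, List[NestedKey]] = "done",
    action_keys: Union[NestedKey, List[NestedKey]] = "action",
) -> Dict[str, Any]:
    reward, done, action = _key_set(reward_keys), _key_set(done_keys), _key_set(action_keys)
    root = _flatten({k: v for k, v in td.items() if k != "next"})
    nxt = _flatten(td.get("next", {}))
    out: Flat = {}
    if keep_other:
        # "Extra" root entries: anything that is not an action, reward or done key.
        out.update({k: v for k, v in root.items() if k not in reward | done | action})
    if not exclude_action:
        # The action only lives at the root, so it is kept from there.
        out.update({k: v for k, v in root.items() if k in action})
    for k, v in nxt.items():
        if (k in reward and exclude_reward) or (k in done and exclude_done):
            continue
        out[k] = v  # entries from "next" replace their root counterparts
    return _unflatten(out)


td = {
    "done": False, "reward": 0.0, "extra": 0.0, "obs": 0.0, "action": 0.0,
    "next": {"done": False, "reward": 1.0, "obs": 1.0},
}
print(sorted(step_mdp_sketch(td)))  # ['done', 'extra', 'obs']
print(sorted(step_mdp_sketch(td, keep_other=False)))  # ['done', 'obs']
```

Because keys are compared as tuples, passing `reward_keys=("agents", "reward")` in this sketch handles a reward stored under a sub-dictionary, loosely mirroring the nested keys that `reward_keys`, `done_keys` and `action_keys` accept.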
Examples: Starting from a tensordict collected after a step, the calls below show how each argument changes the returned tensordict:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.utils import step_mdp
>>> td = TensorDict({
...     "done": torch.zeros((), dtype=torch.bool),
...     "reward": torch.zeros(()),
...     "extra": torch.zeros(()),
...     "next": TensorDict({
...         "done": torch.zeros((), dtype=torch.bool),
...         "reward": torch.zeros(()),
...         "obs": torch.zeros(()),
...     }, []),
...     "obs": torch.zeros(()),
...     "action": torch.zeros(()),
... }, [])
>>> print(step_mdp(td))
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_done=True))  # "done" is dropped
TensorDict(
    fields={
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_reward=False))  # "reward" is kept
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_action=False))  # "action" persists at the root
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, keep_other=False))  # "extra" is missing
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
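The examples show single calls; in practice step_mdp sits inside a rollout loop that alternates an environment step (which writes the t+1 data under "next") with a move of that data to the root. The plain-dict sketch below illustrates that loop. `toy_env_step` (a counter environment) and `advance` (which mimics only the default behaviour on flat dicts) are hypothetical stand-ins, not TorchRL functions; with TorchRL one would typically call an environment's `step` method followed by `step_mdp`.

```python
from typing import Any, Dict, List

SPECIAL = ("next", "action", "reward", "done")


def toy_env_step(td: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical environment: the observation is a counter increased by the
    # action; the episode is done once the counter reaches 3. Like an env step,
    # it writes the t+1 data under "next" and returns the same dict.
    obs = td["obs"] + td["action"]
    td["next"] = {"obs": obs, "reward": float(obs), "done": obs >= 3}
    return td


def advance(td: Dict[str, Any]) -> Dict[str, Any]:
    # Default step_mdp behaviour on a flat dict: keep the extra root keys,
    # drop the action, and move everything from "next" except the reward.
    out = {k: v for k, v in td.items() if k not in SPECIAL}
    out.update({k: v for k, v in td["next"].items() if k != "reward"})
    return out


td: Dict[str, Any] = {"obs": 0, "done": False}
rewards: List[float] = []
while not td["done"]:
    td["action"] = 1                      # fixed policy: always add one
    td = toy_env_step(td)
    rewards.append(td["next"]["reward"])  # read the reward before it is dropped
    td = advance(td)                      # the t+1 data becomes the new root

print(td)       # {'obs': 3, 'done': True}
print(rewards)  # [1.0, 2.0, 3.0]
```

Since the default behaviour discards the reward, the sketch logs it from "next" before advancing; the done flag, which is kept, is what terminates the loop.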
Warning
This function will not work properly if the reward key is also one of the input keys when the reward keys are excluded. This is why the RewardSum transform registers the episode reward in the observation spec rather than the reward spec by default. When using the fast, cached version of this function (_StepMDP), this issue should not be observed.