step_mdp
torchrl.envs.utils.step_mdp(tensordict: TensorDictBase, next_tensordict: Optional[TensorDictBase] = None, keep_other: bool = True, exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, reward_keys: Union[str, Tuple[str, ...], List[Union[str, Tuple[str, ...]]]] = 'reward', done_keys: Union[str, Tuple[str, ...], List[Union[str, Tuple[str, ...]]]] = 'done', action_keys: Union[str, Tuple[str, ...], List[Union[str, Tuple[str, ...]]]] = 'action') → TensorDictBase
Creates a new tensordict that reflects a step in time of the input tensordict.

Given a tensordict retrieved after a step, returns the "next" indexed-tensordict. The arguments allow for precise control over what should be kept and what should be copied from the "next" entry. The default behaviour is: move the observation entries, reward and done states to the root, exclude the current action and keep all extra keys (non-action, non-done, non-reward).

Parameters:
- tensordict (TensorDictBase) – tensordict with keys to be renamed
- next_tensordict (TensorDictBase, optional) – destination tensordict
- keep_other (bool, optional) – if True, all keys that do not start with "next_" will be kept. Default is True.
- exclude_reward (bool, optional) – if True, the "reward" key will be discarded from the resulting tensordict. If False, it will be copied (and replaced) from the "next" entry (if present). Default is True.
- exclude_done (bool, optional) – if True, the "done" key will be discarded from the resulting tensordict. If False, it will be copied (and replaced) from the "next" entry (if present). Default is False.
- exclude_action (bool, optional) – if True, the "action" key will be discarded from the resulting tensordict. If False, it will be kept in the root tensordict (since it should not be present in the "next" entry). Default is True.
- reward_keys (NestedKey or list of NestedKey, optional) – the keys where the reward is written. Defaults to "reward". Nested keys are accepted; see the sketch after the Returns section.
- done_keys (NestedKey or list of NestedKey, optional) – the keys where the done state is written. Defaults to "done".
- action_keys (NestedKey or list of NestedKey, optional) – the keys where the action is written. Defaults to "action".
Returns:
A new tensordict (or next_tensordict) containing the tensors of the t+1 step.
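
The reward_keys, done_keys and action_keys arguments make the same mechanics work when those entries live under nested keys, as in multi-agent setups. A minimal sketch, assuming nested keys are resolved the same way flat ones are (the ("agent", ...) layout is illustrative, not part of the API):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.utils import step_mdp
>>> td = TensorDict({
...     "next": TensorDict({
...         "obs": torch.zeros(()),
...         ("agent", "reward"): torch.zeros(()),
...         ("agent", "done"): torch.zeros((), dtype=torch.bool),
...     }, []),
...     "obs": torch.zeros(()),
...     ("agent", "done"): torch.zeros((), dtype=torch.bool),
...     ("agent", "action"): torch.zeros(()),
... }, [])
>>> out = step_mdp(
...     td,
...     reward_keys=("agent", "reward"),  # nested reward location
...     done_keys=("agent", "done"),      # nested done location
...     action_keys=("agent", "action"),  # nested action location
... )
>>> assert ("agent", "done") in out.keys(True)        # done is carried over (exclude_done=False)
>>> assert ("agent", "reward") not in out.keys(True)  # reward is excluded (exclude_reward=True)
>>> assert ("agent", "action") not in out.keys(True)  # action is excluded (exclude_action=True)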
Examples: This function allows for this kind of loop to be used (a full rollout sketch follows the examples below):
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.utils import step_mdp
>>> td = TensorDict({
...     "done": torch.zeros((), dtype=torch.bool),
...     "reward": torch.zeros(()),
...     "extra": torch.zeros(()),
...     "next": TensorDict({
...         "done": torch.zeros((), dtype=torch.bool),
...         "reward": torch.zeros(()),
...         "obs": torch.zeros(()),
...     }, []),
...     "obs": torch.zeros(()),
...     "action": torch.zeros(()),
... }, [])
>>> print(step_mdp(td))
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_done=True))  # "done" is dropped
TensorDict(
    fields={
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_reward=False))  # "reward" is kept
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, exclude_action=False))  # "action" persists at the root
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        extra: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> print(step_mdp(td, keep_other=False))  # "extra" is missing
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False),
        obs: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
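
In practice, step_mdp is the call that carries the "next" entries over to the root between loop iterations. A minimal rollout sketch, assuming a Gym-backed environment is available (the environment name and the use of rand_step are illustrative only):

>>> from torchrl.envs import GymEnv
>>> from torchrl.envs.utils import step_mdp
>>> env = GymEnv("Pendulum-v1")  # illustrative environment choice
>>> td = env.reset()
>>> for _ in range(10):
...     td = env.rand_step(td)  # writes the step results under the "next" entry
...     # training or buffering logic consuming td would sit here
...     td = step_mdp(td)       # promote "next" to the root for the following step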