terminated_or_truncated
- torchrl.envs.utils.terminated_or_truncated(data: TensorDictBase, full_done_spec: Optional[TensorSpec] = None, key: str = '_reset', write_full_false: bool = False) -> bool
Reads the done / terminated / truncated keys within a tensordict and writes a new tensor in which the values of these signals are aggregated (combined with a logical OR).
The modification occurs in-place within the TensorDict instance provided. This function can be used to compute the “_reset” signals in batched or multiagent settings, hence the default name of the output key.
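The aggregation described above can be sketched without torchrl. The snippet below is a minimal illustration, not the library's implementation: it assumes plain nested dicts stand in for a TensorDict and that each done-like entry is either a bool or a list of per-environment bools (a crude model of the batched case). The helper names aggregate_done, _or and _any are hypothetical. At every nesting level that holds "done", "terminated" or "truncated", those entries are OR-combined and the result is written in place under the output key.

```python
from typing import Any, Dict, List, Optional, Union

# Hypothetical stand-in types: a flag is a bool, or one bool per batched env.
Flag = Union[bool, List[bool]]

DONE_KEYS = ("done", "terminated", "truncated")


def _or(a: Flag, b: Flag) -> Flag:
    """Element-wise logical OR for two bools or two equal-length lists."""
    if isinstance(a, list):
        return [x or y for x, y in zip(a, b)]
    return a or b


def _any(flag: Flag) -> bool:
    """True if the flag (or any element of it) is True."""
    return any(flag) if isinstance(flag, list) else bool(flag)


def aggregate_done(data: Dict[str, Any], key: str = "_reset") -> bool:
    """Hypothetical sketch: OR the done-like entries at each nesting level,
    write the result under `key` at that level, and report whether any was True."""
    any_done = False
    aggregate: Optional[Flag] = None
    # Iterate over a snapshot so that writing `key` cannot disturb the loop.
    for name, value in list(data.items()):
        if name in DONE_KEYS:
            aggregate = value if aggregate is None else _or(aggregate, value)
        elif isinstance(value, dict):
            # Recurse first so nested levels are always processed.
            any_done = aggregate_done(value, key) or any_done
    if aggregate is not None:
        # Copy lists so the result never aliases an input entry.
        data[key] = list(aggregate) if isinstance(aggregate, list) else aggregate
        any_done = any_done or _any(aggregate)
    return any_done


data = {"done": True, "truncated": False,
        "nested": {"done": False, "truncated": True}}
print(aggregate_done(data))                      # True
print(data["_reset"], data["nested"]["_reset"])  # True True
```

With per-environment lists, for instance {"terminated": [False, True, False], "truncated": [False, False, True]}, the written "_reset" entry is [False, True, True], flagging exactly the environments that ended for either reason.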
- Parameters:
  - data (TensorDictBase) – the input data, generally resulting from a call to step().
  - full_done_spec (TensorSpec, optional) – the done_spec from the env, indicating where the done leaves have to be found. If not provided, the default "done", "terminated" and "truncated" entries will be searched for in the data.
  - key (NestedKey, optional) – where the aggregated result should be written. Defaults to "_reset". If None, the function will not write any key but will only output whether any of the done values was True. Note: if a value is already present for the key entry, the previous value will prevail and no update will be achieved.
  - write_full_false (bool, optional) – if True, the reset keys will be written even if the output is False (i.e., no done is True in the provided data structure). Defaults to False.
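The effect of the key and write_full_false parameters can be illustrated with a flat (non-nested) variant. This is a hypothetical sketch, not torchrl code: it assumes a dict of per-environment boolean lists in place of a TensorDict, aggregate_flat is an invented name, and the behavior of an already-present key entry is not modeled.

```python
from typing import Dict, List, Optional

DONE_KEYS = ("done", "terminated", "truncated")


def aggregate_flat(
    data: Dict[str, List[bool]],
    key: Optional[str] = "_reset",
    write_full_false: bool = False,
) -> bool:
    """Hypothetical flat stand-in (not a torchrl function) illustrating the
    `key` and `write_full_false` parameters on per-env boolean lists."""
    flags = [data[k] for k in DONE_KEYS if k in data]
    if not flags:
        return False  # no done-like entry at all: nothing to aggregate
    # Element-wise OR across done / terminated / truncated, one value per env.
    aggregate = [any(column) for column in zip(*flags)]
    any_done = any(aggregate)
    if key is None:
        return any_done  # key=None: report only, never write
    if any_done or write_full_false:
        data[key] = aggregate
    return any_done


batch = {"terminated": [False, False], "truncated": [False, False]}
print(aggregate_flat(batch), "_reset" in batch)                       # False False
print(aggregate_flat(batch, write_full_false=True), batch["_reset"])  # False [False, False]
```

With key=None nothing is written and only the flag is returned; with write_full_false left at False, an all-False result leaves the data untouched.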
- Returns: a boolean value indicating whether any of the done states found in the data contained a True.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.tensor_specs import Categorical, Composite
>>> from torchrl.envs.utils import terminated_or_truncated
>>> spec = Composite(
...     done=Categorical(2, dtype=torch.bool),
...     truncated=Categorical(2, dtype=torch.bool),
...     nested=Composite(
...         done=Categorical(2, dtype=torch.bool),
...         truncated=Categorical(2, dtype=torch.bool),
...     ),
... )
>>> data = TensorDict({
...     "done": True, "truncated": False,
...     "nested": {"done": False, "truncated": True}},
...     batch_size=[],
... )
>>> # writes the "_reset" entries in place and returns a boolean flag
>>> any_done = terminated_or_truncated(data, spec)
>>> print(data["_reset"])
tensor(True)
>>> print(data["nested", "_reset"])
tensor(True)