StepCounter
- class torchrl.envs.transforms.StepCounter(max_steps: int | None = None, truncated_key: str = 'truncated')
Counts the steps since the last reset and, when max_steps is given, sets the done state to True once that many steps have been taken.
- Parameters:
max_steps (int, optional) – a positive integer that indicates the maximum number of steps to take before setting the truncated_key entry to True. Regardless of this limit, the step count is incremented on each call to step() and written to the step_count entry.
truncated_key (str, optional) – the key under which the truncated flag is written. Defaults to "truncated", which data collectors recognise as a reset signal.
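The counting and truncation rule described by these parameters can be sketched in plain Python. This is a toy model, not TorchRL's implementation: the class name `ToyStepCounter`, its `reset`/`step` methods, and the use of a plain `dict` in place of a `TensorDictBase` are illustrative assumptions. It also assumes truncation is flagged on the step that reaches `max_steps`, and that the count keeps growing afterwards.

```python
from typing import Optional


class ToyStepCounter:
    """Hypothetical stand-in for StepCounter that works on plain dicts."""

    def __init__(self, max_steps: Optional[int] = None, truncated_key: str = "truncated"):
        # Mirror the documented constraint: max_steps must be a positive integer.
        if max_steps is not None and max_steps < 1:
            raise ValueError("max_steps must be a positive integer")
        self.max_steps = max_steps
        self.truncated_key = truncated_key

    def reset(self, data: dict) -> dict:
        # A reset restarts the count from zero.
        data["step_count"] = 0
        if self.max_steps is not None:
            data[self.truncated_key] = False
        return data

    def step(self, data: dict) -> dict:
        # The count is incremented on every step, even past max_steps.
        data["step_count"] += 1
        if self.max_steps is not None:
            # Assumption: flag truncation once the step budget is reached.
            data[self.truncated_key] = data["step_count"] >= self.max_steps
        return data


counter = ToyStepCounter(max_steps=3)
state = counter.reset({})
flags = []
for _ in range(4):
    state = counter.step(state)
    flags.append((state["step_count"], state["truncated"]))
print(flags)  # [(1, False), (2, False), (3, True), (4, True)]
```

In a real rollout a data collector would react to the truncated flag by resetting the environment; the toy loop keeps stepping only to show that the count continues past the limit.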
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict and applies the transform to the selected keys.
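The generic contract of forward, reading the input and transforming only the selected keys, can be illustrated with a small dict-based function. The name `toy_forward` and its `in_keys`/`fn` arguments are hypothetical; TorchRL's actual forward operates on `TensorDictBase` objects and may write its results in place, whereas this sketch returns a copy.

```python
from typing import Callable, Iterable


def toy_forward(data: dict, in_keys: Iterable[str], fn: Callable) -> dict:
    """Return a copy of `data` with `fn` applied to the selected keys only."""
    out = dict(data)  # shallow copy so the caller's mapping is left untouched
    for key in in_keys:
        # Keys that are absent from the input are simply skipped.
        if key in out:
            out[key] = fn(out[key])
    return out


obs = {"pixels": 2.0, "reward": 1.0}
print(toy_forward(obs, in_keys=["pixels"], fn=lambda x: x * 10))
# {'pixels': 20.0, 'reward': 1.0}
print(toy_forward(obs, in_keys=[], fn=lambda x: x * 10))
# {'pixels': 2.0, 'reward': 1.0}
```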
- transform_input_spec(input_spec: CompositeSpec) → CompositeSpec
Transforms the input spec such that the resulting spec matches the transform mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
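A spec transform returns the spec the environment exposes once the transform is appended. As a toy illustration, assuming the counter has to be fed back into each step and is therefore declared as a non-negative integer input entry, a dict-based version might look like the following. The `ToySpec` dataclass and `toy_transform_input_spec` function are hypothetical stand-ins for TorchRL's spec classes.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical, minimal description of one entry: a dtype plus a lower bound."""
    dtype: str
    minimum: Optional[int] = None


def toy_transform_input_spec(input_spec: Dict[str, ToySpec]) -> Dict[str, ToySpec]:
    # Build a new spec rather than mutating the one passed in.
    new_spec = dict(input_spec)
    # Assumption: the step count is an integer input entry that never goes below 0.
    new_spec["step_count"] = ToySpec(dtype="int64", minimum=0)
    return new_spec


before = {"action": ToySpec(dtype="float32")}
after = toy_transform_input_spec(before)
print(sorted(after))  # ['action', 'step_count']
```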
- transform_observation_spec(observation_spec: CompositeSpec) → CompositeSpec
Transforms the observation spec such that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform
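For the observation side, the sketch below adds the entries the counter writes according to the parameter descriptions: `step_count` always, and the `truncated_key` flag only when `max_steps` is set (an assumption, since without a step budget there is nothing to truncate). Entries are modelled as plain dtype strings; the function `toy_transform_observation_spec` and this representation are hypothetical, not TorchRL's API.

```python
from typing import Dict, Optional


def toy_transform_observation_spec(
    observation_spec: Dict[str, str],
    max_steps: Optional[int] = None,
    truncated_key: str = "truncated",
) -> Dict[str, str]:
    """Return a copy of `observation_spec` extended with the counter's outputs."""
    new_spec = dict(observation_spec)
    # The step count is reported on every step.
    new_spec["step_count"] = "int64"
    # Assumption: the truncation flag only exists when a step budget is configured.
    if max_steps is not None:
        new_spec[truncated_key] = "bool"
    return new_spec


base = {"observation": "float32"}
print(toy_transform_observation_spec(base))
# {'observation': 'float32', 'step_count': 'int64'}
print(toy_transform_observation_spec(base, max_steps=100))
# {'observation': 'float32', 'step_count': 'int64', 'truncated': 'bool'}
```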