TimeMaxPool
- class torchrl.envs.transforms.TimeMaxPool(in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, T: int = 1, reset_key: Optional[NestedKey] = None)
Take the maximum value in each position over the last T observations.
This transform takes the element-wise maximum of each in_keys tensor over the last T time steps.
- Parameters:
in_keys (sequence of NestedKey, optional) – input keys on which the max pool will be applied. Defaults to “observation” if left empty.
out_keys (sequence of NestedKey, optional) – output keys where the output will be written. Defaults to in_keys if left empty.
T (int, optional) – number of time steps over which to apply max pooling. Defaults to 1.
reset_key (NestedKey, optional) – the reset key to be used as partial reset indicator. Must be unique. If not provided, defaults to the only reset key of the parent environment (if it has only one) and raises an exception otherwise.
Examples
>>> import torch
>>> from torchrl.envs import GymEnv, TransformedEnv
>>> from torchrl.envs.transforms import TimeMaxPool
>>> base_env = GymEnv("Pendulum-v1")
>>> env = TransformedEnv(base_env, TimeMaxPool(in_keys=["observation"], T=10))
>>> torch.manual_seed(0)
>>> env.set_seed(0)
>>> rollout = env.rollout(10)
>>> print(rollout["observation"])  # values should be increasing up until the 10th step
tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0216,  0.0000],
        [ 0.0000,  0.1149,  0.0000],
        [ 0.0000,  0.1990,  0.0000],
        [ 0.0000,  0.2749,  0.0000],
        [ 0.0000,  0.3281,  0.0000],
        [-0.9290,  0.3702, -0.8978]])
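The pooling rule itself can be illustrated without TorchRL. The following is a minimal pure-Python sketch, not the library's implementation: rolling_max_pool is a hypothetical helper, and the zero-filled starting window is an assumption chosen because it is consistent with the zeros in the rollout printed in this example (negative entries only surface once the window holds T real observations).

```python
from collections import deque


def rolling_max_pool(frames, T):
    """Element-wise max over the last T frames.

    Hypothetical helper for illustration only; not part of TorchRL.
    """
    if not frames:
        return []
    width = len(frames[0])
    # Assumption: the window starts out filled with zeros, so early outputs
    # are clipped at 0 until T real frames have been seen.
    window = deque([[0.0] * width for _ in range(T)], maxlen=T)
    pooled = []
    for frame in frames:
        # Appending to a full deque drops the oldest frame automatically.
        window.append(list(frame))
        # zip(*window) groups the values found at each position across frames.
        pooled.append([max(column) for column in zip(*window)])
    return pooled


frames = [[-1.0, 0.5], [-0.5, 0.2], [-2.0, 0.9]]
result = rolling_max_pool(frames, T=2)
print(result)  # [[0.0, 0.5], [-0.5, 0.5], [-0.5, 0.9]]
```

With T=2, the first output still sees one zero-filled slot, so the negative first coordinate is reported as 0.0; from the second step on, the window contains only real frames.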
Note
TimeMaxPool currently only supports a done signal at the root. Nested done signals, such as those found in MARL settings, are currently not supported. If this feature is needed, please raise an issue on the TorchRL repo.
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
- transform_observation_spec(observation_spec: TensorSpec) → TensorSpec
Transforms the observation spec so that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform