
VecNorm

class torchrl.envs.transforms.VecNorm(in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, shared_td: Optional[TensorDictBase] = None, lock: mp.Lock = None, decay: float = 0.9999, eps: float = 0.0001, shapes: List[torch.Size] = None)[source]

Moving average normalization layer for torchrl environments.

VecNorm keeps track of the summary statistics of a dataset to standardize it on-the-fly. If the transform is in ‘eval’ mode, the running statistics are not updated.
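
The idea of standardizing with decayed running statistics can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: the class name RunningNorm, its attributes, and the exact update rule (decayed sum, sum of squares and count) are all hypothetical, and it handles scalars only.

```python
import math


class RunningNorm:
    """Decayed running mean/std for scalar observations (illustrative only)."""

    def __init__(self, decay=0.9999, eps=1e-4):
        self.decay = decay
        self.eps = eps
        self.training = True  # stands in for the train/eval switch
        self.total = 0.0      # decayed sum of observations
        self.total_sq = 0.0   # decayed sum of squared observations
        self.count = 0.0      # decayed number of observations

    def __call__(self, x):
        if self.training:
            # Older samples are down-weighted by `decay` at every update.
            self.total = self.decay * self.total + x
            self.total_sq = self.decay * self.total_sq + x * x
            self.count = self.decay * self.count + 1.0
        if self.count == 0.0:
            return x  # no statistics collected yet
        mean = self.total / self.count
        var = max(self.total_sq / self.count - mean * mean, 0.0)
        std = max(math.sqrt(var), self.eps)  # eps is a lower bound on the std
        return (x - mean) / std


norm = RunningNorm(decay=0.99)
for value in [1, 2, 3, 4, 5] * 200:
    norm(float(value))

norm.training = False  # "eval" mode: statistics are frozen
frozen = (norm.total, norm.total_sq, norm.count)
norm(100.0)            # still normalized, but the statistics do not move
```

With a decay close to 1 the statistics average over a long history; a smaller decay tracks recent observations more closely.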

If multiple processes are running a similar environment, one can pass a TensorDictBase instance that is placed in shared memory: if so, every time the normalization layer is queried it will update the values for all processes that share the same reference.
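
The sharing pattern described above (one set of statistics, updated by every worker that holds the same reference) can be illustrated with a lock-protected object. As a loud caveat, this sketch swaps in threads for processes and a plain Python object for a shared tensordict so that it stays self-contained; SharedStats and worker are hypothetical names, and no decay is applied.

```python
import threading


class SharedStats:
    """Sum, sum of squares and count shared by several workers (illustrative)."""

    def __init__(self):
        self.lock = threading.Lock()
        self.total = 0.0
        self.total_sq = 0.0
        self.count = 0

    def update(self, x):
        # The lock keeps the three fields consistent under concurrent updates.
        with self.lock:
            self.total += x
            self.total_sq += x * x
            self.count += 1

    def mean(self):
        with self.lock:
            return self.total / self.count


def worker(stats, values):
    # Every worker writes into the same object it was handed.
    for v in values:
        stats.update(v)


stats = SharedStats()  # one reference, passed to all workers
threads = [
    threading.Thread(target=worker, args=(stats, [float(i)] * 100))
    for i in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(stats.count, stats.mean())
```

Each worker sees statistics that reflect the observations of all workers, which is the point of passing a single shared container.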

To use VecNorm at inference time without updating the statistics with new observations, replace this layer with the result of vecnorm.to_observation_norm().

Parameters:
  • in_keys (sequence of NestedKey, optional) – keys to be updated. default: [“observation”, “reward”]

  • out_keys (sequence of NestedKey, optional) – destination keys. Defaults to in_keys.

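  • shared_td (TensorDictBase, optional) – A shared tensordict containing the keys of the transform.

  • lock (mp.Lock, optional) – a lock used to prevent race conditions between processes. Default is None.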

  • decay (number, optional) – decay rate of the moving average. Default is 0.9999.

  • eps (number, optional) – lower bound of the running standard deviation (for numerical underflow). Default is 1e-4.

  • shapes (List[torch.Size], optional) – if provided, the shape of each entry in in_keys. Its length must match that of in_keys, and each shape must match the trailing dimensions of the corresponding entry. If not provided, all dimensions of an entry that do not belong to the tensordict batch size are treated as feature dimensions.

Examples

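>>> import torch
>>> from torchrl.envs import TransformedEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs.transforms import VecNorm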
>>> t = VecNorm(decay=0.9)
>>> env = GymEnv("Pendulum-v0")
>>> env = TransformedEnv(env, t)
>>> tds = []
>>> for _ in range(1000):
...     td = env.rand_step()
...     if td.get("done"):
...         _ = env.reset()
...     tds += [td]
>>> tds = torch.stack(tds, 0)
>>> print((abs(tds.get(("next", "observation")).mean(0))<0.2).all())
tensor(True)
>>> print((abs(tds.get(("next", "observation")).std(0)-1)<0.2).all())
tensor(True)

static build_td_for_shared_vecnorm(env: EnvBase, keys: Optional[Sequence[str]] = None, memmap: bool = False) TensorDictBase[source]

Creates a shared tensordict for normalization across processes.

Parameters:
  • env (EnvBase) – example environment to be used to create the tensordict

  • keys (sequence of NestedKey, optional) – keys that have to be normalized. Default is [“next”, “reward”]

  • memmap (bool) – if True, the resulting tensordict will be cast into a memory map (using memmap_()). Otherwise, the tensordict will be placed in shared memory.

Returns:

A tensordict in shared memory (or memory-mapped, if memmap is True) to be sent to each process.

Examples

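>>> from torch import multiprocessing as mp
>>> from torchrl.envs import TransformedEnv
>>> from torchrl.envs.transforms import VecNorm
>>> queue = mp.Queue()
>>> env = make_env()  # make_env: user-defined function returning an EnvBase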
>>> td_shared = VecNorm.build_td_for_shared_vecnorm(env,
...     ["next", "reward"])
>>> assert td_shared.is_shared()
>>> queue.put(td_shared)
>>> # on workers
>>> v = VecNorm(shared_td=queue.get())
>>> env = TransformedEnv(make_env(), v)

forward(tensordict: TensorDictBase) TensorDictBase

Reads the input tensordict, and for the selected keys, applies the transform.

get_extra_state() OrderedDict[source]

Return any extra state to include in the module’s state_dict.

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

Any extra state to store in the module’s state_dict

Return type:

object

set_extra_state(state: OrderedDict) None[source]

Set extra state contained in the loaded state_dict.

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.

Parameters:

state (dict) – Extra state from the state_dict
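
The contract between get_extra_state() and set_extra_state() can be shown with a toy round trip. This is only a sketch of the hook pattern: Counter is a hypothetical plain Python class, not a torch.nn.Module, and pickle stands in for saving and loading a state_dict.

```python
import pickle
from collections import OrderedDict


class Counter:
    """Toy object exposing the two hooks (not a torch.nn.Module)."""

    def __init__(self):
        self.steps = 0

    def get_extra_state(self):
        # Only plain, picklable values go into the returned state.
        return OrderedDict(steps=self.steps)

    def set_extra_state(self, state):
        # Restore exactly what get_extra_state() produced.
        self.steps = state["steps"]


source = Counter()
source.steps = 42
payload = pickle.dumps(source.get_extra_state())  # stands in for saving

restored = Counter()
restored.set_extra_state(pickle.loads(payload))   # stands in for loading
print(restored.steps)
```

Keeping the extra state to simple containers of numbers or tensors makes it less likely that a change in pickled form breaks older checkpoints.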

to_observation_norm() Union[Compose, ObservationNorm][source]

Converts VecNorm into an ObservationNorm class that can be used at inference time.
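
Conceptually, such a conversion freezes the running statistics into a fixed location and scale, so that later observations are normalized without being recorded. The sketch below illustrates that idea in plain Python under stated assumptions: freeze_statistics and normalize are hypothetical helpers, and the way the library derives its own parameters may differ.

```python
import math


def freeze_statistics(total, total_sq, count, eps=1e-4):
    """Turn running sums into a fixed (loc, scale) pair (hypothetical helper)."""
    loc = total / count
    var = max(total_sq / count - loc * loc, 0.0)
    scale = max(math.sqrt(var), eps)  # eps keeps the scale away from zero
    return loc, scale


def normalize(x, loc, scale):
    # Fixed affine map: nothing is updated when new observations arrive.
    return (x - loc) / scale


# Sums accumulated over the observations 1.0, 2.0 and 3.0 during training.
loc, scale = freeze_statistics(total=6.0, total_sq=14.0, count=3)
print(loc, scale, normalize(2.0, loc, scale))
```

Because loc and scale are constants after freezing, the same input always maps to the same output at inference time.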

transform_observation_spec(observation_spec: TensorSpec) TensorSpec[source]

Transforms the observation spec such that the resulting spec matches transform mapping.

Parameters:

observation_spec (TensorSpec) – spec before the transform

Returns:

expected spec after the transform
