- class torchrl.envs.transforms.VecNorm(in_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None, shared_td: Optional[TensorDictBase] = None, lock: Optional[Lock] = None, decay: float = 0.9999, eps: float = 0.0001, shapes: Optional[List[Size]] = None)
Moving average normalization layer for torchrl environments.
VecNorm keeps track of the summary statistics of a dataset to standardize it on-the-fly. If the transform is in ‘eval’ mode, the running statistics are not updated.
If multiple processes run similar environments, one can pass a TensorDictBase instance placed in shared memory as shared_td: every time the normalization layer is queried, it then updates the values for all processes that share the same reference.
To use VecNorm at inference time without updating the statistics with new observations, replace this layer with vecnorm.to_observation_norm().
in_keys (sequence of NestedKey, optional) – keys to be updated. Default: ["observation", "reward"]
shared_td (TensorDictBase, optional) – A shared tensordict containing the keys of the transform.
decay (number, optional) – decay rate of the moving average. Default: 0.9999
eps (number, optional) – lower bound of the running standard deviation (for numerical underflow). Default is 1e-4.
shapes (List[torch.Size], optional) – if provided, represents the shape of each entry in in_keys. Its length must match that of in_keys. Each shape must match the trailing dimensions of the corresponding entry. If not provided, the feature dimensions of the entry (i.e. all dims that do not belong to the tensordict batch-size) will be considered as the feature dimensions.
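As a rough illustration of how decay and eps interact, the sketch below keeps decayed running sums for a single scalar stream and standardizes each value against them. It is a pure-Python stand-in written under assumptions, not torchrl's implementation: the class name RunningNormalizer, its attributes, and the exact update rule are all hypothetical.

```python
import math


class RunningNormalizer:
    """Hypothetical scalar normalizer; not torchrl's actual implementation."""

    def __init__(self, decay=0.9999, eps=1e-4):
        self.decay = decay
        self.eps = eps
        self.total = 0.0     # decayed sum of values
        self.total_sq = 0.0  # decayed sum of squared values
        self.count = 0.0     # decayed number of samples
        self.training = True

    def update(self, x):
        # Older samples are down-weighted by `decay` at every step.
        self.total = self.decay * self.total + x
        self.total_sq = self.decay * self.total_sq + x * x
        self.count = self.decay * self.count + 1.0

    def __call__(self, x):
        if self.training:
            self.update(x)
        if self.count == 0.0:
            return x  # no statistics collected yet
        mean = self.total / self.count
        var = max(self.total_sq / self.count - mean * mean, 0.0)
        std = max(math.sqrt(var), self.eps)  # eps is a lower bound on the std
        return (x - mean) / std


norm = RunningNormalizer(decay=0.99)
for value in [1.0, 2.0, 3.0, 4.0]:
    norm(value)

# In "eval" mode the statistics are read but no longer updated.
norm.training = False
frozen_count = norm.count
norm(100.0)
```

A decay close to 1 gives a long memory, while a smaller value such as 0.9 tracks recent data more closely. The eps floor keeps the division finite when the stream is (nearly) constant.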
>>> import torch
>>> from torchrl.envs.transforms import TransformedEnv, VecNorm
>>> from torchrl.envs.libs.gym import GymEnv
>>> t = VecNorm(decay=0.9)
>>> env = GymEnv("Pendulum-v0")
>>> env = TransformedEnv(env, t)
>>> tds = []
>>> for _ in range(1000):
...     td = env.rand_step()
...     if td.get("done"):
...         _ = env.reset()
...     tds += [td]
>>> tds = torch.stack(tds, 0)
>>> print((abs(tds.get(("next", "observation")).mean(0))<0.2).all())
tensor(True)
>>> print((abs(tds.get(("next", "observation")).std(0)-1)<0.2).all())
tensor(True)
- static build_td_for_shared_vecnorm(env, keys=None, memmap=False) → TensorDictBase
Creates a shared tensordict for normalization across processes.
env (EnvBase) – example environment to be used to create the tensordict
keys (sequence of NestedKey, optional) – keys that have to be normalized. Default is ["next", "reward"]
memmap (bool) – if True, the resulting tensordict will be cast into a memory map (using memmap_()). Otherwise, the tensordict will be placed in shared memory.
Returns a tensordict placed in shared memory (or memory-mapped if memmap is True), to be sent to each process.
>>> from torch import multiprocessing as mp
>>> queue = mp.Queue()
>>> env = make_env()
>>> td_shared = VecNorm.build_td_for_shared_vecnorm(env,
...     ["next", "reward"])
>>> assert td_shared.is_shared()
>>> queue.put(td_shared)
>>> # on workers
>>> v = VecNorm(shared_td=queue.get())
>>> env = TransformedEnv(make_env(), v)
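The idea of several workers updating one set of statistics through a shared reference, guarded by a lock, can be sketched without torchrl. In this simplified analogy, threads stand in for processes and a plain dict stands in for the shared tensordict. The names shared_stats, worker and DECAY are hypothetical, and the update rule is an assumption rather than the library's own.

```python
import threading

# Hypothetical shared statistics; a plain dict stands in for the shared tensordict.
shared_stats = {"sum": 0.0, "ssq": 0.0, "count": 0.0}
lock = threading.Lock()
DECAY = 1.0  # no forgetting, so the final totals do not depend on update order


def worker(values):
    for x in values:
        # The lock makes each read-modify-write update atomic.
        with lock:
            shared_stats["sum"] = DECAY * shared_stats["sum"] + x
            shared_stats["ssq"] = DECAY * shared_stats["ssq"] + x * x
            shared_stats["count"] = DECAY * shared_stats["count"] + 1.0


# Four workers, each feeding 250 copies of its own value into the same stats.
threads = [
    threading.Thread(target=worker, args=([float(i)] * 250,)) for i in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

mean = shared_stats["sum"] / shared_stats["count"]
```

Every worker sees the contributions of all the others, which is the point of passing the same shared_td (and, where needed, the same lock) to each VecNorm instance.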
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
- get_extra_state() → OrderedDict
Returns any extra state to include in the module’s state_dict. Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().
Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.
Any extra state to store in the module’s state_dict
- Return type: OrderedDict
- set_extra_state(state: OrderedDict) → None
This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.
state (dict) – Extra state from the state_dict
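The round trip between get_extra_state() and set_extra_state() can be illustrated with a small torch.nn.Module. This is a minimal sketch assuming PyTorch is installed. The StepCounter class and its steps attribute are hypothetical and unrelated to VecNorm's own extra state.

```python
from torch import nn


class StepCounter(nn.Module):
    """Hypothetical module that stores a plain Python counter as extra state."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        self.steps = 0  # neither a parameter nor a buffer

    def get_extra_state(self):
        # Called while state_dict() is being built.
        return {"steps": self.steps}

    def set_extra_state(self, state):
        # Called from load_state_dict() with the saved extra state.
        self.steps = state["steps"]


source = StepCounter()
source.steps = 42
saved = source.state_dict()

target = StepCounter()
target.load_state_dict(saved)
```

After loading, target.steps matches the value saved from source, even though steps is neither a parameter nor a buffer.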