VecNorm
- class torchrl.envs.transforms.VecNorm(in_keys: Optional[Sequence[NestedKey]] = None, out_keys: Optional[Sequence[NestedKey]] = None, shared_td: Optional[TensorDictBase] = None, lock: Optional[Lock] = None, decay: float = 0.9999, eps: float = 0.0001, shapes: Optional[List[Size]] = None)
Moving average normalization layer for torchrl environments.
VecNorm keeps track of the summary statistics of a dataset to standardize it on-the-fly. If the transform is in ‘eval’ mode, the running statistics are not updated.
If multiple processes are running a similar environment, one can pass a TensorDictBase instance that is placed in shared memory: if so, every time the normalization layer is queried it will update the values for all processes that share the same reference.
To use VecNorm at inference time without updating the statistics with new observations, replace this layer with the output of to_observation_norm(). This provides a static version of VecNorm that is not updated when the source transform is updated. To get a frozen copy of the VecNorm layer, see frozen_copy().
- Parameters:
in_keys (sequence of NestedKey, optional) – keys to be updated. Defaults to [“observation”, “reward”].
out_keys (sequence of NestedKey, optional) – destination keys. Defaults to in_keys.
shared_td (TensorDictBase, optional) – a shared tensordict containing the keys of the transform.
lock (mp.Lock, optional) – a lock to prevent race conditions between processes. Defaults to None (a lock is created during init).
decay (number, optional) – decay rate of the moving average. Defaults to 0.9999.
eps (number, optional) – lower bound of the running standard deviation (to prevent numerical underflow). Defaults to 1e-4.
shapes (List[torch.Size], optional) – if provided, the shape of each entry in in_keys. Its length must match that of in_keys. Each shape must match the trailing dimensions of the corresponding entry. If shapes is not provided, all dimensions of the entry that do not belong to the tensordict batch size are considered feature dimensions.
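The role of decay and eps can be sketched in plain PyTorch. DecayedRunningNorm below is a hypothetical helper written for illustration: it assumes the statistics are tracked as an exponentially decayed sum, sum of squares and sample count, and it is not TorchRL's own implementation. Passing update=False mimics the ‘eval’ mode described above, where statistics are read but not updated.

```python
import torch


class DecayedRunningNorm:
    """Hypothetical helper illustrating decayed running statistics (not TorchRL code)."""

    def __init__(self, num_features: int, decay: float = 0.9999, eps: float = 1e-4):
        self.decay = decay
        self.eps = eps
        # Exponentially decayed sum, sum of squares and sample count.
        self.sum = torch.zeros(num_features)
        self.ssq = torch.zeros(num_features)
        self.count = torch.zeros(())

    def update(self, batch: torch.Tensor) -> None:
        # batch has shape [N, num_features]; older data is down-weighted by `decay`.
        self.sum = self.decay * self.sum + batch.sum(0)
        self.ssq = self.decay * self.ssq + (batch ** 2).sum(0)
        self.count = self.decay * self.count + batch.shape[0]

    @property
    def loc(self) -> torch.Tensor:
        # Weighted running mean.
        return self.sum / self.count

    @property
    def scale(self) -> torch.Tensor:
        # Weighted variance = E[x^2] - E[x]^2; eps lower-bounds the std.
        var = self.ssq / self.count - self.loc ** 2
        return var.clamp_min(0).sqrt().clamp_min(self.eps)

    def __call__(self, batch: torch.Tensor, update: bool = True) -> torch.Tensor:
        # With update=False ("eval" mode) the statistics are only read.
        if update:
            self.update(batch)
        return (batch - self.loc) / self.scale


torch.manual_seed(0)
norm = DecayedRunningNorm(num_features=3, decay=0.9999)
data = torch.randn(10_000, 3) * 2.0 + 3.0  # features with mean 3 and std 2
for chunk in data.split(100):
    norm(chunk)
normalized = norm(data, update=False)
print(normalized.mean(0), normalized.std(0))
```

With a decay close to 1, old samples are forgotten slowly, so the estimates stay close to the plain mean and standard deviation of everything seen so far; a smaller decay tracks non-stationary observations more aggressively.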
Examples
>>> import torch
>>> from torchrl.envs import TransformedEnv, VecNorm
>>> from torchrl.envs.libs.gym import GymEnv
>>> t = VecNorm(decay=0.9)
>>> env = GymEnv("Pendulum-v1")
>>> env = TransformedEnv(env, t)
>>> tds = []
>>> for _ in range(1000):
...     td = env.rand_step()
...     if td.get(("next", "done")):
...         _ = env.reset()
...     tds += [td]
>>> tds = torch.stack(tds, 0)
>>> print((abs(tds.get(("next", "observation")).mean(0))<0.2).all())
tensor(True)
>>> print((abs(tds.get(("next", "observation")).std(0)-1)<0.2).all())
tensor(True)
- classmethod build_td_for_shared_vecnorm(env, keys=None, memmap=False)
Creates a shared tensordict for normalization across processes.
- Parameters:
env (EnvBase) – example environment to be used to create the tensordict
keys (sequence of NestedKey, optional) – keys that have to be normalized. Default is [“next”, “reward”]
memmap (bool) – if True, the resulting tensordict will be cast into a memory-mapped tensordict (using memmap_()). Otherwise, the tensordict will be placed in shared memory.
- Returns:
A tensordict in shared memory to be sent to each process.
Examples
>>> from torch import multiprocessing as mp
>>> queue = mp.Queue()
>>> env = make_env()
>>> td_shared = VecNorm.build_td_for_shared_vecnorm(env,
...     ["next", "reward"])
>>> assert td_shared.is_shared()
>>> queue.put(td_shared)
>>> # on workers
>>> v = VecNorm(shared_td=queue.get())
>>> env = TransformedEnv(make_env(), v)
- forward(tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
- freeze() → VecNorm
Freezes the VecNorm so that its statistics are not updated when it is called.
See unfreeze().
- frozen_copy()
Returns a copy of the Transform that keeps track of the stats but does not update them.
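The difference between freeze() and frozen_copy() can be sketched with a hypothetical RunningMean class (plain PyTorch, not TorchRL code). The sketch assumes, as the wording “keeps track of the stats” suggests, that a frozen copy reads the same statistics the original keeps updating, while freeze() stops updates on the object itself.

```python
import copy

import torch


class RunningMean:
    """Hypothetical stand-in for a stateful normalizer (not TorchRL code)."""

    def __init__(self) -> None:
        # Statistics live in tensors that are updated in place.
        self.total = torch.zeros(())
        self.count = torch.zeros(())
        self.frozen = False

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if not self.frozen:
            # Only an unfrozen instance writes new statistics.
            self.total.add_(x.sum())
            self.count.add_(x.numel())
        return x - self.total / self.count.clamp_min(1)

    def freeze(self) -> "RunningMean":
        # In place: this very object stops updating its statistics.
        self.frozen = True
        return self

    def unfreeze(self) -> "RunningMean":
        self.frozen = False
        return self

    def frozen_copy(self) -> "RunningMean":
        # Shallow copy: the clone shares the statistic tensors, so it sees
        # every update made by the original but never writes to them itself.
        clone = copy.copy(self)
        clone.frozen = True
        return clone


live = RunningMean()
follower = live.frozen_copy()
live(torch.tensor([1.0, 3.0]))   # updates the shared statistics
follower(torch.tensor([100.0]))  # read-only: statistics unchanged
print(live.total / live.count, follower.total / follower.count)
```

By contrast, to_observation_norm() is described as a static snapshot that does not follow later updates of the source transform.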
- get_extra_state() → OrderedDict
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().
Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.
- Returns:
Any extra state to store in the module’s state_dict
- Return type:
object
- property loc
Returns a TensorDict with the loc to be used for an affine transform.
- property scale
Returns a TensorDict with the scale to be used for an affine transform.
- set_extra_state(state: OrderedDict) → None
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.
- Parameters:
state (dict) – Extra state from the state_dict
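A minimal illustration of the get_extra_state()/set_extra_state() pair on a plain torch.nn.Module. StepCounter is a hypothetical module that keeps a Python integer outside its parameters and buffers. The integer is stored under the "_extra_state" key of the state_dict and restored by load_state_dict().

```python
import torch
from torch import nn


class StepCounter(nn.Module):
    """Hypothetical module storing a plain Python counter as extra state."""

    def __init__(self) -> None:
        super().__init__()
        self.steps = 0  # not a tensor, so not saved as a parameter or buffer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.steps += 1
        return x

    def get_extra_state(self) -> dict:
        # Must be picklable so that torch.save(state_dict) works.
        return {"steps": self.steps}

    def set_extra_state(self, state: dict) -> None:
        # Called by load_state_dict() with the object returned above.
        self.steps = state["steps"]


src = StepCounter()
for _ in range(3):
    src(torch.zeros(1))
sd = src.state_dict()
dst = StepCounter()
dst.load_state_dict(sd)
print(sorted(sd.keys()), dst.steps)
```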
- property standard_normal
Whether the affine transform given by loc and scale follows the standard normal equation.
Similar to the ObservationNorm standard_normal attribute. Always returns True.
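With standard_normal=True, ObservationNorm standardizes observations as (obs - loc) / scale, whereas with standard_normal=False it applies obs * scale + loc. The hypothetical affine_norm helper below (plain PyTorch, illustrative loc and scale values) shows both conventions and that one undoes the other.

```python
import torch


def affine_norm(
    obs: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor, standard_normal: bool
) -> torch.Tensor:
    """Hypothetical helper contrasting the two affine conventions (not TorchRL code)."""
    if standard_normal:
        # Standardization: subtract the mean, divide by the standard deviation.
        return (obs - loc) / scale
    # Plain affine map: rescale, then shift.
    return obs * scale + loc


loc = torch.tensor([3.0, -1.0])   # e.g. a running mean (illustrative values)
scale = torch.tensor([2.0, 0.5])  # e.g. a running std (illustrative values)
obs = torch.tensor([5.0, 0.0])

z = affine_norm(obs, loc, scale, standard_normal=True)
back = affine_norm(z, loc, scale, standard_normal=False)
print(z, back)
```

Because VecNorm always standardizes, its loc and scale play the role of a mean and a standard deviation.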
- to_observation_norm() → Union[Compose, ObservationNorm]
Converts VecNorm into an ObservationNorm class that can be used at inference time.
The ObservationNorm layer can be refreshed with the latest statistics using the state_dict() API.
Examples
>>> from torchrl.envs import GymEnv, VecNorm
>>> vecnorm = VecNorm(in_keys=["observation"])
>>> train_env = GymEnv("CartPole-v1", device=None).append_transform(
...     vecnorm)
>>>
>>> r = train_env.rollout(4)
>>>
>>> eval_env = GymEnv("CartPole-v1").append_transform(
...     vecnorm.to_observation_norm())
>>> print(eval_env.transform.loc, eval_env.transform.scale)
>>>
>>> r = train_env.rollout(4)
>>> # Update entries with state_dict
>>> eval_env.transform.load_state_dict(
...     vecnorm.to_observation_norm().state_dict())
>>> print(eval_env.transform.loc, eval_env.transform.scale)
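The snapshot-and-refresh pattern used in this example can be reproduced with plain torch.nn.Module buffers. StaticAffine is a hypothetical stand-in for the static ObservationNorm returned by to_observation_norm(), not its actual implementation. The snapshot ignores later changes to the running statistics until new values are loaded through load_state_dict().

```python
import torch
from torch import nn


class StaticAffine(nn.Module):
    """Hypothetical stand-in for a static ObservationNorm-like layer."""

    def __init__(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        super().__init__()
        # Buffers are saved in (and restored from) the module's state_dict.
        self.register_buffer("loc", loc.clone())
        self.register_buffer("scale", scale.clone())

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return (obs - self.loc) / self.scale


running_loc = torch.tensor([0.0])
running_scale = torch.tensor([1.0])
eval_norm = StaticAffine(running_loc, running_scale)  # snapshot taken here

# Training keeps updating the running statistics in place...
running_loc.fill_(2.0)
running_scale.fill_(4.0)
# ...but the snapshot is unaffected until it is refreshed explicitly.
before = eval_norm.loc.clone()
print(before)
eval_norm.load_state_dict(StaticAffine(running_loc, running_scale).state_dict())
print(eval_norm(torch.tensor([6.0])))
```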
- transform_observation_spec(observation_spec: TensorSpec) → TensorSpec
Transforms the observation spec such that the resulting spec matches the transform mapping.
- Parameters:
observation_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform