VIPRewardTransform
- class torchrl.envs.transforms.VIPRewardTransform(*args, **kwargs)[source]
A VIP transform to compute rewards based on embedded similarity.
This class will update the reward computation
- forward(tensordict)[source]
Reads the input tensordict, and for the selected keys, applies the transform.
By default, this method:
calls directly
_apply_transform()
.does not call
_step()
or_call()
.
This method is not called within env.step at any point. However, is is called within
sample()
.Note
forward
also works with regular keyword arguments usingdispatch
to cast the args names to the keys.Examples
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- transform_input_spec(input_spec: TensorSpec) TensorSpec [source]
Transforms the input spec such that the resulting spec matches transform mapping.
- Parameters:
input_spec (TensorSpec) – spec before the transform
- Returns:
expected spec after the transform