KLRewardTransform
- class torchrl.envs.transforms.KLRewardTransform(actor: ProbabilisticTensorDictModule, coef=1.0, in_keys=None, out_keys=None, requires_grad=False)
A transform to add a KL[pi_current||pi_0] correction term to the reward.
This transform is used to constrain the policy to remain close to its original configuration, which limits overfitting when fine-tuning with RLHF.
- Parameters:
  - actor (ProbabilisticTensorDictModule) – a probabilistic actor. It must have a set of input keys (in_keys) and output keys (out_keys), and it must have a get_dist method that outputs the distribution of the action.
  - coef (float) – the coefficient of the KL term. Defaults to 1.0.
  - in_keys (str or list of str/tuples of str) – the input key where the reward should be fetched. Defaults to "reward".
  - out_keys (str or list of str/tuples of str) – the output key where the reward should be written. Defaults to "reward".
  - requires_grad (bool, optional) – if True, the frozen parameters will consist of differentiable clones of the original params. Defaults to False.
Note
If the parameters are not differentiable (default), they will not follow the module when dtype or device casting operations are called (such as cuda(), to(), etc.). When requires_grad=True, casting operations will work as expected.
Examples
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule as Mod, NormalParamExtractor
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.envs import TransformedEnv
>>> from torchrl.envs.transforms import KLRewardTransform
>>> from torchrl.modules import ProbabilisticActor
>>> from torchrl.modules.distributions import TanhNormal
>>> base_env = GymEnv("Pendulum-v1")
>>> n_obs = base_env.observation_spec["observation"].shape[-1]
>>> n_act = base_env.action_spec.shape[-1]
>>> module = Mod(
...     nn.Sequential(nn.Linear(n_obs, n_act * 2), NormalParamExtractor()),
...     in_keys=["observation"],
...     out_keys=["loc", "scale"],
... )
>>> actor = ProbabilisticActor(
...     module,
...     in_keys=["loc", "scale"],
...     distribution_class=TanhNormal,
...     return_log_prob=True,
... )
>>> transform = KLRewardTransform(actor, out_keys="reward_kl")
>>> env = TransformedEnv(base_env, transform)
>>> with torch.no_grad():
...     # modify the actor parameters
...     _ = TensorDict(dict(actor.named_parameters()), []).apply_(lambda x: x.data.copy_(x.data + 1))
...     td = env.rollout(3, actor)
>>> # check that rewards have been modified
>>> assert (td.get(("next", "reward")) != td.get(("next", "reward_kl"))).all()
Note
Because the KL formula is not always available and the parameters of the original distribution may not have been recorded, we use a stochastic estimate of the KL divergence.
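Concretely, the stochastic estimate replaces the analytic KL with a single-sample Monte Carlo estimate at the action taken during the rollout. A minimal sketch, assuming for illustration that the correction takes the form reward - coef * (log pi(a) - log pi_0(a)):
>>> import torch
>>> from torch.distributions import Normal
>>> pi = Normal(torch.zeros(3), torch.ones(3))   # stands in for the current policy
>>> pi_0 = Normal(torch.ones(3), torch.ones(3))  # stands in for the frozen policy
>>> action = pi.sample()
>>> # single-sample estimate of KL[pi || pi_0] at the sampled action
>>> kl_estimate = pi.log_prob(action) - pi_0.log_prob(action)
>>> coef = 1.0
>>> reward = torch.ones(3)
>>> corrected_reward = reward - coef * kl_estimate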
- forward(next_tensordict: TensorDictBase) → TensorDictBase
Reads the input tensordict, and for the selected keys, applies the transform.
By default, this method:
- calls _apply_transform() directly;
- does not call _step() or _call().
This method is not called within env.step at any point. However, it is called within sample().
Note
forward also works with regular keyword arguments, using dispatch to map the argument names to the keys.
Examples
>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes_in_td
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t)  # works within envs
>>> t(TensorDict(a=0))  # works offline too
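As a hedged illustration of the keyword-argument path mentioned in the note above, a transform with populated in_keys and out_keys can be called with plain tensors; dispatch maps the argument name to the in_key and returns the out_key value. The DoublingTransform class here is hypothetical, purely for illustration:
>>> import torch
>>> from torchrl.envs.transforms import Transform
>>> class DoublingTransform(Transform):
...     def __init__(self):
...         super().__init__(in_keys=["x"], out_keys=["y"])
...     def _apply_transform(self, x: torch.Tensor) -> torch.Tensor:
...         return x * 2
...
>>> t = DoublingTransform()
>>> t(x=torch.ones(3))  # the arg name "x" is cast to the in_key; returns the "y" value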
- transform_output_spec(output_spec: Composite) → Composite
Transforms the output spec such that the resulting spec matches the transform mapping.
This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec().
- Parameters:
  - output_spec (TensorSpec) – spec before the transform
- Returns:
  expected spec after the transform
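For orientation, a hedged sketch of what a reward-spec override might look like for a transform like this one, following the advice above to implement changes in transform_reward_spec() rather than here. The MyKLTransform subclass and the spec-mirroring logic are hypothetical, and the way the reward spec is keyed is an assumption; the real implementation lives in torchrl and may differ:
>>> from torchrl.data import Composite
>>> from torchrl.envs.transforms import KLRewardTransform
>>> class MyKLTransform(KLRewardTransform):
...     def transform_reward_spec(self, reward_spec: Composite) -> Composite:
...         # hypothetical: when out_keys names a new entry (e.g. "reward_kl"),
...         # mirror the source reward spec under that key
...         if self.out_keys[0] != self.in_keys[0]:
...             reward_spec[self.out_keys[0]] = reward_spec[self.in_keys[0]].clone()
...         return reward_spec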