PPOLoss
- class torchrl.objectives.PPOLoss(*args, **kwargs)
A parent PPO loss class.
PPO (Proximal Policy Optimization) is a model-free, online RL algorithm that uses a recorded batch of trajectories to perform several optimization steps, while actively preventing the updated policy from deviating too much from its original parameter configuration.
The PPO loss comes in different flavors, depending on how the constrained optimization is implemented: ClipPPOLoss and KLPENPPOLoss. Unlike its subclasses, this class does not implement any regularization and should therefore be used with caution.
For more details regarding PPO, refer to: “Proximal Policy Optimization Algorithms”, https://arxiv.org/abs/1707.06347
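Because this parent class adds neither clipping nor a KL penalty, its policy term can be expected to be the plain importance-weighted surrogate, -E[ratio × advantage]. The plain-Python sketch below shows that quantity; it is not torchrl's implementation. `surrogate_loss` and `normalize` are hypothetical helpers, log-probabilities are passed as lists, and the standardization used on the `normalize_advantage` path (zero mean, unit population standard deviation) is an assumption.

```python
import math
from statistics import mean, pstdev


def normalize(advantages, eps=1e-6):
    """Standardize advantages to zero mean and unit (population) standard deviation."""
    loc = mean(advantages)
    scale = max(pstdev(advantages), eps)  # avoid dividing by zero for constant inputs
    return [(a - loc) / scale for a in advantages]


def surrogate_loss(new_log_probs, old_log_probs, advantages, normalize_advantage=False):
    """Unclipped surrogate: -mean(ratio * advantage), with no trust-region term."""
    if normalize_advantage:
        advantages = normalize(advantages)
    # Importance ratio pi_new(a|s) / pi_old(a|s), recovered from log-probabilities.
    ratios = [math.exp(new - old) for new, old in zip(new_log_probs, old_log_probs)]
    # Nothing bounds the ratio here, so a single large ratio can dominate the update.
    return -mean(r * a for r, a in zip(ratios, advantages))


# Identical policies give a ratio of 1, so the loss reduces to -mean(advantage).
print(surrogate_loss([-1.0, -0.5], [-1.0, -0.5], [2.0, 4.0]))  # -3.0
```

ClipPPOLoss and KLPENPPOLoss exist to keep the ratio close to 1. Here, nothing stops it from growing.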
- Parameters:
actor_network (ProbabilisticTensorDictSequential) – policy operator.
critic_network (ValueOperator) – value operator.
- Keyword Arguments:
entropy_bonus (bool, optional) – if True, an entropy bonus will be added to the loss to favour exploratory policies.
samples_mc_entropy (int, optional) – if the distribution retrieved from the policy operator does not have a closed-form formula for the entropy, a Monte-Carlo estimate will be used. samples_mc_entropy controls how many samples are used to compute this estimate. Defaults to 1.
entropy_coef (scalar, optional) – entropy multiplier when computing the total loss. Defaults to 0.01.
critic_coef (scalar, optional) – critic loss multiplier when computing the total loss. Defaults to 1.0. Set critic_coef to None to exclude the value loss from the forward outputs.
loss_critic_type (str, optional) – loss function for the value discrepancy. Can be one of “l1”, “l2” or “smooth_l1”. Defaults to "smooth_l1".
normalize_advantage (bool, optional) – if True, the advantage will be normalized before being used. Defaults to False.
separate_losses (bool, optional) – if True, shared parameters between policy and critic will only be trained on the policy loss. Defaults to False, i.e., gradients are propagated to shared parameters for both policy and critic losses.
advantage_key (str, optional) – [Deprecated, use set_keys(advantage_key=advantage_key) instead] The input tensordict key where the advantage is expected to be written. Defaults to "advantage".
value_target_key (str, optional) – [Deprecated, use set_keys(value_target_key=value_target_key) instead] The input tensordict key where the target state value is expected to be written. Defaults to "value_target".
value_key (str, optional) – [Deprecated, use set_keys(value_key=value_key) instead] The input tensordict key where the state value is expected to be written. Defaults to "state_value".
functional (bool, optional) – whether modules should be functionalized. Functionalizing permits features like meta-RL, but makes it impossible to use distributed models (DDP, FSDP, …) and comes with a small overhead. Defaults to True.
reduction (str, optional) – specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied; "mean": the sum of the output will be divided by the number of elements in the output; "sum": the output will be summed. Defaults to "mean".
clip_value (float, optional) – if provided, it will be used to compute a clipped version of the value prediction with respect to the input tensordict value estimate, and this clipped prediction is used to calculate the value loss. Clipping limits the impact of extreme value predictions, which helps stabilize training and prevents large updates. However, it has no effect if the value estimate was computed by the current version of the value estimator. Defaults to None.
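The critic-related options (loss_critic_type, clip_value, critic_coef and reduction) can be illustrated with a small plain-Python sketch. The helpers `distance`, `reduce_losses` and `critic_loss` are hypothetical and do not reproduce torchrl's code. The sketch makes two assumptions. First, "smooth_l1" is taken to be the Huber loss with a transition point of 1.0. Second, when clip_value is set, the sketch follows the common PPO2-style rule of keeping the larger of the clipped and unclipped losses.

```python
from statistics import mean


def distance(pred, target, kind="smooth_l1"):
    """Per-sample value-loss term for the three supported loss_critic_type values."""
    diff = pred - target
    if kind == "l1":
        return abs(diff)
    if kind == "l2":
        return diff * diff
    if kind == "smooth_l1":
        # Huber loss with a transition point of 1.0
        return 0.5 * diff * diff if abs(diff) < 1.0 else abs(diff) - 0.5
    raise ValueError(f"unknown loss_critic_type: {kind!r}")


def reduce_losses(losses, reduction="mean"):
    """Apply the "none" / "mean" / "sum" reduction to a list of per-sample losses."""
    if reduction == "none":
        return list(losses)
    if reduction == "mean":
        return mean(losses)
    if reduction == "sum":
        return sum(losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


def critic_loss(preds, targets, old_values=None, kind="smooth_l1",
                clip_value=None, critic_coef=1.0, reduction="mean"):
    """Scaled, optionally clipped value loss; returns None when critic_coef is None."""
    if critic_coef is None:
        return None  # the value loss is left out of the outputs
    losses = []
    for i, (pred, target) in enumerate(zip(preds, targets)):
        loss = distance(pred, target, kind)
        if clip_value is not None:
            old = old_values[i]
            # Keep the new prediction within clip_value of the stored value estimate.
            clipped = old + max(-clip_value, min(clip_value, pred - old))
            # Assumed PPO2-style rule: keep the larger of the two losses.
            loss = max(loss, distance(clipped, target, kind))
        losses.append(critic_coef * loss)
    return reduce_losses(losses, reduction)


print(critic_loss([0.0, 3.0], [0.5, 0.0]))  # 1.3125
```

If old_values equals the current predictions, the clipped prediction is identical to the unclipped one. This is why clipping has no effect when the value estimate was produced by the current value network.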
Note
The advantage (typically GAE) can be computed by the loss function or in the training loop. The latter option is usually preferred, but the choice is left to the user. If the advantage key ("advantage" by default) is not present in the input tensordict, the advantage will be computed by the forward() method.
>>> ppo_loss = PPOLoss(actor, critic)
>>> advantage = GAE(gamma=0.99, lmbda=0.95, value_network=critic)
>>> data = next(datacollector)
>>> losses = ppo_loss(data)
>>> # equivalent
>>> advantage(data)
>>> losses = ppo_loss(data)
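For reference, the usual GAE recursion fits in a few lines of plain Python. This sketch handles a single time-major trajectory and is not torchrl's `GAE` module. `gae` is a hypothetical helper, and its `gamma`/`lmbda` defaults are illustrative. It also assumes that `terminated` cancels the bootstrap value, while `done` (termination or truncation) stops the recursion.

```python
def gae(rewards, values, next_values, terminated, done, gamma=0.99, lmbda=0.95):
    """Generalized Advantage Estimation for one time-major trajectory.

    Returns (advantages, value_targets), where value_target = advantage + value.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error; a terminated transition has no bootstrap value.
        delta = rewards[t] + gamma * (not terminated[t]) * next_values[t] - values[t]
        # Discounted sum of future TD errors, reset at episode boundaries.
        running = delta + gamma * lmbda * (not done[t]) * running
        advantages[t] = running
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


adv, targets = gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    next_values=[0.0, 0.0, 0.0],
    terminated=[False, False, True],
    done=[False, False, True],
    gamma=1.0,
    lmbda=1.0,
)
print(adv)  # [3.0, 2.0, 1.0]
```

With lmbda=0 the advantages reduce to the one-step TD errors. With lmbda=1 they become Monte-Carlo returns minus the value baseline.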
A custom advantage module can be built using make_value_estimator(). The default is GAE with hyperparameters dictated by default_value_kwargs().
>>> from torchrl.objectives.utils import ValueEstimators
>>> ppo_loss = PPOLoss(actor, critic)
>>> ppo_loss.make_value_estimator(ValueEstimators.TDLambda)
>>> data = next(datacollector)
>>> losses = ppo_loss(data)
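To illustrate what a TD(λ) estimator computes, the sketch below builds λ-returns by walking backwards through a trajectory. `td_lambda_returns` is a hypothetical helper and not `ValueEstimators.TDLambda` itself. It assumes that the return bootstraps on the next state value at the last step and at episode boundaries, and that a terminated transition contributes no bootstrap value.

```python
def td_lambda_returns(rewards, next_values, terminated, done, gamma=0.99, lmbda=0.95):
    """Lambda-returns for one time-major trajectory, computed backwards."""
    returns = [0.0] * len(rewards)
    next_return = 0.0
    for t in reversed(range(len(rewards))):
        if done[t] or t == len(rewards) - 1:
            # At a boundary the next lambda-return is replaced by the bootstrap value.
            next_return = next_values[t]
        # Mix the one-step bootstrap with the longer-horizon return.
        mixed = (1.0 - lmbda) * next_values[t] + lmbda * next_return
        # A terminated transition contributes no future value.
        returns[t] = rewards[t] + gamma * (not terminated[t]) * mixed
        next_return = returns[t]
    return returns


print(td_lambda_returns([1.0, 1.0], [2.0, 4.0], [False, False], [False, False],
                        gamma=0.5, lmbda=1.0))  # [2.5, 3.0]
```

Subtracting V(s_t) from these λ-returns should recover the GAE advantages for the same gamma and lmbda, provided next_values[t] equals values[t + 1] inside a trajectory.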
Note
If the actor and the value function share parameters, one can avoid calling the common module multiple times by passing only the head of the value network to the PPO loss module:
>>> common = SomeModule(in_keys=["observation"], out_keys=["hidden"])
>>> actor_head = SomeActor(in_keys=["hidden"])
>>> value_head = SomeValue(in_keys=["hidden"])
>>> # first option, with 2 calls on the common module
>>> model = ActorValueOperator(common, actor_head, value_head)
>>> loss_module = PPOLoss(model.get_policy_operator(), model.get_value_operator())
>>> # second option, with a single call to the common module
>>> loss_module = PPOLoss(ProbabilisticTensorDictSequential(model, actor_head), value_head)
This will work regardless of whether separate_losses is activated or not.
Examples
>>> import torch
>>> from torch import nn
>>> from torchrl.data.tensor_specs import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.ppo import PPOLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> base_layer = nn.Linear(n_obs, 5)
>>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
...     module=module,
...     distribution_class=TanhNormal,
...     in_keys=["loc", "scale"],
...     spec=spec)
>>> module = nn.Sequential(base_layer, nn.Linear(5, 1))
>>> value = ValueOperator(
...     module=module,
...     in_keys=["observation"])
>>> loss = PPOLoss(actor, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({"observation": torch.randn(*batch, n_obs),
...         "action": action,
...         "sample_log_prob": torch.randn_like(action[..., 1]),
...         ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...         ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...         ("next", "reward"): torch.randn(*batch, 1),
...         ("next", "observation"): torch.randn(*batch, n_obs),
...     }, batch)
>>> loss(data)
TensorDict(
    fields={
        entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This class is also compatible with non-tensordict-based modules and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are ["action", "sample_log_prob", "next_reward", "next_done", "next_terminated"] + the in_keys of the actor and value network. The return value is a tuple of tensors in the following order: ["loss_objective"] + ["entropy", "loss_entropy"] if entropy_bonus is set + "loss_critic" if critic_coef is not None. The output keys can also be filtered using the PPOLoss.select_out_keys() method.
Examples
>>> import torch
>>> from torch import nn
>>> from torchrl.data.tensor_specs import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.ppo import PPOLoss
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> base_layer = nn.Linear(n_obs, 5)
>>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
...     module=module,
...     distribution_class=TanhNormal,
...     in_keys=["loc", "scale"],
...     spec=spec)
>>> module = nn.Sequential(base_layer, nn.Linear(5, 1))
>>> value = ValueOperator(
...     module=module,
...     in_keys=["observation"])
>>> loss = PPOLoss(actor, value)
>>> loss.set_keys(sample_log_prob="sampleLogProb")
>>> _ = loss.select_out_keys("loss_objective")
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_objective = loss(
...         observation=torch.randn(*batch, n_obs),
...         action=action,
...         sampleLogProb=torch.randn_like(action[..., 1]) / 10,
...         next_done=torch.zeros(*batch, 1, dtype=torch.bool),
...         next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
...         next_reward=torch.randn(*batch, 1),
...         next_observation=torch.randn(*batch, n_obs))
>>> loss_objective.backward()
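When the loss is called without tensordicts, nested keys are passed as flat keyword names and the outputs are returned positionally. The two hypothetical helpers below mirror the naming and ordering rules described above, as bookkeeping on the caller side. They are not part of torchrl. They assume that nested keys are joined with an underscore, as the `next_reward` and `next_observation` names suggest.

```python
def flat_key(key):
    """Keyword name for a (possibly nested) tensordict key, e.g. ("next", "reward")."""
    # Assumption: nested keys are joined with "_", matching names such as next_reward.
    return key if isinstance(key, str) else "_".join(key)


def expected_out_keys(entropy_bonus=True, critic_coef=1.0):
    """Order of the returned tensors when select_out_keys() has not been used."""
    keys = ["loss_objective"]
    if entropy_bonus:
        keys += ["entropy", "loss_entropy"]
    if critic_coef is not None:
        keys.append("loss_critic")
    return keys


in_keys = ["action", "sample_log_prob", ("next", "reward"), ("next", "done"),
           ("next", "terminated"), "observation", ("next", "observation")]
print([flat_key(k) for k in in_keys])
print(expected_out_keys())  # ['loss_objective', 'entropy', 'loss_entropy', 'loss_critic']
```

With the default settings and no filtering, one would therefore expect to unpack four tensors, in the order printed above.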
Note
There is an exception regarding compatibility with non-tensordict-based modules. If the actor network is probabilistic and uses a CompositeDistribution, this class must be used with tensordicts and cannot function as a tensordict-independent module. This is because composite action spaces inherently rely on the structured representation of data provided by tensordicts to handle their actions.
- forward(tensordict: TensorDictBase = None) → TensorDictBase
It is designed to read an input TensorDict and return another tensordict with loss keys named “loss*”.
Splitting the loss into its components lets the trainer log the various loss values throughout training. Other scalars present in the output tensordict will be logged too.
- Parameters:
tensordict – an input tensordict with the values required to compute the loss.
- Returns:
A new tensordict with no batch dimension containing various loss scalars which will be named “loss*”. It is essential that the losses are returned with this name as they will be read by the trainer before backpropagation.
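A trainer that follows this convention only needs the key names to assemble the objective. The sketch below assumes the simplest such trainer: it sums every entry whose key starts with "loss" and leaves other scalars (such as "entropy") for logging. `total_loss` is a hypothetical helper, and plain floats stand in for tensors.

```python
def total_loss(loss_outputs):
    """Sum the entries named "loss*"; other scalars (e.g. "entropy") are only logged."""
    return sum(value for key, value in loss_outputs.items() if key.startswith("loss"))


outputs = {
    "loss_objective": -0.2,
    "loss_critic": 0.5,
    "loss_entropy": -0.01,
    "entropy": 1.0,  # logged, not optimized
}
print(round(total_loss(outputs), 6))  # 0.29
```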
- property functional
Whether the module is functional.
Unless it has been specifically designed not to be functional, all losses are functional.
- loss_critic(tensordict: TensorDictBase) → Tensor
Returns the critic loss multiplied by critic_coef, if it is not None.
- make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)
Value-function constructor.
If the non-default value function is wanted, it must be built using this method.
- Parameters:
value_type (ValueEstimators) – a ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.
**hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used.
Examples
>>> import torch
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.objectives.utils import ValueEstimators
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)