DDPGLoss¶
- class torchrl.objectives.DDPGLoss(*args, **kwargs)[source]¶
The DDPG Loss class.
- Parameters:
actor_network (TensorDictModule) – a policy operator.
value_network (TensorDictModule) – a Q value operator.
loss_function (str) – loss function for the value discrepancy. Can be one of “l1”, “l2” or “smooth_l1”.
delay_actor (bool, optional) – whether to separate the target actor networks from the actor networks used for data collection. Default is
False
.delay_value (bool, optional) – whether to separate the target value networks from the value networks used for data collection. Default is
True
.separate_losses (bool, optional) – if
True
, shared parameters between policy and critic will only be trained on the policy loss. Defaults toFalse
, i.e., gradients are propagated to shared parameters for both policy and critic losses.reduction (str, optional) – Specifies the reduction to apply to the output:
"none"
|"mean"
|"sum"
."none"
: no reduction will be applied,"mean"
: the sum of the output will be divided by the number of elements in the output,"sum"
: the output will be summed. Default:"mean"
.
Examples
>>> import torch >>> from torch import nn >>> from torchrl.data import Bounded >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.ddpg import DDPGLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act)) >>> class ValueClass(nn.Module): ... def __init__(self): ... super().__init__() ... self.linear = nn.Linear(n_obs + n_act, 1) ... def forward(self, obs, act): ... return self.linear(torch.cat([obs, act], -1)) >>> module = ValueClass() >>> value = ValueOperator( ... module=module, ... in_keys=["observation", "action"]) >>> loss = DDPGLoss(actor, value) >>> batch = [2, ] >>> data = TensorDict({ ... "observation": torch.randn(*batch, n_obs), ... "action": spec.rand(batch), ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), ... ("next", "reward"): torch.randn(*batch, 1), ... ("next", "observation"): torch.randn(*batch, n_obs), ... }, batch) >>> loss(data) TensorDict( fields={ loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), pred_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), target_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)
This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are:
["next_reward", "next_done", "next_terminated"]
+ in_keys of the actor_network and value_network. The return value is a tuple of tensors in the following order:["loss_actor", "loss_value", "pred_value", "target_value", "pred_value_max", "target_value_max"]
Examples
>>> import torch >>> from torch import nn >>> from torchrl.data import Bounded >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator >>> from torchrl.objectives.ddpg import DDPGLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act)) >>> class ValueClass(nn.Module): ... def __init__(self): ... super().__init__() ... self.linear = nn.Linear(n_obs + n_act, 1) ... def forward(self, obs, act): ... return self.linear(torch.cat([obs, act], -1)) >>> module = ValueClass() >>> value = ValueOperator( ... module=module, ... in_keys=["observation", "action"]) >>> loss = DDPGLoss(actor, value) >>> loss_actor, loss_value, pred_value, target_value, pred_value_max, target_value_max = loss( ... observation=torch.randn(n_obs), ... action=spec.rand(), ... next_done=torch.zeros(1, dtype=torch.bool), ... next_terminated=torch.zeros(1, dtype=torch.bool), ... next_observation=torch.randn(n_obs), ... next_reward=torch.randn(1)) >>> loss_actor.backward()
The output keys can also be filtered using the
DDPGLoss.select_out_keys()
method.Examples
>>> loss.select_out_keys('loss_actor', 'loss_value') >>> loss_actor, loss_value = loss( ... observation=torch.randn(n_obs), ... action=spec.rand(), ... next_done=torch.zeros(1, dtype=torch.bool), ... next_terminated=torch.zeros(1, dtype=torch.bool), ... next_observation=torch.randn(n_obs), ... next_reward=torch.randn(1)) >>> loss_actor.backward()
- forward(tensordict: TensorDictBase = None) TensorDict [source]¶
Computes the DDPG losses given a tensordict sampled from the replay buffer.
- This function will also write a “td_error” key that can be used by prioritized replay buffers to assign
a priority to items in the tensordict.
- Parameters:
tensordict (TensorDictBase) – a tensordict with keys [“done”, “terminated”, “reward”] and the in_keys of the actor and value networks.
- Returns:
a tuple of 2 tensors containing the DDPG loss.
- make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)[source]¶
Value-function constructor.
If the non-default value function is wanted, it must be built using this method.
- Parameters:
value_type (ValueEstimators) – A
ValueEstimators
enum type indicating the value function to use. If none is provided, the default stored in thedefault_value_estimator
attribute will be used. The resulting value estimator class will be registered inself.value_type
, allowing future refinements.**hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by
default_value_kwargs()
will be used.
Examples
>>> from torchrl.objectives import DQNLoss >>> # initialize the DQN loss >>> actor = torch.nn.Linear(3, 4) >>> dqn_loss = DQNLoss(actor, action_space="one-hot") >>> # updating the parameters of the default value estimator >>> dqn_loss.make_value_estimator(gamma=0.9) >>> dqn_loss.make_value_estimator( ... ValueEstimators.TD1, ... gamma=0.9) >>> # if we want to change the gamma value >>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)