TD3BCLoss¶
- class torchrl.objectives.TD3BCLoss(*args, **kwargs)[source]¶
TD3+BC Loss Module.
Implementation of the TD3+BC loss presented in the paper “A Minimalist Approach to Offline Reinforcement Learning” <https://arxiv.org/pdf/2106.06860>.
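For orientation, the actor objective from the paper can be written as follows (this is a paraphrase of the paper's regularized policy update, not text taken from the TorchRL docstring):

\[
\pi = \arg\max_{\pi}\; \mathbb{E}_{(s,a)\sim\mathcal{D}}\Big[\lambda\, Q(s,\pi(s)) - \big(\pi(s)-a\big)^2\Big],
\qquad
\lambda = \frac{\alpha}{\tfrac{1}{N}\sum_{(s_i,a_i)} \lvert Q(s_i,a_i)\rvert},
\]

i.e. standard TD3 with a behavioral cloning term on the actor, where the alpha keyword argument below controls the trade-off through the normalization factor \(\lambda\).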
This class incorporates two loss functions, qvalue_loss() and actor_loss(), executed sequentially within the forward method.
Users also have the option to call these functions directly in the same order if preferred.
- Parameters:
actor_network (TensorDictModule) – the actor to be trained.
qvalue_network (TensorDictModule) – a single Q-value network or a list of Q-value networks. If a single instance of qvalue_network is provided, it will be duplicated num_qvalue_nets times. If a list of modules is passed, their parameters will be stacked unless they share the same identity (in which case the original parameter will be expanded).
Warning: when a list of parameters is passed, it will __not__ be compared against the policy parameters, and all the parameters will be considered as untied.
- Keyword Arguments:
bounds (tuple of float, optional) – the bounds of the action space. Exclusive with action_spec. Either this or action_spec must be provided.
action_spec (TensorSpec, optional) – the action spec. Exclusive with bounds. Either this or bounds must be provided.
num_qvalue_nets (int, optional) – number of Q-value networks to be trained. Default is 2.
policy_noise (float, optional) – standard deviation for the target policy action noise. Default is 0.2.
noise_clip (float, optional) – clipping range value for the sampled target policy action noise. Default is 0.5.
alpha (float, optional) – weight for the behavioral cloning loss. Defaults to 2.5.
priority_key (str, optional) – key where to write the priority value for prioritized replay buffers. Default is "td_error".
loss_function (str, optional) – loss function to be used for the Q-value. Can be one of "smooth_l1", "l2" or "l1". Default is "smooth_l1".
delay_actor (bool, optional) – whether to separate the target actor networks from the actor networks used for data collection. Default is True.
delay_qvalue (bool, optional) – whether to separate the target Q-value networks from the Q-value networks used for data collection. Default is True.
spec (TensorSpec, optional) – the action tensor spec. If not provided and the target entropy is "auto", it will be retrieved from the actor.
separate_losses (bool, optional) – if True, shared parameters between policy and critic will only be trained on the policy loss. Defaults to False, i.e., gradients are propagated to shared parameters for both policy and critic losses.
reduction (str, optional) – specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied, "mean": the sum of the output will be divided by the number of elements in the output, "sum": the output will be summed. Default: "mean".
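As a hedged illustration (not from the original docstring), here is one way several of these keyword arguments might be passed together; actor and qvalue are assumed to be built as in the example below, and the bounds values are purely illustrative and exclusive with action_spec:

>>> # illustrative sketch only: `actor` and `qvalue` as in the example below
>>> loss = TD3BCLoss(
...     actor, qvalue,
...     bounds=(-1.0, 1.0),       # exclusive with action_spec
...     num_qvalue_nets=2,
...     policy_noise=0.2,
...     noise_clip=0.5,
...     alpha=2.5,
...     loss_function="smooth_l1",
... )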
Examples
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.td3_bc import TD3BCLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
...     module=module,
...     spec=spec)
>>> class ValueClass(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
...     module=module,
...     in_keys=['observation', 'action'])
>>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
...     "observation": torch.randn(*batch, n_obs),
...     "action": action,
...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "reward"): torch.randn(*batch, 1),
...     ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
    fields={
        bc_loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        lmbd: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        next_state_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        pred_value: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        state_action_value_actor: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        target_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
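Building on the example above, the two losses can also be computed manually instead of through forward(); the order below (critic first, then actor) and the metadata keys follow the method documentation further down, and the single combined backward() call is a simplification for illustration:

>>> loss_q, q_meta = loss.qvalue_loss(data)   # q_meta holds e.g. the detached "td_error"
>>> loss_a, a_meta = loss.actor_loss(data)    # a_meta holds "bc_loss", "lmbd", ...
>>> (loss_q + loss_a).backward()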
This class is also compatible with non-tensordict based modules and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are: ["action", "next_reward", "next_done", "next_terminated"] + the in_keys of the actor and qvalue network. The return value is a tuple of tensors in the following order: ["loss_actor", "loss_qvalue", "bc_loss", "lmbd", "pred_value", "state_action_value_actor", "next_state_value", "target_value"].
Examples
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
>>> from torchrl.objectives.td3_bc import TD3BCLoss
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> module = nn.Linear(n_obs, n_act)
>>> actor = Actor(
...     module=module,
...     spec=spec)
>>> class ValueClass(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
...     module=module,
...     in_keys=['observation', 'action'])
>>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec)
>>> _ = loss.select_out_keys("loss_actor", "loss_qvalue")
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvalue = loss(
...     observation=torch.randn(*batch, n_obs),
...     action=action,
...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_reward=torch.randn(*batch, 1),
...     next_observation=torch.randn(*batch, n_obs))
>>> loss_actor.backward()
- actor_loss(tensordict) → Tuple[Tensor, dict] [source]¶
Compute the actor loss.
The actor loss should be computed after the qvalue_loss() and is usually delayed by 1-3 critic updates.
- Parameters:
tensordict (TensorDictBase) – the input data for the loss. Check the class’s in_keys to see what fields are required for this to be computed.
- Returns: a differentiable tensor with the actor loss, along with a metadata dictionary containing the detached “bc_loss” used in the combined actor loss, the detached “state_action_value_actor” used to calculate the lambda value, and the lambda value “lmbd” itself.
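As a hedged reading of how these values relate (based on the TD3+BC paper, not a guaranteed description of the implementation):

>>> loss_a, meta = loss.actor_loss(data)      # `loss` and `data` as in the examples above
>>> meta["bc_loss"], meta["lmbd"]             # detached metadata described above
>>> # per the paper, roughly: loss_actor ≈ -lmbd * Q(s, pi(s)).mean() + bc_loss,
>>> # with lmbd = alpha / mean(|state_action_value_actor|)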
- forward(tensordict: TensorDictBase = None) → TensorDictBase [source]¶
The forward method.
Computes successively the actor_loss() and qvalue_loss(), and returns a tensordict with these values. To see what keys are expected in the input tensordict and what keys are expected as output, check the class’s “in_keys” and “out_keys” attributes.
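For example, the expected keys can be inspected directly on the instance built in the examples above (the exact lists may vary across TorchRL versions):

>>> print(loss.in_keys)    # keys read from the input tensordict
>>> print(loss.out_keys)   # keys written to the output tensordict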
- make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)[source]¶
Value-function constructor.
If a non-default value function is wanted, it must be built using this method.
- Parameters:
value_type (ValueEstimators) – A ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.
**hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used.
Examples
>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
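For TD3BCLoss specifically, a minimal sketch on the instance from the earlier examples might look as follows; the choice of TD0 and the gamma value are illustrative assumptions, not defaults documented here:

>>> from torchrl.objectives import ValueEstimators
>>> loss.make_value_estimator(ValueEstimators.TD0, gamma=0.99)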
- qvalue_loss(tensordict) → Tuple[Tensor, dict] [source]¶
Compute the q-value loss.
The q-value loss should be computed before the actor_loss().
- Parameters:
tensordict (TensorDictBase) – the input data for the loss. Check the class’s in_keys to see what fields are required for this to be computed.
- Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing
the detached “td_error” to be used for prioritized sampling, the detached “next_state_value”, the detached “pred_value”, and the detached “target_value”.
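A hedged sketch of consuming this metadata for prioritized sampling (the priority handling itself is left to the replay buffer and is not part of this class):

>>> loss_q, meta = loss.qvalue_loss(data)     # `loss` and `data` as in the examples above
>>> data.set("td_error", meta["td_error"])    # store priorities under the default priority_key
>>> # a prioritized replay buffer can then read this key to update sample priorities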