CrossQLoss¶
- class torchrl.objectives.CrossQLoss(*args, **kwargs)[source]¶
TorchRL implementation of the CrossQ loss.
Presented in “CROSSQ: BATCH NORMALIZATION IN DEEP REINFORCEMENT LEARNING FOR GREATER SAMPLE EFFICIENCY AND SIMPLICITY” https://openreview.net/pdf?id=PczQtTsTIX
This class has three loss functions that will be called sequentially by the forward method: qvalue_loss(), actor_loss() and alpha_loss(). Alternatively, they can be called by the user in that order.
- Parameters:
actor_network (ProbabilisticActor) – stochastic actor
qvalue_network (TensorDictModule) – Q(s, a) parametric model. This module typically outputs a "state_action_value" entry. If a single instance of qvalue_network is provided, it will be duplicated num_qvalue_nets times. If a list of modules is passed, their parameters will be stacked unless they share the same identity (in which case the original parameter will be expanded).
Warning
When a list of parameters is passed, it will __not__ be compared against the policy parameters and all the parameters will be considered as untied.
- Keyword Arguments:
num_qvalue_nets (integer, optional) – number of Q-Value networks used. Defaults to 2.
loss_function (str, optional) – loss function to be used with the value function loss. Default is "smooth_l1".
alpha_init (float, optional) – initial entropy multiplier. Default is 1.0.
min_alpha (float, optional) – min value of alpha. Default is None (no minimum value).
max_alpha (float, optional) – max value of alpha. Default is None (no maximum value).
action_spec (TensorSpec, optional) – the action tensor spec. If not provided and the target entropy is "auto", it will be retrieved from the actor.
fixed_alpha (bool, optional) – if True, alpha will be fixed to its initial value. Otherwise, alpha will be optimized to match the 'target_entropy' value. Default is False.
target_entropy (float or str, optional) – Target entropy for the stochastic policy. Default is "auto", where target entropy is computed as -prod(n_actions).
priority_key (str, optional) – [Deprecated, use .set_keys(priority_key=priority_key) instead] Tensordict key where to write the priority (for prioritized replay buffer usage). Defaults to "td_error".
separate_losses (bool, optional) – if True, shared parameters between policy and critic will only be trained on the policy loss. Defaults to False, i.e., gradients are propagated to shared parameters for both policy and critic losses.
reduction (str, optional) – Specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied, "mean": the sum of the output will be divided by the number of elements in the output, "sum": the output will be summed. Default: "mean".
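A minimal sketch of how a few of these keyword arguments might be combined at construction time (the actor and qvalue modules are assumed to be built as in the examples below):

>>> loss = CrossQLoss(
...     actor,
...     qvalue,
...     num_qvalue_nets=2,
...     alpha_init=1.0,
...     fixed_alpha=False,
...     target_entropy="auto",
...     loss_function="smooth_l1",
...     reduction="mean",
... )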
Examples
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.crossq import CrossQLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
...     module=module,
...     in_keys=["loc", "scale"],
...     spec=spec,
...     distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
...     module=module,
...     in_keys=['observation', 'action'])
>>> loss = CrossQLoss(actor, qvalue)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
...     "observation": torch.randn(*batch, n_obs),
...     "action": action,
...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "reward"): torch.randn(*batch, 1),
...     ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
    fields={
        alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This class is also compatible with non-tensordict based modules and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are: ["action", "next_reward", "next_done", "next_terminated"] + the in_keys of the actor and qvalue network. The return value is a tuple of tensors in the following order: ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"].
Examples
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives import CrossQLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
...     module=module,
...     in_keys=["loc", "scale"],
...     spec=spec,
...     distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
...     module=module,
...     in_keys=['observation', 'action'])
>>> loss = CrossQLoss(actor, qvalue)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvalue, _, _, _ = loss(
...     observation=torch.randn(*batch, n_obs),
...     action=action,
...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_observation=torch.zeros(*batch, n_obs),
...     next_reward=torch.randn(*batch, 1))
>>> loss_actor.backward()
The output keys can also be filtered using the CrossQLoss.select_out_keys() method.
Examples
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> loss_actor, loss_qvalue = loss(
...     observation=torch.randn(*batch, n_obs),
...     action=action,
...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_observation=torch.zeros(*batch, n_obs),
...     next_reward=torch.randn(*batch, 1))
>>> loss_actor.backward()
- actor_loss(tensordict: TensorDictBase) Tuple[Tensor, Dict[str, Tensor]] [source]¶
Compute the actor loss.
The actor loss should be computed after the qvalue_loss() and before the alpha_loss(), which requires the log_prob field of the metadata returned by this method.
- Parameters:
tensordict (TensorDictBase) – the input data for the loss. Check the class’s in_keys to see what fields are required for this to be computed.
Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached "log_prob" of the sampled action.
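Following the recommended order, a minimal sketch that reuses the loss and data objects from the tensordict example above:

>>> loss_qvalue, _ = loss.qvalue_loss(data)
>>> loss_actor, metadata = loss.actor_loss(data)
>>> log_prob = metadata["log_prob"]  # detached log-prob of the sampled action, used by alpha_loss()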
- alpha_loss(log_prob: Tensor) Tensor [source]¶
Compute the entropy loss.
The entropy loss should be computed last.
- Parameters:
log_prob (torch.Tensor) – a log-probability as computed by the actor_loss() and returned in the metadata.
Returns: a differentiable tensor with the entropy loss.
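Continuing the sketch from actor_loss() above, the entropy loss would then be obtained from that log-probability:

>>> loss_alpha = loss.alpha_loss(log_prob)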
- forward(tensordict: TensorDictBase = None) TensorDictBase [source]¶
The forward method.
Successively computes the qvalue_loss(), actor_loss() and alpha_loss(), and returns a tensordict with these values along with the "alpha" value and the "entropy" value (detached). To see what keys are expected in the input tensordict and what keys are expected as output, check the class's "in_keys" and "out_keys" attributes.
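As a small sketch (reusing the loss and data objects from the first example), these attributes can be inspected before calling the module:

>>> print(loss.in_keys)   # entries expected in the input tensordict
>>> print(loss.out_keys)  # entries written to the output tensordict, e.g. "loss_actor", "loss_qvalue", ...
>>> out = loss(data)      # equivalent to loss.forward(data)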
- load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)¶
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
Warning
If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: True
assign (bool, optional) – When False, the properties of the tensors in the current module are preserved, while when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False
- Returns:
- missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
- unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
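A minimal round-trip sketch on this loss module (this is the standard torch.nn.Module pattern, not specific to CrossQ):

>>> sd = loss.state_dict()
>>> missing, unexpected = loss.load_state_dict(sd, strict=True)
>>> assert not missing and not unexpected  # loading back the same state should report no mismatches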
- make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)[source]¶
Value-function constructor.
If a non-default value function is wanted, it must be built using this method.
- Parameters:
value_type (ValueEstimators) – A ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.
**hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used.
Examples
>>> from torchrl.objectives import DQNLoss
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
- qvalue_loss(tensordict: TensorDictBase) Tuple[Tensor, Dict[str, Tensor]] [source]¶
Compute the q-value loss.
The q-value loss should be computed before the actor_loss().
- Parameters:
tensordict (TensorDictBase) – the input data for the loss. Check the class’s in_keys to see what fields are required for this to be computed.
- Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing the detached "td_error" to be used for prioritized sampling.
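A minimal sketch of calling this method directly and writing the returned priority back into the batch, assuming the loss and data objects from the first example (the "td_error" entry matches the default priority key documented above):

>>> loss_qvalue, metadata = loss.qvalue_loss(data)
>>> data.set("td_error", metadata["td_error"])  # priority signal for a prioritized replay buffer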
- state_dict(*args, destination=None, prefix='', keep_vars=False)¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module's parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it's set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
- property target_entropy_buffer¶
The target entropy.
This value can be controlled via the target_entropy kwarg in the constructor.
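For instance, a fixed value could be passed at construction and read back through this property; a small sketch reusing the actor and qvalue modules from the examples above:

>>> loss = CrossQLoss(actor, qvalue, target_entropy=-4.0)
>>> loss.target_entropy_buffer  # the stored target entropy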