- class torchrl.objectives.SACLoss(*args, **kwargs)
TorchRL implementation of the SAC loss.
Presented in “Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor” and “Soft Actor-Critic Algorithms and Applications”.
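As a rough illustration of what the Q-value part of this loss targets, the sketch below computes the entropy-regularized bootstrap target described in the SAC papers for a single transition: the minimum over the copies of the Q-value network, minus alpha times the log-probability of the next action. It is plain Python with scalar inputs; the function name `soft_q_target`, its arguments and the `gamma` default are illustrative assumptions, not TorchRL's API, and the actual loss works on batched tensordicts and target-network parameters.

```python
from typing import Sequence


# Hypothetical helper for illustration; not part of TorchRL.
def soft_q_target(
    reward: float,
    terminated: bool,
    next_q_values: Sequence[float],
    next_log_prob: float,
    alpha: float,
    gamma: float = 0.99,
) -> float:
    """Entropy-regularized TD target for one transition (no value network).

    next_q_values holds one estimate of Q(s', a') per Q-network copy,
    where a' is sampled from the current policy at s'.
    """
    # Pessimistic estimate: the minimum across the Q-network copies.
    min_q = min(next_q_values)
    # Soft state value: V(s') = min_i Q_i(s', a') - alpha * log pi(a'|s').
    soft_value = min_q - alpha * next_log_prob
    # Do not bootstrap past a terminal state.
    not_terminated = 0.0 if terminated else 1.0
    return reward + gamma * not_terminated * soft_value
```

With two Q-network copies returning 2.0 and 3.0, `soft_q_target(1.0, False, [2.0, 3.0], -1.0, 0.5, gamma=0.9)` uses the smaller estimate and gives 1.0 + 0.9 * 2.5 = 3.25.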
- Parameters:
actor_network (ProbabilisticActor) – stochastic actor
qvalue_network (TensorDictModule) – Q(s, a) parametric model. This module typically outputs a “state_action_value” entry. If a single instance of qvalue_network is provided, it will be duplicated num_qvalue_nets times. If a list of modules is passed, their parameters will be stacked unless they share the same identity (in which case the original parameter will be expanded).
Warning
When a list of parameters is passed, it will __not__ be compared against the policy parameters and all the parameters will be considered as untied.
value_network (TensorDictModule, optional) – V(s) parametric model. This module typically outputs a “state_value” entry. If not provided, the second version of SAC is assumed, where only the Q-Value network is needed.
- Keyword Arguments:
num_qvalue_nets (integer, optional) – number of Q-Value networks used. Defaults to 2.
loss_function (str, optional) – loss function to be used with the value function loss. Default is “smooth_l1”.
alpha_init (float, optional) – initial entropy multiplier. Default is 1.0.
min_alpha (float, optional) – min value of alpha. Default is None (no minimum value).
max_alpha (float, optional) – max value of alpha. Default is None (no maximum value).
action_spec (TensorSpec, optional) – the action tensor spec. If not provided and the target entropy is “auto”, it will be retrieved from the actor.
fixed_alpha (bool, optional) – if True, alpha will be fixed to its initial value. Otherwise, alpha will be optimized to match the ‘target_entropy’ value. Default is False.
target_entropy (float or str, optional) – Target entropy for the stochastic policy. Default is “auto”, where target entropy is computed as -prod(n_actions).
delay_actor (bool, optional) – Whether to separate the target actor networks from the actor networks used for data collection. Default is False.
delay_qvalue (bool, optional) – Whether to separate the target Q value networks from the Q value networks used for data collection. Default is True.
delay_value (bool, optional) – Whether to separate the target value networks from the value networks used for data collection. Default is True.
priority_key (str, optional) – [Deprecated, use .set_keys(priority_key=priority_key) instead] Tensordict key where to write the priority (for prioritized replay buffer usage). Defaults to “td_error”.
separate_losses (bool, optional) – if True, shared parameters between policy and critic will only be trained on the policy loss. Defaults to False, i.e., gradients are propagated to shared parameters for both policy and critic losses.
reduction (str, optional) – Specifies the reduction to apply to the output: “none”: no reduction will be applied, “mean”: the sum of the output will be divided by the number of elements in the output, “sum”: the output will be summed. Default: “mean”.
skip_done_states (bool, optional) – whether the actor network used for value computation should only be run on valid, non-terminating next states. If True, it is assumed that the done state can be broadcast to the shape of the data and that masking the data results in a valid data structure. Among other things, this may not be true in MARL settings or when using RNNs. Defaults to False.
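The target-entropy and alpha options above can be illustrated with a small plain-Python sketch: the “auto” heuristic -prod(n_actions), optional clamping of alpha to [min_alpha, max_alpha], and the temperature loss from “Soft Actor-Critic Algorithms and Applications” written in terms of log(alpha) and averaged as with the “mean” reduction. The function names are hypothetical, and whether TorchRL bounds alpha exactly this way (rather than, for example, clamping log(alpha)) is an assumption.

```python
import math
from typing import Optional, Sequence


# Hypothetical helpers for illustration; not TorchRL's API.
def auto_target_entropy(action_shape: Sequence[int]) -> float:
    """The "auto" heuristic: minus the number of action dimensions."""
    return -float(math.prod(action_shape))


def clamp_alpha(alpha: float, min_alpha: Optional[float], max_alpha: Optional[float]) -> float:
    """Keep alpha within optional bounds; None means no bound on that side."""
    if min_alpha is not None:
        alpha = max(alpha, min_alpha)
    if max_alpha is not None:
        alpha = min(alpha, max_alpha)
    return alpha


def alpha_loss(log_alpha: float, log_probs: Sequence[float], target_entropy: float) -> float:
    """Temperature loss averaged over a batch ("mean" reduction).

    The gradient w.r.t. log_alpha is -(log_prob + target_entropy), so gradient
    descent raises alpha when the policy entropy (-log_prob) is below the
    target and lowers it otherwise. log_probs are treated as constants.
    """
    per_sample = [-log_alpha * (lp + target_entropy) for lp in log_probs]
    return sum(per_sample) / len(per_sample)
```

For a 4-dimensional action space, `auto_target_entropy((4,))` returns -4.0. When fixed_alpha is True, only the initial value would be used and no temperature loss would be optimized.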
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.sac import SACLoss
>>> from tensordict import TensorDict
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
...     module=module,
...     in_keys=["loc", "scale"],
...     spec=spec,
...     distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         # concatenate observation and action along the last dimension
...         return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
...     module=module,
...     in_keys=['observation', 'action'])
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
...     module=module,
...     in_keys=["observation"])
>>> loss = SACLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> data = TensorDict({
...     "observation": torch.randn(*batch, n_obs),
...     "action": action,
...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "reward"): torch.randn(*batch, 1),
...     ("next", "observation"): torch.randn(*batch, n_obs),
... }, batch)
>>> loss(data)
TensorDict(
    fields={
        alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This class is compatible with non-tensordict based modules too and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are:
["action", "next_reward", "next_done", "next_terminated"] + in_keys of the actor, value, and qvalue network.
The return value is a tuple of tensors in the following order:
["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"] + ["loss_value"] if version one is used (i.e., when a value_network is provided).
Examples
>>> import torch
>>> from torch import nn
>>> from torchrl.data import Bounded
>>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal
>>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
>>> from torchrl.modules.tensordict_module.common import SafeModule
>>> from torchrl.objectives.sac import SACLoss
>>> _ = torch.manual_seed(42)
>>> n_act, n_obs = 4, 3
>>> spec = Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,))
>>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
>>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
>>> actor = ProbabilisticActor(
...     module=module,
...     in_keys=["loc", "scale"],
...     spec=spec,
...     distribution_class=TanhNormal)
>>> class ValueClass(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(n_obs + n_act, 1)
...     def forward(self, obs, act):
...         # concatenate observation and action along the last dimension
...         return self.linear(torch.cat([obs, act], -1))
>>> module = ValueClass()
>>> qvalue = ValueOperator(
...     module=module,
...     in_keys=['observation', 'action'])
>>> module = nn.Linear(n_obs, 1)
>>> value = ValueOperator(
...     module=module,
...     in_keys=["observation"])
>>> loss = SACLoss(actor, qvalue, value)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_qvalue, _, _, _, _ = loss(
...     observation=torch.randn(*batch, n_obs),
...     action=action,
...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_observation=torch.zeros(*batch, n_obs),
...     next_reward=torch.randn(*batch, 1))
>>> loss_actor.backward()
The output keys can also be filtered using the SACLoss.select_out_keys() method.
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> loss_actor, loss_qvalue = loss(
...     observation=torch.randn(*batch, n_obs),
...     action=action,
...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
...     next_observation=torch.zeros(*batch, n_obs),
...     next_reward=torch.randn(*batch, 1))
>>> loss_actor.backward()
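The flat-keyword calling convention above can be modelled in a few lines of plain Python. This sketch assumes, consistently with the names listed above (next_reward for the ("next", "reward") entry), that nested keys are flattened by joining their parts with an underscore, and it returns outputs as a tuple ordered like the selected out-keys. flat_name, pack_kwargs and select_outputs are hypothetical helpers, not TorchRL or tensordict functions.

```python
from typing import Dict, Iterable, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


# Hypothetical helpers for illustration; not TorchRL or tensordict API.
def flat_name(key: NestedKey, separator: str = "_") -> str:
    """Map a (possibly nested) key to a keyword-argument name."""
    return key if isinstance(key, str) else separator.join(key)


def pack_kwargs(in_keys: Iterable[NestedKey], **kwargs) -> Dict[NestedKey, object]:
    """Collect flat keyword arguments back under their nested keys."""
    packed = {}
    for key in in_keys:
        name = flat_name(key)
        if name not in kwargs:
            raise TypeError(f"missing keyword argument {name!r}")
        packed[key] = kwargs[name]
    return packed


def select_outputs(results: Dict[str, object], out_keys: Iterable[str]) -> tuple:
    """Return the selected results as a tuple, in the order of out_keys."""
    return tuple(results[k] for k in out_keys)
```

With in_keys ["action", ("next", "reward")], calling `pack_kwargs(in_keys, action=a, next_reward=r)` yields `{"action": a, ("next", "reward"): r}`.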
- forward(tensordict: TensorDictBase = None) → TensorDictBase
It is designed to read an input TensorDict and return another tensordict with loss keys named “loss*”.
Splitting the loss into its components lets the trainer log each loss value separately throughout training. Other scalars present in the output tensordict are logged too.
- Parameters:
tensordict – an input tensordict with the values required to compute the loss.
- Returns:
A new tensordict with no batch dimension containing various loss scalars which will be named “loss*”. It is essential that the losses are returned with this name as they will be read by the trainer before backpropagation.
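The “loss*” naming convention is what lets a trainer tell what to backpropagate from what to merely log. The plain-Python sketch below uses a dict of floats in place of the output tensordict; split_outputs and total_loss are hypothetical names, not part of the TorchRL trainer API.

```python
from typing import Dict, Tuple


# Hypothetical helpers for illustration; not TorchRL's trainer API.
def split_outputs(outputs: Dict[str, float]) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Separate "loss*" entries from the other scalars, which are only logged."""
    losses = {k: v for k, v in outputs.items() if k.startswith("loss")}
    others = {k: v for k, v in outputs.items() if not k.startswith("loss")}
    return losses, others


def total_loss(outputs: Dict[str, float]) -> float:
    """Sum of every "loss*" entry: the quantity a trainer would backpropagate."""
    losses, _ = split_outputs(outputs)
    return sum(losses.values())
```

Because sum() starts from 0 and adds the entries one by one, the same total_loss would also work on a dict of scalar tensors, returning a tensor on which .backward() can be called.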
- load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved whereas setting it to True preserves properties of the Tensors in the state dict. The only exception is the requires_grad field of Parameter, for which the value from the module is preserved. Default: False
- Returns:
- missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict
- unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
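The strict flag and the returned key lists amount to two set differences. The plain-Python sketch below models only that bookkeeping; IncompatibleKeys and check_keys are hypothetical names chosen to mirror the NamedTuple described above, while real PyTorch also copies tensor data and checks shapes, which is omitted here.

```python
from typing import Iterable, List, NamedTuple


# Hypothetical names for illustration; not PyTorch's implementation.
class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


def check_keys(expected: Iterable[str], provided: Iterable[str], strict: bool = True) -> IncompatibleKeys:
    """Compare the module's expected keys with the keys of a state dict."""
    expected_set, provided_set = set(expected), set(provided)
    result = IncompatibleKeys(
        missing_keys=sorted(expected_set - provided_set),      # expected but absent
        unexpected_keys=sorted(provided_set - expected_set),   # present but not expected
    )
    # With strict=True, any mismatch is an error; otherwise it is only reported.
    if strict and (result.missing_keys or result.unexpected_keys):
        raise RuntimeError(f"Error(s) in loading state_dict: {result}")
    return result
```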
- make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)
Value-function constructor.
If the non-default value function is wanted, it must be built using this method.
- Parameters:
value_type (ValueEstimators) – A ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.
**hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by torchrl.objectives.utils.default_value_kwargs() will be used.
>>> import torch
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.objectives.utils import ValueEstimators
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> # switching to a TD(1) estimator
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
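To show what the choice of estimator and gamma changes, here is a plain-Python sketch of two textbook targets over a single trajectory: the one-step TD(0) target and the TD(λ) recursion, which reduces to TD(0) for λ = 0 and to a bootstrapped Monte-Carlo return (the usual meaning of TD(1)) for λ = 1. The function names and list inputs are assumptions for illustration; TorchRL's estimators operate on tensordicts and distinguish done from terminated signals, which this sketch collapses into a single terminated flag.

```python
from typing import List, Sequence


# Hypothetical helpers for illustration; not TorchRL's value estimators.
def td0_targets(
    rewards: Sequence[float],
    next_values: Sequence[float],
    terminated: Sequence[bool],
    gamma: float,
) -> List[float]:
    """One-step target r_t + gamma * V(s_{t+1}), with no bootstrap past a terminal state."""
    return [
        r + gamma * (0.0 if term else 1.0) * v
        for r, v, term in zip(rewards, next_values, terminated)
    ]


def td_lambda_targets(
    rewards: Sequence[float],
    next_values: Sequence[float],
    terminated: Sequence[bool],
    gamma: float,
    lmbda: float,
) -> List[float]:
    """Backward recursion
    G_t = r_t + gamma * (1 - term_t) * ((1 - lmbda) * V(s_{t+1}) + lmbda * G_{t+1}),
    bootstrapping from V(s_T) after the last step.
    """
    targets = [0.0] * len(rewards)
    next_return = next_values[-1]  # stands in for G_T at the end of the trajectory
    for t in reversed(range(len(rewards))):
        not_term = 0.0 if terminated[t] else 1.0
        blended = (1.0 - lmbda) * next_values[t] + lmbda * next_return
        targets[t] = rewards[t] + gamma * not_term * blended
        next_return = targets[t]
    return targets
```

With rewards [1.0, 1.0], zero next-state values and gamma = 0.5, λ = 1 gives [1.5, 1.0]: each step adds half of the following return.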
- state_dict(*args, destination=None, prefix='', keep_vars=False)
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
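How prefix, destination and the skipping of None entries combine can be sketched in plain Python by modelling a module as a nested dict of sub-modules and values. flatten_state is a hypothetical helper, not PyTorch's implementation (which walks registered parameters, buffers and child modules); the sketch only reproduces the dotted-key composition and the shallow-copy behavior described above.

```python
from typing import Dict, Optional


# Hypothetical helper for illustration; not PyTorch's state_dict implementation.
def flatten_state(
    tree: Dict[str, object],
    prefix: str = "",
    destination: Optional[Dict[str, object]] = None,
) -> Dict[str, object]:
    """Flatten a nested {name: submodule-dict or value} mapping into dotted keys."""
    if destination is None:
        destination = {}
    for name, value in tree.items():
        if value is None:
            continue  # entries set to None are not included
        if isinstance(value, dict):
            # A child "module": recurse with its name appended to the prefix.
            flatten_state(value, prefix + name + ".", destination)
        else:
            # Store a reference, not a copy (shallow copy semantics).
            destination[prefix + name] = value
    return destination
```

For example, `flatten_state({"actor": {"weight": w, "bias": b}}, prefix="loss.")` yields the keys "loss.actor.weight" and "loss.actor.bias".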