DQNLoss
- class torchrl.objectives.DQNLoss(*args, **kwargs)
The DQN Loss class.
- Parameters:
value_network (QValueActor or nn.Module) – a Q value operator.
- Keyword Arguments:
loss_function (str, optional) – loss function for the value discrepancy. Can be one of “l1”, “l2” or “smooth_l1”. Defaults to “l2”.
delay_value (bool, optional) – whether to duplicate the value network into a new target value network to create a DQN with a target network. Defaults to True.
double_dqn (bool, optional) – whether to use Double DQN, as described in https://arxiv.org/abs/1509.06461. Defaults to False.
action_space (str or TensorSpec, optional) – the action space. Must be one of "one-hot", "mult_one_hot", "binary" or "categorical", or an instance of the corresponding specs (torchrl.data.OneHot, torchrl.data.MultiOneHot, torchrl.data.Binary or torchrl.data.Categorical). If not provided, an attempt to retrieve it from the value network will be made.
priority_key (NestedKey, optional) – [Deprecated, use .set_keys(priority_key=priority_key) instead] The key at which priority is assumed to be stored within TensorDicts added to this ReplayBuffer. This is to be used when the sampler is of type PrioritizedSampler. Defaults to "td_error".
reduction (str, optional) – specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied; "mean": the sum of the output will be divided by the number of elements in the output; "sum": the output will be summed. Defaults to "mean".
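To make the double_dqn flag more concrete, the following is a minimal sketch in plain PyTorch of how the next-state value can be chosen with and without Double DQN. It follows the textbook formulation rather than torchrl's actual code path; the helper name `next_state_value` and its arguments are hypothetical.

```python
import torch


def next_state_value(q_online_next, q_target_next, double_dqn=False):
    """Next-state value used in the TD target (hypothetical helper, not torchrl API).

    q_online_next: [batch, n_actions] next-state Q-values from the online network.
    q_target_next: [batch, n_actions] next-state Q-values from the target network.
    """
    if double_dqn:
        # Double DQN: the online network selects the action,
        # the target network evaluates it.
        best_action = q_online_next.argmax(dim=-1, keepdim=True)
    else:
        # DQN with a target network: the target network selects and evaluates.
        best_action = q_target_next.argmax(dim=-1, keepdim=True)
    return q_target_next.gather(-1, best_action)


q_online_next = torch.tensor([[1.0, 5.0, 2.0]])
q_target_next = torch.tensor([[4.0, 3.0, 0.5]])

vanilla = next_state_value(q_online_next, q_target_next, double_dqn=False)
double = next_state_value(q_online_next, q_target_next, double_dqn=True)
print(vanilla, double)
```

In this toy batch the two networks disagree on the best action, so the Double DQN value is the target network's estimate for the action preferred by the online network rather than the target network's own maximum.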
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import MLP, QValueActor
>>> from torchrl.data import OneHot
>>> from torchrl.objectives import DQNLoss
>>> n_obs, n_act = 4, 3
>>> value_net = MLP(in_features=n_obs, out_features=n_act)
>>> spec = OneHot(n_act)
>>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
>>> loss = DQNLoss(actor, action_space=spec)
>>> batch = [10,]
>>> data = TensorDict({
...     "observation": torch.randn(*batch, n_obs),
...     "action": spec.rand(batch),
...     ("next", "observation"): torch.randn(*batch, n_obs),
...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "reward"): torch.randn(*batch, 1)
... }, batch)
>>> loss(data)
TensorDict(
    fields={
        loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This class is also compatible with non-tensordict based modules and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are ["observation", "next_observation", "action", "next_reward", "next_done", "next_terminated"], and a single loss value is returned.
Examples
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.data import OneHot
>>> from torch import nn
>>> import torch
>>> n_obs = 3
>>> n_action = 4
>>> action_spec = OneHot(n_action)
>>> value_network = nn.Linear(n_obs, n_action)  # a simple value model
>>> dqn_loss = DQNLoss(value_network, action_space=action_spec)
>>> # define data
>>> observation = torch.randn(n_obs)
>>> next_observation = torch.randn(n_obs)
>>> action = action_spec.rand()
>>> next_reward = torch.randn(1)
>>> next_done = torch.zeros(1, dtype=torch.bool)
>>> next_terminated = torch.zeros(1, dtype=torch.bool)
>>> loss_val = dqn_loss(
...     observation=observation,
...     next_observation=next_observation,
...     next_reward=next_reward,
...     next_done=next_done,
...     next_terminated=next_terminated,
...     action=action)
- forward(tensordict: TensorDictBase = None) → TensorDict
Computes the DQN loss given a tensordict sampled from the replay buffer.
This function will also write a "td_error" key that can be used by prioritized replay buffers to assign a priority to items in the tensordict.
- Parameters:
tensordict (TensorDictBase) – a tensordict containing the "action" key, the in_keys of the value network (the observations), and a "next" sub-tensordict holding the next observations together with the "done", "terminated" and "reward" entries.
- Returns:
a tensordict containing the DQN loss under the "loss" entry.
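The computation described above can be sketched in plain PyTorch as follows. This is an illustration under stated assumptions, not torchrl's implementation: the helper `td_loss` is hypothetical, a one-step bootstrapped target is assumed, and the absolute TD error is used here as the priority signal (the quantity actually written under "td_error" by the library may differ).

```python
import torch
import torch.nn.functional as F


def td_loss(q_taken, reward, next_value, terminated,
            gamma=0.99, loss_function="l2", reduction="mean"):
    """One-step TD loss and per-sample TD error (hypothetical helper)."""
    # No bootstrapping from the next state once the episode has terminated.
    not_terminated = (~terminated).float()
    target = (reward + gamma * next_value * not_terminated).detach()

    # Per-sample error that a prioritized sampler could use (assumption: |delta|).
    td_error = (q_taken - target).abs().detach()

    distances = {
        "l1": F.l1_loss,
        "l2": F.mse_loss,
        "smooth_l1": F.smooth_l1_loss,
    }
    if loss_function not in distances:
        raise ValueError(f"unknown loss_function: {loss_function}")
    loss = distances[loss_function](q_taken, target, reduction="none")

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    elif reduction != "none":
        raise ValueError(f"unknown reduction: {reduction}")
    return loss, td_error


q_taken = torch.tensor([1.0, 2.0])        # Q(s, a) for the sampled actions
reward = torch.tensor([1.0, 0.0])
next_value = torch.tensor([2.0, 4.0])     # value of the next state
terminated = torch.tensor([False, True])

loss_mean, td_error = td_loss(q_taken, reward, next_value, terminated, gamma=0.5)
loss_l1_none, _ = td_loss(q_taken, reward, next_value, terminated,
                          gamma=0.5, loss_function="l1", reduction="none")
print(loss_mean, td_error, loss_l1_none)
```

With reduction="none" the loss keeps one value per sample, whereas "mean" and "sum" collapse it to a scalar.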
- make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)
Value-function constructor.
If a non-default value function is wanted, it must be built using this method.
- Parameters:
value_type (ValueEstimators) – a ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.
**hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used.
Examples
>>> import torch
>>> from torchrl.objectives import DQNLoss, ValueEstimators
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
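For intuition about what the gamma hyperparameter controls, here is a small pure-Python sketch contrasting a one-step bootstrapped target with a multi-step discounted return. Both functions are hypothetical and follow textbook definitions; they do not reproduce what any particular ValueEstimators member computes inside torchrl.

```python
def one_step_target(reward, next_value, gamma, terminated=False):
    """Bootstrapped one-step target: r + gamma * V(s') unless terminated."""
    if terminated:
        return reward
    return reward + gamma * next_value


def discounted_return(rewards, bootstrap_value, gamma):
    """Discounted sum of rewards along a trajectory, bootstrapped at its end."""
    ret = bootstrap_value
    # Accumulate backwards so each step discounts everything that follows it.
    for r in reversed(rewards):
        ret = r + gamma * ret
    return ret


short_sighted = one_step_target(reward=1.0, next_value=10.0, gamma=0.5)
far_sighted = one_step_target(reward=1.0, next_value=10.0, gamma=1.0)
multi_step = discounted_return([1.0, 1.0, 1.0], bootstrap_value=8.0, gamma=0.5)
print(short_sighted, far_sighted, multi_step)
```

A smaller gamma shrinks the contribution of future values, which is why changing it through make_value_estimator changes the targets the loss regresses towards.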