class torchrl.objectives.DQNLoss(*args, **kwargs)

The DQN Loss class.


Parameters:
  • value_network (QValueActor or nn.Module) – a Q-value operator.

Keyword Arguments:
  • loss_function (str, optional) – loss function for the value discrepancy. Can be one of “l1”, “l2” or “smooth_l1”. Defaults to “l2”.

  • delay_value (bool, optional) – whether to duplicate the value network into a new target value network to create a double DQN. Default is False.

  • action_space (str or TensorSpec, optional) – the action space. Must be one of "one-hot", "mult_one_hot", "binary" or "categorical", or an instance of the corresponding spec class. If not provided, an attempt is made to retrieve it from the value network.

  • priority_key (NestedKey, optional) – [Deprecated, use .set_keys(priority_key=priority_key) instead] The key at which priority is assumed to be stored within TensorDicts added to this ReplayBuffer. This is to be used when the sampler is of type PrioritizedSampler. Defaults to "td_error".
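The three loss_function options can be compared on a single value discrepancy. The following is a minimal pure-Python sketch under the usual textbook definitions, with the smooth-L1 threshold beta assumed to be 1.0. The helper value_discrepancy_loss is hypothetical and is not part of TorchRL; it only illustrates how the three choices penalise the same error.

```python
def value_discrepancy_loss(pred, target, loss_function="l2", beta=1.0):
    """Loss between a predicted and a target value (illustrative only)."""
    diff = pred - target
    if loss_function == "l1":
        return abs(diff)
    if loss_function == "l2":
        return diff ** 2
    if loss_function == "smooth_l1":
        # Quadratic near zero, linear in the tails; beta is an assumed threshold.
        if abs(diff) < beta:
            return 0.5 * diff ** 2 / beta
        return abs(diff) - 0.5 * beta
    raise ValueError(f"unknown loss_function: {loss_function!r}")


# The same discrepancy of 3.0 evaluated under each option.
losses = {
    name: value_discrepancy_loss(3.0, 0.0, name)
    for name in ("l1", "l2", "smooth_l1")
}
```

In this sketch a large discrepancy grows quadratically under "l2" but only linearly under "l1" and "smooth_l1", which is why the latter two are often preferred when targets are noisy.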


Examples:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules import MLP, QValueActor
>>> from torchrl.objectives import DQNLoss
>>> n_obs, n_act = 4, 3
>>> value_net = MLP(in_features=n_obs, out_features=n_act)
>>> spec = OneHotDiscreteTensorSpec(n_act)
>>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
>>> loss = DQNLoss(actor, action_space=spec)
>>> batch = [10,]
>>> data = TensorDict({
...     "observation": torch.randn(*batch, n_obs),
...     "action": spec.rand(batch),
...     ("next", "observation"): torch.randn(*batch, n_obs),
...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "reward"): torch.randn(*batch, 1)
... }, batch)
>>> loss(data)
TensorDict(
    fields={
        loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

This class is also compatible with non-tensordict-based modules and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are ["observation", "next_observation", "action", "next_reward", "next_done"], and a single loss value is returned.


>>> from torchrl.objectives import DQNLoss
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torch import nn
>>> import torch
>>> n_obs = 3
>>> n_action = 4
>>> action_spec = OneHotDiscreteTensorSpec(n_action)
>>> value_network = nn.Linear(n_obs, n_action) # a simple value model
>>> dqn_loss = DQNLoss(value_network, action_space=action_spec)
>>> # define data
>>> observation = torch.randn(n_obs)
>>> next_observation = torch.randn(n_obs)
>>> action = action_spec.rand()
>>> next_reward = torch.randn(1)
>>> next_done = torch.zeros(1, dtype=torch.bool)
>>> loss_val = dqn_loss(
...     observation=observation,
...     next_observation=next_observation,
...     next_reward=next_reward,
...     next_done=next_done,
...     action=action)

forward(tensordict: TensorDictBase) → TensorDict

Computes the DQN loss given a tensordict sampled from the replay buffer.

This function will also write a "td_error" key that can be used by prioritized replay buffers to assign a priority to items in the tensordict.


Parameters:
  • tensordict (TensorDictBase) – a tensordict with the key "action" and the in_keys of the value network (observations, plus "done" and "reward" in a "next" tensordict).

Returns:
  a TensorDict containing the DQN loss under the "loss" key.
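
To make the quantities involved concrete, here is a pure-Python sketch of a textbook one-step DQN target and the resulting discrepancy for a single transition. It is an illustration under assumptions, not TorchRL's implementation: the function name one_step_td, the discount value, and the use of the absolute difference as the priority signal are all choices made for this example.

```python
def one_step_td(q_values, action_one_hot, reward, done, next_q_values, gamma=0.99):
    """Return (target, td_error) for one transition (illustrative only)."""
    # Q(s, a): pick the value of the action that was actually taken.
    chosen_q = sum(q * a for q, a in zip(q_values, action_one_hot))
    # Bootstrap from the best next action unless the episode has ended.
    bootstrap = 0.0 if done else gamma * max(next_q_values)
    target = reward + bootstrap
    # Assumed priority signal: absolute gap between prediction and target.
    return target, abs(chosen_q - target)


target, td_error = one_step_td(
    q_values=[1.0, 2.0, 0.5],
    action_one_hot=[0, 1, 0],
    reward=1.0,
    done=False,
    next_q_values=[0.0, 3.0, 1.0],
    gamma=0.5,
)
```

A prioritized sampler would then tend to draw transitions with a larger td_error more often, since those are the ones the current value network predicts worst.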

make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)

Value-function constructor.

If a non-default value function is wanted, it must be built using this method.

Parameters:
  • value_type (ValueEstimators) – A ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.

  • **hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used.


>>> import torch
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.objectives.utils import ValueEstimators
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)
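
As a rough illustration of what the choice of estimator and of gamma changes, the sketch below contrasts a one-step bootstrapped target with a full discounted return that bootstraps only after the last step of a trajectory. It is written in plain Python under textbook definitions. The helper names are hypothetical, and any correspondence with specific ValueEstimators members is an assumption rather than a description of TorchRL's internals.

```python
def one_step_target(reward, next_value, gamma):
    """One-step target: r_t + gamma * V(s_{t+1})."""
    return reward + gamma * next_value


def discounted_return(rewards, final_value, gamma):
    """Discounted sum of rewards, bootstrapped with the value after the last step."""
    ret = final_value
    # Work backwards so each step folds in the discounted future return.
    for reward in reversed(rewards):
        ret = reward + gamma * ret
    return ret


rewards = [1.0, 1.0, 1.0]
# A smaller gamma shortens the effective horizon of the estimate.
short_horizon = discounted_return(rewards, final_value=0.0, gamma=0.5)
long_horizon = discounted_return(rewards, final_value=0.0, gamma=1.0)
```

With these inputs the same reward sequence yields a smaller return under the smaller discount, which is the kind of effect that changing gamma through make_value_estimator is meant to control.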

