DiscreteCQLLoss
- class torchrl.objectives.DiscreteCQLLoss(*args, **kwargs)
TorchRL implementation of the discrete CQL loss.
This class implements the discrete conservative Q-learning (CQL) loss function, as presented in the paper “Conservative Q-Learning for Offline Reinforcement Learning” (https://arxiv.org/abs/2006.04779).
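As a rough illustration of the conservative term, the sketch below uses plain Python rather than TorchRL, and it is not the library's implementation. It assumes the form commonly used for discrete actions, where the penalty for a state is the log-sum-exp of its Q-values minus the Q-value of the action found in the dataset. The helper names `logsumexp` and `cql_penalty` are hypothetical.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cql_penalty(q_values, action_index):
    """Assumed conservative penalty for one state.

    Log-sum-exp over all action values minus the value of the dataset action.
    """
    return logsumexp(q_values) - q_values[action_index]


# Q-values for three discrete actions; the dataset action is index 1.
q = [1.0, 2.0, 0.5]
penalty = cql_penalty(q, action_index=1)
print(penalty)
```

Because the log-sum-exp is never smaller than the largest Q-value, this penalty is non-negative, and it shrinks as the dataset action's value comes to dominate the others.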
- Parameters:
value_network (Union[QValueActor, nn.Module]) – The Q-value network used to estimate state-action values.
- Keyword Arguments:
loss_function (Optional[str]) – The distance function used to calculate the distance between the predicted Q-values and the target Q-values. Defaults to "l2".
delay_value (bool) – Whether to separate the target Q value networks from the Q value networks used for data collection. Default is True.
gamma (float, optional) – Discount factor. Default is None.
action_space – The action space of the environment. If None, it is inferred from the value network. Defaults to None.
reduction (str, optional) – Specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied, "mean": the sum of the output will be divided by the number of elements in the output, "sum": the output will be summed. Default: "mean".
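The three reduction modes described above can be sketched in plain Python as follows. This is only an illustration on a list of per-sample losses; the helper name `apply_reduction` is hypothetical and is not part of TorchRL.

```python
def apply_reduction(values, reduction="mean"):
    """Reduce a list of per-sample losses following the documented modes."""
    if reduction == "none":
        # No reduction: return the per-sample values unchanged.
        return list(values)
    if reduction == "sum":
        return sum(values)
    if reduction == "mean":
        # Sum divided by the number of elements.
        return sum(values) / len(values)
    raise ValueError(f"Unknown reduction: {reduction!r}")


per_sample = [1.0, 2.0, 3.0]
print(apply_reduction(per_sample, "none"))
print(apply_reduction(per_sample, "mean"))
print(apply_reduction(per_sample, "sum"))
```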
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import MLP, QValueActor
>>> from torchrl.data import OneHot
>>> from torchrl.objectives import DiscreteCQLLoss
>>> n_obs, n_act = 4, 3
>>> value_net = MLP(in_features=n_obs, out_features=n_act)
>>> spec = OneHot(n_act)
>>> actor = QValueActor(value_net, in_keys=["observation"], action_space=spec)
>>> loss = DiscreteCQLLoss(actor, action_space=spec)
>>> batch = [10,]
>>> data = TensorDict({
...     "observation": torch.randn(*batch, n_obs),
...     "action": spec.rand(batch),
...     ("next", "observation"): torch.randn(*batch, n_obs),
...     ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
...     ("next", "reward"): torch.randn(*batch, 1)
... }, batch)
>>> loss(data)
TensorDict(
    fields={
        loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        td_error: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This class is also compatible with non-tensordict-based modules and can be used without resorting to any tensordict-related primitive. In this case, the expected keyword arguments are ["observation", "next_observation", "action", "next_reward", "next_done", "next_terminated"], and a single loss value is returned.
Examples
>>> from torchrl.objectives import DiscreteCQLLoss
>>> from torchrl.data import OneHot
>>> from torch import nn
>>> import torch
>>> n_obs = 3
>>> n_action = 4
>>> action_spec = OneHot(n_action)
>>> value_network = nn.Linear(n_obs, n_action)  # a simple value model
>>> dcql_loss = DiscreteCQLLoss(value_network, action_space=action_spec)
>>> # define data
>>> observation = torch.randn(n_obs)
>>> next_observation = torch.randn(n_obs)
>>> action = action_spec.rand()
>>> next_reward = torch.randn(1)
>>> next_done = torch.zeros(1, dtype=torch.bool)
>>> next_terminated = torch.zeros(1, dtype=torch.bool)
>>> loss_val = dcql_loss(
...     observation=observation,
...     next_observation=next_observation,
...     next_reward=next_reward,
...     next_done=next_done,
...     next_terminated=next_terminated,
...     action=action)
- forward(tensordict: TensorDictBase = None) → TensorDict
Computes the (DQN) CQL loss given a tensordict sampled from the replay buffer.
This function will also write a “td_error” key that can be used by prioritized replay buffers to assign a priority to items in the tensordict.
- Parameters:
tensordict (TensorDictBase) – a tensordict with keys [“action”] and the in_keys of the value network (observations, “done”, “terminated”, “reward” in a “next” tensordict).
- Returns:
a tensordict containing the CQL loss terms (for instance, the loss_cql and loss_qvalue entries shown in the class example above).
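To make the role of the “td_error” entry more concrete, here is a minimal plain-Python sketch of a one-step TD target and a replay priority derived from it. It rests on assumptions: a one-step bootstrapped target, an absolute-error priority, and the hypothetical helper names `td_target` and `priority`. The estimator TorchRL actually uses, and its exact definition of td_error, may differ.

```python
def td_target(reward, next_q_values, gamma, terminated):
    """Assumed one-step target: r + gamma * max_a' Q(s', a').

    No bootstrapping is applied when the transition is terminal.
    """
    bootstrap = 0.0 if terminated else gamma * max(next_q_values)
    return reward + bootstrap


def priority(pred_q, target, eps=1e-6):
    """Assumed priority: absolute TD error plus a small constant.

    The constant keeps every sample at a non-zero priority.
    """
    return abs(pred_q - target) + eps


# A non-terminal transition with two actions available in the next state.
target = td_target(reward=1.0, next_q_values=[0.5, 2.0], gamma=0.9, terminated=False)
prio = priority(pred_q=2.0, target=target)
print(target, prio)
```

Transitions whose predicted value is far from the target receive a larger priority, so a prioritized buffer would sample them more often.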
- make_value_estimator(value_type: Optional[ValueEstimators] = None, **hyperparams)
Value-function constructor.
If a non-default value function is wanted, it must be built using this method.
- Parameters:
value_type (ValueEstimators) – A ValueEstimators enum type indicating the value function to use. If none is provided, the default stored in the default_value_estimator attribute will be used. The resulting value estimator class will be registered in self.value_type, allowing future refinements.
**hyperparams – hyperparameters to use for the value function. If not provided, the value indicated by default_value_kwargs() will be used.
Examples
>>> import torch
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.objectives.utils import ValueEstimators
>>> # initialize the DQN loss
>>> actor = torch.nn.Linear(3, 4)
>>> dqn_loss = DQNLoss(actor, action_space="one-hot")
>>> # updating the parameters of the default value estimator
>>> dqn_loss.make_value_estimator(gamma=0.9)
>>> dqn_loss.make_value_estimator(
...     ValueEstimators.TD1,
...     gamma=0.9)
>>> # if we want to change the gamma value
>>> dqn_loss.make_value_estimator(dqn_loss.value_type, gamma=0.9)