class torchrl.modules.DistributionalQValueHook(action_space: str, support: Tensor, var_nums: Optional[int] = None, action_value_key: Optional[Union[str, Tuple[str, ...]]] = None, action_mask_key: Optional[Union[str, Tuple[str, ...]]] = None, out_keys: Optional[Sequence[Union[str, Tuple[str, ...]]]] = None)

Distributional Q-Value hook for Q-value policies.

Given the output of a mapping operator that represents the log-probabilities of the available action-value bins, a DistributionalQValueHook transforms these values into their argmax component using the provided support.
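
The transformation can be illustrated with a dependency-free sketch. It assumes that the "argmax component" is obtained by weighting each bin's probability by its support value, summing over bins to get one expected value per action, and then taking the argmax over actions. The function name `greedy_one_hot` and the list-based layout (bins as rows, actions as columns, matching the `(nbins, n_actions)` layout of the example further down) are illustrative choices, not part of the library.

```python
import math

def greedy_one_hot(log_probs, support):
    """Pick the greedy action from per-action bin log-probabilities.

    log_probs[b][a] is the log-probability of bin b for action a.
    support[b] is the action value associated with bin b.
    (Hypothetical helper for illustration; not the TorchRL implementation.)
    """
    n_actions = len(log_probs[0])
    # Expected value of each action: sum over bins of p(bin | action) * support[bin].
    expected = [
        sum(math.exp(log_probs[b][a]) * support[b] for b in range(len(support)))
        for a in range(n_actions)
    ]
    # Greedy choice: index of the action with the largest expected value.
    best = max(range(n_actions), key=expected.__getitem__)
    one_hot = [int(a == best) for a in range(n_actions)]
    return expected, one_hot

# 3 bins (rows) x 2 actions (columns); each column sums to 1 in probability space.
probs = [[0.7, 0.1],
         [0.2, 0.2],
         [0.1, 0.7]]
log_probs = [[math.log(p) for p in row] for row in probs]
support = [0.0, 1.0, 2.0]

expected, action = greedy_one_hot(log_probs, support)
print(expected, action)
```

Here the second action concentrates its probability mass on the highest bin, so it has the larger expected value and is the one selected.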

For more details regarding Distributional DQN, refer to “A Distributional Perspective on Reinforcement Learning” (Bellemare, Dabney and Munos, 2017).

  • action_space (str) – Action space. Must be one of "one_hot", "mult_one_hot", "binary" or "categorical".

  • action_value_key (str or tuple of str, optional) – The input key representing the action value. Used when the hook is attached to a TensorDictModule. Defaults to "action_value".

  • action_mask_key (str or tuple of str, optional) – The input key representing the action mask. Defaults to None (equivalent to no masking).

  • support (torch.Tensor) – Support of the action values, i.e. the value associated with each bin.

  • var_nums (int, optional) – If action_space is "mult_one_hot", this value represents the cardinality of each action component.
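
As a rough illustration of how `action_space` and an action mask shape the result, the sketch below encodes the same greedy choice in two ways and optionally masks actions first. It is plain Python rather than the library's code: the helper name `encode_greedy` is hypothetical, the `"one_hot"` spelling follows the usage example, only two of the four action spaces are covered, and the mask convention (True means the action is allowed) is an assumption.

```python
def encode_greedy(values, action_space, mask=None):
    """Encode the greedy action for a list of per-action values.

    values: one expected value per action.
    mask: optional list of booleans; True is assumed to mean "allowed".
    (Hypothetical helper for illustration; not the TorchRL implementation.)
    """
    if mask is not None:
        # Disallowed actions are pushed to -inf so they can never win the argmax.
        values = [v if ok else float("-inf") for v, ok in zip(values, mask)]
    best = max(range(len(values)), key=values.__getitem__)
    if action_space == "categorical":
        # A single integer index.
        return best
    if action_space == "one_hot":
        # A vector with a 1 at the chosen index and 0 elsewhere.
        return [int(i == best) for i in range(len(values))]
    raise ValueError(f"action_space not covered by this sketch: {action_space!r}")

values = [0.4, 1.6, 0.9]
as_one_hot = encode_greedy(values, "one_hot")
as_index = encode_greedy(values, "categorical")
masked_index = encode_greedy(values, "categorical", mask=[True, False, True])
print(as_one_hot, as_index, masked_index)
```

With the best action masked out, the choice falls back to the highest-valued action that is still allowed.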


>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHotDiscreteTensorSpec
>>> from torchrl.modules.tensordict_module.actors import DistributionalQValueHook, Actor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
>>> class CustomDistributionalQval(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = nn.Linear(4, nbins*4)
...     def forward(self, x):
...         return self.linear(x).view(-1, nbins, 4).log_softmax(-2)
>>> module = CustomDistributionalQval()
>>> params = make_functional(module)
>>> action_spec = OneHotDiscreteTensorSpec(4)
>>> hook = DistributionalQValueHook("one_hot", support=torch.arange(nbins))
>>> module.register_forward_hook(hook)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
>>> qvalue_actor(td, params=params)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(torch.Size([5, 4]), dtype=torch.int64),
        action_value: Tensor(torch.Size([5, 3, 4]), dtype=torch.float32),
        observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

