
QValueHook

class torchrl.modules.QValueHook(action_space: str, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None)

Q-value hook for Q-value policies.

Given the output of a regular nn.Module that represents the values of the available discrete actions, a QValueHook transforms these values into the greedy action, i.e. the argmax over the action values, encoded according to action_space.
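As a rough illustration of this transformation, the sketch below uses plain PyTorch (not TorchRL's own code) to turn a batch of action values into the greedy action in the "categorical" (index) and "one-hot" encodings, together with the value of the chosen action. The helper name greedy_actions is hypothetical, and ties are broken by taking the first maximum via argmax, which may differ from the library's behavior.

```python
import torch
import torch.nn.functional as F


def greedy_actions(action_values: torch.Tensor):
    """Hypothetical helper: map values of shape [..., n] to greedy actions."""
    # Index of the highest-valued action ("categorical" encoding).
    index = action_values.argmax(dim=-1)
    # The same choice written as a one-hot vector ("one-hot" encoding).
    one_hot = F.one_hot(index, num_classes=action_values.shape[-1])
    # Value of the chosen action, i.e. max over a of Q(s, a); shape [..., 1].
    chosen_value = action_values.gather(-1, index.unsqueeze(-1))
    return index, one_hot, chosen_value


values = torch.tensor([[0.1, 2.0, -1.0, 0.5],
                       [3.0, 0.0, 1.0, 2.9]])
index, one_hot, chosen = greedy_actions(values)
print(index.tolist())    # [1, 0]
print(one_hot.tolist())  # [[0, 1, 0, 0], [1, 0, 0, 0]]
print(chosen.tolist())   # [[2.0], [3.0]]
```

The three returned tensors line up with the default out_keys listed below: the action, the action values it was derived from, and the chosen action value.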

Parameters:
  • action_space (str) – The encoding of the discrete action space. Must be one of "one-hot", "mult-one-hot", "binary" or "categorical".

  • var_nums (int, optional) – Only used when action_space is "mult-one-hot": the cardinality of each action component.

  • action_value_key (str or tuple of str, optional) – Used when the hook is registered on a TensorDictModule: the input key holding the action values. Defaults to "action_value".

  • action_mask_key (str or tuple of str, optional) – The input key holding the action mask. Defaults to None (no masking).

  • out_keys (list of str or tuple of str, optional) – Used when the hook is registered on a TensorDictModule: the output keys for the action, the action values and the chosen action value. Defaults to ["action", "action_value", "chosen_action_value"].
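The var_nums and action_mask_key parameters can be illustrated in the same way. The sketch below, again plain PyTorch with hypothetical helper names (mult_one_hot_greedy, masked_greedy), assumes that for "mult-one-hot" the last dimension is split into components of var_nums values each, with a one-hot argmax taken per component, and that a boolean mask marks allowed actions with True, so disallowed actions are pushed to the lowest representable value before the argmax. These are assumptions about the semantics, not a copy of TorchRL's implementation.

```python
import torch
import torch.nn.functional as F


def mult_one_hot_greedy(action_values: torch.Tensor, var_nums) -> torch.Tensor:
    """Hypothetical helper: greedy one-hot action per component, concatenated."""
    # Assumption: split the last dim into chunks of var_nums values
    # (Tensor.split accepts an int chunk size or a list of sizes).
    chunks = action_values.split(var_nums, dim=-1)
    parts = [F.one_hot(c.argmax(dim=-1), num_classes=c.shape[-1]) for c in chunks]
    return torch.cat(parts, dim=-1)


def masked_greedy(action_values: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: argmax restricted to actions where the mask is True."""
    # Assumption: disallowed actions get the lowest finite value of the dtype,
    # so argmax can only select an allowed action.
    lowest = torch.finfo(action_values.dtype).min
    masked = action_values.masked_fill(~action_mask, lowest)
    return masked.argmax(dim=-1)


# Two action components with three choices each (var_nums=3).
values = torch.tensor([[0.2, 1.5, -0.3, 0.9, 0.1, 2.2]])
print(mult_one_hot_greedy(values, var_nums=3).tolist())  # [[0, 1, 0, 0, 0, 1]]

# Action 1 has the highest value but is masked out, so action 3 is chosen.
q = torch.tensor([[0.5, 3.0, 1.0, 2.0]])
mask = torch.tensor([[True, False, True, True]])
print(masked_greedy(q, mask).tolist())  # [3]
```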

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> module = nn.Linear(4, 4)
>>> hook = QValueHook("one_hot")
>>> module.register_forward_hook(hook)
>>> action_spec = OneHot(4)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)