QValueHook
- class torchrl.modules.QValueHook(action_space: str, var_nums: Optional[int] = None, action_value_key: Optional[NestedKey] = None, action_mask_key: Optional[NestedKey] = None, out_keys: Optional[Sequence[NestedKey]] = None)
Q-value hook for Q-value policies.
Given the output of a regular nn.Module, representing the values of the different discrete actions available, a QValueHook transforms these values into the action that maximizes them (i.e. the greedy action).
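The conversion from action values to a greedy action can be sketched in plain PyTorch. This is a minimal illustration, not TorchRL's implementation: the helper name `greedy_action` is hypothetical, it only covers the `"one-hot"` and `"categorical"` action spaces, and it assumes the action values lie along the last dimension.

```python
import torch
import torch.nn.functional as F


def greedy_action(values: torch.Tensor, action_space: str) -> torch.Tensor:
    """Hypothetical helper: turn per-action values into a greedy action.

    Assumes the action values lie along the last dimension of ``values``.
    """
    # Index of the highest-valued action for each batch element.
    index = values.argmax(dim=-1)
    if action_space == "categorical":
        # Categorical actions are returned as integer indices.
        return index
    if action_space == "one-hot":
        # One-hot actions have the same shape as the values, with a 1 at the argmax.
        return F.one_hot(index, num_classes=values.shape[-1])
    raise ValueError(f"unsupported action space in this sketch: {action_space!r}")


values = torch.tensor([[0.1, 2.0, -1.0, 0.5],
                       [3.0, 0.0, 1.0, 2.5]])
print(greedy_action(values, "categorical"))  # tensor([1, 0])
print(greedy_action(values, "one-hot").tolist())  # [[0, 1, 0, 0], [1, 0, 0, 0]]
```

On ties, `torch.argmax` returns the index of the first maximal value, so this sketch always produces exactly one selected action per row.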
- Parameters:
  - action_space (str) – Action space. Must be one of "one-hot", "mult-one-hot", "binary" or "categorical".
  - var_nums (int, optional) – If action_space = "mult-one-hot", this value represents the cardinality of each action component.
  - action_value_key (str or tuple of str, optional) – To be used when hooked on a TensorDictModule. The input key representing the action value. Defaults to "action_value".
  - action_mask_key (str or tuple of str, optional) – The input key representing the action mask. Defaults to None (equivalent to no masking).
  - out_keys (list of str or tuple of str, optional) – To be used when hooked on a TensorDictModule. The output keys representing the actions, action values and chosen action value. Defaults to ["action", "action_value", "chosen_action_value"].
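Three of the parameters above (`var_nums`, the action mask and the chosen action value) can be illustrated with a second plain-PyTorch sketch, again not taken from TorchRL. It assumes that each component of a `"mult-one-hot"` action has `var_nums` actions laid out contiguously along the last dimension; that a boolean mask marks allowed actions with `True`, and that disallowed actions are excluded by replacing their values with the dtype's minimum before the argmax; and that the chosen action value is the value at the selected index. Both function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def mult_one_hot_action(values: torch.Tensor, var_nums: int) -> torch.Tensor:
    """Hypothetical: greedy multi-one-hot action, one one-hot block per component.

    Assumes the last dimension concatenates components of ``var_nums`` actions each.
    """
    blocks = values.split(var_nums, dim=-1)
    # Greedy one-hot per component, then concatenate back along the last dimension.
    return torch.cat(
        [F.one_hot(b.argmax(dim=-1), num_classes=b.shape[-1]) for b in blocks], dim=-1
    )


def masked_greedy(values: torch.Tensor, mask: torch.Tensor):
    """Hypothetical: categorical greedy action restricted to allowed actions.

    Returns the action index and the value of that action (the chosen action value).
    """
    # Disallowed actions get the lowest representable value so argmax skips them.
    fill = torch.full_like(values, torch.finfo(values.dtype).min)
    masked = torch.where(mask, values, fill)
    action = masked.argmax(dim=-1)
    # Read back the original value of the selected action.
    chosen_value = values.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    return action, chosen_value


mo = mult_one_hot_action(torch.tensor([[0.2, 0.9, 0.1, 0.4, -0.3, 0.0]]), var_nums=3)
print(mo)  # tensor([[0, 1, 0, 1, 0, 0]])

q = torch.tensor([[1.0, 5.0, 3.0], [2.0, 0.5, 4.0]])
mask = torch.tensor([[True, False, True], [True, True, False]])
action, chosen = masked_greedy(q, mask)
print(action)  # tensor([2, 0])
print(chosen)  # tensor([3., 2.])
```

In this sketch, a row whose actions are all masked falls back to index 0, because every masked entry holds the same minimum value.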
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueHook, Actor
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> module = nn.Linear(4, 4)
>>> hook = QValueHook("one_hot")
>>> module.register_forward_hook(hook)
>>> action_spec = OneHot(4)
>>> qvalue_actor = Actor(module=module, spec=action_spec, out_keys=["action", "action_value"])
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)
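The example relies on PyTorch forward hooks. A hook registered with `register_forward_hook` is called as `hook(module, args, output)`, and if it returns a value, that value replaces the module's output. The class below is a hypothetical, stripped-down stand-in that shows only this mechanism for a one-hot action space. Unlike `QValueHook`, it returns a plain tuple and has no notion of TensorDict keys, masks or other action spaces.

```python
import torch
import torch.nn.functional as F
from torch import nn


class GreedyOneHotHook:
    """Hypothetical, minimal stand-in for a Q-value forward hook (one-hot actions only)."""

    def __call__(self, module: nn.Module, args: tuple, output: torch.Tensor):
        # ``output`` holds the action values computed by ``module``.
        action = F.one_hot(output.argmax(dim=-1), num_classes=output.shape[-1])
        # Returning a value from a forward hook replaces the module's output.
        return action, output


torch.manual_seed(0)
module = nn.Linear(4, 4)
handle = module.register_forward_hook(GreedyOneHotHook())
action, action_value = module(torch.randn(5, 4))
print(action.shape, action_value.shape)  # torch.Size([5, 4]) torch.Size([5, 4])
handle.remove()  # detach the hook; the module returns raw values again
```

Returning the action together with the raw values mirrors the two `out_keys` ("action" and "action_value") that the `Actor` in the example above reads from the hooked module.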