EGreedyWrapper
- class torchrl.modules.tensordict_module.EGreedyWrapper(*args, **kwargs)
[Deprecated] Epsilon-Greedy policy wrapper.
- Parameters:
policy (TensorDictModule) – a deterministic policy.
- Keyword Arguments:
eps_init (scalar, optional) – initial epsilon value. Default: 1.0.
eps_end (scalar, optional) – final epsilon value. Default: 0.1.
annealing_num_steps (int, optional) – number of steps it will take for epsilon to reach the eps_end value (see the annealing sketch after this parameter list).
action_key (NestedKey, optional) – the key where the action can be found in the input tensordict. Default is "action".
action_mask_key (NestedKey, optional) – the key where the action mask can be found in the input tensordict. Default is None (corresponding to no mask).
spec (TensorSpec, optional) – if provided, the sampled action will be taken from this action space. If not provided, the exploration wrapper will attempt to recover it from the policy.
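To illustrate how eps_init, eps_end and annealing_num_steps interact, the sketch below defines a hypothetical helper that mirrors a linear annealing schedule: each call to step() is assumed to lower epsilon by a fixed increment until it is clamped at eps_end. This is an illustration of the schedule, not the wrapper's exact internals.

>>> def annealed_eps(eps, eps_init=1.0, eps_end=0.1, annealing_num_steps=1000):
...     # hypothetical helper: one linear annealing increment, clamped at eps_end
...     return max(eps_end, eps - (eps_init - eps_end) / annealing_num_steps)

Under this assumption, it takes annealing_num_steps calls to step() for epsilon to move from eps_init down to eps_end.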
Note

Once a module has been wrapped in EGreedyWrapper, it is crucial to incorporate a call to step() in the training loop to update the exploration factor. Since it is not easy to capture this omission, no warning or exception will be raised if it is omitted!

Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import EGreedyWrapper, Actor
>>> from torchrl.data import BoundedTensorSpec
>>> torch.manual_seed(0)
>>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(spec=spec, module=module)
>>> explorative_policy = EGreedyWrapper(policy, eps_init=0.2)
>>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
>>> print(explorative_policy(td).get("action"))
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9055, -0.9277, -0.6295, -0.2532],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<AddBackward0>)
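As a complement to the example above, the following is a minimal sketch of where step() fits in a training loop. The randomly generated observations and the omitted loss/optimization step are placeholders added for illustration, not part of the wrapper's API.

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import EGreedyWrapper, Actor
>>> from torchrl.data import BoundedTensorSpec
>>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
>>> policy = Actor(spec=spec, module=torch.nn.Linear(4, 4, bias=False))
>>> explorative_policy = EGreedyWrapper(policy, eps_init=1.0, eps_end=0.1, annealing_num_steps=1000)
>>> for _ in range(1000):
...     td = TensorDict({"observation": torch.randn(10, 4)}, batch_size=[10])
...     td = explorative_policy(td)   # epsilon-greedy action selection
...     # ... compute the loss and update the policy here (omitted) ...
...     explorative_policy.step()     # anneal epsilon toward eps_end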