
EGreedyWrapper

class torchrl.modules.EGreedyWrapper(*args, **kwargs)

[Deprecated] Epsilon-greedy exploration wrapper for a policy.

Parameters:

policy (TensorDictModule) – a deterministic policy.

Keyword Arguments:
  • eps_init (scalar, optional) – initial epsilon value. Default is 1.0.

  • eps_end (scalar, optional) – final epsilon value. Default is 0.1.

  • annealing_num_steps (int, optional) – number of annealing steps, counted through step(), over which epsilon decays from eps_init to eps_end.

  • action_key (NestedKey, optional) – the key where the action can be found in the input tensordict. Default is "action".

  • action_mask_key (NestedKey, optional) – the key where the action mask can be found in the input tensordict. Default is None (corresponding to no mask).

  • spec (TensorSpec, optional) – if provided, the sampled action will be taken from this action space. If not provided, the exploration wrapper will attempt to recover it from the policy.
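The keyword arguments above describe the usual epsilon-greedy rule: with probability epsilon, the policy's action is replaced by a random action drawn from the action space, restricted to the allowed actions when a mask is given. The sketch below illustrates that rule in plain Python for a discrete action space. It assumes integer action indices and a boolean mask, and `egreedy_action` is a hypothetical helper, not TorchRL's implementation.

```python
import random
from typing import Optional, Sequence


def egreedy_action(
    greedy_action: int,
    num_actions: int,
    eps: float,
    action_mask: Optional[Sequence[bool]] = None,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical illustration of epsilon-greedy selection (not TorchRL code)."""
    if rng is None:
        rng = random.Random()
    # With probability 1 - eps, keep the deterministic policy's action.
    if rng.random() >= eps:
        return greedy_action
    # Otherwise sample uniformly among allowed actions (all of them if no mask).
    allowed = [a for a in range(num_actions) if action_mask is None or action_mask[a]]
    return rng.choice(allowed)


# eps = 0 never explores; eps = 1 always samples from the allowed actions.
print(egreedy_action(2, 4, eps=0.0))
print(egreedy_action(2, 4, eps=1.0, action_mask=[False, True, False, False]))
```

With eps set to 1.0 and only action 1 unmasked, the random branch can only return 1, which mirrors how an action mask narrows the space that exploratory actions are drawn from.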

Note

Once a module has been wrapped in EGreedyWrapper, it is crucial to call step() in the training loop to update the exploration factor. Because this omission is hard to detect, no warning or exception is raised if step() is never called!
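To see why a missing step() call goes unnoticed, consider a toy stand-in for the exploration factor. `LinearEpsilon` is a hypothetical class, not part of TorchRL, and it assumes a linear decay from eps_init to eps_end over annealing_num_steps steps. A loop that never calls step() runs without any error and simply keeps exploring at eps_init.

```python
class LinearEpsilon:
    """Hypothetical stand-in for the wrapper's exploration factor (not TorchRL code).

    Assumes a linear schedule: each step() lowers eps by
    (eps_init - eps_end) / annealing_num_steps, never going below eps_end.
    """

    def __init__(self, *, annealing_num_steps, eps_init=1.0, eps_end=0.1):
        self.eps_init = eps_init
        self.eps_end = eps_end
        self.annealing_num_steps = annealing_num_steps
        self.eps = eps_init

    def step(self):
        decrement = (self.eps_init - self.eps_end) / self.annealing_num_steps
        self.eps = max(self.eps_end, self.eps - decrement)


annealed = LinearEpsilon(annealing_num_steps=10)
forgotten = LinearEpsilon(annealing_num_steps=10)  # never stepped below
for _ in range(5):
    # ... collect a batch with the exploratory policy and update the model ...
    annealed.step()  # the call the note says must not be omitted
print(round(annealed.eps, 2), forgotten.eps)  # 0.55 1.0
```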

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import EGreedyWrapper, Actor
>>> from torchrl.data import Bounded
>>> torch.manual_seed(0)
>>> spec = Bounded(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(spec=spec, module=module)
>>> explorative_policy = EGreedyWrapper(policy, eps_init=0.2)
>>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
>>> print(explorative_policy(td).get("action"))
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9055, -0.9277, -0.6295, -0.2532],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<AddBackward0>)
