EGreedyWrapper
- class torchrl.modules.tensordict_module.EGreedyWrapper(*args, **kwargs)
Epsilon-Greedy PO wrapper.
- Parameters:
policy (TensorDictModule) – a deterministic policy.
- Keyword Arguments:
eps_init (scalar, optional) – initial epsilon value. default: 1.0
eps_end (scalar, optional) – final epsilon value. default: 0.1
annealing_num_steps (int, optional) – number of steps it takes for epsilon to decay from eps_init to eps_end.
action_key (NestedKey, optional) – the key under which the action is written. If the policy module has more than one output key, its output spec is of type CompositeSpec, and this key tells the wrapper where to find the action spec. Default is “action”.
spec (TensorSpec, optional) – if provided, the sampled action will be projected onto the valid action space once explored. If not provided, the exploration wrapper will attempt to recover it from the policy.
Note
Once a policy has been wrapped in EGreedyWrapper, it is crucial to call step() in the training loop to update the exploration factor. Because this omission is hard to detect, no warning or exception is raised if the call is omitted!
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import EGreedyWrapper, Actor
>>> from torchrl.data import BoundedTensorSpec
>>> torch.manual_seed(0)
>>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(spec=spec, module=module)
>>> explorative_policy = EGreedyWrapper(policy, eps_init=0.2)
>>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
>>> print(explorative_policy(td).get("action"))
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9055, -0.9277, -0.6295, -0.2532],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<AddBackward0>)
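The role of eps_init, eps_end, annealing_num_steps and the step() call can be illustrated without torchrl. The following is a minimal sketch in plain Python, not the library's implementation: the LinearEpsilon class and the epsilon_greedy function are hypothetical names introduced here, and the linear decay is an assumption (the parameter descriptions above only state that epsilon reaches eps_end after annealing_num_steps steps).

```python
import random


class LinearEpsilon:
    """Hypothetical helper (not part of torchrl): a linearly annealed epsilon."""

    def __init__(self, eps_init=1.0, eps_end=0.1, annealing_num_steps=1000):
        self.eps_init = eps_init
        self.eps_end = eps_end
        self.annealing_num_steps = annealing_num_steps
        self.eps = eps_init

    def step(self, frames=1):
        # Assumed schedule: subtract a fixed amount per frame and clamp at eps_end.
        decrement = frames * (self.eps_init - self.eps_end) / self.annealing_num_steps
        self.eps = max(self.eps_end, self.eps - decrement)


def epsilon_greedy(greedy_action, sample_random_action, eps, rng=random):
    # With probability eps, explore with a random action; otherwise keep the greedy one.
    if rng.random() < eps:
        return sample_random_action()
    return greedy_action


schedule = LinearEpsilon(eps_init=1.0, eps_end=0.1, annealing_num_steps=10)
history = []
for _ in range(12):
    history.append(round(schedule.eps, 2))
    # Forgetting this call would leave epsilon at eps_init for the whole run.
    schedule.step()
```

In this sketch, history starts at eps_init and stays at eps_end once the annealing steps are exhausted. If schedule.step() is left out of the loop, nothing fails, but epsilon never decays, which is the silent omission the note above warns about.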
- forward(tensordict: TensorDictBase) → TensorDictBase
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function, since the former takes care of running the registered hooks while the latter silently ignores them.
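The difference between calling the module instance and calling forward() directly can be seen with a plain torch.nn.Module. The sketch below uses a torch.nn.Linear layer rather than EGreedyWrapper, and the variable names are chosen here for illustration only.

```python
import torch

module = torch.nn.Linear(4, 4, bias=False)
calls = []

# Register a forward hook that records each time it is triggered.
handle = module.register_forward_hook(
    lambda mod, inputs, output: calls.append("hook")
)

x = torch.zeros(2, 4)

module(x)  # calling the instance runs the registered hook
calls_after_call = len(calls)

module.forward(x)  # calling forward() directly bypasses the hook
calls_after_forward = len(calls)

handle.remove()
```

After both calls the hook has been recorded only once, so the direct forward() call skipped it.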