
Actor

class torchrl.modules.tensordict_module.Actor(*args, **kwargs)

General class for deterministic actors in RL.

The Actor class provides default values for in_keys (["observation"]) and out_keys (["action"]). If a spec is provided but is not a Composite object, it is automatically wrapped as spec = Composite(action=spec).

Parameters:
  • module (nn.Module) – a Module used to map the input to the output parameter space.

  • in_keys (iterable of str, optional) – keys to be read from the input tensordict and passed to the module. If it contains more than one element, the values are passed to the module in the order given by the in_keys iterable. Defaults to ["observation"].

  • out_keys (iterable of str, optional) – keys to be written to the input tensordict. The length of out_keys must match the number of tensors returned by the embedded module. Using "_" as a key prevents the corresponding tensor from being written to the output. Defaults to ["action"].

Keyword Arguments:
  • spec (TensorSpec, optional) – Keyword-only argument. Spec of the output tensor. If the module outputs multiple tensors, spec characterizes the space of the first output tensor.

  • safe (bool) – Keyword-only argument. If True, the value of the output is checked against the provided spec. Out-of-domain values can occur because of exploration policies or numerical under/overflow issues. If the value is out of bounds, it is projected back onto the desired space using the spec's project() method. Default is False.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import Unbounded
>>> from torchrl.modules import Actor
>>> torch.manual_seed(0)
>>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,])
>>> action_spec = Unbounded(4)
>>> module = torch.nn.Linear(4, 4)
>>> td_module = Actor(
...    module=module,
...    spec=action_spec,
...    )
>>> td_module(td)
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> print(td.get("action"))
tensor([[-1.3635, -0.0340,  0.1476, -1.3911],
        [-0.1664,  0.5455,  0.2247, -0.4583],
        [-0.2916,  0.2160,  0.5337, -0.5193]], grad_fn=<AddmmBackward0>)
