MultiAgentConvNet

class torchrl.modules.MultiAgentConvNet(n_agents: int, centralized: bool | None = None, share_params: bool | None = None, *, in_features: int | None = None, device: DEVICE_TYPING | None = None, num_cells: Sequence[int] | None = None, kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 5, strides: Union[Sequence, int] = 2, paddings: Union[Sequence, int] = 0, activation_class: Type[nn.Module] = <class 'torch.nn.modules.activation.ELU'>, use_td_params: bool = True, **kwargs)

Multi-agent CNN.

In MARL settings, agents may or may not share the same policy for their actions; we say that the parameters can be shared or not. Similarly, a network may use the entire observation space (across agents) or only each agent's own observations to compute its output, which we refer to as “centralized” and “non-centralized”, respectively.

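It expects inputs with shape (*B, n_agents, channels, x, y) and returns outputs with shape (*B, n_agents, n_features), where n_features depends on the convolutional layers and on the spatial size of the input.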
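The shape contract above can be illustrated with a small, dependency-free sketch of the shape bookkeeping. The helper `per_agent_input_shape` is hypothetical and not part of TorchRL; it only restates, as code, the shape each agent's ConvNet conceptually receives in the centralized and non-centralized cases, using the values from the Examples section.

```python
from typing import Tuple


def per_agent_input_shape(
    batch: Tuple[int, ...],
    n_agents: int,
    channels: int,
    x: int,
    y: int,
    centralized: bool,
) -> Tuple[int, ...]:
    """Hypothetical helper: shape seen by one agent's ConvNet.

    The module input itself is (*B, n_agents, channels, x, y).
    """
    if centralized:
        # Every agent sees all agents' channels stacked together.
        return (*batch, n_agents * channels, x, y)
    # Each agent sees only its own (channels, x, y) slice.
    return (*batch, channels, x, y)


print(per_agent_input_shape((3, 2), 7, 3, 100, 100, centralized=True))   # (3, 2, 21, 100, 100)
print(per_agent_input_shape((3, 2), 7, 3, 100, 100, centralized=False))  # (3, 2, 3, 100, 100)
```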

Note

To initialize the MARL module parameters with the torch.nn.init module, please refer to the get_stateful_net() and from_stateful_net() methods.

Parameters:
  • n_agents (int) – number of agents.

  • centralized (bool) – If True, each agent will use the inputs of all agents to compute its output, resulting in an input of shape (*B, n_agents * channels, x, y). Otherwise, each agent will only use its own data as input.

  • share_params (bool) – If True, the same ConvNet will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different ConvNet to process its input (heterogeneous policies).
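How centralized and share_params combine can be seen with a toy stand-in. The sketch below is hypothetical: it replaces each ConvNet with a single random linear map over plain Python lists, so it only mirrors the parameter layout (one network if shared, one per agent otherwise) and the input routing (all agents' data if centralized, own data otherwise). As in the Examples section, agents should produce identical outputs only when both observations and parameters are shared.

```python
import random
from typing import List


def toy_forward(
    obs: List[List[float]], centralized: bool, share_params: bool, seed: int = 0
) -> List[float]:
    """Toy stand-in for MultiAgentConvNet: one linear map per network.

    obs holds one feature vector per agent (n_agents x channels).
    """
    n_agents, channels = len(obs), len(obs[0])
    n_in = n_agents * channels if centralized else channels
    # One weight vector if parameters are shared, else one per agent.
    n_nets = 1 if share_params else n_agents
    rng = random.Random(seed)
    nets = [[rng.uniform(-1.0, 1.0) for _ in range(n_in)] for _ in range(n_nets)]
    # Centralized agents all read the concatenation of every agent's input.
    flat = [v for agent_obs in obs for v in agent_obs]
    outputs = []
    for i in range(n_agents):
        weights = nets[0] if share_params else nets[i]
        inputs = flat if centralized else obs[i]
        outputs.append(sum(w * v for w, v in zip(weights, inputs)))
    return outputs


obs = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 0.5, -1.0]]
for centralized in (True, False):
    for share_params in (True, False):
        out = toy_forward(obs, centralized, share_params)
        # Prints whether all agents produced the same output.
        print(centralized, share_params, len(set(out)) == 1)
```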

Keyword Arguments:
  • in_features (int, optional) – the input feature dimension, i.e. the number of input channels per agent. If left to None, a lazy module is used.

  • device (str or torch.device, optional) – device to create the module on.

  • num_cells (int or Sequence[int], optional) – number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, the output channels of the convolutional layers will match the content of num_cells.

  • kernel_sizes (int, Sequence[Union[int, Sequence[int]]]) – Kernel size(s) of the convolutional network. Defaults to 5.

  • strides (int or Sequence[int]) – Stride(s) of the convolutional network. If iterable, the length must match the depth, defined by the num_cells or depth arguments. Defaults to 2.

  • paddings (int or Sequence[int]) – Padding(s) of the convolutional network. Defaults to 0.

  • activation_class (Type[nn.Module]) – activation class to be used. Defaults to torch.nn.ELU.

  • use_td_params (bool, optional) – if True, the parameters are stored in self.params, a TensorDictParams object (which inherits from both TensorDict and nn.Module). If False, the parameters are contained in self._empty_net. The two approaches should behave roughly identically but are not interchangeable: for instance, a state_dict created with use_td_params=True cannot be loaded when use_td_params=False.

  • **kwargs – additional keyword arguments passed to ConvNet to customize it.
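Together, num_cells, kernel_sizes, strides and paddings determine the size of the last output dimension. The following stdlib sketch reproduces that arithmetic under stated assumptions: square kernels given as plain ints, dilation 1, and a final SquashDims layer that flattens (channels, x, y). The helpers `_per_layer`, `conv_out` and `squashed_features` are hypothetical, and the three-layer, 32-channel setting is read off the printed network in the Examples section rather than taken from a documented default.

```python
from typing import List, Sequence, Union

IntOrSeq = Union[int, Sequence[int]]


def _per_layer(value: IntOrSeq, depth: int) -> List[int]:
    # An int is broadcast to every layer; a sequence must match the depth.
    if isinstance(value, int):
        return [value] * depth
    if len(value) != depth:
        raise ValueError(f"expected {depth} values, got {len(value)}")
    return list(value)


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    # Standard Conv2d output-size formula with dilation = 1.
    return (size + 2 * padding - kernel) // stride + 1


def squashed_features(
    x: int,
    y: int,
    num_cells: Sequence[int],
    kernel_sizes: IntOrSeq = 5,
    strides: IntOrSeq = 2,
    paddings: IntOrSeq = 0,
) -> int:
    """Size of the last output dimension after the final SquashDims layer."""
    depth = len(num_cells)
    layers = zip(
        _per_layer(kernel_sizes, depth),
        _per_layer(strides, depth),
        _per_layer(paddings, depth),
    )
    for k, s, p in layers:
        x, y = conv_out(x, k, s, p), conv_out(y, k, s, p)
    # SquashDims flattens (channels, x, y) into a single dimension.
    return num_cells[-1] * x * y


print(squashed_features(100, 100, num_cells=[32, 32, 32]))  # 2592
```

With a 100x100 input, each layer maps 100 -> 48 -> 22 -> 9, which matches the torch.Size([3, 2, 7, 2592]) reported in the examples.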

Examples

>>> import torch
>>> from torchrl.modules import MultiAgentConvNet
>>> batch = (3,2)
>>> n_agents = 7
>>> channels, x, y = 3, 100, 100
>>> obs = torch.randn(*batch, n_agents, channels, x, y)
>>> # Let's consider a centralized network with shared parameters.
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized=True,
...     share_params=True
... )
>>> print(cnn)
MultiAgentConvNet(
    (agent_networks): ModuleList(
        (0): ConvNet(
        (0): LazyConv2d(0, 32, kernel_size=(5, 5), stride=(2, 2))
        (1): ELU(alpha=1.0)
        (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (3): ELU(alpha=1.0)
        (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (5): ELU(alpha=1.0)
        (6): SquashDims()
        )
    )
)
>>> result = cnn(obs)
>>> # The last dimension of the output depends on the layer arguments and on the spatial shape of 'obs'.
>>> # Here each 5x5, stride-2 convolution maps 100 -> 48 -> 22 -> 9, so SquashDims yields 32 * 9 * 9 = 2592 features.
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> # Since both observations and parameters are shared, we expect all agents to have identical outputs (e.g. for a value function)
>>> print(all(result[0,0,0] == result[0,0,1]))
True
>>> # Alternatively, a local network with parameter sharing (e.g. a decentralized policy with weight sharing)
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized=False,
...     share_params=True
... )
>>> print(cnn)
MultiAgentConvNet(
    (agent_networks): ModuleList(
        (0): ConvNet(
        (0): LazyConv2d(0, 32, kernel_size=(5, 5), stride=(2, 2))
        (1): ELU(alpha=1.0)
        (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (3): ELU(alpha=1.0)
        (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (5): ELU(alpha=1.0)
        (6): SquashDims()
        )
    )
)
>>> result = cnn(obs)
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> # Parameters are shared but not observations, hence each agent has a different output.
>>> print(all(result[0,0,0] == result[0,0,1]))
False
>>> # Or multiple local networks identical in structure but with differing weights.
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized=False,
...     share_params=False
... )
>>> print(cnn)
MultiAgentConvNet(
    (agent_networks): ModuleList(
        (0-6): 7 x ConvNet(
        (0): LazyConv2d(0, 32, kernel_size=(5, 5), stride=(2, 2))
        (1): ELU(alpha=1.0)
        (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (3): ELU(alpha=1.0)
        (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (5): ELU(alpha=1.0)
        (6): SquashDims()
        )
    )
)
>>> result = cnn(obs)
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> print(all(result[0,0,0] == result[0,0,1]))
False
>>> # Or where inputs are shared but not parameters.
>>> cnn = MultiAgentConvNet(
...     n_agents,
...     centralized=True,
...     share_params=False
... )
>>> print(cnn)
MultiAgentConvNet(
    (agent_networks): ModuleList(
        (0-6): 7 x ConvNet(
        (0): LazyConv2d(0, 32, kernel_size=(5, 5), stride=(2, 2))
        (1): ELU(alpha=1.0)
        (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (3): ELU(alpha=1.0)
        (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
        (5): ELU(alpha=1.0)
        (6): SquashDims()
        )
    )
)
>>> result = cnn(obs)
>>> print(result.shape)
torch.Size([3, 2, 7, 2592])
>>> print(all(result[0,0,0] == result[0,0,1]))
False
