

class torchrl.modules.MultiAgentMLP(n_agent_inputs: int | None, n_agent_outputs: int, n_agents: int, centralised: bool, share_params: bool, device: Optional[DEVICE_TYPING] = None, depth: Optional[int] = None, num_cells: Optional[Union[Sequence, int]] = None, activation_class: Optional[Type[nn.Module]] = <class 'torch.nn.modules.activation.Tanh'>, **kwargs)

Multi-agent MLP.

This is an MLP that can be used in multi-agent contexts, for example as a policy or as a value function. See examples/multiagent for examples.

It expects inputs with shape (*B, n_agents, n_agent_inputs) and returns outputs with shape (*B, n_agents, n_agent_outputs), where *B denotes any number of leading batch dimensions.
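As a rough illustration of this shape contract (a plain PyTorch sketch, not TorchRL code), an MLP built from nn.Linear layers only transforms the last dimension, so the leading batch dimensions *B and the agent dimension pass through unchanged. The sizes below are arbitrary assumptions.

```python
import torch
from torch import nn

# Arbitrary illustrative sizes (assumptions, not library defaults).
n_agents, n_agent_inputs, n_agent_outputs = 4, 3, 2

# nn.Linear acts on the last dimension only, so (*B, n_agents) is preserved.
mlp = nn.Sequential(
    nn.Linear(n_agent_inputs, 32),
    nn.Tanh(),
    nn.Linear(32, n_agent_outputs),
)

# Try no batch dims, one batch dim, and two batch dims.
batch_shapes = [(), (8,), (5, 8)]
out_shapes = [
    tuple(mlp(torch.zeros(*b, n_agents, n_agent_inputs)).shape)
    for b in batch_shapes
]
print(out_shapes)  # [(4, 2), (8, 4, 2), (5, 8, 4, 2)]
```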

If share_params is True, the same MLP will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process its input (heterogeneous policies).
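The difference between the two settings can be sketched in plain PyTorch. This is an illustration under our own assumptions (the helpers make_mlp, forward_shared and forward_independent are hypothetical), not how TorchRL organizes its parameters internally: a homogeneous policy applies one network to every agent, while a heterogeneous one routes agent i's input through its own network and stacks the results along the agent dimension.

```python
import torch
from torch import nn

# Illustrative sizes (assumptions).
n_agents, n_in, n_out = 3, 4, 2

def make_mlp() -> nn.Module:
    # Hypothetical helper: a small two-layer MLP for one agent.
    return nn.Sequential(nn.Linear(n_in, 16), nn.Tanh(), nn.Linear(16, n_out))

def forward_shared(net: nn.Module, obs: torch.Tensor) -> torch.Tensor:
    # Homogeneous: the same weights process every agent's input at once.
    return net(obs)

def forward_independent(nets: nn.ModuleList, obs: torch.Tensor) -> torch.Tensor:
    # Heterogeneous: agent i's slice goes through nets[i]; the per-agent
    # outputs are stacked back along the agent dimension (dim=-2).
    outs = [net(obs[..., i, :]) for i, net in enumerate(nets)]
    return torch.stack(outs, dim=-2)

shared = make_mlp()
independent = nn.ModuleList([make_mlp() for _ in range(n_agents)])

obs = torch.randn(5, n_agents, n_in)  # batch of 5
out_shared = forward_shared(shared, obs)                 # shape (5, 3, 2)
out_independent = forward_independent(independent, obs)  # shape (5, 3, 2)

n_shared = sum(p.numel() for p in shared.parameters())
n_independent = sum(p.numel() for p in independent.parameters())
print(n_shared, n_independent)  # 114 342
```

With independent parameters the parameter count grows linearly with n_agents (here 3 x 114 = 342 versus 114), which is the main cost of heterogeneous policies.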

If centralised is True, each agent will use the inputs of all agents to compute its output (n_agent_inputs * n_agents will be the number of inputs for one agent). Otherwise, each agent will only use its data as input.
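Centralisation can be pictured as a reshape: each agent's input becomes the concatenation of all agents' inputs, which is why the first layer of a centralised network has n_agents * n_agent_inputs input features. The sketch below uses a hypothetical centralise helper written only for illustration; TorchRL may implement this step differently.

```python
import torch
from torch import nn

def centralise(obs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (*B, n_agents, n_in) -> (*B, n_agents, n_agents * n_in).

    Every agent receives the concatenation of all agents' inputs.
    """
    *batch, n_agents, n_in = obs.shape
    flat = obs.reshape(*batch, n_agents * n_in)  # (*B, n_agents * n_in)
    # Add an agent dimension and repeat the flat vector for each agent (a view, no copy).
    return flat.unsqueeze(-2).expand(*batch, n_agents, n_agents * n_in)

# Batch of 2, 3 agents, 2 inputs per agent.
obs = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
central = centralise(obs)
print(central.shape)  # torch.Size([2, 3, 6])

# With shared parameters every agent sees the same input, so outputs match.
net = nn.Linear(6, 2)
out = net(central)  # shape (2, 3, 2)
```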

Parameters:

  • n_agent_inputs (int or None) – number of inputs for each agent. If None, the number of inputs is inferred lazily during the first call.

  • n_agent_outputs (int) – number of outputs for each agent.

  • n_agents (int) – number of agents.

  • centralised (bool) – If centralised is True, each agent will use the inputs of all agents to compute its output (n_agent_inputs * n_agents will be the number of inputs for one agent). Otherwise, each agent will only use its data as input.

  • share_params (bool) – If share_params is True, the same MLP will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process its input (heterogeneous policies).

  • device (str or torch.device, optional) – device to create the module on.

  • depth (int, optional) – depth of the network. A depth of 0 will produce a single linear layer network with the desired input and output size. A depth of 1 will create 2 linear layers, etc. If no depth is indicated, the depth information should be contained in the num_cells argument (see below). If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to depth. default: 3.

  • num_cells (int or Sequence[int], optional) – number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, the linear layers out_features will match the content of num_cells. default: 32.

  • activation_class (Type[nn.Module]) – activation class to be used. default: nn.Tanh.

  • **kwargs – additional keyword arguments for torchrl.modules.models.MLP, which can be passed to customize the MLPs.
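The depth and num_cells rules above can be summarized by a small helper that lists the (in_features, out_features) pair of every Linear layer. layer_sizes is hypothetical and only mirrors the documented behaviour (including the defaults of 32 cells and depth 3); it is not part of TorchRL.

```python
from typing import List, Optional, Sequence, Tuple, Union

def layer_sizes(
    in_features: int,
    out_features: int,
    depth: Optional[int] = None,
    num_cells: Optional[Union[Sequence[int], int]] = None,
) -> List[Tuple[int, int]]:
    """Hypothetical helper: (in_features, out_features) of each Linear layer."""
    if num_cells is None:
        num_cells = 32  # documented default
    if isinstance(num_cells, int):
        if depth is None:
            depth = 3  # documented default
        hidden = [num_cells] * depth  # every hidden layer has the same width
    else:
        hidden = list(num_cells)
        if depth is not None and len(hidden) != depth:
            raise ValueError("len(num_cells) must be equal to depth")
    # depth hidden layers give depth + 1 Linear layers in total.
    dims = [in_features, *hidden, out_features]
    return list(zip(dims[:-1], dims[1:]))

print(layer_sizes(3, 2, depth=2))  # [(3, 32), (32, 32), (32, 2)]
print(layer_sizes(3, 2, depth=0))  # [(3, 2)]
```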


>>> from torchrl.modules import MultiAgentMLP
>>> import torch
>>> n_agents = 6
>>> n_agent_inputs = 3
>>> n_agent_outputs = 2
>>> batch = 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
First, let's instantiate a local network shared by all agents (e.g. a parameter-shared policy):
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralised=False,
...     share_params=True,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0): MLP(
      (0): Linear(in_features=3, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
Now let's instantiate a centralised network shared by all agents (e.g. a centralised value function):
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralised=True,
...     share_params=True,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0): MLP(
      (0): Linear(in_features=18, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
We can see that the input to the first layer is n_agents * n_agent_inputs (6 * 3 = 18),
because in this case the network acts as a centralised MLP (like a single huge agent).
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
Since every agent receives the same concatenated input and uses the same parameters, outputs will be identical for all agents.
Both examples above can also be built with an independent set of parameters for each agent.
Let's show the centralised=False case:
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralised=False,
...     share_params=False,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0-5): 6 x MLP(
      (0): Linear(in_features=3, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
We can see that this is the same as in the first example, but now we have 6 MLPs, one per agent!
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
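The cost of each configuration can be checked with simple arithmetic on the layer sizes printed above. This sketch assumes each agent network contains exactly those Linear layers (Tanh has no parameters) and that independent parameters mean one full copy per agent; it does not query TorchRL. The centralised, independent combination is not printed above and is extrapolated under the same assumptions.

```python
from typing import List, Tuple

def linear_params(sizes: List[Tuple[int, int]]) -> int:
    # Each Linear layer has in * out weights plus out biases.
    return sum(i * o + o for i, o in sizes)

n_agents = 6
# Layer sizes read off the printed networks above (depth=2, num_cells=32).
local = [(3, 32), (32, 32), (32, 2)]      # centralised=False
central = [(18, 32), (32, 32), (32, 2)]   # centralised=True

counts = {
    "local, shared": linear_params(local),
    "centralised, shared": linear_params(central),
    "local, independent": n_agents * linear_params(local),
    "centralised, independent": n_agents * linear_params(central),
}
for name, n in counts.items():
    print(f"{name}: {n}")
# local, shared: 1250
# centralised, shared: 1730
# local, independent: 7500
# centralised, independent: 10380
```

Independent parameters multiply the count by n_agents, while centralisation only enlarges the first layer (18 x 32 instead of 3 x 32 weights).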

