MultiAgentMLP

class torchrl.modules.MultiAgentMLP(n_agent_inputs: int | None, n_agent_outputs: int, n_agents: int, *, centralized: bool | None = None, share_params: bool | None = None, device: Optional[DEVICE_TYPING] = None, depth: Optional[int] = None, num_cells: Optional[Union[Sequence, int]] = None, activation_class: Optional[Type[nn.Module]] = <class 'torch.nn.modules.activation.Tanh'>, use_td_params: bool = True, **kwargs)[source]

Multi-agent MLP.

This is an MLP that can be used in multi-agent contexts, for example as a policy or as a value function. See examples/multiagent for examples.

It expects inputs with shape (*B, n_agents, n_agent_inputs) and returns outputs with shape (*B, n_agents, n_agent_outputs).

If share_params is True, the same MLP will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process its input (heterogeneous policies).

If centralized is True, each agent will use the inputs of all agents to compute its output (n_agent_inputs * n_agents will be the number of inputs for one agent). Otherwise, each agent will only use its data as input.
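As a rough illustration of how the two flags interact, the sketch below restates the rules above as arithmetic: the input size of the first layer and the number of distinct MLPs. The helper describe_layout is hypothetical and is not part of torchrl.

```python
def describe_layout(n_agent_inputs, n_agents, centralized, share_params):
    """Hypothetical helper (not a torchrl API) restating the rules above."""
    # Centralized: each agent sees the concatenated inputs of all agents.
    in_features = n_agent_inputs * n_agents if centralized else n_agent_inputs
    # Shared parameters: a single MLP serves every agent.
    n_networks = 1 if share_params else n_agents
    return in_features, n_networks


print(describe_layout(3, 6, centralized=True, share_params=True))    # (18, 1)
print(describe_layout(3, 6, centralized=False, share_params=False))  # (3, 6)
```

The two printed cases correspond to the centralized shared network and the decentralized per-agent networks shown in the Examples section.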

Parameters:
  • n_agent_inputs (int or None) – number of inputs for each agent. If None, the number of inputs is lazily instantiated during the first call.

  • n_agent_outputs (int) – number of outputs for each agent.

  • n_agents (int) – number of agents.

Keyword Arguments:
  • centralized (bool) – If centralized is True, each agent will use the inputs of all agents to compute its output (n_agent_inputs * n_agents will be the number of inputs for one agent). Otherwise, each agent will only use its data as input.

  • share_params (bool) – If share_params is True, the same MLP will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process its input (heterogeneous policies).

  • device (str or torch.device, optional) – device to create the module on.

  • depth (int, optional) – depth of the network. A depth of 0 will produce a single linear layer network with the desired input and output size. A depth of 1 will create 2 linear layers, etc. If no depth is indicated, the depth information should be contained in the num_cells argument (see below). If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to depth. default: 3.

  • num_cells (int or Sequence[int], optional) – number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, the linear layers out_features will match the content of num_cells. default: 32.

  • activation_class (Type[nn.Module]) – activation class to be used. default: nn.Tanh.

  • use_td_params (bool, optional) – if True, the parameters can be found in self.params which is a TensorDictParams object (which inherits both from TensorDict and nn.Module). If False, parameters are contained in self._empty_net. All things considered, these two approaches should be roughly identical but not interchangeable: for instance, a state_dict created with use_td_params=True cannot be used when use_td_params=False.

  • **kwargs – keyword arguments for torchrl.modules.models.MLP, which can be passed to customize the MLPs.
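The relationship between depth and num_cells described above can be sketched as a small calculation of the (in_features, out_features) pair of each linear layer. The helper mlp_layer_shapes is hypothetical and not a torchrl function; falling back to a depth of 3 when depth is omitted and num_cells is an integer is an assumption based on the documented default.

```python
from collections.abc import Sequence


def mlp_layer_shapes(in_features, out_features, depth=None, num_cells=32):
    """Hypothetical helper (not a torchrl API): (in, out) sizes of each linear layer."""
    if isinstance(num_cells, Sequence):
        hidden = list(num_cells)
        # When both are given, they must agree.
        if depth is not None and depth != len(hidden):
            raise ValueError("len(num_cells) must be equal to depth")
    else:
        # Assumption: fall back to the documented default depth of 3.
        hidden = [num_cells] * (3 if depth is None else depth)
    sizes = [in_features, *hidden, out_features]
    # Pair each layer's input size with its output size.
    return list(zip(sizes[:-1], sizes[1:]))


print(mlp_layer_shapes(3, 2, depth=2))  # [(3, 32), (32, 32), (32, 2)]
print(mlp_layer_shapes(3, 2, depth=0))  # [(3, 2)]
```

With depth=2 and the default num_cells of 32, the result has the same three linear layers as the printed networks in the Examples section.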

Note

To initialize the MARL module parameters with the torch.nn.init module, please refer to the get_stateful_net() and from_stateful_net() methods.

Examples

>>> from torchrl.modules import MultiAgentMLP
>>> import torch
>>> n_agents = 6
>>> n_agent_inputs=3
>>> n_agent_outputs=2
>>> batch = 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
>>> # instantiate a local network shared by all agents (e.g. a parameter-shared policy)
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=False,
...     share_params=True,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0): MLP(
      (0): Linear(in_features=3, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
>>> # Now let's instantiate a centralized network shared by all agents (e.g. a centralized value function)
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=True,
...     share_params=True,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0): MLP(
      (0): Linear(in_features=18, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
>>> # The input to the first layer is n_agents * n_agent_inputs,
>>> # because in this case the net acts as a centralized MLP (like a single huge agent).
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
>>> # Outputs will be identical for all agents.
>>> # Both examples above can also be built with an independent set of parameters for each agent.
>>> # Let's show the centralized=False case.
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=False,
...     share_params=False,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0-5): 6 x MLP(
      (0): Linear(in_features=3, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
>>> # This is the same as in the first example, but now we have 6 MLPs, one per agent!
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
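
To see why a centralized network with shared parameters returns the same output for every agent, consider the plain-Python sketch below. It is not how torchrl implements the forward pass: toy_net is a hypothetical stand-in for an MLP (a fixed weighted sum), and centralized_shared_forward is an illustrative helper.

```python
def centralized_shared_forward(obs, net):
    """obs: list of per-agent input lists; net: callable shared by all agents."""
    # Centralized: flatten every agent's inputs into one joint vector.
    joint = [x for agent_obs in obs for x in agent_obs]
    # Shared parameters: the same function is applied to the same joint
    # input once per agent, so all outputs coincide.
    return [net(joint) for _ in obs]


def toy_net(x):
    # Hypothetical stand-in for an MLP: a fixed weighted sum of the inputs.
    return sum((i + 1) * v for i, v in enumerate(x))


obs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 agents, 2 inputs each
out = centralized_shared_forward(obs, toy_net)
print(out)  # [91.0, 91.0, 91.0]
```

With share_params=False, each agent would apply its own function to the joint input, so the outputs would generally differ.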
