MultiAgentMLP¶
- class torchrl.modules.MultiAgentMLP(n_agent_inputs: int | None, n_agent_outputs: int, n_agents: int, centralized: bool | None = None, share_params: bool | None = None, device: Optional[DEVICE_TYPING] = None, depth: Optional[int] = None, num_cells: Optional[Union[Sequence, int]] = None, activation_class: Optional[Type[nn.Module]] = <class 'torch.nn.modules.activation.Tanh'>, **kwargs)[source]¶
Multi-agent MLP.
This is an MLP that can be used in multi-agent contexts. For example, as a policy or as a value function. See examples/multiagent for examples.
It expects inputs with shape (*B, n_agents, n_agent_inputs) and returns outputs with shape (*B, n_agents, n_agent_outputs).
If share_params is True, the same MLP will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process its input (heterogeneous policies).
If centralized is True, each agent will use the inputs of all agents to compute its output (n_agent_inputs * n_agents will be the number of inputs for one agent). Otherwise, each agent will only use its data as input.
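The two flags combine freely. The combination not shown in the examples below, centralized=True together with share_params=False, gives each agent its own network that reads the concatenated inputs of all agents (e.g. per-agent centralized value functions). A minimal sketch using only the constructor arguments documented on this page:

>>> import torch
>>> from torchrl.modules import MultiAgentMLP
>>> n_agents, n_agent_inputs, n_agent_outputs, batch = 6, 3, 2, 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
>>> # one centralized network per agent: each agent reads all 6 * 3 = 18 inputs,
>>> # but keeps its own set of parameters
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=True,
...     share_params=False,
...     depth=2,
... )
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)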
- Parameters:
n_agent_inputs (int or None) – number of inputs for each agent. If None, the number of inputs is lazily instantiated during the first call (see the sketch after this parameter list).
n_agent_outputs (int) – number of outputs for each agent.
n_agents (int) – number of agents.
centralized (bool) – If centralized is True, each agent will use the inputs of all agents to compute its output (n_agent_inputs * n_agents will be the number of inputs for one agent). Otherwise, each agent will only use its data as input.
share_params (bool) – If share_params is True, the same MLP will be used to make the forward pass for all agents (homogeneous policies). Otherwise, each agent will use a different MLP to process its input (heterogeneous policies).
device (str or torch.device, optional) – device to create the module on.
depth (int, optional) – depth of the network. A depth of 0 will produce a single linear layer network with the desired input and output size. A depth of 1 will create 2 linear layers, etc. If no depth is indicated, the depth information should be contained in the num_cells argument (see below). If num_cells is an iterable and depth is indicated, both should match: len(num_cells) must be equal to depth. default: 3.
num_cells (int or Sequence[int], optional) – number of cells of every layer in between the input and output. If an integer is provided, every layer will have the same number of cells. If an iterable is provided, the linear layers' out_features will match the content of num_cells. default: 32.
activation_class (Type[nn.Module]) – activation class to be used. default: nn.Tanh.
**kwargs – additional keyword arguments for torchrl.modules.models.MLP, passed through to customize the MLPs.
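As a complement to the n_agent_inputs entry above, a minimal sketch of the lazy-instantiation path: passing n_agent_inputs=None defers the input size to the first forward call, whose last input dimension then fixes it (the remaining arguments mirror the examples below).

>>> import torch
>>> from torchrl.modules import MultiAgentMLP
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=None,   # input size inferred on the first call
...     n_agent_outputs=2,
...     n_agents=6,
...     centralized=False,
...     share_params=True,
...     depth=2,
... )
>>> obs = torch.zeros(64, 6, 3)         # last dim fixes n_agent_inputs to 3
>>> assert mlp(obs).shape == (64, 6, 2)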
Note

To initialize the MARL module parameters with the torch.nn.init module, please refer to the get_stateful_net() and from_stateful_net() methods.

Examples
>>> from torchrl.modules import MultiAgentMLP
>>> import torch
>>> n_agents = 6
>>> n_agent_inputs = 3
>>> n_agent_outputs = 2
>>> batch = 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
>>> # instantiate a local network shared by all agents (e.g. a parameter-shared policy)
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=False,
...     share_params=True,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0): MLP(
      (0): Linear(in_features=3, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)
>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)

Now let's instantiate a centralized network shared by all agents (e.g. a centralized value function).

>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=True,
...     share_params=True,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0): MLP(
      (0): Linear(in_features=18, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)

We can see that the input to the first layer is n_agents * n_agent_inputs; this is because in this case the net acts as a centralized MLP (like a single huge agent).

>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)

Outputs will be identical for all agents. Now we can do both examples just shown but with an independent set of parameters for each agent. Let's show the centralized=False case.

>>> mlp = MultiAgentMLP(
...     n_agent_inputs=n_agent_inputs,
...     n_agent_outputs=n_agent_outputs,
...     n_agents=n_agents,
...     centralized=False,
...     share_params=False,
...     depth=2,
... )
>>> print(mlp)
MultiAgentMLP(
  (agent_networks): ModuleList(
    (0-5): 6 x MLP(
      (0): Linear(in_features=3, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=2, bias=True)
    )
  )
)

We can see that this is the same as in the first example, but now we have 6 MLPs, one per agent!

>>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)
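Following the note above, here is a minimal sketch of parameter initialization through the stateful view of the module. It assumes get_stateful_net() can be called without arguments and that from_stateful_net() accepts the module it returns; check the method signatures in your version.

>>> import torch
>>> from torchrl.modules import MultiAgentMLP
>>> mlp = MultiAgentMLP(
...     n_agent_inputs=3, n_agent_outputs=2, n_agents=6,
...     centralized=False, share_params=True, depth=2,
... )
>>> stateful_net = mlp.get_stateful_net()  # plain nn.Module view of the parameters
>>> def init_weights(module):
...     # assumption: standard torch.nn.init calls on the Linear layers of the stateful view
...     if isinstance(module, torch.nn.Linear):
...         torch.nn.init.orthogonal_(module.weight)
...         torch.nn.init.zeros_(module.bias)
>>> _ = stateful_net.apply(init_weights)
>>> mlp.from_stateful_net(stateful_net)    # write the initialized parameters back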