
QMixer

class torchrl.modules.QMixer(state_shape: Union[Tuple[int, ...], Size], mixing_embed_dim: int, n_agents: int, device: Union[device, str, int])

QMix mixer.

Mixes the agents' local Q values into a global Q value through a monotonic mixing network whose weights are generated by hyper-networks conditioned on the global state. From the QMIX paper: https://arxiv.org/abs/1803.11485.

It transforms the value of each agent's chosen action, of shape (*B, self.n_agents, 1), into a global value of shape (*B, 1). It is meant to be used with torchrl.objectives.QMixerLoss. See examples/multiagent/qmix_vdn.py for complete examples.
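To illustrate the mixing described above, here is a minimal sketch of a QMIX-style monotonic mixer in plain PyTorch. It is not TorchRL's implementation: the class name MonotonicMixerSketch, the exact layer layout, and the use of ELU are assumptions that follow the design described in the QMIX paper, and only a single leading batch dimension is handled. Hyper-networks map the global state to the mixing weights, and torch.abs keeps those weights non-negative, so the global value cannot decrease when any agent's value increases.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MonotonicMixerSketch(nn.Module):
    """Hypothetical QMIX-style mixer (assumed layout, not TorchRL's code).

    Hyper-networks turn the global state into the weights of a two-layer
    mixing network; torch.abs keeps those weights non-negative, so the
    global value is monotonic in every agent's value.
    """

    def __init__(self, state_dim: int, n_agents: int, embed_dim: int):
        super().__init__()
        self.n_agents = n_agents
        self.embed_dim = embed_dim
        # Hyper-networks: flat state -> mixing-network parameters.
        self.hyper_w1 = nn.Linear(state_dim, n_agents * embed_dim)
        self.hyper_b1 = nn.Linear(state_dim, embed_dim)
        self.hyper_w2 = nn.Linear(state_dim, embed_dim)
        # State-dependent scalar bias; its sign is left unconstrained.
        self.hyper_v = nn.Sequential(
            nn.Linear(state_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, 1)
        )

    def forward(self, agent_qs: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        # agent_qs: (B, n_agents, 1), state: (B, state_dim); one batch dim only.
        batch = agent_qs.shape[0]
        qs = agent_qs.view(batch, 1, self.n_agents)
        w1 = torch.abs(self.hyper_w1(state)).view(batch, self.n_agents, self.embed_dim)
        b1 = self.hyper_b1(state).view(batch, 1, self.embed_dim)
        hidden = F.elu(torch.bmm(qs, w1) + b1)  # (B, 1, embed_dim)
        w2 = torch.abs(self.hyper_w2(state)).view(batch, self.embed_dim, 1)
        v = self.hyper_v(state).view(batch, 1, 1)
        return (torch.bmm(hidden, w2) + v).view(batch, 1)  # (B, 1)


torch.manual_seed(0)
mixer = MonotonicMixerSketch(state_dim=8, n_agents=3, embed_dim=16)
agent_qs = torch.randn(5, 3, 1, requires_grad=True)
q_tot = mixer(agent_qs, torch.randn(5, 8))
q_tot.sum().backward()
# Monotonicity check: every partial derivative dQ_tot/dQ_a is non-negative.
print(q_tot.shape, bool((agent_qs.grad >= 0).all()))
```

The non-negativity comes from the chain rule: each partial derivative is a sum of products of non-negative weights and the strictly positive ELU derivative.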

Parameters:
  • state_shape (tuple or torch.Size) – the shape of the state, excluding any leading batch dimensions.

  • mixing_embed_dim (int) – the embedding dimension of the mixing network.

  • n_agents (int) – number of agents.

  • device (str, int or torch.device) – torch device for the network.
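As a worked example of how these parameters size the network, the helper below computes the output sizes of the hyper-networks in a QMIX-style two-layer mixer. It assumes the layout described in the QMIX paper (first-layer weights of size n_agents x mixing_embed_dim, a first-layer bias, second-layer weights of size mixing_embed_dim x 1, and a scalar state-dependent bias V(s)), and it assumes the state is flattened before entering the hyper-networks. qmix_hypernet_sizes is a hypothetical name, not a TorchRL function.

```python
import math
from typing import Dict, Sequence


def qmix_hypernet_sizes(
    state_shape: Sequence[int], mixing_embed_dim: int, n_agents: int
) -> Dict[str, int]:
    """Hypothetical helper: hyper-network output sizes under the QMIX-paper layout."""
    # Assumption: the state is flattened to a vector before the hyper-networks.
    state_dim = math.prod(state_shape)
    return {
        "state_dim": state_dim,
        "w1": n_agents * mixing_embed_dim,  # first mixing layer weights
        "b1": mixing_embed_dim,             # first mixing layer bias
        "w_final": mixing_embed_dim,        # second mixing layer weights (embed x 1)
        "v": 1,                             # state-dependent scalar bias V(s)
    }


print(qmix_hypernet_sizes((64, 64, 3), mixing_embed_dim=32, n_agents=4))
```

Under these assumptions, the (64, 64, 3) state used in the example below becomes a 12288-dimensional input to every hyper-network, so state_shape tends to dominate the mixer's parameter count.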

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules.models.multiagent import QMixer
>>> n_agents = 4
>>> qmix = TensorDictModule(
...     module=QMixer(
...         state_shape=(64, 64, 3),
...         mixing_embed_dim=32,
...         n_agents=n_agents,
...         device="cpu",
...     ),
...     in_keys=[("agents", "chosen_action_value"), "state"],
...     out_keys=["chosen_action_value"],
... )
>>> td = TensorDict(
...     {
...         "agents": TensorDict(
...             {"chosen_action_value": torch.zeros(32, n_agents, 1)},
...             [32, n_agents],
...         ),
...         "state": torch.zeros(32, 64, 64, 3),
...     },
...     [32],
... )
>>> td
TensorDict(
    fields={
        agents: TensorDict(
            fields={
                chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([32, 4]),
            device=None,
            is_shared=False),
        state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)
>>> qmix(td)
TensorDict(
    fields={
        agents: TensorDict(
            fields={
                chosen_action_value: Tensor(shape=torch.Size([32, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([32, 4]),
            device=None,
            is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        state: Tensor(shape=torch.Size([32, 64, 64, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([32]),
    device=None,
    is_shared=False)
mix(chosen_action_value: Tensor, state: Tensor)

Forward pass for the mixer.

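Parameters:

  • chosen_action_value (Tensor) – the values of the agents' chosen actions, of shape [*B, n_agents].

  • state (Tensor) – the global state, of shape [*B, *state_shape].

Returns:

chosen_action_value – the mixed global value, a Tensor of shape [*B].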
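The class-level shapes ((*B, n_agents, 1) in, (*B, 1) out) and the mix shapes ([*B, n_agents] in, [*B] out) differ by a trailing singleton dimension. One way to bridge them is to flatten the leading batch dimensions, call the mixing function, and restore them afterwards. The sketch below assumes that design; apply_mix_over_batch is a hypothetical helper, not part of TorchRL, and the sum over agents used as a placeholder mix is a VDN-style sum, not QMIX.

```python
from typing import Callable, Sequence

import torch


def apply_mix_over_batch(
    mix_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    chosen_action_value: torch.Tensor,
    state: torch.Tensor,
    state_shape: Sequence[int],
) -> torch.Tensor:
    """Hypothetical wrapper: (*B, n_agents, 1) values -> (*B, 1) global value."""
    batch_shape = chosen_action_value.shape[:-2]
    n_agents = chosen_action_value.shape[-2]
    # Collapse all leading batch dims into one so mix_fn sees [B_flat, n_agents].
    flat_q = chosen_action_value.reshape(-1, n_agents)
    flat_state = state.reshape(-1, *state_shape)
    out = mix_fn(flat_q, flat_state)  # [B_flat]; a [B_flat, 1] output also works
    # Restore the leading batch dims and append the trailing singleton.
    return out.reshape(*batch_shape, 1)


def sum_mix(q: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    # Placeholder mix (VDN-style sum over agents), used only to show shapes.
    return q.sum(-1)


q = torch.randn(2, 3, 4, 1)  # *B = (2, 3), n_agents = 4
s = torch.randn(2, 3, 5)     # state_shape = (5,)
print(apply_mix_over_batch(sum_mix, q, s, state_shape=(5,)).shape)
```

Because the final reshape only fixes the leading dimensions and the trailing singleton, the wrapper accepts a mix output of either [B_flat] or [B_flat, 1].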
